LeetCode-in-Java

2179. Count Good Triplets in an Array

Hard

You are given two 0-indexed arrays nums1 and nums2 of length n, both of which are permutations of [0, 1, ..., n - 1].

A good triplet is a set of 3 distinct values which are present in increasing order by position both in nums1 and nums2. In other words, if we consider pos1_v as the index of the value v in nums1 and pos2_v as the index of the value v in nums2, then a good triplet will be a set (x, y, z) where 0 <= x, y, z <= n - 1, such that pos1_x < pos1_y < pos1_z and pos2_x < pos2_y < pos2_z.

Return the total number of good triplets.
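The definition translates directly into a brute-force count, which is handy as a reference when testing a faster solution. The sketch below is not part of the original solution; the class and method names (`BruteForceGoodTriplets`, `count`) are hypothetical. It records each value's position in nums2 and tries every index triple i < j < k of nums1. This is O(n^3), so it is only practical for small inputs such as the examples.

```java
// Hypothetical reference implementation: counts good triplets straight from the definition.
class BruteForceGoodTriplets {
    static long count(int[] nums1, int[] nums2) {
        int n = nums1.length;
        // pos2[v] = index of value v in nums2 (valid because nums2 is a permutation of 0..n-1)
        int[] pos2 = new int[n];
        for (int i = 0; i < n; i++) {
            pos2[nums2[i]] = i;
        }
        long count = 0;
        // Picking indices i < j < k in nums1 already guarantees pos1_x < pos1_y < pos1_z,
        // so only the order of the same three values in nums2 has to be checked.
        for (int i = 0; i < n; i++) {
            for (int j = i + 1; j < n; j++) {
                for (int k = j + 1; k < n; k++) {
                    int x = nums1[i], y = nums1[j], z = nums1[k];
                    if (pos2[x] < pos2[y] && pos2[y] < pos2[z]) {
                        count++;
                    }
                }
            }
        }
        return count;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[] {2, 0, 1, 3}, new int[] {0, 1, 2, 3}));       // 1
        System.out.println(count(new int[] {4, 0, 1, 3, 2}, new int[] {4, 1, 0, 2, 3})); // 4
    }
}
```

Because the loops fix the order of the three values by their position in nums1, each set of values is counted at most once.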

Example 1:

Input: nums1 = [2,0,1,3], nums2 = [0,1,2,3]

Output: 1

Explanation:

There are 4 triplets (x,y,z) such that pos1_x < pos1_y < pos1_z. They are (2,0,1), (2,0,3), (2,1,3), and (0,1,3).

Out of those triplets, only the triplet (0,1,3) satisfies pos2_x < pos2_y < pos2_z.

Hence, there is only 1 good triplet.

Example 2:

Input: nums1 = [4,0,1,3,2], nums2 = [4,1,0,2,3]

Output: 4

Explanation: The 4 good triplets are (4,0,3), (4,0,2), (4,1,3), and (4,1,2).
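Both examples can be reduced to a single array by replacing each value of nums1 with its index in nums2. After that, array indices play the role of positions in nums1 and array values play the role of positions in nums2, so a good triplet is exactly a strictly increasing subsequence of length 3. The sketch below mirrors the idx/arr construction used in the Solution; the names `PositionRemap` and `remap` are hypothetical.

```java
import java.util.Arrays;

// Hypothetical helper: maps every value of nums1 to its index in nums2.
class PositionRemap {
    static int[] remap(int[] nums1, int[] nums2) {
        int n = nums1.length;
        int[] idx = new int[n]; // idx[v] = index of value v in nums2
        for (int i = 0; i < n; i++) {
            idx[nums2[i]] = i;
        }
        int[] arr = new int[n]; // arr[i] = index in nums2 of the value nums1[i]
        for (int i = 0; i < n; i++) {
            arr[i] = idx[nums1[i]];
        }
        return arr;
    }

    public static void main(String[] args) {
        // Example 1: nums2 is the identity permutation, so arr equals nums1.
        System.out.println(Arrays.toString(remap(new int[] {2, 0, 1, 3}, new int[] {0, 1, 2, 3})));
        // prints [2, 0, 1, 3]
        // Example 2
        System.out.println(Arrays.toString(remap(new int[] {4, 0, 1, 3, 2}, new int[] {4, 1, 0, 2, 3})));
        // prints [0, 2, 1, 4, 3]
    }
}
```

For Example 2 the remapped array is [0, 2, 1, 4, 3]. Its increasing index triples (0, 1, 3), (0, 1, 4), (0, 2, 3) and (0, 2, 4) select the values (4,0,3), (4,0,2), (4,1,3) and (4,1,2) from nums1, which are exactly the four good triplets listed above.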

Constraints:

Solution

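public class Solution {
    public long goodTriplets(int[] nums1, int[] nums2) {
        int n = nums1.length;
        // idx[v] = position of value v in nums2
        int[] idx = new int[n];
        // arr[i] = position in nums2 of the value nums1[i]; a good triplet is now
        // any i < j < k with arr[i] < arr[j] < arr[k]
        int[] arr = new int[n];
        for (int i = 0; i < n; i++) {
            idx[nums2[i]] = i;
        }
        for (int i = 0; i < n; i++) {
            arr[i] = idx[nums1[i]];
        }
        Tree tree = new Tree(n);
        long res = 0L;
        // Treat each i as the middle element of a triplet.
        for (int i = 0; i < n; i++) {
            // Earlier elements smaller than arr[i] (values 0..arr[i]-1 sit at tree indices 1..arr[i])
            int smaller = tree.query(arr[i]);
            // n - (arr[i] + 1) values exceed arr[i]; (i - smaller) of them already appeared on the left
            int bigger = n - (arr[i] + 1) - (i - smaller);
            // Cast before multiplying so the product cannot overflow int
            res += (long) smaller * bigger;
            // Record arr[i]; Fenwick indices are 1-based, hence the + 1
            tree.update(arr[i] + 1, 1);
        }
        return res;
    }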

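    // Fenwick (binary indexed) tree over 1-based indices 1..n; each slot holds a count.
    private static class Tree {
        int[] array;
        int n;

        public Tree(int n) {
            this.n = n;
            array = new int[n + 1]; // index 0 is unused
        }

        // Lowest set bit of x, e.g. lowbit(12) = 4
        int lowbit(int x) {
            return x & (-x);
        }

        // Adds delta at index i (1-based) and at every node whose range covers i
        void update(int i, int delta) {
            while (i <= n) {
                array[i] += delta;
                i += lowbit(i);
            }
        }

        // Returns the sum of the counts at indices 1..k
        int query(int k) {
            int ans = 0;
            while (k > 0) {
                ans += array[k];
                k -= lowbit(k);
            }
            return ans;
        }
    }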
}
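The Fenwick tree only accelerates one counting step. For each middle index i of the remapped array, the number of good triplets with that middle is smaller × bigger, where smaller counts earlier entries below arr[i] and bigger counts later entries above it. The Solution derives bigger instead of counting it: n - 1 - arr[i] values exceed arr[i] overall, and i - smaller of them have already appeared to the left. The sketch below (hypothetical class `QuadraticGoodTriplets`) counts both sides with plain loops in O(n^2), which makes the per-middle idea explicit and can serve as a cross-check for the O(n log n) version on small inputs.

```java
// Hypothetical O(n^2) variant of the same "count per middle element" idea, without a Fenwick tree.
class QuadraticGoodTriplets {
    static long count(int[] nums1, int[] nums2) {
        int n = nums1.length;
        int[] idx = new int[n]; // idx[v] = position of value v in nums2
        for (int i = 0; i < n; i++) {
            idx[nums2[i]] = i;
        }
        int[] arr = new int[n]; // arr[i] = position in nums2 of nums1[i]
        for (int i = 0; i < n; i++) {
            arr[i] = idx[nums1[i]];
        }
        long res = 0L;
        for (int i = 0; i < n; i++) {
            int smaller = 0; // indices j < i with arr[j] < arr[i]
            for (int j = 0; j < i; j++) {
                if (arr[j] < arr[i]) {
                    smaller++;
                }
            }
            int bigger = 0; // indices k > i with arr[k] > arr[i]
            for (int k = i + 1; k < n; k++) {
                if (arr[k] > arr[i]) {
                    bigger++;
                }
            }
            // Equals the Fenwick version's n - (arr[i] + 1) - (i - smaller)
            res += (long) smaller * bigger;
        }
        return res;
    }

    public static void main(String[] args) {
        System.out.println(count(new int[] {2, 0, 1, 3}, new int[] {0, 1, 2, 3}));       // 1
        System.out.println(count(new int[] {4, 0, 1, 3, 2}, new int[] {4, 1, 0, 2, 3})); // 4
    }
}
```

Replacing the two inner loops with prefix-count queries on a Fenwick tree is what brings the running time down to O(n log n).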