LeetCode-in-Java

1803. Count Pairs With XOR in a Range

Hard

Given a 0-indexed integer array nums and two integers low and high, return the number of nice pairs.

A nice pair is a pair (i, j) where 0 <= i < j < nums.length and low <= (nums[i] XOR nums[j]) <= high.

Example 1:

Input: nums = [1,4,2,7], low = 2, high = 6

Output: 6

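Explanation: All nice pairs (i, j) are as follows:

- (0, 1): nums[0] XOR nums[1] = 5
- (0, 2): nums[0] XOR nums[2] = 3
- (0, 3): nums[0] XOR nums[3] = 6
- (1, 2): nums[1] XOR nums[2] = 6
- (1, 3): nums[1] XOR nums[3] = 3
- (2, 3): nums[2] XOR nums[3] = 5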

Example 2:

Input: nums = [9,8,4,2,1], low = 5, high = 14

Output: 8

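Explanation: All nice pairs (i, j) are as follows:

- (0, 2): nums[0] XOR nums[2] = 13
- (0, 3): nums[0] XOR nums[3] = 11
- (0, 4): nums[0] XOR nums[4] = 8
- (1, 2): nums[1] XOR nums[2] = 12
- (1, 3): nums[1] XOR nums[3] = 10
- (1, 4): nums[1] XOR nums[4] = 9
- (2, 3): nums[2] XOR nums[3] = 6
- (2, 4): nums[2] XOR nums[4] = 5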

Constraints:

- 1 <= nums.length <= 2 * 10^4
- 1 <= nums[i] <= 2 * 10^4
- 1 <= low <= high <= 2 * 10^4

Solution

public class Solution {
    /**
     * Counts the nice pairs (i, j), i &lt; j, whose XOR lies in the inclusive range [low, high].
     *
     * <p>Uses prefix counting:
     * nice pairs = (pairs with XOR &lt; high + 1) - (pairs with XOR &lt; low).
     * Each number is queried against a binary trie of the numbers that precede it and is then
     * inserted, so every index pair is counted exactly once.
     *
     * <p>The trie depth is derived from the input rather than fixed. Values and bounds of any
     * size up to {@code Integer.MAX_VALUE} are handled correctly.
     *
     * @param nums the input values; assumed non-negative (negative values are not supported)
     * @param low inclusive lower bound on the XOR value
     * @param high inclusive upper bound on the XOR value
     * @return the number of nice pairs, or 0 when {@code low > high}
     */
    public int countPairs(int[] nums, int low, int high) {
        if (low > high) {
            return 0;
        }
        // high + 1 overflows int when high == Integer.MAX_VALUE, so the bounds are carried as long.
        // XORs of non-negative values are >= 0, so negative bounds are clamped to 0
        // ("XOR < 0" matches nothing).
        long upper = Math.max((long) high + 1, 0L);
        long lower = Math.max((long) low, 0L);
        int topBit = highestRelevantBit(nums, upper);

        Trie root = new Trie();
        int pairsCount = 0;
        for (int num : nums) {
            int pairsCountHigh = countPairsWhoseXorLessThanX(num, root, upper, topBit);
            int pairsCountLow = countPairsWhoseXorLessThanX(num, root, lower, topBit);
            pairsCount += pairsCountHigh - pairsCountLow;
            root.insertNumber(num, topBit);
        }
        return pairsCount;
    }

    /**
     * Returns the index of the highest set bit across all of {@code nums} and {@code upper}.
     * Returns -1 if every value is 0.
     *
     * <p>Bits above this index are 0 in every value, so they cannot affect any XOR or comparison.
     * Since {@code lower <= upper}, the bits of the lower bound are covered too.
     */
    private static int highestRelevantBit(int[] nums, long upper) {
        long max = upper;
        for (int num : nums) {
            max = Math.max(max, num);
        }
        return 63 - Long.numberOfLeadingZeros(max);
    }

    /**
     * Counts values already in the trie whose XOR with {@code num} is strictly less than
     * {@code x}.
     *
     * <p>Walks bits from {@code topBit} down to 0 and follows the path where the XOR equals
     * {@code x} so far.
     * <ul>
     *   <li>Where x's bit is 1, every value whose XOR bit here is 0 is strictly smaller. Those
     *       values live under {@code child[numBit]}; add their count, then continue down the
     *       branch where the XOR bit is 1.</li>
     *   <li>Where x's bit is 0, the XOR bit must also be 0, so follow {@code child[numBit]}.</li>
     * </ul>
     * The path that equals {@code x} exactly is never counted, which gives the strict "&lt;".
     */
    private int countPairsWhoseXorLessThanX(int num, Trie root, long x, int topBit) {
        int pairs = 0;
        Trie curr = root;
        for (int i = topBit; i >= 0 && curr != null; i--) {
            int numIthBit = (num >> i) & 1;
            int xIthBit = (int) ((x >> i) & 1L);
            if (xIthBit == 1) {
                if (curr.child[numIthBit] != null) {
                    pairs += curr.child[numIthBit].count;
                }
                curr = curr.child[1 - numIthBit];
            } else {
                curr = curr.child[numIthBit];
            }
        }
        return pairs;
    }

    /** Binary trie node; {@code count} is how many inserted values pass through this node. */
    private static class Trie {
        Trie[] child;
        int count;

        public Trie() {
            child = new Trie[2];
            count = 0;
        }

        /**
         * Inserts {@code num} using bits {@code topBit} down to 0 (most significant first).
         * Increments the pass-through count of every node on its path.
         */
        public void insertNumber(int num, int topBit) {
            Trie curr = this;
            for (int i = topBit; i >= 0; i--) {
                int ithBit = (num >> i) & 1;
                if (curr.child[ithBit] == null) {
                    curr.child[ithBit] = new Trie();
                }
                curr.child[ithBit].count++;
                curr = curr.child[ithBit];
            }
        }
    }
}