LeetCode-in-Java

1577. Number of Ways Where Square of Number Is Equal to Product of Two Numbers

Medium

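Given two arrays of integers nums1 and nums2, return the number of triplets formed (type 1 and type 2) under the following rules:

Type 1: Triplet (i, j, k) if nums1[i]^2 == nums2[j] * nums2[k] where 0 <= i < nums1.length and 0 <= j < k < nums2.length.

Type 2: Triplet (i, j, k) if nums2[i]^2 == nums1[j] * nums1[k] where 0 <= i < nums2.length and 0 <= j < k < nums1.length.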

Example 1:

Input: nums1 = [7,4], nums2 = [5,2,8,9]

Output: 1

Explanation: Type 1: (1, 1, 2), nums1[1]^2 = nums2[1] * nums2[2]. (4^2 = 2 * 8).

Example 2:

Input: nums1 = [1,1], nums2 = [1,1,1]

Output: 9

Explanation: All triplets are valid, because 1^2 = 1 * 1.

Type 1: (0,0,1), (0,0,2), (0,1,2), (1,0,1), (1,0,2), (1,1,2). nums1[i]^2 = nums2[j] * nums2[k].

Type 2: (0,0,1), (1,0,1), (2,0,1). nums2[i]^2 = nums1[j] * nums1[k].
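As a quick cross-check of the total in Example 2: every element equals 1, so any index i combined with any pair j < k is valid. Type 1 picks i from nums1 (2 choices) and a pair from the 3 elements of nums2; type 2 picks i from nums2 (3 choices) and the single pair from the 2 elements of nums1. This "choose 2 out of q" count is the same one the `q * (q - 1) / 2` term in the solution computes when a whole run of equal values matches:

```latex
\underbrace{2 \cdot \binom{3}{2}}_{\text{type 1}} + \underbrace{3 \cdot \binom{2}{2}}_{\text{type 2}} = 2 \cdot 3 + 3 \cdot 1 = 9
```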

Example 3:

Input: nums1 = [7,7,8,3], nums2 = [1,2,9,7]

Output: 2

Explanation: There are 2 valid triplets.

Type 1: (3,0,2). nums1[3]^2 = nums2[0] * nums2[2].

Type 2: (3,0,1). nums2[3]^2 = nums1[0] * nums1[1].
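The examples can be reproduced with a direct brute-force count that tries every (i, j, k) with j < k, which follows the type 1 and type 2 definitions literally. This is a minimal sketch for checking small inputs, not the author's solution; the class name `BruteForceTriplets` and its methods are hypothetical. It takes O(|nums1| * |nums2|^2 + |nums2| * |nums1|^2) time, so it is only practical for small arrays.

```java
// Hypothetical brute-force reference: checks every (i, j, k) with j < k.
class BruteForceTriplets {
    // Counts (i, j, k) with j < k and sq[i]^2 == pair[j] * pair[k].
    static int countType(int[] sq, int[] pair) {
        int total = 0;
        for (int v : sq) {
            long target = (long) v * v; // long avoids overflowing the square
            for (int j = 0; j < pair.length; j++) {
                for (int k = j + 1; k < pair.length; k++) {
                    if ((long) pair[j] * pair[k] == target) {
                        total++;
                    }
                }
            }
        }
        return total;
    }

    // Type 1 takes squares from nums1, type 2 takes squares from nums2.
    static int numTriplets(int[] nums1, int[] nums2) {
        return countType(nums1, nums2) + countType(nums2, nums1);
    }

    public static void main(String[] args) {
        System.out.println(numTriplets(new int[] {7, 4}, new int[] {5, 2, 8, 9}));       // 1
        System.out.println(numTriplets(new int[] {1, 1}, new int[] {1, 1, 1}));          // 9
        System.out.println(numTriplets(new int[] {7, 7, 8, 3}, new int[] {1, 2, 9, 7})); // 2
    }
}
```

Because it enumerates pairs explicitly, this version is handy for randomized comparison against the faster solution.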

Constraints:

Solution

import java.util.Arrays;

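public class Solution {
    public int numTriplets(int[] nums1, int[] nums2) {
        // Sorting lets count() use two pointers and group equal values into runs.
        Arrays.sort(nums1);
        Arrays.sort(nums2);
        // Type 1 triplets (squares from nums1) plus type 2 triplets (squares from nums2).
        return count(nums1, nums2) + count(nums2, nums1);
    }

    // Counts triplets (i, j, k) with j < k and a[i]^2 == b[j] * b[k]; b must be sorted.
    public int count(int[] a, int[] b) {
        int m = b.length;
        int count = 0;
        for (int value : a) {
            // Use long: squares and products can exceed the int range.
            long x = (long) value * value;
            int j = 0;
            int k = m - 1;
            while (j < k) {
                long prod = (long) b[j] * b[k];
                if (prod < x) {
                    j++; // product too small: move to a larger left value
                } else if (prod > x) {
                    k--; // product too large: move to a smaller right value
                } else if (b[j] != b[k]) {
                    // Match with distinct values: measure the run of b[j] starting at j
                    // and the run of b[k] ending at k; every left/right pairing is valid.
                    int jNew = j;
                    int kNew = k;
                    while (b[j] == b[jNew]) {
                        jNew++;
                    }
                    while (b[k] == b[kNew]) {
                        kNew--;
                    }
                    count += (jNew - j) * (k - kNew);
                    j = jNew;
                    k = kNew;
                } else {
                    // b[j] == b[k], so every element of b[j..k] is equal and any two
                    // of them form a valid pair: choose 2 out of q.
                    int q = k - j + 1;
                    count += (q) * (q - 1) / 2;
                    break;
                }
            }
        }
        return count;
    }
}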
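For comparison, the same counts can be obtained without sorting by scanning the other array once per square and looking up the required partner in a frequency map of values seen so far, in the style of two-sum. This is an alternative technique, not the approach used above; the class `HashMapTriplets` is hypothetical, and the sketch assumes all values are non-zero (as in the examples), since a zero element would need separate handling. Like the sorted two-pointer solution it does O(|nums1| * |nums2|) work (expected, for the map), but it avoids the duplicate-run bookkeeping at the cost of allocating a map per square.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical alternative: hash-map pair counting. Assumes no element is 0.
class HashMapTriplets {
    // Counts (i, j, k) with j < k and sq[i]^2 == pair[j] * pair[k].
    static int countType(int[] sq, int[] pair) {
        int total = 0;
        for (int v : sq) {
            long target = (long) v * v;
            Map<Integer, Integer> seen = new HashMap<>(); // value -> count at earlier indices
            for (int x : pair) {
                // x (at index k) pairs with an earlier y (index j < k) iff y == target / x exactly.
                if (x != 0 && target % x == 0) {
                    long need = target / x;
                    if (need >= Integer.MIN_VALUE && need <= Integer.MAX_VALUE) {
                        total += seen.getOrDefault((int) need, 0);
                    }
                }
                seen.merge(x, 1, Integer::sum); // record x for later indices
            }
        }
        return total;
    }

    static int numTriplets(int[] nums1, int[] nums2) {
        return countType(nums1, nums2) + countType(nums2, nums1);
    }

    public static void main(String[] args) {
        System.out.println(numTriplets(new int[] {7, 4}, new int[] {5, 2, 8, 9}));       // 1
        System.out.println(numTriplets(new int[] {1, 1}, new int[] {1, 1, 1}));          // 9
        System.out.println(numTriplets(new int[] {7, 7, 8, 3}, new int[] {1, 2, 9, 7})); // 2
    }
}
```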