LeetCode-in-Java

3757. Number of Effective Subsequences

Hard

You are given an integer array nums.

The strength of the array is defined as the bitwise OR of all its elements.

A subsequence is considered effective if, after removing it, the strength of the remaining elements is strictly less than the strength of the original array.

Return the number of effective subsequences in nums. Since the answer may be large, return it modulo 10^9 + 7.

The bitwise OR of an empty array is 0.

Example 1:

Input: nums = [1,2,3]

Output: 3

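Explanation: The strength of nums is 1 OR 2 OR 3 = 3. The effective subsequences are [1,3], [2,3], and [1,2,3]. Removing them leaves strengths of 2, 1, and 0 respectively.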

Example 2:

Input: nums = [7,4,6]

Output: 4

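Explanation: The strength of nums is 7. Bit 0 appears only in 7, so every subsequence containing 7 is effective: [7], [7,4], [7,6], and [7,4,6]. A subsequence that does not contain 7 leaves 7 behind, so the strength stays 7.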

Example 3:

Input: nums = [8,8]

Output: 1

Explanation: The strength of nums is 8. Only removing both elements ([8,8]) lowers it, to 0.

Example 4:

Input: nums = [2,2,1]

Output: 5

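Explanation: The strength of nums is 2 OR 2 OR 1 = 3. Removing the 1, together with any choice of the two 2s, clears bit 0. This gives 4 effective subsequences: [1], [2,1] using the first 2, [2,1] using the second 2, and [2,2,1]. Removing both 2s while keeping the 1 clears bit 1, which adds one more: [2,2]. The total is 5.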

Constraints:

Solution

/**
 * Counts the subsequences whose removal strictly lowers the bitwise OR of the remaining elements.
 *
 * <p>Removing a subsequence lowers the OR exactly when, for at least one bit set in the total OR,
 * every element carrying that bit is removed. Take a nonempty set {@code B} of such bits. The
 * subsequences that remove every element touching {@code B} number {@code 2^k}, where {@code k}
 * is the count of elements sharing no bit with {@code B}. Summing over all nonempty {@code B}
 * with inclusion-exclusion gives the answer. The value of {@code k} for every {@code B} comes
 * from a single sum-over-subsets pass on the complement mask.
 */
public class Solution {
    private static final int MOD = 1_000_000_007;

    /**
     * Returns the number of effective subsequences of {@code nums}, modulo {@code 10^9 + 7}.
     *
     * <p>Work is {@code O(n * m + m * 2^m)} and extra memory is {@code O(n + 2^m)}, where
     * {@code m} is the number of distinct bits set across {@code nums}.
     *
     * @param nums the input array; values are expected to be non-negative
     * @return the count modulo {@code 10^9 + 7}; {@code 0} when the OR of {@code nums} is
     *     {@code 0}, since a strength of 0 cannot decrease
     * @throws IllegalArgumentException if the OR of {@code nums} has so many set bits that a
     *     {@code 2^m} table cannot be indexed by an {@code int}
     */
    public int countEffective(int[] nums) {
        int total = 0;
        for (int v : nums) {
            total |= v;
        }
        if (total == 0) {
            return 0;
        }
        // Consider every set bit of the OR, not just a fixed low range, so high bits are not
        // silently ignored.
        int[] bits = setBits(total);
        int m = bits.length;
        if (m >= Integer.SIZE - 1) {
            throw new IllegalArgumentException("too many distinct set bits for a 2^m table: " + m);
        }
        // After the transform, covered[mask] = number of elements whose compressed bits are all
        // inside mask.
        int[] covered = compressedCounts(nums, bits);
        sumOverSubsets(covered, m);
        long[] pow2 = powersOfTwo(nums.length);

        int all = (1 << m) - 1;
        long ans = 0;
        for (int chosen = 1; chosen <= all; ++chosen) {
            // Elements avoiding every bit in `chosen` may be kept or removed freely.
            long ways = pow2[covered[all ^ chosen]];
            // Odd-sized bit sets add, even-sized ones subtract; MOD - ways keeps ans non-negative.
            ans += (Integer.bitCount(chosen) & 1) == 1 ? ways : MOD - ways;
            ans %= MOD;
        }
        return (int) ans;
    }

    /** Returns the positions of the set bits of {@code value}, lowest position first. */
    private static int[] setBits(int value) {
        int[] bits = new int[Integer.bitCount(value)];
        int rest = value;
        for (int i = 0; rest != 0; ++i) {
            bits[i] = Integer.numberOfTrailingZeros(rest);
            rest &= rest - 1; // clear the lowest set bit
        }
        return bits;
    }

    /**
     * Re-encodes each value onto the compact positions listed in {@code bits}: bit {@code j} of
     * the compressed mask mirrors bit {@code bits[j]} of the value. Returns, for every compressed
     * mask, how many values map to it.
     */
    private static int[] compressedCounts(int[] nums, int[] bits) {
        int[] counts = new int[1 << bits.length];
        for (int v : nums) {
            int mask = 0;
            for (int j = 0; j < bits.length; ++j) {
                mask |= ((v >>> bits[j]) & 1) << j;
            }
            counts[mask]++;
        }
        return counts;
    }

    /**
     * Replaces {@code f[mask]}, in place, with the sum of the original {@code f[sub]} over every
     * submask {@code sub} of {@code mask}, for all masks over {@code m} bits.
     */
    private static void sumOverSubsets(int[] f, int m) {
        for (int i = 0; i < m; ++i) {
            int bit = 1 << i;
            for (int mask = 0; mask < f.length; ++mask) {
                if ((mask & bit) != 0) {
                    f[mask] += f[mask ^ bit];
                }
            }
        }
    }

    /** Returns an array whose entry {@code i} is {@code 2^i mod MOD}, for {@code i = 0..n}. */
    private static long[] powersOfTwo(int n) {
        long[] pow = new long[n + 1];
        pow[0] = 1;
        for (int i = 1; i <= n; ++i) {
            pow[i] = (pow[i - 1] << 1) % MOD;
        }
        return pow;
    }
}