LeetCode-in-Java

3510. Minimum Pair Removal to Sort Array II

Hard

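Given an array nums, you can perform the following operation any number of times:

- Select the adjacent pair with the minimum sum in nums. If multiple such pairs exist, choose the leftmost one.
- Replace the pair with their sum.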

Return the minimum number of operations needed to make the array non-decreasing.

An array is said to be non-decreasing if each element is greater than or equal to its previous element (if it exists).

Example 1:

Input: nums = [5,2,3,1]

Output: 2

Explanation:

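- The pair (3,1) has the minimum sum of 4. After replacement, nums = [5,2,4].
- The pair (2,4) has the minimum sum of 6. After replacement, nums = [5,6].

The array nums became non-decreasing in two operations.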

Example 2:

Input: nums = [1,2,2]

Output: 0

Explanation:

The array nums is already sorted.

Constraints:

- 1 <= nums.length <= 10^5
- -10^9 <= nums[i] <= 10^9

Solution

import java.util.Arrays;

public class Solution {
    /** Marker in {@code arrIdxToSegIdx} meaning "there is no neighbouring pair on this side". */
    private static final int NO_SEGMENT = -1;

    /**
     * Counts the operations "merge the leftmost adjacent pair with the minimum sum" needed until
     * {@code nums} is non-decreasing.
     *
     * <p>Each adjacent pair (i, i + 1) of the original array is a leaf of a min segment tree keyed
     * by the pair's current sum. After a merge the left index survives and holds the merged
     * value. The neighbouring pairs are rewired to it, so each operation costs O(log n).
     *
     * @param nums the input array; not modified
     * @return the number of operations; 0 for arrays with fewer than two elements
     */
    public int minimumPairRemoval(int[] nums) {
        int n = nums.length;
        if (n <= 1) {
            // Empty or single-element arrays are already non-decreasing.
            return 0;
        }
        int pairCount = n - 1;
        // Smallest power of two >= pairCount, computed exactly with integers instead of
        // floating-point logarithms.
        int size = 1;
        while (size < pairCount) {
            size <<= 1;
        }
        long[] segment = new long[size * 2 - 1];
        // Unused leaves and retired pairs hold MAX_VALUE so they never win a min query.
        Arrays.fill(segment, Long.MAX_VALUE);
        // lefts[leaf] / rights[leaf]: the array indices currently forming that leaf's pair.
        int[] lefts = new int[size * 2 - 1];
        int[] rights = new int[size * 2 - 1];
        // sums[i]: current value of the (possibly merged) element that survives at index i.
        long[] sums = new long[n];
        // arrIdxToSegIdx[i] = {leaf of the pair ending at i, leaf of the pair starting at i}.
        int[][] arrIdxToSegIdx = new int[n][];
        sums[0] = nums[0];
        int count = 0;
        arrIdxToSegIdx[0] = new int[] {NO_SEGMENT, size - 1};
        for (int i = 1; i < n; i++) {
            if (nums[i] < nums[i - 1]) {
                count++;
            }
            int leaf = size - 1 + (i - 1);
            lefts[leaf] = i - 1;
            rights[leaf] = i;
            // Widen before adding: the sum of two ints can overflow int.
            segment[leaf] = nums[i - 1] + (long) nums[i];
            arrIdxToSegIdx[i] = new int[] {leaf, leaf + 1};
            sums[i] = nums[i];
        }
        // The last element has no pair to its right.
        arrIdxToSegIdx[n - 1][1] = NO_SEGMENT;
        for (int i = size - 2; i >= 0; i--) {
            segment[i] = Math.min(segment[2 * i + 1], segment[2 * i + 2]);
        }
        return getRes(count, segment, lefts, rights, sums, arrIdxToSegIdx);
    }

    /**
     * Repeatedly merges the leftmost minimum-sum adjacent pair until no descending pair remains.
     *
     * @param count number of live adjacent pairs whose left value is greater than the right value
     * @param segment min segment tree over the live pair sums; retired leaves hold MAX_VALUE
     * @param lefts left array index of each leaf's pair
     * @param rights right array index of each leaf's pair
     * @param sums current element values, indexed by surviving array index
     * @param arrIdxToSegIdx for each array index, the leaves of its left and right pairs
     * @return the number of merges performed
     */
    private int getRes(
            int count,
            long[] segment,
            int[] lefts,
            int[] rights,
            long[] sums,
            int[][] arrIdxToSegIdx) {
        int res = 0;
        // While a descending pair exists, at least two elements remain. The tree minimum is
        // therefore a live pair, not a MAX_VALUE sentinel.
        while (count > 0) {
            int segIdx = findLeftmostMinLeaf(segment);
            int arrIdxL = lefts[segIdx];
            int arrIdxR = rights[segIdx];
            long numL = sums[arrIdxL];
            long numR = sums[arrIdxR];
            if (numL > numR) {
                // The merged pair itself was descending and disappears.
                count--;
            }
            long newSum = numL + numR;
            // The left index survives with the merged value; arrIdxR is never read again.
            sums[arrIdxL] = newSum;
            int[] leftPointer = arrIdxToSegIdx[arrIdxL];
            int[] rightPointer = arrIdxToSegIdx[arrIdxR];
            int prvSegIdx = leftPointer[0];
            int nextSegIdx = rightPointer[1];
            // The survivor inherits the right element's outgoing pair.
            leftPointer[1] = nextSegIdx;
            if (prvSegIdx != NO_SEGMENT) {
                // Pair (l, arrIdxL): its right value changes from numL to newSum.
                int l = lefts[prvSegIdx];
                if (sums[l] > numL && sums[l] <= newSum) {
                    count--;
                } else if (sums[l] <= numL && sums[l] > newSum) {
                    count++;
                }
                modify(segment, prvSegIdx, sums[l] + newSum);
            }
            if (nextSegIdx != NO_SEGMENT) {
                // Pair (arrIdxR, r) becomes (arrIdxL, r): its left value changes from numR to newSum.
                int r = rights[nextSegIdx];
                if (numR > sums[r] && newSum <= sums[r]) {
                    count--;
                } else if (numR <= sums[r] && newSum > sums[r]) {
                    count++;
                }
                modify(segment, nextSegIdx, newSum + sums[r]);
                lefts[nextSegIdx] = arrIdxL;
            }
            // Retire the merged pair's leaf.
            modify(segment, segIdx, Long.MAX_VALUE);
            res++;
        }
        return res;
    }

    /**
     * Descends from the root to the leaf holding the minimum value. On ties it prefers the left
     * child, so it returns the leftmost minimal leaf.
     *
     * @param segment the min segment tree
     * @return index of the leftmost leaf with the minimum value
     */
    private int findLeftmostMinLeaf(long[] segment) {
        int segIdx = 0;
        while (2 * segIdx + 1 < segment.length) {
            int l = 2 * segIdx + 1;
            int r = l + 1;
            segIdx = segment[l] <= segment[r] ? l : r;
        }
        return segIdx;
    }

    /**
     * Sets node {@code idx} to {@code num} and recomputes the minimum of every ancestor up to the
     * root. It does nothing when the stored value is already {@code num}.
     *
     * @param segment the min segment tree
     * @param idx the node (leaf) index to update
     * @param num the new value
     */
    private void modify(long[] segment, int idx, long num) {
        if (segment[idx] == num) {
            return;
        }
        segment[idx] = num;
        while (idx != 0) {
            idx = (idx - 1) / 2;
            segment[idx] = Math.min(segment[2 * idx + 1], segment[2 * idx + 2]);
        }
    }
}