LeetCode-in-Java

3139. Minimum Cost to Equalize Array

Hard

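You are given an integer array nums and two integers cost1 and cost2. You are allowed to perform either of the following operations any number of times:

- Choose an index i from nums and increase nums[i] by 1 for a cost of cost1.
- Choose two different indices i, j, from nums and increase nums[i] and nums[j] by 1 for a cost of cost2.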

Return the minimum cost required to make all elements in the array equal.

Since the answer may be very large, return it modulo 10^9 + 7.
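One way to reason about the problem is to fix a target value t and count deficits d_i = t - nums[i]. With T = sum of d_i and D = max d_i, a pair increment must touch two different indices, so at most min(T / 2, T - D) increments can be paired, and a pair is only worth cost2 when that is cheaper than two single increments at cost1. The sketch below is a hypothetical brute-force reference (the class `EqualizeReference` and its methods are illustrative names, not part of the original solution). It scans targets from max(nums) to max(nums) + 2 * (max - min) + 2, a range assumed here to cover the targets the closed-form solution considers, and it skips the modulo, so it is only meant for cross-checking small inputs.

```java
final class EqualizeReference {
    // Minimum cost to raise every element of nums to exactly target (target >= max(nums)).
    static long costAtTarget(int[] nums, long target, long cost1, long cost2) {
        long total = 0L;   // T: total number of unit increments needed
        long largest = 0L; // D: the largest single deficit
        for (int num : nums) {
            long deficit = target - num;
            total += deficit;
            largest = Math.max(largest, deficit);
        }
        // A pair needs two different indices, so the largest deficit can share
        // at most (total - largest) of its increments with other elements.
        long pairs = Math.min(total / 2, total - largest);
        long singles = total - 2 * pairs;
        // Use a pair only when it is cheaper than two single increments.
        long pairPrice = Math.min(cost2, 2 * cost1);
        return pairs * pairPrice + singles * cost1;
    }

    // Scans candidate targets and keeps the cheapest one (no modulo; small inputs only).
    static long minCost(int[] nums, int cost1, int cost2) {
        long max = Long.MIN_VALUE;
        long min = Long.MAX_VALUE;
        for (int num : nums) {
            max = Math.max(max, num);
            min = Math.min(min, num);
        }
        long best = Long.MAX_VALUE;
        for (long target = max; target <= max + 2 * (max - min) + 2; target++) {
            best = Math.min(best, costAtTarget(nums, target, cost1, cost2));
        }
        return best;
    }
}
```

Because it evaluates every candidate target explicitly, this reference is slow but easy to trust, which makes it handy for randomized comparisons against the closed-form solution.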

Example 1:

Input: nums = [4,1], cost1 = 5, cost2 = 2

Output: 15

Explanation:

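The following operations can be performed to make the values equal:

- Increase nums[1] by 1 for a cost of 5. nums becomes [4,2].
- Increase nums[1] by 1 for a cost of 5. nums becomes [4,3].
- Increase nums[1] by 1 for a cost of 5. nums becomes [4,4].

The total cost is 15.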
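For Example 1, one possible sequence applies the single-index operation to index 1 three times. The hypothetical helper `OperationReplay.replay` below (an illustrative name, not part of the original solution) applies a list of operations, where a one-element entry is a single increment at cost1 and a two-element entry is a pair increment on two different indices at cost2, and then checks that all values ended up equal. With only two elements, the pair operation raises both values together and never closes the gap, so only single increments help here.

```java
import java.util.Arrays;

final class OperationReplay {
    // Applies each operation in order and returns the total cost.
    // Throws if an operation is malformed or the final array is not all equal.
    static long replay(int[] nums, int[][] ops, long cost1, long cost2) {
        int[] values = Arrays.copyOf(nums, nums.length);
        long cost = 0L;
        for (int[] op : ops) {
            if (op.length == 1) {
                values[op[0]]++;          // single increment
                cost += cost1;
            } else if (op.length == 2 && op[0] != op[1]) {
                values[op[0]]++;          // pair increment on two different indices
                values[op[1]]++;
                cost += cost2;
            } else {
                throw new IllegalArgumentException("bad operation: " + Arrays.toString(op));
            }
        }
        if (Arrays.stream(values).distinct().count() > 1) {
            throw new IllegalStateException("not equal: " + Arrays.toString(values));
        }
        return cost;
    }
}
```

For instance, `replay(new int[] {4, 1}, new int[][] {{1}, {1}, {1}}, 5, 2)` returns 15, while replaying a single pair operation `{{0, 1}}` leaves [5, 2] and throws.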

Example 2:

Input: nums = [2,3,3,3,5], cost1 = 2, cost2 = 1

Output: 6

Explanation:

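The following operations can be performed to make the values equal:

- Increase nums[0] and nums[1] by 1 for a cost of 1. nums becomes [3,4,3,3,5].
- Increase nums[0] and nums[2] by 1 for a cost of 1. nums becomes [4,4,4,3,5].
- Increase nums[0] and nums[3] by 1 for a cost of 1. nums becomes [5,4,4,4,5].
- Increase nums[1] and nums[2] by 1 for a cost of 1. nums becomes [5,5,5,4,5].
- Increase nums[3] by 1 for a cost of 2. nums becomes [5,5,5,5,5].

The total cost is 6.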
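In Example 2 the deficits to target 5 are [3,2,2,2,0], nine unit increments in total, so at most four pairs (cost 1 each) plus one single (cost 2) are needed, giving 6. The sketch below reaches the same count with a greedy that repeatedly pairs the two largest remaining deficits using a `PriorityQueue`. This greedy is a technique swapped in for illustration (the original solution uses a closed form), and the class name `GreedyPairing` is hypothetical. It pairs only while a pair is cheaper than two singles and walks one unit at a time, so it is only practical for small deficits.

```java
import java.util.Collections;
import java.util.PriorityQueue;

final class GreedyPairing {
    // Raises every element to target. While cost2 < 2 * cost1, it pairs the two
    // largest remaining deficits; leftover increments are paid one at a time with cost1.
    static long costAtTarget(int[] nums, int target, long cost1, long cost2) {
        PriorityQueue<Integer> deficits = new PriorityQueue<>(Collections.reverseOrder());
        for (int num : nums) {
            if (target - num > 0) {
                deficits.add(target - num);
            }
        }
        long cost = 0L;
        if (cost2 < 2 * cost1) {
            while (deficits.size() >= 2) {
                int a = deficits.poll();
                int b = deficits.poll();
                cost += cost2; // one pair operation on two different indices
                if (a > 1) deficits.add(a - 1);
                if (b > 1) deficits.add(b - 1);
            }
        }
        for (int rest : deficits) {
            cost += rest * cost1; // remaining increments, one index at a time
        }
        return cost;
    }
}
```

Pairing the two largest deficits first keeps the biggest deficit from being left alone at the end, which is what lets the greedy use as many pairs as possible.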

Example 3:

Input: nums = [3,5,3], cost1 = 1, cost2 = 3

Output: 4

Explanation:

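The following operations can be performed to make the values equal:

- Increase nums[0] by 1 for a cost of 1. nums becomes [4,5,3].
- Increase nums[0] by 1 for a cost of 1. nums becomes [5,5,3].
- Increase nums[2] by 1 for a cost of 1. nums becomes [5,5,4].
- Increase nums[2] by 1 for a cost of 1. nums becomes [5,5,5].

The total cost is 4.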
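In Example 3, two single increments cost 2 * cost1 = 2, which is less than cost2 = 3, so pairing never pays off. In that case every increment is paid with cost1, raising the target above max(nums) only adds n * cost1, and the answer is (n * max - sum) * cost1. The same formula covers arrays with at most two elements, where pairing cannot reduce a gap. The class name `SingleOnly` below is hypothetical; the sketch omits the final modulo.

```java
final class SingleOnly {
    // Cost when every increment is a single operation: bring each element up to max(nums).
    static long cost(int[] nums, long cost1) {
        long max = Long.MIN_VALUE;
        long sum = 0L;
        for (int num : nums) {
            max = Math.max(max, num);
            sum += num;
        }
        return (max * nums.length - sum) * cost1;
    }
}
```

For Example 3 this gives (3 * 5 - 11) * 1 = 4, and for Example 1 (two elements) it gives (2 * 4 - 5) * 5 = 15.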

Constraints:

- 1 <= nums.length <= 10^5
- 1 <= nums[i] <= 10^6
- 1 <= cost1 <= 10^6
- 1 <= cost2 <= 10^6
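Assuming bounds of 10^5 elements and values and costs up to 10^6, the total number of unit increments at target max(nums) is at most about 10^5 * 10^6 = 10^11, and multiplying by a cost gives about 10^17. That product already overflows 32-bit `int` but fits comfortably in a 64-bit `long` (whose maximum is about 9.2 * 10^18), which is why the solution accumulates in `long` and reduces modulo 10^9 + 7 only at the end. The sketch below checks this with `Math.multiplyExact`, which throws `ArithmeticException` on overflow; the class `OverflowCheck` is a hypothetical name, and the extra increments from raising the target are ignored as a small additive term.

```java
final class OverflowCheck {
    // Worst-case increments times the largest cost, in exact 64-bit arithmetic.
    static long worstCaseProduct(long n, long maxValue, long maxCost) {
        return Math.multiplyExact(Math.multiplyExact(n, maxValue), maxCost);
    }

    // Reports whether n * maxValue overflows 32-bit int arithmetic.
    static boolean overflowsInt(int n, int maxValue) {
        try {
            Math.multiplyExact(n, maxValue);
            return false;
        } catch (ArithmeticException e) {
            return true;
        }
    }
}
```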

Solution

public class Solution {
    private static final int MOD = 1_000_000_007;
    private static final long LMOD = MOD;

    public int minCostToEqualizeArray(int[] nums, int cost1, int cost2) {
        long max = 0L;
        long min = Long.MAX_VALUE;
        long sum = 0L;
        for (long num : nums) {
            if (num > max) {
                max = num;
            }
            if (num < min) {
                min = num;
            }
            sum += num;
        }
        final int n = nums.length;
        long total = max * n - sum;
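        // Operation one is always better when a pair costs at least two singles,
        // or when n <= 2 (a pair raises both elements and cannot shrink a gap).
        if ((cost1 << 1) <= cost2 || n <= 2) {
            return (int) (total * cost1 % LMOD);
        }
        // When operation two is moderately better, aim for target = max:
        // op1 counts the increments of the smallest element that cannot be paired,
        // op2 the increments that can be paired (one is left over if op2 is odd).
        long op1 = Math.max(0L, ((max - min) << 1L) - total);
        long op2 = total - op1;
        long result = (op1 + (op2 & 1L)) * cost1 + (op2 >> 1L) * cost2;
        // When operation two is significantly better, raise the target: each +1 adds n
        // to total but only 1 to the largest deficit, so op1 shrinks by n - 2.
        total += op1 / (n - 2L) * n;
        op1 %= n - 2L;
        op2 = total - op1;
        result = Math.min(result, (op1 + (op2 & 1L)) * cost1 + (op2 >> 1L) * cost2);
        // When operation two is always better: one more raise makes every increment
        // pairable; trying two raises covers both parities of total when n is odd.
        for (int i = 0; i < 2; ++i) {
            total += n;
            result = Math.min(result, (total & 1L) * cost1 + (total >> 1L) * cost2);
        }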
        return (int) (result % LMOD);
    }
}
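The "significantly better" step relies on a counting fact: raising the target by one adds n to the total deficit T but only 1 to the largest deficit D, so the number of unpairable increments, max(0, 2D - T), drops by n - 2 per step. The hypothetical helper below (`RaiseTarget.unpairable` is an illustrative name, not part of the solution) evaluates that quantity directly for a given target. For nums = [1, 10, 10, 10] (n = 4) it starts at 9 for target 10, which is the solution's op1, then falls by 2 per raise; after raising the target by op1 / (n - 2) = 4 only op1 % (n - 2) = 1 remains, matching `total += op1 / (n - 2L) * n; op1 %= n - 2L;`, and one more raise makes everything pairable.

```java
final class RaiseTarget {
    // Increments that cannot be paired at the given target: max(0, 2 * D - T),
    // where T is the total deficit and D is the largest single deficit.
    static long unpairable(int[] nums, long target) {
        long total = 0L;
        long largest = 0L;
        for (int num : nums) {
            long deficit = target - num;
            total += deficit;
            largest = Math.max(largest, deficit);
        }
        return Math.max(0L, 2 * largest - total);
    }
}
```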