LeetCode-in-Java

3426. Manhattan Distances of All Arrangements of Pieces

Hard

You are given three integers m, n, and k.

There is a rectangular grid of size m × n containing k identical pieces. Return the sum of Manhattan distances between every pair of pieces over all valid arrangements of pieces.

A valid arrangement is a placement of all k pieces on the grid with at most one piece per cell.

Since the answer may be very large, return it modulo 10^9 + 7.

The Manhattan distance between two cells (x_i, y_i) and (x_j, y_j) is |x_i - x_j| + |y_i - y_j|.
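For small inputs the definition can be checked directly. The following is a hypothetical brute-force reference. The class name `BruteForceSketch` and the method `bruteForce` are illustrative and not part of the original solution. It treats every bitmask with exactly k set bits over the m × n cells as one arrangement and adds up the pairwise Manhattan distances. Its running time is exponential in m · n, so it is only meant for cross-checking tiny cases such as the examples.

```java
// Hypothetical brute-force reference; only practical for tiny grids (m * n well below 31).
class BruteForceSketch {
    // Sums pairwise Manhattan distances over every way to choose k of the m * n cells.
    static long bruteForce(int m, int n, int k) {
        int cells = m * n;
        long total = 0;
        // Each bitmask with exactly k set bits is one valid arrangement.
        for (int mask = 0; mask < (1 << cells); mask++) {
            if (Integer.bitCount(mask) != k) {
                continue;
            }
            for (int a = 0; a < cells; a++) {
                if ((mask & (1 << a)) == 0) {
                    continue;
                }
                for (int b = a + 1; b < cells; b++) {
                    if ((mask & (1 << b)) == 0) {
                        continue;
                    }
                    // Cell index c maps to row c / n and column c % n.
                    total += Math.abs(a / n - b / n) + Math.abs(a % n - b % n);
                }
            }
        }
        return total;
    }
}
```

No modulus is applied because the tiny grids this is meant for never approach 10^9 + 7.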

Example 1:

Input: m = 2, n = 2, k = 2

Output: 8

Explanation:

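The valid arrangements are the 6 ways to place the two pieces on the 4 cells. In 4 of them the pieces sit in adjacent cells (Manhattan distance 1), and in the other 2 they sit on opposite corners (Manhattan distance 2).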

Thus, the total Manhattan distance across all valid arrangements is 1 + 1 + 1 + 1 + 2 + 2 = 8.

Example 2:

Input: m = 1, n = 4, k = 3

Output: 20

Explanation:

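The valid arrangements are the 4 ways to choose 3 of the 4 cells in the single row. Writing each arrangement as its occupied column indices, {0, 1, 2} gives 1 + 1 + 2 = 4, {0, 1, 3} gives 1 + 2 + 3 = 6, {0, 2, 3} gives 2 + 1 + 3 = 6, and {1, 2, 3} gives 1 + 1 + 2 = 4.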

The total Manhattan distance between all pairs of pieces across all arrangements is 4 + 6 + 6 + 4 = 20.
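The same total can also be reached pair by pair, which is the counting idea the solution relies on. Any fixed pair of cells is completed to an arrangement by putting the third piece in one of the 2 remaining cells, so each pair is counted C(2, 1) = 2 times. The 1 × 4 row has 3 cell pairs at distance 1, 2 at distance 2, and 1 at distance 3:

```latex
\sum_{\text{arrangements}} \sum_{\text{pairs}} \mathrm{dist}
  = \binom{4-2}{3-2} \left( 3 \cdot 1 + 2 \cdot 2 + 1 \cdot 3 \right)
  = 2 \cdot 10 = 20
```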

Constraints:

Solution
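The code below uses a contribution argument instead of enumerating arrangements. Fix two distinct cells p and q. The remaining k - 2 pieces can go in any of the other m · n - 2 cells, so that pair appears together in C(mn - 2, k - 2) arrangements. Splitting the Manhattan distance into its column and row parts, the unordered cell pairs whose columns differ by d number (n - d) · m^2 (pick the left column, then a row for each cell), and symmetrically for rows. All arithmetic is carried out modulo 10^9 + 7:

```latex
\text{Total}
  = \binom{mn-2}{k-2}
    \left( m^2 \sum_{d=1}^{n-1} d\,(n-d) \;+\; n^2 \sum_{d=1}^{m-1} d\,(m-d) \right)
```

In the code, `base` is the binomial factor and the two loops in `distanceSum` evaluate the two sums.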

public class Solution {
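    // Returns C(a, b) mod `mod`; `mod` must be prime so that b! has a modular inverse.
    private long comb(long a, long b, long mod) {
        if (b > a) {
            return 0;
        }
        // numer = a * (a - 1) * ... * (a - b + 1), denom = b!
        long numer = 1;
        long denom = 1;
        for (long i = 0; i < b; ++i) {
            numer = numer * (a - i) % mod;
            denom = denom * (i + 1) % mod;
        }
        // Inverse of denom via Fermat's little theorem: denom^(mod - 2) mod `mod`,
        // computed by binary exponentiation.
        long denomInv = 1;
        long exp = mod - 2;
        while (exp > 0) {
            if (exp % 2 > 0) {
                denomInv = denomInv * denom % mod;
            }
            denom = denom * denom % mod;
            exp /= 2;
        }
        return numer * denomInv % mod;
    }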

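    public int distanceSum(int m, int n, int k) {
        long res = 0;
        long mod = 1000000007;
        // Every pair of distinct cells appears together in C(m * n - 2, k - 2) arrangements.
        long base = comb((long) m * n - 2, k - 2L, mod);
        // Horizontal part: n - d column pairs at distance d, times m * m row choices.
        for (int d = 1; d < n; ++d) {
            res = (res + (long) d * (n - d) % mod * m % mod * m % mod) % mod;
        }
        // Vertical part: m - d row pairs at distance d, times n * n column choices.
        for (int d = 1; d < m; ++d) {
            res = (res + (long) d * (m - d) % mod * n % mod * n % mod) % mod;
        }
        return (int) (res * base % mod);
    }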
}
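The two loops in `distanceSum` add up d · (n - d) term by term. That sum has the closed form (n - 1) · n · (n + 1) / 6, so the loops could be replaced by a constant-time expression. The sketch below is an illustrative variant under that substitution, not the original author's code. `ClosedFormSketch`, `modPow`, and `pairSum` are hypothetical names, and the binomial coefficient is still computed in O(k) time, as in `comb`. It assumes m and n stay small enough (around 10^5 or less) that (len - 1) · len · (len + 1) fits in a `long`.

```java
// Hypothetical closed-form variant of distanceSum; names are illustrative only.
class ClosedFormSketch {
    static final long MOD = 1_000_000_007L;

    // Computes base^exp mod MOD by repeated squaring.
    static long modPow(long base, long exp) {
        long result = 1;
        base %= MOD;
        while (exp > 0) {
            if ((exp & 1) == 1) {
                result = result * base % MOD;
            }
            base = base * base % MOD;
            exp >>= 1;
        }
        return result;
    }

    // C(a, b) mod MOD, using Fermat's little theorem for the inverse of b!.
    static long comb(long a, long b) {
        if (b < 0 || b > a) {
            return 0;
        }
        long numer = 1;
        long denom = 1;
        for (long i = 0; i < b; i++) {
            numer = numer * ((a - i) % MOD) % MOD;
            denom = denom * ((i + 1) % MOD) % MOD;
        }
        return numer * modPow(denom, MOD - 2) % MOD;
    }

    // Sum of d * (len - d) for d = 1 .. len - 1, equal to (len - 1) * len * (len + 1) / 6.
    static long pairSum(long len) {
        // A product of three consecutive integers is divisible by 6, so the division is exact.
        return (len - 1) * len * (len + 1) / 6 % MOD;
    }

    static int distanceSum(int m, int n, int k) {
        long base = comb((long) m * n - 2, k - 2L);
        long mm = (long) m * m % MOD;
        long nn = (long) n * n % MOD;
        // Horizontal part weighted by m^2 plus vertical part weighted by n^2.
        long lines = (pairSum(n) * mm + pairSum(m) * nn) % MOD;
        return (int) (lines * base % MOD);
    }
}
```

For example, `ClosedFormSketch.distanceSum(2, 3, 3)` evaluates to C(4, 1) · (4 · 4 + 1 · 9) = 4 · 25 = 100.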