LeetCode-in-Java

1584. Min Cost to Connect All Points

Medium

You are given an array points representing integer coordinates of some points on a 2D-plane, where points[i] = [xi, yi].

The cost of connecting two points [xi, yi] and [xj, yj] is the Manhattan distance between them: |xi - xj| + |yi - yj|, where |val| denotes the absolute value of val.

Return the minimum cost to make all points connected. All points are connected if there is exactly one simple path between any two points.

Example 1:

Input: points = [[0,0],[2,2],[3,10],[5,2],[7,0]]

Output: 20

Explanation:

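We can connect the points with the edges [0,0]-[2,2] (cost 4), [2,2]-[5,2] (cost 3), [5,2]-[7,0] (cost 4) and [2,2]-[3,10] (cost 9) to get the minimum cost of 4 + 3 + 4 + 9 = 20.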

Notice that there is a unique path between every pair of points.

Example 2:

Input: points = [[3,12],[-2,5],[-4,1]]

Output: 18

Constraints:

- 1 <= points.length <= 1000
- -10^6 <= xi, yi <= 10^6
- All pairs (xi, yi) are distinct.

Solution

import java.util.Arrays;
import java.util.Comparator;
import java.util.PriorityQueue;

/**
 * Computes the minimum total Manhattan-distance cost needed to connect a set of 2D points, i.e.
 * the weight of a minimum spanning tree over the complete graph of the points.
 *
 * <p>The implementation is Prim's algorithm with a priority queue and lazy deletion: whenever a
 * cheaper edge to a vertex is discovered a new queue entry is added, and outdated entries are
 * skipped when they are eventually polled.
 */
public class Solution {
    /**
     * Returns the minimum cost to connect all points.
     *
     * @param points array of {@code [x, y]} coordinates; must not be {@code null} and every
     *     element must hold at least two values
     * @return the sum of the edge weights of a minimum spanning tree, or {@code 0} when there are
     *     fewer than two points
     */
    public int minCostConnectPoints(int[][] points) {
        int v = points.length;
        if (v < 2) {
            return 0;
        }
        boolean[] mst = new boolean[v];
        int[] dist = new int[v];
        // Integer.MAX_VALUE is a safe "not reached yet" marker: a fixed sentinel such as
        // 1_000_000 is smaller than legitimate distances (up to 4_000_000 for coordinates in
        // [-10^6, 10^6]) and would leave far-apart points unconnected.
        Arrays.fill(dist, Integer.MAX_VALUE);
        dist[0] = 0;
        PriorityQueue<Pair> pq = new PriorityQueue<>(v, Pair.BY_DISTANCE);
        pq.add(new Pair(0, 0));

        int cost = 0;
        int connected = 0;
        // Iterative on purpose: the number of polls is bounded by the queue size, not by v, so a
        // recursive formulation could exhaust the call stack.
        while (connected < v && !pq.isEmpty()) {
            Pair next = pq.poll();
            int pointIndex = next.getV();
            if (mst[pointIndex]) {
                // Stale entry: this vertex was already attached through a cheaper edge.
                continue;
            }
            mst[pointIndex] = true;
            connected++;
            // The first entry polled for a vertex always carries its best known distance.
            cost += next.getDis();
            for (int i = 0; i < v; i++) {
                if (mst[i]) {
                    continue;
                }
                int d = getDistance(points[pointIndex], points[i]);
                if (d < dist[i]) {
                    dist[i] = d;
                    pq.add(new Pair(d, i));
                }
            }
        }
        return cost;
    }

    /**
     * Tells whether at least one element of the given array is {@code false}.
     *
     * @param mst flags to inspect
     * @return {@code true} if any element is {@code false}; {@code false} if all elements are
     *     {@code true} or the array is empty
     */
    public boolean containsFalse(boolean[] mst) {
        for (boolean b : mst) {
            if (!b) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the Manhattan distance between two points.
     *
     * @param p1 first point as {@code [x, y]}
     * @param p2 second point as {@code [x, y]}
     * @return {@code |p1.x - p2.x| + |p1.y - p2.y|}
     */
    public int getDistance(int[] p1, int[] p2) {
        return Math.abs(p1[0] - p2[0]) + Math.abs(p1[1] - p2[1]);
    }

    /** Priority-queue entry: a candidate connection cost together with the vertex it reaches. */
    private static class Pair {
        /** Orders entries by ascending connection cost. */
        static final Comparator<Pair> BY_DISTANCE =
                new Comparator<Pair>() {
                    @Override
                    public int compare(Pair p1, Pair p2) {
                        // Integer.compare avoids the overflow pitfalls of subtraction.
                        return Integer.compare(p1.dis, p2.dis);
                    }
                };

        private final int dis;
        private final int v;

        Pair(int dis, int v) {
            this.dis = dis;
            this.v = v;
        }

        int getDis() {
            return dis;
        }

        int getV() {
            return v;
        }
    }
}