LeetCode-in-Java

3600. Maximize Spanning Tree Stability with Upgrades

Hard

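You are given an integer n, representing n nodes numbered from 0 to n - 1, and a list of edges, where edges[i] = [ui, vi, si, musti]:

- ui and vi indicate an undirected edge between nodes ui and vi.
- si is the strength of the edge.
- musti is an integer (0 or 1). If musti == 1, the edge must be included in the spanning tree, and such edges cannot be upgraded.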

You are also given an integer k, the maximum number of upgrades you can perform. Each upgrade doubles the strength of an edge, and each eligible edge (with musti == 0) can be upgraded at most once.

The stability of a spanning tree is defined as the minimum strength score among all edges included in it.

Return the maximum possible stability of any valid spanning tree. If it is impossible to connect all nodes, return -1.

Note: A spanning tree of a graph with n nodes is a subset of the edges that connects all nodes together (i.e. the graph is connected) without forming any cycles, and uses exactly n - 1 edges.

Example 1:

Input: n = 3, edges = [[0,1,2,1],[1,2,3,0]], k = 1

Output: 2

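Explanation:

- Edge [0,1] with strength 2 is mandatory, so it must be in the spanning tree and cannot be upgraded.
- Edge [1,2] is optional; it has strength 3, or 6 if upgraded with the single allowed upgrade.
- The spanning tree consists of these two edges, so its minimum strength is capped at 2 by the mandatory edge. The maximum possible stability is therefore 2.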

Example 2:

Input: n = 3, edges = [[0,1,4,0],[1,2,3,0],[0,2,1,0]], k = 2

Output: 6

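Explanation:

- All edges are optional, and up to k = 2 upgrades are allowed.
- Upgrade edge [0,1] from 4 to 8 and edge [1,2] from 3 to 6.
- The spanning tree formed by these two edges has strengths 8 and 6, so its stability is 6, the maximum possible.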

Example 3:

Input: n = 3, edges = [[0,1,1,1],[1,2,1,1],[2,0,1,1]], k = 0

Output: -1

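Explanation:

- All edges are mandatory, but together they form a cycle, which violates the acyclicity requirement of a spanning tree. No valid spanning tree exists, so the answer is -1.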

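Constraints:

- 2 <= n <= 10^5
- 1 <= edges.length <= 10^5
- edges[i] = [ui, vi, si, musti]
- 0 <= ui, vi < n
- ui != vi
- 1 <= si <= 10^5
- musti is either 0 or 1.
- 0 <= k <= n
- There are no duplicate edges.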

Solution

public class Solution {
    /**
     * Returns the maximum stability (minimum effective edge strength) over all valid spanning
     * trees, or -1 if no valid spanning tree exists.
     *
     * <p>Binary-searches the answer {@code t} over {@code [0, 2 * maxStrength]}. Each candidate is
     * checked with {@link #feasible}, which is monotone: if stability {@code t} is achievable,
     * every smaller value is too. Results larger than {@link Integer#MAX_VALUE} saturate to
     * {@code Integer.MAX_VALUE}, because the return type is {@code int}.
     *
     * @param n number of nodes, labelled {@code 0 .. n - 1}
     * @param edges each entry is {@code [u, v, strength, must]}; {@code must == 1} marks a
     *     mandatory edge that cannot be upgraded
     * @param k maximum number of upgrades; each upgrade doubles one optional edge's strength once
     * @return the maximum achievable stability, or -1 if no valid spanning tree exists
     */
    public int maxStability(int n, int[][] edges, int k) {
        int maxStrength = 0;
        for (int[] edge : edges) {
            maxStrength = Math.max(maxStrength, edge[2]);
        }
        // An upgrade at most doubles a strength, so no tree can beat 2 * maxStrength.
        // Compute in long and clamp, since the doubled value may not fit in an int.
        long low = 0;
        long high = Math.min(2L * maxStrength, Integer.MAX_VALUE);
        int ans = -1;
        while (low <= high) {
            // long arithmetic: neither low + high nor mid + 1 can overflow here.
            int mid = (int) ((low + high) >>> 1);
            if (feasible(mid, n, edges, k)) {
                ans = mid;
                low = mid + 1L;
            } else {
                high = mid - 1L;
            }
        }
        return ans;
    }

    /**
     * Checks whether a spanning tree exists whose every edge has effective strength at least
     * {@code t} while using at most {@code k} upgrades.
     *
     * @param t target stability (non-negative)
     * @param n number of nodes
     * @param edges edge list as described in {@link #maxStability}
     * @param k upgrade budget
     * @return {@code true} if stability {@code t} is achievable
     */
    private boolean feasible(int t, int n, int[][] edges, int k) {
        UnionFind uf = new UnionFind(n);

        // Mandatory edges must all be kept, cannot be upgraded, and must not form a cycle.
        for (int[] edge : edges) {
            int u = edge[0];
            int v = edge[1];
            int s = edge[2];
            int must = edge[3];
            if (must == 1 && (s < t || !uf.union(u, v))) {
                return false;
            }
        }

        // Optional edges that already meet t cost nothing; take every one that merges components.
        for (int[] edge : edges) {
            if (edge[3] == 0 && edge[2] >= t) {
                uf.union(edge[0], edge[1]);
            }
        }
        if (uf.components() == 1) {
            return true;
        }

        // ceil(t / 2), written so it cannot overflow even at t == Integer.MAX_VALUE.
        // An edge with strength >= half reaches t once doubled.
        int half = t - t / 2;
        int upgrades = 0;
        for (int[] edge : edges) {
            int u = edge[0];
            int v = edge[1];
            int s = edge[2];
            int must = edge[3];
            // Each successful union here consumes exactly one upgrade.
            if (must == 0 && s >= half && s < t && uf.union(u, v)) {
                upgrades++;
                if (upgrades > k) {
                    return false;
                }
            }
        }
        return uf.components() == 1;
    }

    /**
     * Disjoint-set forest using union by rank and path halving. It keeps a running count of
     * disjoint components.
     */
    private static final class UnionFind {
        private final int[] parent;
        private final int[] rank;
        private int components;

        UnionFind(int n) {
            parent = new int[n];
            rank = new int[n];
            components = n;
            for (int i = 0; i < n; i++) {
                parent[i] = i;
            }
        }

        /** Returns the current number of disjoint components. */
        int components() {
            return components;
        }

        /** Returns the root of {@code x}'s set, halving the path along the way. */
        int find(int x) {
            while (parent[x] != x) {
                parent[x] = parent[parent[x]];
                x = parent[x];
            }
            return x;
        }

        /**
         * Merges the sets containing {@code a} and {@code b}.
         *
         * @return {@code true} if two distinct sets were merged; {@code false} if already joined
         */
        boolean union(int a, int b) {
            int ra = find(a);
            int rb = find(b);
            if (ra == rb) {
                return false;
            }
            if (rank[ra] < rank[rb]) {
                int tmp = ra;
                ra = rb;
                rb = tmp;
            }
            parent[rb] = ra;
            if (rank[ra] == rank[rb]) {
                rank[ra]++;
            }
            components--;
            return true;
        }
    }
}