LeetCode-in-Java

3367. Maximize Sum of Weights after Edge Removals

Hard

There exists an undirected tree with n nodes numbered 0 to n - 1. You are given a 2D integer array edges of length n - 1, where edges[i] = [ui, vi, wi] indicates that there is an edge between nodes ui and vi with weight wi in the tree.

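Your task is to remove zero or more edges such that:

- Each node has an edge with at most k other nodes, where k is given.
- The sum of the weights of the remaining edges is maximized.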

Return the maximum possible sum of weights for the remaining edges after making the necessary removals.

Example 1:

Input: edges = [[0,1,4],[0,2,2],[2,3,12],[2,4,6]], k = 2

Output: 22

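Explanation: Node 2 is connected to three nodes (0, 3 and 4), which exceeds k = 2. Removing the lightest of its edges, [0,2,2], leaves every node with at most 2 neighbours. The remaining weights sum to 4 + 12 + 6 = 22, and no other removal gives a larger sum.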

Example 2:

Input: edges = [[0,1,5],[1,2,10],[0,3,15],[3,4,20],[3,5,5],[0,6,10]], k = 3

Output: 65

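Explanation: No node is connected to more than k = 3 other nodes (nodes 0 and 3 each have exactly three neighbours), so no edge needs to be removed. The answer is the total weight 5 + 10 + 15 + 20 + 5 + 10 = 65.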

Constraints:

Solution

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.PriorityQueue;

/**
 * Solution for LeetCode 3367, "Maximize Sum of Weights after Edge Removals".
 *
 * <p>Tree DP rooted at node 0. For every node {@code v} two values are computed bottom-up:
 *
 * <ul>
 *   <li>{@code full[v]} - best total weight inside the subtree of {@code v} when {@code v} may
 *       keep up to {@code k} edges to its children (the edge to its parent is removed);
 *   <li>{@code capped[v]} - the same, but {@code v} may keep only {@code k - 1} child edges
 *       because the edge to its parent is kept.
 * </ul>
 *
 * <p>The traversal is iterative so that deep (path-like) trees cannot overflow the call stack.
 */
public class Solution {

    /**
     * Returns the maximum sum of weights of the edges that remain after removing edges so that
     * every node is incident to at most {@code k} remaining edges.
     *
     * @param edges tree edges, each {@code [u, v, w]}; assumed to form a tree on nodes
     *     {@code 0..edges.length}
     * @param k maximum number of remaining edges per node; {@code k <= 0} yields {@code 0}
     * @return the maximum achievable total weight of the remaining edges
     */
    public long maximizeSumOfWeights(int[][] edges, int k) {
        if (edges == null || edges.length == 0 || k <= 0) {
            // No edges, or no node may keep any edge: nothing can remain.
            return 0L;
        }
        int n = edges.length + 1;
        List<List<int[]>> adjacency = buildAdjacency(edges, n);

        int[] parent = new int[n];
        int[] order = preorder(adjacency, parent, n);

        long[] full = new long[n];
        long[] capped = new long[n];
        // Walking the pre-order backwards visits every child before its parent.
        for (int i = n - 1; i >= 0; i--) {
            int v = order[i];
            List<int[]> incident = adjacency.get(v);
            long[] gains = new long[incident.size()];
            int childCount = 0;
            long base = 0L;
            for (int[] e : incident) {
                int child = other(e, v);
                if (child == parent[v]) {
                    continue;
                }
                // Baseline: drop the edge to this child and let the child use all k slots.
                base += full[child];
                // Extra weight from keeping the edge instead (child is then limited to k-1).
                gains[childCount++] = Math.max(0L, e[2] + capped[child] - full[child]);
            }
            Arrays.sort(gains, 0, childCount);
            long withK = base;
            long withKMinusOne = base;
            // Take the largest gains: up to k for full[v], up to k-1 for capped[v].
            for (int taken = 0; taken < k && taken < childCount; taken++) {
                long gain = gains[childCount - 1 - taken];
                withK += gain;
                if (taken < k - 1) {
                    withKMinusOne += gain;
                }
            }
            full[v] = withK;
            capped[v] = withKMinusOne;
        }
        return full[0];
    }

    /** Builds an adjacency list in which each edge array is shared by both of its endpoints. */
    private static List<List<int[]>> buildAdjacency(int[][] edges, int n) {
        List<List<int[]>> adjacency = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            adjacency.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            adjacency.get(e[0]).add(e);
            adjacency.get(e[1]).add(e);
        }
        return adjacency;
    }

    /**
     * Returns an iterative pre-order of the tree rooted at node 0 and fills {@code parent}
     * along the way ({@code -1} for the root).
     */
    private static int[] preorder(List<List<int[]>> adjacency, int[] parent, int n) {
        int[] order = new int[n];
        // In a tree every node is pushed exactly once, so n slots suffice.
        int[] stack = new int[n];
        int top = 0;
        int visited = 0;
        parent[0] = -1;
        stack[top++] = 0;
        while (top > 0) {
            int v = stack[--top];
            order[visited++] = v;
            for (int[] e : adjacency.get(v)) {
                int child = other(e, v);
                if (child != parent[v]) {
                    parent[child] = v;
                    stack[top++] = child;
                }
            }
        }
        return order;
    }

    /** Returns the endpoint of {@code edge} that is not {@code v}. */
    private static int other(int[] edge, int v) {
        return edge[0] == v ? edge[1] : edge[0];
    }
}