Hard
There is an undirected tree with n nodes labeled from 0 to n - 1. You are given the integer n and a 2D integer array edges of length n - 1, where edges[i] = [ai, bi] indicates that there is an edge between nodes ai and bi in the tree.
You are also given a 0-indexed integer array values of length n, where values[i] is the value associated with the ith node, and an integer k.
A valid split of the tree is obtained by removing any set of edges, possibly empty, from the tree such that the resulting components all have values that are divisible by k, where the value of a connected component is the sum of the values of its nodes.
Return the maximum number of components in any valid split.
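To make the definition concrete, here is a brute-force sketch that tries every subset of edges to remove, groups the remaining nodes with a union-find, and keeps the largest component count among valid splits. The class BruteForceSplit and the method maxComponentsBruteForce are hypothetical names chosen for this illustration. It enumerates 2^(n-1) subsets, so it is only a reference for checking tiny inputs such as the examples below, not a solution that fits the constraints.

```java
public class BruteForceSplit {
    // Union-find root lookup with path halving.
    private static int find(int[] parent, int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    // Bit i of mask set means edges[i] is removed. Only practical for small n,
    // since it enumerates 2^(n-1) subsets (and 1 << m requires m < 31).
    public static int maxComponentsBruteForce(int n, int[][] edges, int[] values, int k) {
        int m = edges.length;
        int best = 0; // remains 0 only if no subset yields a valid split
        for (int mask = 0; mask < (1 << m); mask++) {
            int[] parent = new int[n];
            for (int i = 0; i < n; i++) {
                parent[i] = i;
            }
            // Union the endpoints of every edge that is kept.
            for (int i = 0; i < m; i++) {
                if ((mask & (1 << i)) == 0) {
                    parent[find(parent, edges[i][0])] = find(parent, edges[i][1]);
                }
            }
            // Component value = sum of its nodes' values, stored at the root.
            long[] compSum = new long[n];
            for (int v = 0; v < n; v++) {
                compSum[find(parent, v)] += values[v];
            }
            int components = 0;
            boolean valid = true;
            for (int v = 0; v < n; v++) {
                if (find(parent, v) == v) {
                    components++;
                    if (compSum[v] % k != 0) {
                        valid = false;
                    }
                }
            }
            if (valid) {
                best = Math.max(best, components);
            }
        }
        return best;
    }

    public static void main(String[] args) {
        System.out.println(maxComponentsBruteForce(5,
            new int[][]{{0, 2}, {1, 2}, {1, 3}, {2, 4}}, new int[]{1, 8, 1, 4, 4}, 6)); // 2
        System.out.println(maxComponentsBruteForce(7,
            new int[][]{{0, 1}, {0, 2}, {1, 3}, {1, 4}, {2, 5}, {2, 6}},
            new int[]{3, 0, 6, 1, 5, 2, 1}, 3)); // 3
    }
}
```

Running main prints the expected outputs of the two examples, which makes the brute force a handy oracle for randomized testing of a faster solution on small trees.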
Example 1:
Input: n = 5, edges = [[0,2],[1,2],[1,3],[2,4]], values = [1,8,1,4,4], k = 6
Output: 2
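Explanation: We remove the edge connecting nodes 1 and 2. The resulting split is valid because the component containing nodes 1 and 3 has value values[1] + values[3] = 12, and the component containing nodes 0, 2, and 4 has value values[0] + values[2] + values[4] = 6; both are divisible by 6.
It can be shown that no other valid split has more than 2 connected components.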
Example 2:
Input: n = 7, edges = [[0,1],[0,2],[1,3],[1,4],[2,5],[2,6]], values = [3,0,6,1,5,2,1], k = 3
Output: 3
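Explanation: We remove the edge connecting nodes 0 and 2, and the edge connecting nodes 0 and 1. The resulting split is valid because the component containing node 0 has value values[0] = 3, the component containing nodes 2, 5, and 6 has value values[2] + values[5] + values[6] = 9, and the component containing nodes 1, 3, and 4 has value values[1] + values[3] + values[4] = 6; all three are divisible by 3.
It can be shown that no other valid split has more than 3 connected components.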
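Both explanations cut exactly the edges whose lower endpoint roots a subtree with a sum divisible by k. Writing S(v) (a symbol introduced here, not part of the problem statement) for the sum of values in the subtree of v when the tree is rooted at node 0, the derivation below sketches why counting such subtrees gives the answer, relying on the constraint that the total sum is divisible by k.

```latex
S(v) = \mathrm{values}[v] + \sum_{c \in \mathrm{children}(v)} S(c)

(p, c) \text{ cut in a valid split} \;\Rightarrow\; S(c) \equiv 0 \pmod{k}

\text{answer} = \bigl|\{\, v : S(v) \equiv 0 \pmod{k} \,\}\bigr|

\text{Example 2 } (k = 3):\quad S = (18, 6, 9, 1, 5, 2, 1) \;\Rightarrow\; \text{answer} = 3
```

If an edge (p, c) is cut, the subtree of c is a union of whole components, each divisible by k, so S(c) is too. Conversely, cutting every such edge at once is valid: each resulting component's value is S of its top node minus S of the cut-off subtrees below it, all divisible by k. The root always qualifies because S(0) is the total sum, and each cut edge adds one component.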
Constraints:
1 <= n <= 3 * 10^4
edges.length == n - 1
edges[i].length == 2
0 <= ai, bi < n
values.length == n
0 <= values[i] <= 10^9
1 <= k <= 10^9
The sum of values is divisible by k.
edges represents a valid tree.

import java.util.ArrayList;
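import java.util.List;

public class Solution {
    // Number of components counted during the current call.
    private int ans = 0;

    public int maxKDivisibleComponents(int n, int[][] edges, int[] values, int k) {
        ans = 0; // reset so the same Solution instance can be reused
        // Build an adjacency list for the undirected tree.
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            adj.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            int start = edge[0];
            int end = edge[1];
            adj.get(start).add(end);
            adj.get(end).add(start);
        }
        // Root the tree at node 0; -1 means "no parent".
        get(0, -1, adj, values, k);
        return ans;
    }

    // Post-order DFS. Returns the sum of the nodes in curNode's subtree that
    // have not already been split off into counted components. If that sum is
    // divisible by k, the edge above curNode is cut (for the root, this closes
    // the last component): one component is counted and 0 is passed upward.
    // In a tree, skipping the parent is enough to avoid revisiting nodes.
    private long get(int curNode, int parent, List<List<Integer>> adj, int[] values, long k) {
        long sum = values[curNode];
        for (int ele : adj.get(curNode)) {
            if (ele != parent) {
                sum += get(ele, curNode, adj, values, k);
            }
        }
        if (sum % k == 0) {
            ans++;
            return 0;
        }
        return sum;
    }
}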
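The recursive get uses one stack frame per tree level, so on a path-shaped tree with n = 3 * 10^4 nodes it may risk a StackOverflowError under a default JVM thread stack size (this depends on the JVM and its -Xss setting). Below is a minimal iterative sketch of the same greedy, in a hypothetical class IterativeSolution: a BFS from node 0 records parents in an order where every parent precedes its children, and walking that order backwards processes children first. Running sums are kept reduced mod k, but the per-node accumulator is still a long, because subtree sums can reach 3 * 10^4 * 10^9 = 3 * 10^13, well beyond the int range.

```java
import java.util.ArrayList;
import java.util.List;

public class IterativeSolution {
    // Counts nodes whose (remaining) subtree sum is divisible by k, without recursion.
    public int maxKDivisibleComponents(int n, int[][] edges, int[] values, int k) {
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            adj.add(new ArrayList<>());
        }
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        // BFS from the root: order[] lists every parent before its children.
        int[] parent = new int[n];
        int[] order = new int[n];
        parent[0] = -1;
        order[0] = 0;
        int head = 0, tail = 1;
        while (head < tail) {
            int u = order[head++];
            for (int v : adj.get(u)) {
                if (v != parent[u]) {
                    parent[v] = u;
                    order[tail++] = v;
                }
            }
        }
        // rem[u] collects, mod k, the part of u's subtree still attached to u.
        long[] rem = new long[n];
        int components = 0;
        for (int i = n - 1; i >= 0; i--) {
            int u = order[i];
            rem[u] = (rem[u] + values[u]) % k;
            if (rem[u] == 0) {
                components++; // cut the edge above u (or u is the root)
            } else if (parent[u] != -1) {
                rem[parent[u]] += rem[u];
            }
        }
        return components;
    }

    public static void main(String[] args) {
        IterativeSolution s = new IterativeSolution();
        System.out.println(s.maxKDivisibleComponents(5,
            new int[][]{{0, 2}, {1, 2}, {1, 3}, {2, 4}}, new int[]{1, 8, 1, 4, 4}, 6)); // 2
        System.out.println(s.maxKDivisibleComponents(7,
            new int[][]{{0, 1}, {0, 2}, {1, 3}, {1, 4}, {2, 5}, {2, 6}},
            new int[]{3, 0, 6, 1, 5, 2, 1}, 3)); // 3
    }
}
```

The explicit arrays trade the recursion's implicit call stack for O(n) heap memory, which the JVM does not cap the way it caps thread stack depth.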