Hard
There is a tree (i.e. a connected, undirected graph with no cycles) consisting of n nodes numbered from 0 to n - 1 and exactly n - 1 edges.

You are given a 0-indexed integer array vals of length n where vals[i] denotes the value of the i-th node. You are also given a 2D integer array edges where edges[i] = [a_i, b_i] denotes that there exists an undirected edge connecting nodes a_i and b_i.
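A good path is a simple path that satisfies the following conditions:
1. The starting node and the ending node have the same value.
2. All nodes between the starting node and the ending node have values less than or equal to the starting node's value (i.e. the starting node's value is the maximum value along the path).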
Return the number of distinct good paths.

Note that a path and its reverse are counted as the same path. For example, 0 -> 1 is considered to be the same as 1 -> 0. A single node is also considered a valid path.
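To make the definition concrete, here is a brute-force reference that checks it directly: a good path joins two nodes of equal value v such that no node on the path exceeds v. From each start node it walks the tree, never entering a node whose value exceeds the start value, and counts the equal-valued nodes it reaches (only those with a larger index, so each unordered pair is counted once). The class `GoodPathsBruteForce` and method `countGoodPaths` are hypothetical names; this is a sketch for cross-checking small inputs, running in O(n^2) time.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical brute-force reference for small inputs only (O(n^2)).
class GoodPathsBruteForce {
    public static int countGoodPaths(int[] vals, int[][] edges) {
        int n = vals.length;
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        int pairs = 0; // good paths with two distinct endpoints
        for (int start = 0; start < n; start++) {
            pairs += dfs(start, -1, vals[start], start, vals, adj);
        }
        return n + pairs; // every single node is also a good path
    }

    // Counts nodes reachable from `node` (not going back to `from`) whose value equals
    // `target`, whose index is larger than `start`, and where no node on the way exceeds `target`.
    private static int dfs(int node, int from, int target, int start,
                           int[] vals, List<List<Integer>> adj) {
        int found = 0;
        for (int next : adj.get(node)) {
            if (next == from || vals[next] > target) continue; // path would exceed the start value
            if (vals[next] == target && next > start) found++;
            found += dfs(next, node, target, start, vals, adj);
        }
        return found;
    }
}
```

On the three examples below it returns 6, 7 and 1 respectively.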
Example 1:
Input: vals = [1,3,2,1,3], edges = [[0,1],[0,2],[2,3],[2,4]]
Output: 6
Explanation: There are 5 good paths consisting of a single node.
There is 1 additional good path: 1 -> 0 -> 2 -> 4.
(The reverse path 4 -> 2 -> 0 -> 1 is treated as the same as 1 -> 0 -> 2 -> 4.)
Note that 0 -> 2 -> 3 is not a good path because vals[2] > vals[0].
Example 2:
Input: vals = [1,1,2,2,3], edges = [[0,1],[1,2],[2,3],[2,4]]
Output: 7
Explanation: There are 5 good paths consisting of a single node.
There are 2 additional good paths: 0 -> 1 and 2 -> 3.
Example 3:
Input: vals = [1], edges = []
Output: 1
Explanation: The tree consists of only one node, so there is one good path.
Constraints:
n == vals.length
1 <= n <= 3 * 10^4
0 <= vals[i] <= 10^5
edges.length == n - 1
edges[i].length == 2
0 <= a_i, b_i < n
a_i != b_i
edges represents a valid tree.

import java.util.Arrays;
import java.util.Comparator;
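public class Solution {
    public int numberOfGoodPaths(int[] vals, int[][] edges) {
        int n = vals.length;
        int[] parent = new int[n];     // union-find parent pointers
        int[] maxElement = new int[n]; // for a root: the largest value in its component
        int[] count = new int[n];      // for a root: how many nodes in its component hold that largest value
        for (int i = 0; i < n; i++) {
            parent[i] = i;
            maxElement[i] = vals[i];
            count[i] = 1;
        }
        // Process edges by the larger endpoint value, so each merge only ever joins
        // nodes whose values are <= the current edge's maximum.
        Arrays.sort(edges, Comparator.comparingInt(a -> Math.max(vals[a[0]], vals[a[1]])));
        int ans = n; // every single node is a good path
        for (int[] it : edges) {
            int a = findParent(parent, it[0]);
            int b = findParent(parent, it[1]);
            if (maxElement[a] != maxElement[b]) {
                // The endpoint on the larger side carries the larger maximum, so no new
                // good path crosses this edge; just absorb the smaller-maximum component.
                if (maxElement[a] > maxElement[b]) {
                    parent[b] = a;
                } else {
                    parent[a] = b;
                }
            } else {
                // Same maximum on both sides: every pair of max-valued nodes (one from each
                // component) is now joined by a path whose largest value is that maximum.
                parent[b] = a;
                ans += count[a] * count[b];
                count[a] += count[b];
            }
        }
        return ans;
    }

    // Iterative find with path compression (avoids deep recursion on long chains).
    private int findParent(int[] parent, int a) {
        int root = a;
        while (root != parent[root]) {
            root = parent[root];
        }
        while (a != root) {
            int next = parent[a];
            parent[a] = root;
            a = next;
        }
        return root;
    }
}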
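For comparison, the same union-find idea can also be driven by nodes instead of edges. This is an alternative sketch, not the method used in the solution above: nodes are activated in increasing value order, each node is united with neighbours whose value is not larger, and then, within each resulting component, the c nodes holding the current value contribute c * (c + 1) / 2 good paths (c single-node paths plus c * (c - 1) / 2 pairs). The class `GoodPathsByNodeOrder` and method `countGoodPaths` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical alternative: activate nodes in increasing value order.
class GoodPathsByNodeOrder {
    // Find with path halving.
    private static int find(int[] parent, int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    public static int countGoodPaths(int[] vals, int[][] edges) {
        int n = vals.length;
        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i < n; i++) adj.add(new ArrayList<>());
        for (int[] e : edges) {
            adj.get(e[0]).add(e[1]);
            adj.get(e[1]).add(e[0]);
        }
        Integer[] order = new Integer[n];
        for (int i = 0; i < n; i++) order[i] = i;
        Arrays.sort(order, (x, y) -> Integer.compare(vals[x], vals[y]));

        int[] parent = new int[n];
        for (int i = 0; i < n; i++) parent[i] = i;

        int result = 0;
        int i = 0;
        while (i < n) {
            // [i, j) is the group of nodes sharing the value vals[order[i]].
            int j = i;
            while (j < n && vals[order[j]] == vals[order[i]]) j++;
            // Unite each node in the group with neighbours whose value is <= its own.
            for (int k = i; k < j; k++) {
                int u = order[k];
                for (int v : adj.get(u)) {
                    if (vals[v] <= vals[u]) parent[find(parent, u)] = find(parent, v);
                }
            }
            // Count the group's nodes per component; c nodes yield c * (c + 1) / 2 good paths.
            Map<Integer, Integer> perRoot = new HashMap<>();
            for (int k = i; k < j; k++) perRoot.merge(find(parent, order[k]), 1, Integer::sum);
            for (int c : perRoot.values()) result += c * (c + 1) / 2;
            i = j;
        }
        return result;
    }
}
```

Both formulations count the same thing; the node-order version makes the "c equal-valued nodes per component" bookkeeping explicit, while the edge-sorted version folds it into the `count` array during each merge.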