Hard
You are given an undirected tree rooted at node 0, with $n$ nodes numbered from $0$ to $n - 1$. The tree is represented by a 2D array edges of length $n - 1$, where $\text{edges}[i] = [u_i, v_i, \text{length}_i]$ indicates an edge between nodes $u_i$ and $v_i$ with length $\text{length}_i$. You are also given an integer array nums, where nums[i] represents the value at node $i$.
A special path is defined as a downward path from an ancestor node to a descendant node in which all node values are distinct, except for at most one value that may appear twice.
Return an array result of size 2, where result[0] is the length of the longest special path (the length of a path is the sum of its edge lengths), and result[1] is the minimum number of nodes among all possible longest special paths.
Example 1:
Input: edges = [[0,1,1],[1,2,3],[1,3,1],[2,4,6],[4,7,2],[3,5,2],[3,6,5],[6,8,3]], nums = [1,1,0,3,1,2,1,1,0]
Output: [9,3]
Explanation:
In the accompanying image (not included here), nodes are colored by their corresponding values in nums.
The longest special paths are 1 -> 2 -> 4 and 1 -> 3 -> 6 -> 8, both having a length of 9. The minimum number of nodes across all longest special paths is 3.
Example 2:
Input: edges = [[1,0,3],[0,2,4],[0,3,5]], nums = [1,1,0,2]
Output: [5,2]
Explanation:
The longest path is 0 -> 3, consisting of 2 nodes with a length of 5.
Constraints:
$2 \le n \le 5 \times 10^4$
edges.length == n - 1
edges[i].length == 3
$0 \le u_i, v_i < n$
$1 \le \text{length}_i \le 10^3$
nums.length == n
$0 \le \text{nums}[i] \le 5 \times 10^4$
edges represents a valid tree.
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
/**
 * Finds the longest "special" downward path in a weighted tree rooted at node 0.
 *
 * <p>A special path goes from an ancestor to a descendant and has pairwise-distinct node values,
 * except that at most one value may occur twice. The traversal is iterative (explicit per-depth
 * arrays instead of recursion), so a path-shaped tree with tens of thousands of nodes cannot
 * overflow the thread stack.
 */
public class Solution {

    /**
     * Returns {@code {bestLength, fewestNodes}}.
     *
     * <p>{@code bestLength} is the largest total edge length over all special paths.
     * {@code fewestNodes} is the minimum node count among the special paths achieving that
     * length. A single node is a special path of length 0 with 1 node, so the result is never
     * below {@code {0, 1}}.
     *
     * @param edges tree edges as {@code {u, v, length}}; node ids are expected to lie in
     *     {@code [0, nums.length)}
     * @param nums value of each node; {@code nums.length} is the node count (must be at least 1)
     * @return a two-element array {@code {longest length, minimum node count}}
     */
    public int[] longestSpecialPath(int[][] edges, int[] nums) {
        int n = nums.length;
        int[] ans = {0, 1};
        List<List<int[]>> graph = buildGraph(n, edges);
        int[] color = compressValues(nums);

        // State of the current root-to-node path; index d describes the node at depth d.
        int[] pathNode = new int[n];
        int[] pathCost = new int[n];  // distance from the root to pathNode[d]
        // pathLo[d]: earliest start depth of a special path ending at depth d (second-tightest
        // bound). pathHi[d]: tightest bound, which the single permitted repeat may violate.
        int[] pathLo = new int[n];
        int[] pathHi = new int[n];
        int[] savedLast = new int[n];  // lastDepth[color] before pathNode[d] overwrote it
        int[] nextEdge = new int[n];   // next adjacency index to explore from pathNode[d]
        // lastDepth[c]: deepest depth on the current path holding a node of color c, or -1.
        int[] lastDepth = new int[n];
        Arrays.fill(lastDepth, -1);

        // Push the root. On its own it is the path {length 0, 1 node}, which ans already holds.
        int depth = 0;
        pathNode[0] = 0;
        savedLast[0] = -1;
        lastDepth[color[0]] = 0;

        while (depth >= 0) {
            int node = pathNode[depth];
            List<int[]> adjacent = graph.get(node);
            if (nextEdge[depth] == adjacent.size()) {
                // Subtree finished: pop it and restore the previous occurrence of its color.
                lastDepth[color[node]] = savedLast[depth];
                depth--;
                continue;
            }
            int[] edge = adjacent.get(nextEdge[depth]++);
            int child = edge[0];
            if (depth > 0 && child == pathNode[depth - 1]) {
                continue;  // the edge back to the parent
            }

            // Each earlier occurrence of the child's value at depth `seen` forces the start to be
            // after `seen`. One such constraint may be broken, so only the two tightest bounds
            // matter, and the start is the smaller of them.
            int lo = pathLo[depth];
            int hi = pathHi[depth];
            int seen = lastDepth[color[child]];
            if (seen >= 0) {
                int bound = seen + 1;
                if (bound > hi) {
                    lo = hi;
                    hi = bound;
                } else if (bound > lo) {
                    lo = bound;
                }
            }

            int childDepth = depth + 1;
            pathNode[childDepth] = child;
            pathCost[childDepth] = pathCost[depth] + edge[1];
            pathLo[childDepth] = lo;
            pathHi[childDepth] = hi;
            nextEdge[childDepth] = 0;  // depth slots are reused after pops
            savedLast[childDepth] = seen;
            lastDepth[color[child]] = childDepth;
            depth = childDepth;

            // Longest special path ending at this node starts at depth `lo`.
            int length = pathCost[depth] - pathCost[lo];
            int nodes = depth + 1 - lo;
            if (length > ans[0] || (length == ans[0] && nodes < ans[1])) {
                ans[0] = length;
                ans[1] = nodes;
            }
        }
        return ans;
    }

    /** Builds an undirected adjacency list whose entries are {@code {neighbor, edgeLength}}. */
    private static List<List<int[]>> buildGraph(int n, int[][] edges) {
        List<List<int[]>> graph = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            graph.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            graph.get(edge[0]).add(new int[] {edge[1], edge[2]});
            graph.get(edge[1]).add(new int[] {edge[0], edge[2]});
        }
        return graph;
    }

    /** Maps each node to a dense id in {@code [0, nums.length)} shared by nodes of equal value. */
    private static int[] compressValues(int[] nums) {
        Map<Integer, Integer> ids = new HashMap<>();
        int[] color = new int[nums.length];
        for (int i = 0; i < nums.length; i++) {
            Integer id = ids.get(nums[i]);
            if (id == null) {
                id = ids.size();
                ids.put(nums[i], id);
            }
            color[i] = id;
        }
        return color;
    }
}