LeetCode-in-Java

3486. Longest Special Path II

Hard

You are given an undirected tree rooted at node 0, with n nodes numbered from 0 to n - 1. This is represented by a 2D array edges of length n - 1, where edges[i] = [ui, vi, lengthi] indicates an edge between nodes ui and vi with length lengthi. You are also given an integer array nums, where nums[i] represents the value at node i.

A special path is defined as a downward path from an ancestor node to a descendant node in which all node values are distinct, except for at most one value that may appear twice.

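Return an array result of size 2, where result[0] is the length of the longest special path, and result[1] is the minimum number of nodes among all longest special paths. The length of a path is the sum of the lengths of its edges; a single node is a special path of length 0.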

Example 1:

Input: edges = [[0,1,1],[1,2,3],[1,3,1],[2,4,6],[4,7,2],[3,5,2],[3,6,5],[6,8,3]], nums = [1,1,0,3,1,2,1,1,0]

Output: [9,3]

Explanation:

In the original problem's illustration (not reproduced here), nodes are colored by their corresponding values in nums.

The longest special paths are 1 -> 2 -> 4 and 1 -> 3 -> 6 -> 8, both having a length of 9. The minimum number of nodes across all longest special paths is 3.

Example 2:

Input: edges = [[1,0,3],[0,2,4],[0,3,5]], nums = [1,1,0,2]

Output: [5,2]

Explanation:

The longest special path is 0 -> 3, which consists of 2 nodes and has a length of 5.

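Constraints:

- 2 <= n <= 5 * 10^4
- edges.length == n - 1
- edges[i].length == 3
- 0 <= ui, vi < n
- 1 <= lengthi <= 10^3
- nums.length == n
- 0 <= nums[i] <= 5 * 10^4
- The input is generated such that edges represents a valid tree.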

Solution

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * LeetCode 3486 "Longest Special Path II".
 *
 * <p>A special path is a downward (ancestor-to-descendant) path in which every node value is
 * distinct, except that at most one value may appear twice. The tree is walked depth-first from
 * node 0 using an explicit per-depth stack instead of recursion, so a path-shaped tree with tens
 * of thousands of nodes does not exhaust the Java call stack.
 */
public class Solution {

    /**
     * Finds the length of the longest special path and, among all longest special paths, the
     * smallest number of nodes.
     *
     * @param edges tree edges as {@code [u, v, length]} triples; the tree is rooted at node 0
     * @param nums {@code nums[i]} is the value of node {@code i}; its length is the node count
     * @return {@code [length, nodes]}; a lone node counts as a path of length 0 with 1 node
     */
    public int[] longestSpecialPath(int[][] edges, int[] nums) {
        int n = nums.length;
        List<List<int[]>> graph = buildGraph(n, edges);
        int[] ans = {0, 1};

        // State of the current root-to-node path, indexed by depth (the root is depth 0).
        int[] pathNode = new int[n];
        // Sum of edge lengths from the root down to pathNode[d].
        int[] pathCost = new int[n];
        // The two largest lower bounds on a path's start depth imposed by repeated values
        // (startLo <= startHi). One repeated value is allowed, so the largest bound may be
        // waived: a special path ending at depth d may start no earlier than startLo[d].
        int[] startLo = new int[n];
        int[] startHi = new int[n];
        // Depth at which nums[pathNode[d]] was last seen before entering d, or -1 if never.
        int[] savedLast = new int[n];
        // Next adjacency index of pathNode[d] to explore; -1 means d has not been entered yet.
        int[] nextEdge = new int[n];
        // Most recent depth on the current path holding each value.
        Map<Integer, Integer> lastDepth = new HashMap<>();

        int depth = 0;
        pathNode[0] = 0;
        nextEdge[0] = -1;
        while (depth >= 0) {
            int node = pathNode[depth];
            if (nextEdge[depth] < 0) {
                // First visit: record this value's depth and score the best path ending here.
                Integer previous = lastDepth.put(nums[node], depth);
                savedLast[depth] = previous == null ? -1 : previous;
                int length = pathCost[depth] - pathCost[startLo[depth]];
                int nodes = depth + 1 - startLo[depth];
                if (length > ans[0] || (length == ans[0] && nodes < ans[1])) {
                    ans[0] = length;
                    ans[1] = nodes;
                }
                nextEdge[depth] = 0;
            }
            List<int[]> neighbors = graph.get(node);
            if (nextEdge[depth] < neighbors.size()) {
                int[] edge = neighbors.get(nextEdge[depth]++);
                int child = edge[0];
                if (depth > 0 && child == pathNode[depth - 1]) {
                    continue; // the parent: never walk back up the tree
                }
                // If child's value already occurs on the path, a path containing child only
                // avoids that repeat by starting strictly below the earlier occurrence.
                Integer seen = lastDepth.get(nums[child]);
                int bound = seen == null ? 0 : seen + 1;
                int next = depth + 1;
                pathNode[next] = child;
                pathCost[next] = pathCost[depth] + edge[1];
                // Keep the two largest of {startLo, startHi, bound}.
                startLo[next] = Math.max(startLo[depth], Math.min(startHi[depth], bound));
                startHi[next] = Math.max(startHi[depth], bound);
                nextEdge[next] = -1;
                depth = next;
            } else {
                // Subtree finished: undo this node's lastDepth entry so siblings see their path.
                if (savedLast[depth] < 0) {
                    lastDepth.remove(nums[node]);
                } else {
                    lastDepth.put(nums[node], savedLast[depth]);
                }
                depth--;
            }
        }
        return ans;
    }

    /** Builds an undirected adjacency list where graph.get(u) holds {v, length} per edge u-v. */
    private static List<List<int[]>> buildGraph(int n, int[][] edges) {
        List<List<int[]>> graph = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            graph.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            graph.get(edge[0]).add(new int[] {edge[1], edge[2]});
            graph.get(edge[1]).add(new int[] {edge[0], edge[2]});
        }
        return graph;
    }
}