LeetCode-in-Java

2867. Count Valid Paths in a Tree

Hard

There is an undirected tree with n nodes labeled from 1 to n. You are given the integer n and a 2D integer array edges of length n - 1, where edges[i] = [u_i, v_i] indicates that there is an edge between nodes u_i and v_i in the tree.

Return the number of valid paths in the tree.

A path (a, b) is valid if there exists exactly one prime number among the node labels in the path from a to b.

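Note that:

- The path (a, b) is a sequence of distinct nodes starting with node a and ending with node b such that every two adjacent nodes in the sequence share an edge in the tree.
- Path (a, b) and path (b, a) are considered the same and counted only once.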

Example 1:

Input: n = 5, edges = [[1,2],[1,3],[2,4],[2,5]]

Output: 4

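Explanation: The pairs with exactly one prime number on the path between them are:

- (1, 2) since the path from 1 to 2 contains prime number 2.
- (1, 3) since the path from 1 to 3 contains prime number 3.
- (1, 4) since the path from 1 to 4 contains prime number 2.
- (2, 4) since the path from 2 to 4 contains prime number 2.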

It can be shown that there are only 4 valid paths.

Example 2:

Input: n = 6, edges = [[1,2],[1,3],[2,4],[3,5],[3,6]]

Output: 6

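Explanation: The pairs with exactly one prime number on the path between them are:

- (1, 2) since the path from 1 to 2 contains prime number 2.
- (1, 3) since the path from 1 to 3 contains prime number 3.
- (1, 4) since the path from 1 to 4 contains prime number 2.
- (1, 6) since the path from 1 to 6 contains prime number 3.
- (2, 4) since the path from 2 to 4 contains prime number 2.
- (3, 6) since the path from 3 to 6 contains prime number 3.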

It can be shown that there are only 6 valid paths.

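Constraints:

- 1 <= n <= 10^5
- edges.length == n - 1
- edges[i].length == 2
- 1 <= u_i, v_i <= n
- The input is generated such that edges represent a valid tree.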

Solution

import java.util.ArrayList;
import java.util.List;

/**
 * Counts the paths in a tree whose node labels contain exactly one prime number.
 *
 * <p>The tree is rooted at node 1. For every node {@code u} two counters are kept over the part of
 * its subtree merged so far: {@code zero[u]}, the number of downward paths starting at {@code u}
 * that contain no prime, and {@code one[u]}, the number that contain exactly one prime. Both
 * counters include the single-node path {@code u} itself. When a finished child subtree is merged
 * into its parent, every path that joins the child's part to the parent's part is counted once.
 *
 * <p>The traversal is iterative on purpose: a path-shaped tree has depth {@code n}, and a recursive
 * DFS of that depth can exhaust the default JVM thread stack.
 */
public class Solution {

    /**
     * Returns the number of unordered pairs {@code (a, b)}, {@code a != b}, whose connecting path
     * contains exactly one prime label.
     *
     * @param n number of nodes, labeled {@code 1..n}
     * @param edges the {@code n - 1} undirected tree edges as {@code [u, v]} pairs
     * @return the number of valid paths
     */
    public long countPaths(int n, int[][] edges) {
        boolean[] isPrime = preparePrime(n);
        List<List<Integer>> tree = prepareTree(n, edges);

        // Iterative pre-order walk from node 1, recording each node's parent (0 = no parent,
        // which is safe because labels start at 1).
        int[] parent = new int[n + 1];
        int[] order = new int[n];
        int[] stack = new int[n];
        int visited = 0;
        int top = 0;
        stack[top++] = 1;
        while (top > 0) {
            int node = stack[--top];
            order[visited++] = node;
            for (int neigh : tree.get(node)) {
                if (neigh != parent[node]) {
                    parent[neigh] = node;
                    stack[top++] = neigh;
                }
            }
        }

        long[] zero = new long[n + 1];
        long[] one = new long[n + 1];
        for (int node = 1; node <= n; node++) {
            if (isPrime[node]) {
                one[node] = 1;
            } else {
                zero[node] = 1;
            }
        }

        long result = 0;
        // Reverse pre-order visits every child after all of its descendants, so each child's
        // counters are final when it is merged. Index 0 is the root, which has no parent.
        for (int i = visited - 1; i > 0; i--) {
            int child = order[i];
            int par = parent[child];
            // Join every path ending at par (merged part) with every downward path from child.
            result += zero[par] * one[child] + one[par] * zero[child];
            if (isPrime[par]) {
                // par already supplies the one prime; only prime-free child paths stay valid.
                one[par] += zero[child];
            } else {
                zero[par] += zero[child];
                one[par] += one[child];
            }
        }
        return result;
    }

    /**
     * Builds a primality table for {@code 0..n} with the sieve of Eratosthenes.
     *
     * @param n largest value to classify
     * @return array where index {@code i} is {@code true} iff {@code i} is prime
     */
    private static boolean[] preparePrime(int n) {
        boolean[] isPrime = new boolean[n + 1];
        for (int i = 2; i <= n; i++) {
            isPrime[i] = true;
        }
        // Only primes up to sqrt(n) need to strike multiples, starting at i * i: smaller
        // multiples of i have a smaller prime factor and were already removed.
        for (int i = 2; (long) i * i <= n; i++) {
            if (isPrime[i]) {
                for (int j = i * i; j <= n; j += i) {
                    isPrime[j] = false;
                }
            }
        }
        return isPrime;
    }

    /**
     * Builds an adjacency list for nodes {@code 0..n}; every node gets a (possibly empty) list.
     *
     * @param n number of nodes
     * @param edges undirected edges as {@code [u, v]} pairs
     * @return adjacency lists indexed by node label
     */
    private static List<List<Integer>> prepareTree(int n, int[][] edges) {
        List<List<Integer>> tree = new ArrayList<>(n + 1);
        for (int i = 0; i <= n; i++) {
            tree.add(new ArrayList<>());
        }
        for (int[] edge : edges) {
            tree.get(edge[0]).add(edge[1]);
            tree.get(edge[1]).add(edge[0]);
        }
        return tree;
    }
}