LeetCode-in-Java

2791. Count Paths That Can Form a Palindrome in a Tree

Hard

You are given a tree (i.e. a connected, undirected graph that has no cycles) rooted at node 0 consisting of n nodes numbered from 0 to n - 1. The tree is represented by a 0-indexed array parent of size n, where parent[i] is the parent of node i. Since node 0 is the root, parent[0] == -1.

You are also given a string s of length n, where s[i] is the character assigned to the edge between i and parent[i]. s[0] can be ignored.

Return the number of pairs of nodes (u, v) such that u < v and the characters assigned to edges on the path from u to v can be rearranged to form a palindrome.

A string is a palindrome when it reads the same backwards as forwards.

Example 1:

Input: parent = [-1,0,0,1,1,2], s = "acaabc"

Output: 8

Explanation:

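The valid pairs are:

- All the pairs (0,1), (0,2), (1,3), (1,4) and (2,5) result in one character, which is always a palindrome.
- The pair (2,3) results in the string "aca", which is a palindrome.
- The pair (1,5) results in the string "cac", which is a palindrome.
- The pair (3,5) results in the string "acac", which can be rearranged into the palindrome "acca".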

Example 2:

Input: parent = [-1,0,0,0,0], s = "aaaaa"

Output: 10

Explanation: Any pair of nodes (u,v) where u < v is valid.

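Constraints:

- n == parent.length == s.length
- 1 <= n <= 10^5
- 0 <= parent[i] <= n - 1 for all i >= 1
- parent[0] == -1
- parent represents a valid tree.
- s consists of only lowercase English letters.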

Solution

import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class Solution {
    /**
     * Returns the parity mask of the root-to-{@code idx} path, memoizing it in {@code dp}.
     *
     * <p>Bit {@code c - 'a'} of the mask is set iff letter {@code c} occurs an odd number of
     * times on the edges from the root down to {@code idx}. The walk is iterative rather than
     * recursive so that a chain-shaped tree (depth up to n) cannot overflow the call stack.
     *
     * @param parent parent array; {@code parent.get(i)} is the parent of node {@code i}
     * @param s edge labels; {@code s.charAt(i)} labels the edge between {@code i} and its parent
     * @param dp memo table where {@code -1} means "not yet computed"; the root entry must
     *     already be seeded with {@code 0}
     * @param idx node whose mask is requested
     * @return the parity mask of the root-to-{@code idx} path
     */
    private int getMap(List<Integer> parent, String s, int[] dp, int idx) {
        // Pass 1: climb to the nearest ancestor with a known mask, XOR-ing every edge bit seen.
        int climbed = 0;
        int node = idx;
        while (dp[node] < 0) {
            climbed ^= 1 << (s.charAt(node) - 'a');
            node = parent.get(node);
        }
        int mask = dp[node] ^ climbed;

        // Pass 2: retrace the same path and memoize every node, so later queries are O(1).
        int value = mask;
        node = idx;
        while (dp[node] < 0) {
            dp[node] = value;
            // Stripping this node's edge bit yields its parent's mask.
            value ^= 1 << (s.charAt(node) - 'a');
            node = parent.get(node);
        }
        return mask;
    }

    /**
     * Counts node pairs {@code (u, v)} with {@code u < v} whose path labels can be rearranged
     * into a palindrome.
     *
     * <p>The letter-parity mask of the u-v path equals {@code mask(u) ^ mask(v)}, because the
     * shared root-to-LCA prefix cancels out. A multiset of letters can be arranged into a
     * palindrome iff at most one letter has an odd count, i.e. the XOR is zero or has exactly
     * one bit set.
     *
     * @param parent parent array with {@code parent.get(0) == -1} (node 0 is the root)
     * @param s edge labels of lowercase letters; {@code s.charAt(0)} is ignored
     * @return the number of valid pairs
     */
    public long countPalindromePaths(List<Integer> parent, String s) {
        int n = parent.size();
        if (n == 0) {
            return 0;
        }
        int[] dp = new int[n];
        Arrays.fill(dp, -1);
        dp[0] = 0; // the root's path is empty

        Map<Integer, Integer> mapCount = new HashMap<>();
        for (int i = 0; i < n; i++) {
            mapCount.merge(getMap(parent, s, dp, i), 1, Integer::sum);
        }

        long ans = 0;
        for (Map.Entry<Integer, Integer> entry : mapCount.entrySet()) {
            int mask = entry.getKey();
            long count = entry.getValue();
            // Pairs with identical masks: the XOR is 0.
            ans += count * (count - 1) / 2;
            // Pairs whose masks differ in exactly one bit. Each unordered pair of masks is
            // counted once, from the side on which the differing bit is set.
            for (int bit = 0; bit < 26; bit++) {
                int b = 1 << bit;
                if ((mask & b) != 0) {
                    ans += count * mapCount.getOrDefault(mask ^ b, 0);
                }
            }
        }
        return ans;
    }
}