LeetCode-in-Java

2458. Height of Binary Tree After Subtree Removal Queries

Hard

You are given the root of a binary tree with n nodes. Each node is assigned a unique value from 1 to n. You are also given an array queries of size m.

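You have to perform m independent queries on the tree where, in the ith query, you remove the subtree rooted at the node with the value queries[i] from the tree. It is guaranteed that queries[i] will not be equal to the value of the root.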

Return an array answer of size m where answer[i] is the height of the tree after performing the ith query.

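Note:

- The queries are independent, so the tree returns to its initial state after each query.
- The height of a tree is the number of edges in the longest simple path from the root to some node in the tree.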
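Height here counts edges, not nodes. A single node therefore has height 0, and a three-node path such as 1 -> 3 -> 2 has height 2. The minimal sketch below uses the common convention that an empty tree has height -1. The `TreeNode` class and the `HeightDemo` name are hypothetical stand-ins, not the repository's `com_github_leetcode.TreeNode`.

```java
// Hypothetical minimal node type (stand-in for the repository's TreeNode).
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int val) {
        this.val = val;
    }
}

public class HeightDemo {
    // Height in edges: -1 for an empty tree, 0 for a single node.
    static int height(TreeNode node) {
        if (node == null) {
            return -1;
        }
        return 1 + Math.max(height(node.left), height(node.right));
    }

    public static void main(String[] args) {
        // The path 1 -> 3 -> 2 has three nodes and two edges.
        TreeNode root = new TreeNode(1);
        root.left = new TreeNode(3);
        root.left.left = new TreeNode(2);
        System.out.println(height(root)); // prints 2
    }
}
```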

Example 1:

Input: root = [1,3,4,2,null,6,5,null,null,null,null,null,7], queries = [4]

Output: [2]

Explanation: After removing the subtree rooted at the node with value 4, the height of the tree is 2 (the path 1 -> 3 -> 2).
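You can check the examples by hand with a brute-force approach: for each query, ignore the removed subtree and recompute the height. The sketch below does exactly that. It is not the repository's solution. It does O(n) work per query, so it is too slow for large inputs, but it makes a handy reference for small cases. `build` assumes LeetCode's usual level-order input format, in which `null` marks a missing child. `TreeNode`, `build`, `heightWithout` and `BruteForce` are hypothetical helpers.

```java
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Queue;

// Hypothetical minimal node type (stand-in for the repository's TreeNode).
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int val) {
        this.val = val;
    }
}

public class BruteForce {
    // Builds a tree from level-order values; each dequeued node consumes the next two entries.
    static TreeNode build(Integer[] values) {
        if (values.length == 0 || values[0] == null) {
            return null;
        }
        TreeNode root = new TreeNode(values[0]);
        Queue<TreeNode> queue = new ArrayDeque<>();
        queue.add(root);
        int i = 1;
        while (!queue.isEmpty() && i < values.length) {
            TreeNode node = queue.poll();
            if (i < values.length && values[i] != null) {
                node.left = new TreeNode(values[i]);
                queue.add(node.left);
            }
            i++;
            if (i < values.length && values[i] != null) {
                node.right = new TreeNode(values[i]);
                queue.add(node.right);
            }
            i++;
        }
        return root;
    }

    // Height in edges, treating the subtree rooted at `removed` as absent.
    static int heightWithout(TreeNode node, int removed) {
        if (node == null || node.val == removed) {
            return -1;
        }
        return 1 + Math.max(heightWithout(node.left, removed), heightWithout(node.right, removed));
    }

    // One full traversal per query: O(n * m) overall.
    static int[] treeQueries(TreeNode root, int[] queries) {
        int[] answer = new int[queries.length];
        for (int i = 0; i < queries.length; i++) {
            answer[i] = heightWithout(root, queries[i]);
        }
        return answer;
    }

    public static void main(String[] args) {
        TreeNode root = build(new Integer[] {1, 3, 4, 2, null, 6, 5, null, null, null, null, null, 7});
        System.out.println(Arrays.toString(treeQueries(root, new int[] {4}))); // prints [2]
    }
}
```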

Example 2:

Input: root = [5,8,9,2,1,3,7,4,6], queries = [3,2,4,8]

Output: [3,2,3,2]

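Explanation: We have the following queries:

- Removing the subtree rooted at the node with value 3 leaves a tree of height 3 (the path 5 -> 8 -> 2 -> 4).
- Removing the subtree rooted at the node with value 2 leaves a tree of height 2 (the path 5 -> 8 -> 1).
- Removing the subtree rooted at the node with value 4 leaves a tree of height 3 (the path 5 -> 8 -> 2 -> 6).
- Removing the subtree rooted at the node with value 8 leaves a tree of height 2 (the path 5 -> 9 -> 3).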
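Example 2 also illustrates the observation behind the per-level bookkeeping in the solution below. Every node deeper than d lies in the subtree of exactly one node at depth d. So after removing the subtree of a node q at depth d, the height is the larger of two values:

- d - 1, because q's parent survives;
- d + height(u), taken over every other node u at depth d.

The sketch below checks this formula directly by scanning q's level. It is an illustration under that reasoning, not the repository's code, and `TreeNode` and `LevelFormula` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical minimal node type (stand-in for the repository's TreeNode).
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

public class LevelFormula {
    private final Map<Integer, Integer> depth = new HashMap<>();          // value -> depth
    private final Map<Integer, Integer> height = new HashMap<>();         // value -> subtree height in edges
    private final Map<Integer, List<Integer>> byDepth = new HashMap<>();  // depth -> values at that depth

    LevelFormula(TreeNode root) {
        record(root, 0);
    }

    // Post-order pass filling the three maps; returns the subtree height in edges (-1 for null).
    private int record(TreeNode node, int d) {
        if (node == null) {
            return -1;
        }
        int h = 1 + Math.max(record(node.left, d + 1), record(node.right, d + 1));
        depth.put(node.val, d);
        height.put(node.val, h);
        byDepth.computeIfAbsent(d, k -> new ArrayList<>()).add(node.val);
        return h;
    }

    // max(d - 1, d + height(u)) over the other nodes u at q's depth; O(level width) per query.
    int answer(int q) {
        int d = depth.get(q);
        int best = d - 1;
        for (int v : byDepth.get(d)) {
            if (v != q) {
                best = Math.max(best, d + height.get(v));
            }
        }
        return best;
    }

    public static void main(String[] args) {
        // Example 2, root = [5,8,9,2,1,3,7,4,6], built by hand.
        TreeNode root = new TreeNode(5,
                new TreeNode(8,
                        new TreeNode(2, new TreeNode(4, null, null), new TreeNode(6, null, null)),
                        new TreeNode(1, null, null)),
                new TreeNode(9, new TreeNode(3, null, null), new TreeNode(7, null, null)));
        LevelFormula f = new LevelFormula(root);
        for (int q : new int[] {3, 2, 4, 8}) {
            System.out.println(q + " -> " + f.answer(q)); // 3 -> 3, 2 -> 2, 4 -> 3, 8 -> 2
        }
    }
}
```

The solution below avoids the level scan by keeping only the two largest subtree heights per depth. Only the largest one at the removed node's depth can be affected.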

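Constraints:

- The number of nodes in the tree is n.
- 2 <= n <= 10^5
- 1 <= Node.val <= n
- All the values in the tree are unique.
- m == queries.length
- 1 <= m <= min(n, 10^4)
- 1 <= queries[i] <= n
- queries[i] != root.val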

Solution

import com_github_leetcode.TreeNode;
import java.util.HashMap;
import java.util.Map;

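public class Solution {
    public int[] treeQueries(TreeNode root, int[] queries) {
        // levels: depth -> {largest, second largest} subtree height (in edges) among nodes at that depth
        Map<Integer, int[]> levels = new HashMap<>();
        // map: node value -> {subtree height in edges, depth}
        Map<Integer, int[]> map = new HashMap<>();
        // dfs returns the height in nodes, so subtract 1 to get the tree height in edges
        int max = dfs(root, 0, map, levels) - 1;
        int n = queries.length;
        for (int i = 0; i < n; i++) {
            int q = queries[i];
            int[] node = map.get(q);
            int height = node[0];
            int level = node[1];
            int[] lev = levels.get(level);
            if (lev[0] == height) {
                // The removed node holds the tallest subtree at its depth, so max == level + lev[0].
                if (lev[1] != -1) {
                    // Another node at the same depth survives: the new height is level + lev[1].
                    queries[i] = max - (lev[0] - lev[1]);
                } else {
                    // The node is alone at its depth: only its parent's depth (level - 1) remains.
                    queries[i] = max - height - 1;
                }
            } else {
                // A taller subtree at the same depth is untouched, so the height does not change.
                queries[i] = max;
            }
        }
        // The answers overwrite the input array in place.
        return queries;
    }

    // Post-order DFS: records each node's subtree height (edges) and depth, and keeps the two
    // largest subtree heights per depth. Returns the subtree height in nodes (0 for null).
    // Recursion depth equals the tree depth, so a very deep, skewed tree may need a larger stack.
    private int dfs(TreeNode root, int level, Map<Integer, int[]> map, Map<Integer, int[]> levels) {
        if (root == null) {
            return 0;
        }
        int left = dfs(root.left, level + 1, map, levels);
        int right = dfs(root.right, level + 1, map, levels);
        // Height of this node's subtree measured in edges.
        int height = Math.max(left, right);
        // {largest, second largest}; -1 marks a slot not filled yet.
        int[] lev = levels.getOrDefault(level, new int[] {-1, -1});
        if (height >= lev[0]) {
            lev[1] = lev[0];
            lev[0] = height;
        } else {
            lev[1] = Math.max(lev[1], height);
        }
        levels.put(level, lev);
        map.put(root.val, new int[] {height, level});
        return height + 1;
    }
}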
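For comparison, a known alternative avoids per-level maps by using two preorder traversals. The first visits left children first and the second visits right children first. Each records, for every node, the deepest depth seen before entering that node:

- In the left-first order, the nodes visited before q are its ancestors and everything to its left.
- In the right-first order, they are its ancestors and everything to its right.

Together these are exactly the nodes outside q's subtree, so the larger of the two recorded values is the answer. The sketch below shows this different technique. It is not the repository's code. `TwoPassSolution` and its `TreeNode` are hypothetical stand-ins, and it uses recursion, so very deep trees may need a larger stack.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

// Hypothetical minimal node type (stand-in for the repository's TreeNode).
class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;

    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

public class TwoPassSolution {
    // value -> height of the tree after removing that node's subtree
    private final Map<Integer, Integer> heightWithout = new HashMap<>();
    // deepest depth seen so far in the current traversal
    private int maxDepth;

    public int[] treeQueries(TreeNode root, int[] queries) {
        maxDepth = 0;
        leftToRight(root, 0);
        maxDepth = 0;
        rightToLeft(root, 0);
        int[] answer = new int[queries.length];
        for (int i = 0; i < queries.length; i++) {
            answer[i] = heightWithout.get(queries[i]);
        }
        return answer;
    }

    // Preorder, left child first: on entry, maxDepth covers ancestors and nodes to the left.
    private void leftToRight(TreeNode node, int depth) {
        if (node == null) {
            return;
        }
        heightWithout.put(node.val, maxDepth);
        maxDepth = Math.max(maxDepth, depth);
        leftToRight(node.left, depth + 1);
        leftToRight(node.right, depth + 1);
    }

    // Preorder, right child first: on entry, maxDepth covers ancestors and nodes to the right.
    private void rightToLeft(TreeNode node, int depth) {
        if (node == null) {
            return;
        }
        heightWithout.merge(node.val, maxDepth, Integer::max);
        maxDepth = Math.max(maxDepth, depth);
        rightToLeft(node.right, depth + 1);
        rightToLeft(node.left, depth + 1);
    }

    public static void main(String[] args) {
        // Example 2, root = [5,8,9,2,1,3,7,4,6], built by hand.
        TreeNode root = new TreeNode(5,
                new TreeNode(8,
                        new TreeNode(2, new TreeNode(4, null, null), new TreeNode(6, null, null)),
                        new TreeNode(1, null, null)),
                new TreeNode(9, new TreeNode(3, null, null), new TreeNode(7, null, null)));
        int[] answer = new TwoPassSolution().treeQueries(root, new int[] {3, 2, 4, 8});
        System.out.println(Arrays.toString(answer)); // prints [3, 2, 3, 2]
    }
}
```

Both approaches run in O(n + m) time. The two-pass version trades the per-depth top-two bookkeeping for a second traversal.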