Hard
You are given the root of a binary tree with n nodes. Each node is assigned a unique value from 1 to n. You are also given an array queries of size m.
You have to perform m independent queries on the tree where, in the ith query, you do the following:
Remove the subtree rooted at the node with the value queries[i] from the tree. It is guaranteed that queries[i] will not be equal to the value of the root.
Return an array answer of size m where answer[i] is the height of the tree after performing the ith query.
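The definition can be checked directly. Because the queries are independent, answer[i] is the height of the original tree computed while treating the node with value queries[i] as absent. The sketch below does exactly that in O(n · m) time, which is far too slow for the upper constraints but handy as a test oracle. The `Node` and `BruteForceQueries` names are hypothetical stand-ins, not the `TreeNode` type used by the solution, and the recursion assumes the tree is shallow enough for the default Java stack.

```java
// Hypothetical minimal node type; not the TreeNode class imported by the solution.
class Node {
    int val;
    Node left;
    Node right;

    Node(int val, Node left, Node right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class BruteForceQueries {
    // Height in edges of the tree rooted at `node`, treating the subtree rooted
    // at the node with value `removed` as absent. An empty tree has height -1.
    static int heightWithout(Node node, int removed) {
        if (node == null || node.val == removed) {
            return -1;
        }
        return 1 + Math.max(heightWithout(node.left, removed), heightWithout(node.right, removed));
    }

    // Queries are independent, so the tree is never modified; each answer is a
    // fresh O(n) traversal, for O(n * m) overall.
    static int[] treeQueries(Node root, int[] queries) {
        int[] answer = new int[queries.length];
        for (int i = 0; i < queries.length; i++) {
            answer[i] = heightWithout(root, queries[i]);
        }
        return answer;
    }
}
```

Measuring height in edges (a single node has height 0) matches the convention used in the examples.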
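Note:
The queries are independent, so the tree returns to its initial state after each query.
The height of a tree is the number of edges in the longest simple path from the root to some node in the tree.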
Example 1:
Input: root = [1,3,4,2,null,6,5,null,null,null,null,null,7], queries = [4]
Output: [2]
Explanation: After removing the subtree rooted at the node with value 4, the height of the tree is 2 (the path 1 -> 3 -> 2).
Example 2:
Input: root = [5,8,9,2,1,3,7,4,6], queries = [3,2,4,8]
Output: [3,2,3,2]
Explanation: We have the following queries:
Removing the subtree rooted at the node with value 3: the height of the tree becomes 3 (the path 5 -> 8 -> 2 -> 4).
Removing the subtree rooted at the node with value 2: the height of the tree becomes 2 (the path 5 -> 8 -> 1).
Removing the subtree rooted at the node with value 4: the height of the tree becomes 3 (the path 5 -> 8 -> 2 -> 6).
Removing the subtree rooted at the node with value 8: the height of the tree becomes 2 (the path 5 -> 9 -> 3).
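The example inputs use LeetCode's level-order form: values are listed level by level, null marks a missing child, and trailing nulls are omitted. The sketch below decodes that form, assuming the usual convention that each non-null node consumes the next two slots for its children. `Node` and `LevelOrder` are hypothetical helper names; `height` measures height in edges, which is 3 for both example trees before any query.

```java
import java.util.ArrayDeque;
import java.util.Queue;

// Hypothetical minimal node type for this sketch.
class Node {
    int val;
    Node left;
    Node right;

    Node(int val) {
        this.val = val;
    }
}

class LevelOrder {
    // Decodes a level-order array; each non-null node takes the next two
    // slots as its left and right child (null means no child).
    static Node decode(Integer[] data) {
        if (data.length == 0 || data[0] == null) {
            return null;
        }
        Node root = new Node(data[0]);
        Queue<Node> pending = new ArrayDeque<>();
        pending.add(root);
        int i = 1;
        while (!pending.isEmpty() && i < data.length) {
            Node cur = pending.poll();
            if (i < data.length && data[i] != null) {
                cur.left = new Node(data[i]);
                pending.add(cur.left);
            }
            i++;
            if (i < data.length && data[i] != null) {
                cur.right = new Node(data[i]);
                pending.add(cur.right);
            }
            i++;
        }
        return root;
    }

    // Height in edges; an empty tree has height -1.
    static int height(Node node) {
        return node == null ? -1 : 1 + Math.max(height(node.left), height(node.right));
    }
}
```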
Constraints:
The number of nodes in the tree is n.
2 <= n <= 10^5
1 <= Node.val <= n
m == queries.length
1 <= m <= min(n, 10^4)
1 <= queries[i] <= n
queries[i] != root.val
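```java
import com_github_leetcode.TreeNode;
import java.util.HashMap;
import java.util.Map;

public class Solution {
    public int[] treeQueries(TreeNode root, int[] queries) {
        // levels: depth -> {largest, second-largest} subtree height (in edges) at that depth; -1 = none.
        Map<Integer, int[]> levels = new HashMap<>();
        // map: node value -> {height of the node's subtree in edges, depth of the node}.
        Map<Integer, int[]> map = new HashMap<>();
        // dfs returns the height in nodes, so subtract 1 to get the height in edges.
        int max = dfs(root, 0, map, levels) - 1;
        int n = queries.length;
        for (int i = 0; i < n; i++) {
            int q = queries[i];
            int[] node = map.get(q);
            int height = node[0];
            int level = node[1];
            int[] lev = levels.get(level);
            // The deepest path passes through a node at this depth whose subtree
            // height is lev[0], so max == level + lev[0].
            if (lev[0] == height) {
                if (lev[1] != -1) {
                    // Another node at this depth survives, and the best remaining path
                    // goes through it: level + lev[1] == max - (lev[0] - lev[1]).
                    // On a tie (lev[1] == lev[0]) the height is unchanged.
                    queries[i] = max - Math.abs(lev[0] - lev[1]);
                } else {
                    // The removed node is alone at its depth, so the tree now ends
                    // at its parent: level - 1 == max - height - 1.
                    queries[i] = max - height - 1;
                }
            } else {
                // A taller subtree at the same depth survives, so the height is unchanged.
                queries[i] = max;
            }
        }
        // The answers are written back into the input array.
        return queries;
    }

    private int dfs(TreeNode root, int level, Map<Integer, int[]> map, Map<Integer, int[]> levels) {
        if (root == null) {
            return 0;
        }
        int left = dfs(root.left, level + 1, map, levels);
        int right = dfs(root.right, level + 1, map, levels);
        // The children's height in nodes equals this node's subtree height in edges.
        int height = Math.max(left, right);
        // Keep the two largest subtree heights seen at this depth.
        int[] lev = levels.getOrDefault(level, new int[] {-1, -1});
        if (height >= lev[0]) {
            lev[1] = lev[0];
            lev[0] = height;
        } else {
            lev[1] = Math.max(lev[1], height);
        }
        levels.put(level, lev);
        map.put(root.val, new int[] {height, level});
        // Height of this subtree in nodes.
        return height + 1;
    }
}
```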
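The same per-depth idea (keep the two largest subtree heights at each depth) can be written without hash maps, because the constraints guarantee node values are exactly 1 to n. The sketch below is a hypothetical variant, not the code above: it records nodes in BFS order, computes subtree heights by scanning that order backwards (children always appear after their parent), and then answers each query from the per-depth top two. Replacing recursion with this scan avoids deep call stacks on a skewed tree with 10^5 nodes. `Node` and `LevelTopTwoQueries` are illustrative names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical minimal node type for this sketch.
class Node {
    int val;
    Node left;
    Node right;

    Node(int val, Node left, Node right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}

class LevelTopTwoQueries {
    // Assumes node values are exactly 1..n, as the constraints guarantee.
    static int[] treeQueries(Node root, int[] queries) {
        // BFS order: every node appears after its parent, and depths never decrease.
        List<Node> order = new ArrayList<>();
        List<Integer> depths = new ArrayList<>();
        order.add(root);
        depths.add(0);
        for (int i = 0; i < order.size(); i++) {
            Node cur = order.get(i);
            int d = depths.get(i);
            if (cur.left != null) {
                order.add(cur.left);
                depths.add(d + 1);
            }
            if (cur.right != null) {
                order.add(cur.right);
                depths.add(d + 1);
            }
        }
        int n = order.size();
        int[] depth = new int[n + 1];  // indexed by node value
        int[] height = new int[n + 1]; // subtree height in edges, indexed by node value
        int maxDepth = depths.get(n - 1);
        int[] best = new int[maxDepth + 1];   // largest height at each depth
        int[] second = new int[maxDepth + 1]; // second largest; -1 if none
        Arrays.fill(best, -1);
        Arrays.fill(second, -1);
        // Reverse BFS order visits children before parents.
        for (int i = n - 1; i >= 0; i--) {
            Node cur = order.get(i);
            int h = 0;
            if (cur.left != null) {
                h = Math.max(h, height[cur.left.val] + 1);
            }
            if (cur.right != null) {
                h = Math.max(h, height[cur.right.val] + 1);
            }
            int d = depths.get(i);
            depth[cur.val] = d;
            height[cur.val] = h;
            if (h >= best[d]) {
                second[d] = best[d];
                best[d] = h;
            } else if (h > second[d]) {
                second[d] = h;
            }
        }
        int treeHeight = height[root.val]; // equals d + best[d] for every depth d
        int[] answer = new int[queries.length];
        for (int i = 0; i < queries.length; i++) {
            int v = queries[i];
            int d = depth[v];
            if (height[v] < best[d]) {
                answer[i] = treeHeight;    // a taller subtree at this depth survives
            } else if (second[d] != -1) {
                answer[i] = d + second[d]; // runner-up at this depth (a tie keeps treeHeight)
            } else {
                answer[i] = d - 1;         // v was alone at its depth; its parent is now deepest
            }
        }
        return answer;
    }
}
```

On the tree from Example 2 with queries [3, 2, 4, 8], this returns [3, 2, 3, 2], matching the expected output.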