LeetCode-in-Java

1373. Maximum Sum BST in Binary Tree

Hard

Given a binary tree root, return the maximum sum of all keys of any sub-tree which is also a Binary Search Tree (BST).

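Assume a BST is defined as follows:

*   The left subtree of a node contains only nodes with keys strictly less than the node's key.
*   The right subtree of a node contains only nodes with keys strictly greater than the node's key.
*   Both the left and right subtrees must also be binary search trees.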

Example 1:

Input: root = [1,4,3,2,4,2,5,null,null,null,null,null,null,4,6]

Output: 20

Explanation: The maximum sum of a valid binary search tree is obtained from the subtree rooted at the node with key 3 (3 + 2 + 5 + 4 + 6 = 20).

Example 2:

Input: root = [4,3,null,1,2]

Output: 2

Explanation: The maximum sum of a valid binary search tree is obtained from the single-node subtree whose key is 2.

Example 3:

Input: root = [-4,-2,-5]

Output: 0

Explanation: All values are negative, so the best choice is the empty BST, whose sum is 0.

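Constraints:

*   The number of nodes in the tree is in the range [1, 4 * 10^4].
*   -4 * 10^4 <= Node.val <= 4 * 10^4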

Solution

import com_github_leetcode.TreeNode;

/*
 * Definition for a binary tree node.
 * public class TreeNode {
 *     int val;
 *     TreeNode left;
 *     TreeNode right;
 *     TreeNode() {}
 *     TreeNode(int val) { this.val = val; }
 *     TreeNode(int val, TreeNode left, TreeNode right) {
 *         this.val = val;
 *         this.left = left;
 *         this.right = right;
 *     }
 * }
 */
@SuppressWarnings("java:S1700")
public class Solution {
    /**
     * Summary of an empty (null) subtree: it counts as a valid BST with sum 0. Its key range is
     * inverted ({@code hi = MIN_VALUE}, {@code lo = MAX_VALUE}) so that it never narrows a parent's
     * range, and its best sum is {@code MIN_VALUE} because it contains no subtree to score.
     */
    private static final SubtreeInfo EMPTY =
            new SubtreeInfo(Integer.MIN_VALUE, Integer.MAX_VALUE, true, 0, Integer.MIN_VALUE);

    /**
     * Returns the largest key sum among all subtrees of {@code root} that are binary search trees.
     *
     * @param root root of the binary tree; may be {@code null}
     * @return the best BST subtree sum, or 0 when no subtree has a positive sum
     */
    public int maxSumBST(TreeNode root) {
        int best = summarize(root).bestSum;
        // The empty BST (sum 0) is always an option, so never report a negative value.
        return best > 0 ? best : 0;
    }

    /** Immutable bottom-up summary of one subtree. */
    private static final class SubtreeInfo {
        /** Largest key in the subtree. */
        final int hi;
        /** Smallest key in the subtree. */
        final int lo;
        /** Whether the whole subtree is a BST. */
        final boolean valid;
        /** Sum of every key in the subtree, whether or not it is a BST. */
        final int total;
        /** Best BST sum recorded anywhere inside the subtree. */
        final int bestSum;

        SubtreeInfo(int hi, int lo, boolean valid, int total, int bestSum) {
            this.hi = hi;
            this.lo = lo;
            this.valid = valid;
            this.total = total;
            this.bestSum = bestSum;
        }
    }

    /**
     * Builds the summary for {@code node} from the summaries of its children (post-order).
     *
     * @param node current subtree root; may be {@code null}
     * @return the summary of the subtree rooted at {@code node}
     */
    private static SubtreeInfo summarize(TreeNode node) {
        if (node == null) {
            return EMPTY;
        }
        SubtreeInfo left = summarize(node.left);
        SubtreeInfo right = summarize(node.right);
        int key = node.val;

        int hi = Math.max(key, Math.max(left.hi, right.hi));
        int lo = Math.min(key, Math.min(left.lo, right.lo));
        int total = left.total + right.total + key;

        // Best sums found below this node carry upward whether or not this node forms a BST.
        int best = Math.max(left.bestSum, right.bestSum);

        // This subtree is a BST only if both children are, and the key lies strictly between
        // the largest key on the left and the smallest key on the right.
        boolean valid = left.valid && right.valid && key > left.hi && key < right.lo;
        if (valid) {
            int candidate = Math.max(total, Math.max(left.total, right.total));
            best = Math.max(candidate, best);
        }
        return new SubtreeInfo(hi, lo, valid, total, best);
    }
}