// Java program to transform a BST into a "greater sum tree": every
// node's value is replaced by the sum of all values >= it (itself included).
import java.util.HashMap;
import java.util.Map;
/** A binary-tree node holding an {@code int} value and links to its two children. */
class Node {
    int data;
    Node left;
    Node right;

    /** Creates a leaf node (both children {@code null}) holding {@code value}. */
    Node(int value) {
        this.data = value;
        this.left = this.right = null;
    }
}
class GfG {

    /**
     * Adds to {@code mp[curr]} the value of every node in the subtree rooted at
     * {@code root} whose value is greater than or equal to {@code curr.data}.
     *
     * <p>The whole subtree is scanned; BST ordering is not used to prune it.
     * Because the comparison is {@code >=}, {@code curr}'s own value (and any
     * equal value) is included in the total.
     *
     * @param root subtree to scan; {@code null} is a no-op
     * @param curr node whose running total is updated; must be non-null
     *             whenever {@code root} is non-null
     * @param mp   running totals keyed by node
     */
    static void findGreaterNodes(Node root, Node curr,
                                 Map<Node, Integer> mp) {
        if (root == null) {
            return;
        }
        // ">=" rather than ">" so the node's own value counts toward its total.
        if (root.data >= curr.data) {
            mp.merge(curr, root.data, Integer::sum);
        }
        findGreaterNodes(root.left, curr, mp);
        findGreaterNodes(root.right, curr, mp);
    }

    /**
     * For every node {@code curr} in the given subtree, scans the whole tree
     * from {@code root} and records in {@code mp} the sum of all values
     * {@code >= curr.data}. Node values are not modified here.
     *
     * @param curr subtree whose nodes get totals; {@code null} is a no-op
     * @param root root of the full tree that is scanned for each node
     * @param mp   receives one total per visited node
     */
    static void transformToGreaterSumTree(Node curr, Node root,
                                          Map<Node, Integer> mp) {
        if (curr == null) {
            return;
        }
        findGreaterNodes(root, curr, mp);
        transformToGreaterSumTree(curr.left, root, mp);
        transformToGreaterSumTree(curr.right, root, mp);
    }

    /**
     * Overwrites each node's value with its total from {@code mp}, visiting
     * nodes in pre-order. A node with no entry in {@code mp} is set to 0.
     *
     * @param root subtree to update; {@code null} is a no-op
     * @param mp   totals previously computed per node
     */
    static void preOrderTrav(Node root, Map<Node, Integer> mp) {
        if (root == null) {
            return;
        }
        root.data = mp.getOrDefault(root, 0);
        preOrderTrav(root.left, mp);
        preOrderTrav(root.right, mp);
    }

    /**
     * Replaces every node's value with the sum of all values in the tree that
     * are greater than or equal to it (its own value included).
     *
     * <p>Runs in two passes. All totals are computed first and only then written
     * back, so every comparison in the first pass sees the original values. Time
     * is O(n^2), because the full tree is scanned once per node.
     *
     * <p>NOTE(review): for a BST with distinct keys, a single reverse in-order
     * traversal carrying a running sum gives the same result in O(n).
     *
     * @param root root of the tree to transform; {@code null} is a no-op
     */
    static void transformTree(Node root) {
        // Totals are staged here so no node is overwritten while still being compared.
        Map<Node, Integer> mp = new HashMap<>();
        transformToGreaterSumTree(root, root, mp);
        preOrderTrav(root, mp);
    }

    /**
     * Prints the tree's values in in-order, each followed by a single space,
     * with no trailing newline.
     *
     * @param root subtree to print; {@code null} prints nothing
     */
    static void inorder(Node root) {
        if (root == null) {
            return;
        }
        inorder(root.left);
        System.out.print(root.data + " ");
        inorder(root.right);
    }

    public static void main(String[] args) {
        // Representation of input binary tree:
        //         50
        //       /    \
        //     30      70
        //    /  \    /  \
        //   20  40  60  80
        Node root = new Node(50);
        root.left = new Node(30);
        root.right = new Node(70);
        root.left.left = new Node(20);
        root.left.right = new Node(40);
        root.right.left = new Node(60);
        root.right.right = new Node(80);

        transformTree(root);
        // Expected output: 350 330 300 260 210 150 80
        inorder(root);
    }
}