// Java program to find the maximum element on the path
// between two nodes of a binary search tree.
import java.util.*;
/**
 * A binary tree node carrying an int value and links to its left and
 * right children. A freshly constructed node has no children.
 */
class Node {
    int data;
    Node left;
    Node right;

    /**
     * Creates a childless node.
     *
     * @param x the value stored in this node
     */
    Node(int x) {
        this.data = x;
        this.right = null;
        this.left = null;
    }
}
class GfG {

    /**
     * Inserts {@code x} into the binary search tree rooted at {@code root}.
     * Values smaller than a node go into its left subtree; values greater
     * than or equal to it go into its right subtree. No rebalancing is done.
     *
     * @param root root of the tree, or {@code null} for an empty tree
     * @param x    value to insert
     * @return the root of the tree after insertion: {@code root} itself, or
     *         the newly created node when {@code root} was {@code null}
     */
    static Node insertNode(Node root, int x) {
        Node created = new Node(x);
        // Java passes references by value, so reassigning the parameter would
        // be invisible to the caller; an empty tree is reported via the return.
        if (root == null) {
            return created;
        }
        // Walk down to the empty child slot where x belongs.
        Node current = root;
        Node parent = root;
        while (current != null) {
            parent = current;
            current = (x < current.data) ? current.left : current.right;
        }
        if (x < parent.data) {
            parent.left = created;
        } else {
            parent.right = created;
        }
        return root;
    }

    /**
     * Records the parent of every node in the subtree rooted at {@code root}.
     * An explicit stack is used instead of recursion because
     * {@link #insertNode} builds unbalanced trees (sorted input yields a
     * list-shaped tree), which could overflow the call stack.
     *
     * @param root      subtree to walk; {@code null} is a no-op
     * @param parentMap receives a {@code child -> parent} entry per non-root node
     * @param parent    parent of {@code root}, or {@code null} when {@code root}
     *                  is the tree root (then no entry is stored for it)
     */
    static void dfs(Node root, Map<Node, Node> parentMap,
                    Node parent) {
        if (root == null) {
            return;
        }
        if (parent != null) {
            parentMap.put(root, parent);
        }
        Deque<Node> stack = new ArrayDeque<>();
        stack.push(root);
        while (!stack.isEmpty()) {
            Node node = stack.pop();
            if (node.right != null) {
                parentMap.put(node.right, node);
                stack.push(node.right);
            }
            if (node.left != null) {
                parentMap.put(node.left, node);
                stack.push(node.left);
            }
        }
    }

    /**
     * Finds a node holding {@code val} by searching the whole tree in
     * preorder (node, then left subtree, then right subtree). The search does
     * not rely on the BST ordering, so it also works on arbitrary trees.
     *
     * @param root root of the tree to search; may be {@code null}
     * @param val  value to look for
     * @return the first matching node in preorder, or {@code null} if none
     */
    static Node findNode(Node root, int val) {
        Deque<Node> stack = new ArrayDeque<>();
        if (root != null) {
            stack.push(root);
        }
        while (!stack.isEmpty()) {
            Node node = stack.pop();
            if (node.data == val) {
                return node;
            }
            // Push right first so the left subtree is explored first.
            if (node.right != null) {
                stack.push(node.right);
            }
            if (node.left != null) {
                stack.push(node.left);
            }
        }
        return null;
    }

    /**
     * Returns the maximum value on the tree path between the nodes holding
     * {@code x} and {@code y}, both endpoints included. The path runs from
     * each node up to their lowest common ancestor (LCA).
     *
     * @param root root of the tree
     * @param x    value of the first endpoint
     * @param y    value of the second endpoint
     * @return the largest value on the path, or {@code -1} if either value is
     *         not in the tree (note: indistinguishable from a real max of -1)
     */
    static int findMaxElement(Node root, int x, int y) {
        Map<Node, Node> parentMap = new HashMap<>();
        dfs(root, parentMap, null);

        Node p1 = findNode(root, x);
        Node p2 = findNode(root, y);
        if (p1 == null || p2 == null) {
            return -1;
        }

        // Both walks stop at the LCA, so no node above it is ever considered.
        Node lca = lowestCommonAncestor(p1, p2, parentMap);
        return Math.max(maxOnPathUpTo(p1, lca, parentMap),
                        maxOnPathUpTo(p2, lca, parentMap));
    }

    /**
     * Returns the lowest common ancestor of {@code a} and {@code b}
     * (a node counts as its own ancestor).
     *
     * @throws IllegalArgumentException if the nodes share no ancestor
     *         according to {@code parentMap}
     */
    private static Node lowestCommonAncestor(Node a, Node b,
                                             Map<Node, Node> parentMap) {
        Set<Node> ancestorsOfA = new HashSet<>();
        for (Node cur = a; cur != null; cur = parentMap.get(cur)) {
            ancestorsOfA.add(cur);
        }
        Node cur = b;
        while (cur != null && !ancestorsOfA.contains(cur)) {
            cur = parentMap.get(cur);
        }
        if (cur == null) {
            throw new IllegalArgumentException("nodes are not in the same tree");
        }
        return cur;
    }

    /**
     * Returns the maximum value on the upward path from {@code from} to its
     * ancestor {@code stop}, both included.
     */
    private static int maxOnPathUpTo(Node from, Node stop,
                                     Map<Node, Node> parentMap) {
        int best = stop.data;
        for (Node cur = from; cur != stop; cur = parentMap.get(cur)) {
            best = Math.max(best, cur.data);
        }
        return best;
    }

    public static void main(String[] args) {
        int[] arr = {18, 36, 9, 6, 12, 10, 1, 8};
        int a = 1, b = 10;
        Node root = null;
        for (int value : arr) {
            root = insertNode(root, value);
        }
        System.out.println(findMaxElement(root, a, b));
    }
}