# Python program to count the perfect
# square numbers in a subarray and to
# perform point updates, using a segment tree
from math import sqrt, floor, ceil, log2
from typing import List
# NOTE(review): MAX is not referenced by any code in this file;
# presumably a leftover size limit — confirm before removing.
MAX = 1000
# Function to check if a number is
# a perfect square or not
def isPerfectSquare(x: int) -> bool:
    """Return True if ``x`` is the square of a non-negative integer.

    The check is done with exact integer arithmetic (Newton's method for
    floor(sqrt(x))), so it stays correct for arbitrarily large ints, where
    comparing a float ``math.sqrt`` result against its floor can give
    false positives.

    Args:
        x: value to test. Negative values and non-integral values
            (e.g. the float 2.25) are never perfect squares.

    Returns:
        True if x == r * r for some integer r >= 0, else False.
    """
    # Negative numbers have no real square root (math.sqrt would raise).
    if x < 0:
        return False
    value = int(x)
    # A non-integral number cannot be the square of an integer.
    if value != x:
        return False
    # 0 and 1 are perfect squares; handling them here also avoids
    # a division by zero in the Newton step below.
    if value < 2:
        return True
    # Start from a power of two that is >= floor(sqrt(value)), then apply
    # integer Newton steps; the sequence decreases until it reaches
    # floor(sqrt(value)), after which the next step no longer decreases.
    root = 1 << ((value.bit_length() + 1) // 2)
    while True:
        nxt = (root + value // root) // 2
        if nxt >= root:
            break
        root = nxt
    return root * root == value
# Helper: midpoint of an inclusive index range.
def getMid(s: int, e: int) -> int:
    """Return the middle index of the inclusive range [s, e].

    Computed as ``s`` plus half the span (floor division), so for an
    even-length range the left of the two middle indexes is chosen.
    """
    half_span = (e - s) // 2
    return s + half_span
# Recursive range-count query over the segment tree.
def queryUtil(st: List[int], ss: int, se: int,
              qs: int, qe: int,
              index: int) -> int:
    """Count perfect squares in the query window [qs, qe] below node ``index``.

    Args:
        st: the segment tree; st[k] holds the number of perfect squares in
            the segment covered by node k. Children of node k are stored at
            2*k + 1 and 2*k + 2. The root is at index 0.
        ss, se: first and last array index (inclusive) covered by the node
            ``st[index]``.
        qs, qe: first and last array index (inclusive) of the query range.
        index: position of the current node in ``st``.

    Returns:
        The number of perfect squares in the intersection of [ss, se] with
        [qs, qe].
    """
    # The node's whole segment lies inside the query window: use its stored count.
    if ss >= qs and se <= qe:
        return st[index]
    # The segment and the window are disjoint: nothing to count here.
    if qe < ss or qs > se:
        return 0
    # Partial overlap: split at the midpoint and add up both halves.
    mid = getMid(ss, se)
    left_total = queryUtil(st, ss, mid, qs, qe, 2 * index + 1)
    right_total = queryUtil(st, mid + 1, se, qs, qe, 2 * index + 2)
    return left_total + right_total
# Propagate a count change down the path of nodes covering one array index.
def updateValueUtil(st: List[int], ss: int, se: int,
                    i: int, diff: int,
                    si: int) -> None:
    """Add ``diff`` to every node whose segment contains array index ``i``.

    Args:
        st: the segment tree (see queryUtil for its layout).
        ss, se: first and last array index (inclusive) covered by node ``si``.
        i: index, in the input array, of the element that changed.
        diff: amount added to each node whose range contains ``i``.
        si: position of the starting node in ``st``.

    The segments of a node's two children are disjoint, so at most one child
    contains ``i``. This walks that single root-to-leaf path iteratively.
    """
    # Index not covered by this node: nothing to update.
    if not (ss <= i <= se):
        return
    lo, hi, node = ss, se, si
    while True:
        st[node] += diff
        if lo == hi:
            break  # reached the leaf for index i
        mid = getMid(lo, hi)
        if i <= mid:
            hi, node = mid, 2 * node + 1
        else:
            lo, node = mid + 1, 2 * node + 2
# Function to update a value in the
# input array and segment tree.
# It uses updateValueUtil() to update
# the value in segment tree
def updateValue(arr: List[int], st: List[int],
                n: int, i: int,
                new_val: int) -> None:
    """Set ``arr[i] = new_val`` and keep the segment tree's counts in sync.

    Args:
        arr: the input array (modified in place).
        st: the segment tree built from ``arr`` (modified in place).
        n: number of elements in ``arr``.
        i: index of the element to replace.
        new_val: the new value for ``arr[i]``.

    If ``i`` is outside [0, n - 1], prints "Invalid Input" and changes nothing.
    The tree is touched only when the element's perfect-square status changes:
    square -> non-square subtracts 1 and non-square -> square adds 1.
    """
    # Check for erroneous input index
    if i < 0 or i > n - 1:
        print("Invalid Input")
        return
    was_square = isPerfectSquare(arr[i])
    arr[i] = new_val
    is_square = isPerfectSquare(new_val)
    # Unchanged status (both squares or both non-squares): no count changes.
    if was_square == is_square:
        return
    # Previously the non-square -> square case was never reached, so the
    # tree was left stale; derive the delta directly from the new status.
    diff = 1 if is_square else -1
    updateValueUtil(st, 0, n - 1, i, diff, 0)
# Print and return the number of perfect squares
# in the range from index qs (query start)
# to qe (query end), inclusive.
# It mainly uses queryUtil()
def query(st: List[int], n: int, qs: int, qe: int) -> int:
    """Print and return how many perfect squares lie in arr[qs..qe] (inclusive).

    Args:
        st: segment tree built over an array of ``n`` elements.
        n: number of elements in the underlying array.
        qs, qe: inclusive bounds of the query range.

    Returns:
        The count that was printed. It is also returned so callers can use
        the value programmatically; the original only printed it.
    """
    perfectSquareInRange = queryUtil(st, 0, n - 1, qs, qe, 0)
    print(perfectSquareInRange)
    return perfectSquareInRange
# Recursively build the segment tree for arr[ss..se] rooted at node si.
def constructSTUtil(arr: List[int], ss: int, se: int, st: List[int],
                    si: int) -> int:
    """Fill node ``si`` of ``st`` (and its subtree) for arr[ss..se].

    A leaf stores 1 if its element is a perfect square and 0 otherwise.
    An internal node stores the sum of its two children, at 2*si + 1 and
    2*si + 2.

    Returns:
        The value written to ``st[si]``.
    """
    if ss != se:
        # Internal node: build the left half first, then the right half.
        mid = getMid(ss, se)
        left_count = constructSTUtil(arr, ss, mid, st, si * 2 + 1)
        right_count = constructSTUtil(arr, mid + 1, se, st, si * 2 + 2)
        st[si] = left_count + right_count
    else:
        # Leaf: a single element either is or is not a perfect square.
        st[si] = 1 if isPerfectSquare(arr[ss]) else 0
    return st[si]
# Function to construct a segment
# tree from given array. This
# function allocates memory for
# segment tree and calls
# constructSTUtil() to fill
# the allocated memory
def constructST(arr: List[int], n: int) -> List[int]:
    """Build and return a perfect-square-count segment tree over arr[0..n-1].

    The tree is a flat list. Node k's children are at 2*k + 1 and 2*k + 2,
    and the root (index 0) covers the whole array.

    Args:
        arr: the input array.
        n: number of elements of ``arr`` to cover.

    Returns:
        The tree list. For ``n <= 0`` a single zero-count root ``[0]`` is
        returned instead of crashing in ``log2(0)``.
    """
    if n <= 0:
        return [0]
    # Smallest power of two >= n, computed exactly with integer bit ops.
    # A float ceil(log2(n)) can round down for large n and undersize the tree.
    leaves = 1 << (n - 1).bit_length()
    # A complete binary tree with `leaves` leaves has 2 * leaves - 1 nodes.
    st = [0] * (2 * leaves - 1)
    # Fill the allocated memory st
    constructSTUtil(arr, 0, n - 1, st, 0)
    return st
# Driver code: build the tree, query, update one element, query again.
if __name__ == "__main__":
    # Sample input; 16, 9 and 25 are the perfect squares.
    arr = [16, 15, 8, 9, 14, 25]
    n = len(arr)

    # Build the segment tree over the whole array.
    st = constructST(arr, n)

    # First query: perfect squares in arr[0..4].
    start, end = 0, 4
    query(st, n, start, end)

    # Update: set arr[3] to 11 (9 -> 11, so one perfect square is removed).
    i, x = 3, 11
    updateValue(arr, st, n, i, x)

    # Uncomment to inspect the array after the update:
    # print(*arr)

    # Second query: perfect squares in arr[0..4] after the update.
    start, end = 0, 4
    query(st, n, start, end)

# This code is contributed by sanjeev2552