class Compressed2DBIT:
    """Two-dimensional Binary Indexed (Fenwick) Tree over an n x m grid.

    Cells are addressed with 1-based coordinates ``(x, y)`` where
    ``1 <= x <= n`` and ``1 <= y <= m``.  Point updates and prefix-sum
    queries each walk O(log n) rows times O(log m) columns.
    """

    def __init__(self, n, m):
        """Create an all-zero tree for an n-row, m-column grid.

        Args:
            n: number of rows.
            m: number of columns.
        """
        self.n = n
        self.m = m
        # Row 0 and column 0 are unused padding so the tree can be indexed
        # directly with the 1-based coordinates the lowbit tricks require.
        self.bit = [[0] * (m + 1) for _ in range(n + 1)]
        # Dense 0-based copy of the cell values; populated by compress().
        self.tree = None

    def _check_index(self, x, y):
        """Raise IndexError unless (x, y) is a valid 1-based cell."""
        if not (1 <= x <= self.n and 1 <= y <= self.m):
            raise IndexError(
                f"cell ({x}, {y}) is outside 1..{self.n} x 1..{self.m}"
            )

    def update(self, x, y, val):
        """Add ``val`` to the cell at 1-based position ``(x, y)``.

        Args:
            x: row index, 1..n.
            y: column index, 1..m.
            val: amount to add to the cell.

        Raises:
            IndexError: if ``(x, y)`` lies outside the grid.  Without this
                check an index of 0 or below makes the ``x & -x`` step
                reach zero and loop forever, and an index past the edge
                would be silently ignored.
        """
        self._check_index(x, y)
        while x <= self.n:
            y1 = y
            while y1 <= self.m:
                self.bit[x][y1] += val
                # Move to the next node whose column range covers y.
                y1 += y1 & -y1
            # Move to the next node whose row range covers x.
            x += x & -x

    def query(self, x, y):
        """Return the sum of all cells in the rectangle (1, 1)..(x, y).

        A coordinate of 0 or below yields 0 (the empty prefix), which is
        what inclusion-exclusion needs at the grid border.  Coordinates
        past the grid edge are not supported.
        """
        s = 0
        while x > 0:
            y1 = y
            while y1 > 0:
                s += self.bit[x][y1]
                # Strip the lowest set bit to reach the previous range.
                y1 -= y1 & -y1
            x -= x & -x
        return s

    def compress(self):
        """Decode the tree into a dense grid stored in ``self.tree``.

        Afterwards ``self.tree[i][j]`` holds the value of cell
        ``(i + 1, j + 1)``.  ``self.bit`` is left intact, so ``update`` and
        ``query`` keep working (later updates are not reflected in
        ``self.tree`` until compress() is called again).
        """
        # prefix[i][j] == query(i, j); row 0 and column 0 stay 0 so the
        # border cells need no special case below.
        prefix = [[0] * (self.m + 1) for _ in range(self.n + 1)]
        for i in range(1, self.n + 1):
            for j in range(1, self.m + 1):
                prefix[i][j] = self.query(i, j)
        # Recover each cell by inclusion-exclusion on four prefix sums:
        #   a(i, j) = P(i, j) - P(i-1, j) - P(i, j-1) + P(i-1, j-1)
        # Reusing the table costs one BIT query per cell instead of four.
        self.tree = [
            [
                prefix[i][j]
                - prefix[i - 1][j]
                - prefix[i][j - 1]
                + prefix[i - 1][j - 1]
                for j in range(1, self.m + 1)
            ]
            for i in range(1, self.n + 1)
        ]
        # Kept only for backward compatibility: this attribute has always
        # been set here, but nothing reads it and it frees no memory.
        self.bit_compressed = None
# Example usage: load a 4 x 5 grid into the tree cell by cell, then check
# prefix sums before and after decoding the grid with compress().
arr = [
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
]
bit = Compressed2DBIT(len(arr), len(arr[0]))
for i, row in enumerate(arr):
    for j, value in enumerate(row):
        bit.update(i + 1, j + 1, value)  # the tree uses 1-based indices

# Top-left 2 x 3 corner: 1 + 2 + 3 + 6 + 7 + 8 = 27.
print(bit.query(2, 3))
# Whole grid: 1 + 2 + ... + 20 = 210.
print(bit.query(4, 5))

bit.compress()
# compress() leaves bit.bit untouched, so these still print 27 and 210.
print(bit.query(2, 3))
print(bit.query(4, 5))

# The decoded cell values should reproduce arr exactly:
# [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10],
#  [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]]
print(bit.tree)