Suppose we have a 2D list called 'tree' that represents an n-ary tree, and another list called 'color'. The tree is given as an adjacency list, and its root is node 0 (whose neighbors are stored in tree[0]).
For the i-th node −
tree[i] lists its neighbors, that is, its children and its parent (the root has no parent).
color[i] is its color.
We call a node N "special" if every node in the subtree rooted at N has a distinct color, that is, no two nodes in that subtree share a color. Given this tree, we have to find the number of special nodes.
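So, if the input is like tree = [ [1,2], [0], [0,3], [2] ]
color = [1, 2, 1, 1], then the output will be 2. Nodes 1 and 3 are leaves, so they are special. Node 2 is not special because it and its child, node 3, both have color 1, and node 0 is not special because its subtree contains nodes 0, 2 and 3, which all have color 1.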
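To make the definition concrete, here is a direct but slower check that could be used to validate the faster solution on small inputs. It is only a sketch: the function name count_special_brute_force is hypothetical and not part of the original solution. It roots the tree at node 0, collects the full list of colors in every subtree, and calls a node special when that list has no duplicates.

```python
def count_special_brute_force(tree, color):
    """Directly apply the definition: a node is special if all colors in its subtree are distinct."""
    n = len(tree)
    # Root the tree at node 0 by recording each node's parent in BFS order.
    parent = [-1] * n
    seen = [False] * n
    seen[0] = True
    order = [0]
    for node in order:  # appending while iterating walks the list in BFS order
        for nxt in tree[node]:
            if not seen[nxt]:
                seen[nxt] = True
                parent[nxt] = node
                order.append(nxt)
    # Gather every color of each subtree (O(n^2) work in the worst case).
    subtree_colors = [[color[v]] for v in range(n)]
    for v in reversed(order):  # descendants are finished before their ancestors
        if parent[v] != -1:
            subtree_colors[parent[v]].extend(subtree_colors[v])
    # A node is special when no color appears twice in its subtree.
    return sum(1 for v in range(n) if len(subtree_colors[v]) == len(set(subtree_colors[v])))

print(count_special_brute_force([[1, 2], [0], [0, 3], [2]], [1, 2, 1, 1]))  # 2
```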
To solve this, we will follow these steps −
result := 0
dfs(0, -1)
return result
Define a function check_intersection(). This will take colors, child_colors and return True if the two sets share a color
    if length of (colors) < length of (child_colors), then
        for each c in colors, do
            if c is present in child_colors, then
                return True
    otherwise,
        for each c in child_colors, do
            if c is present in colors, then
                return True
    return False
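The steps above amount to an "is the intersection non-empty?" test that always scans the smaller of the two sets. A minimal standalone version, written here as a free function rather than a method, could look like the following; the comparison with Python's built-in set.isdisjoint is only a sanity check that both give the same answer.

```python
def check_intersection(colors, child_colors):
    """Return True if the two color sets share at least one color."""
    # Scan the smaller set and probe the larger one, so the work is
    # proportional to min(len(colors), len(child_colors)) set lookups.
    if len(colors) < len(child_colors):
        smaller, larger = colors, child_colors
    else:
        smaller, larger = child_colors, colors
    for c in smaller:
        if c in larger:
            return True
    return False

a, b = {1, 4, 7}, {2, 3, 4, 5, 6}
print(check_intersection(a, b), not a.isdisjoint(b))  # True True
```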
Define a function dfs(). This will take node, prev and return the set of colors in the subtree of node, or null if some color repeats in that subtree
    colors := {color[node]}
    for each child in tree[node], do
        if child is not same as prev, then
            child_colors := dfs(child, node)
            if colors and child_colors are both not null, then
                if check_intersection(colors, child_colors) is true, then
                    colors := null
                otherwise,
                    if length of (colors) < length of (child_colors), then
                        child_colors := union of child_colors and colors
                        colors := child_colors
                    otherwise,
                        colors := union of colors and child_colors
            otherwise,
                colors := null
    if colors is not null, then
        result := result + 1
    return colors
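The union step in dfs always inserts the smaller set into the larger one, a technique often called small-to-large merging. Each time a color is copied, it lands in a set at least twice the size of the one it came from, so a single color is copied at most about log2(n) times; the usual argument therefore gives roughly O(n log n) expected set operations overall. A minimal sketch of just the merge step, using the hypothetical helper name merge_small_to_large:

```python
def merge_small_to_large(a, b):
    """Union two sets by inserting the smaller one into the larger one.

    The larger set is mutated in place and returned, mirroring the
    colors/child_colors update in the dfs steps.
    """
    if len(a) < len(b):
        a, b = b, a  # make `a` refer to the larger set
    a |= b  # only len(b) insertions are performed
    return a

print(sorted(merge_small_to_large({1}, {2, 3, 4})))  # [1, 2, 3, 4]
```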
Example
Let us see the following implementation to get a better understanding −
    class Solution:
        def solve(self, tree, color):
            self.result = 0

            def dfs(node, prev):
                # Start with this node's own color; None means "this subtree is not special".
                colors = {color[node]}
                for child in tree[node]:
                    if child != prev:  # skip the edge back to the parent
                        child_colors = dfs(child, node)
                        if colors and child_colors:
                            if self.check_intersection(colors, child_colors):
                                colors = None  # a color repeats in this subtree
                            else:
                                # Merge the smaller set into the larger one.
                                if len(colors) < len(child_colors):
                                    child_colors |= colors
                                    colors = child_colors
                                else:
                                    colors |= child_colors
                        else:
                            # This subtree or the child's subtree already has a repeated color.
                            colors = None
                if colors:
                    self.result += 1
                return colors

            dfs(0, -1)
            return self.result

        def check_intersection(self, colors, child_colors):
            # Return True if the two sets share a color, scanning the smaller set.
            if len(colors) < len(child_colors):
                for c in colors:
                    if c in child_colors:
                        return True
            else:
                for c in child_colors:
                    if c in colors:
                        return True
            return False

    ob = Solution()
    print(ob.solve([[1, 2], [0], [0, 3], [2]], [1, 2, 1, 1]))
Input
[ [1,2], [0], [0,3], [2] ], [1, 2, 1, 1]
Output
2
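The recursive dfs makes one nested Python call per tree level, so on a very deep tree (for example, a path of a few thousand nodes) it can exceed CPython's default recursion limit of 1000 and raise RecursionError. One possible workaround, sketched here as an assumption rather than as part of the original solution, is to compute a traversal order with an explicit stack and then process nodes children-first. The function name count_special_nodes is hypothetical, and it uses set.isdisjoint in place of check_intersection.

```python
def count_special_nodes(tree, color):
    """Iterative variant of dfs: count nodes whose subtree colors are all distinct."""
    n = len(tree)
    parent = [-1] * n
    visited = [False] * n
    visited[0] = True
    order = []  # preorder: every node appears after its parent
    stack = [0]
    while stack:
        node = stack.pop()
        order.append(node)
        for nxt in tree[node]:
            if not visited[nxt]:
                visited[nxt] = True
                parent[nxt] = node
                stack.append(nxt)

    # colors[v] is the set of colors in v's subtree, or None once a repeat is found.
    colors = [{color[v]} for v in range(n)]
    result = 0
    for v in reversed(order):  # all children of v are finished before v
        if colors[v] is not None:
            result += 1
        p = parent[v]
        if p == -1:
            continue  # v is the root
        mine, theirs = colors[p], colors[v]
        if mine is None or theirs is None or not mine.isdisjoint(theirs):
            colors[p] = None
        else:
            # Small-to-large: insert the smaller set into the larger one.
            if len(mine) < len(theirs):
                mine, theirs = theirs, mine
            mine |= theirs
            colors[p] = mine
    return result

print(count_special_nodes([[1, 2], [0], [0, 3], [2]], [1, 2, 1, 1]))  # 2
```

Processing nodes in reverse preorder plays the role of the recursive call stack: by the time a node is reached, each child has already merged its color set (or None) into it.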