Skip to content

Commit 4379438

Browse files
perf: Speedup internal tree comparisons (#1060)
1 parent d1b87e2 commit 4379438

File tree

1 file changed

+42
-75
lines changed

1 file changed

+42
-75
lines changed

bigframes/core/nodes.py

Lines changed: 42 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ class Field:
5151
dtype: bigframes.dtypes.Dtype
5252

5353

54-
@dataclass(frozen=True)
55-
class BigFrameNode:
54+
@dataclass(eq=False, frozen=True)
55+
class BigFrameNode(abc.ABC):
5656
"""
5757
Immutable node for representing 2D typed array as a tree of operators.
5858
@@ -95,12 +95,30 @@ def session(self):
9595
return sessions[0]
9696
return None
9797

98+
def _as_tuple(self) -> Tuple:
99+
"""Get all fields as tuple."""
100+
return tuple(getattr(self, field.name) for field in fields(self))
101+
102+
def __hash__(self) -> int:
103+
# Custom hash that uses cache to avoid costly recomputation
104+
return self._cached_hash
105+
106+
def __eq__(self, other) -> bool:
107+
# Custom eq that tries to short-circuit full structural comparison
108+
if not isinstance(other, self.__class__):
109+
return False
110+
if self is other:
111+
return True
112+
if hash(self) != hash(other):
113+
return False
114+
return self._as_tuple() == other._as_tuple()
115+
98116
# BigFrameNode trees can be very deep, so it's important to avoid recalculating the hash from scratch
99117
# Each subclass of BigFrameNode should use this property to implement __hash__
100118
# The default dataclass-generated __hash__ method is not cached
101119
@functools.cached_property
102-
def _node_hash(self):
103-
return hash(tuple(hash(getattr(self, field.name)) for field in fields(self)))
120+
def _cached_hash(self):
121+
return hash(self._as_tuple())
104122

105123
@property
106124
def roots(self) -> typing.Set[BigFrameNode]:
@@ -226,7 +244,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
226244
return self.transform_children(lambda x: x.prune(used_cols))
227245

228246

229-
@dataclass(frozen=True)
247+
@dataclass(frozen=True, eq=False)
230248
class UnaryNode(BigFrameNode):
231249
child: BigFrameNode
232250

@@ -252,7 +270,7 @@ def order_ambiguous(self) -> bool:
252270
return self.child.order_ambiguous
253271

254272

255-
@dataclass(frozen=True)
273+
@dataclass(frozen=True, eq=False)
256274
class JoinNode(BigFrameNode):
257275
left_child: BigFrameNode
258276
right_child: BigFrameNode
@@ -285,9 +303,6 @@ def explicitly_ordered(self) -> bool:
285303
# Do not consider user pre-join ordering intent - they need to re-order post-join in unordered mode.
286304
return False
287305

288-
def __hash__(self):
289-
return self._node_hash
290-
291306
@functools.cached_property
292307
def fields(self) -> Tuple[Field, ...]:
293308
return tuple(itertools.chain(self.left_child.fields, self.right_child.fields))
@@ -320,7 +335,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
320335
return self.transform_children(lambda x: x.prune(new_used))
321336

322337

323-
@dataclass(frozen=True)
338+
@dataclass(frozen=True, eq=False)
324339
class ConcatNode(BigFrameNode):
325340
# TODO: Explicitly map column ids from each child
326341
children: Tuple[BigFrameNode, ...]
@@ -345,9 +360,6 @@ def explicitly_ordered(self) -> bool:
345360
# Consider concat as an ordered operation (even though input frames may not be ordered)
346361
return True
347362

348-
def __hash__(self):
349-
return self._node_hash
350-
351363
@functools.cached_property
352364
def fields(self) -> Tuple[Field, ...]:
353365
# TODO: Output names should probably be aligned beforehand or be part of concat definition
@@ -371,16 +383,13 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
371383
return self
372384

373385

374-
@dataclass(frozen=True)
386+
@dataclass(frozen=True, eq=False)
375387
class FromRangeNode(BigFrameNode):
376388
# TODO: Enforce single-row, single column constraint
377389
start: BigFrameNode
378390
end: BigFrameNode
379391
step: int
380392

381-
def __hash__(self):
382-
return self._node_hash
383-
384393
@property
385394
def roots(self) -> typing.Set[BigFrameNode]:
386395
return {self}
@@ -419,7 +428,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
419428
# Input Nodes
420429
# TODO: Most leaf nodes produce fixed column names based on the datasource
421430
# They should support renaming
422-
@dataclass(frozen=True)
431+
@dataclass(frozen=True, eq=False)
423432
class LeafNode(BigFrameNode):
424433
@property
425434
def roots(self) -> typing.Set[BigFrameNode]:
@@ -451,7 +460,7 @@ class ScanList:
451460
items: typing.Tuple[ScanItem, ...]
452461

453462

454-
@dataclass(frozen=True)
463+
@dataclass(frozen=True, eq=False)
455464
class ReadLocalNode(LeafNode):
456465
feather_bytes: bytes
457466
data_schema: schemata.ArraySchema
@@ -460,9 +469,6 @@ class ReadLocalNode(LeafNode):
460469
scan_list: ScanList
461470
session: typing.Optional[bigframes.session.Session] = None
462471

463-
def __hash__(self):
464-
return self._node_hash
465-
466472
@functools.cached_property
467473
def fields(self) -> Tuple[Field, ...]:
468474
return tuple(Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
@@ -547,7 +553,7 @@ class BigqueryDataSource:
547553

548554

549555
## Put ordering in here or just add order_by node above?
550-
@dataclass(frozen=True)
556+
@dataclass(frozen=True, eq=False)
551557
class ReadTableNode(LeafNode):
552558
source: BigqueryDataSource
553559
# Subset of physical schema column
@@ -570,9 +576,6 @@ def __post_init__(self):
570576
def session(self):
571577
return self.table_session
572578

573-
def __hash__(self):
574-
return self._node_hash
575-
576579
@functools.cached_property
577580
def fields(self) -> Tuple[Field, ...]:
578581
return tuple(Field(col_id, dtype) for col_id, dtype, _ in self.scan_list.items)
@@ -616,15 +619,12 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
616619
return ReadTableNode(self.source, new_scan_list, self.table_session)
617620

618621

619-
@dataclass(frozen=True)
622+
@dataclass(frozen=True, eq=False)
620623
class CachedTableNode(ReadTableNode):
621624
# The original BFET subtree that was cached
622625
# note: this isn't a "child" node.
623626
original_node: BigFrameNode = field()
624627

625-
def __hash__(self):
626-
return self._node_hash
627-
628628
def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
629629
new_scan_list = ScanList(
630630
tuple(item for item in self.scan_list.items if item.id in used_cols)
@@ -635,13 +635,10 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
635635

636636

637637
# Unary nodes
638-
@dataclass(frozen=True)
638+
@dataclass(frozen=True, eq=False)
639639
class PromoteOffsetsNode(UnaryNode):
640640
col_id: bigframes.core.identifiers.ColumnId
641641

642-
def __hash__(self):
643-
return self._node_hash
644-
645642
@property
646643
def non_local(self) -> bool:
647644
return True
@@ -666,17 +663,14 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
666663
return self.transform_children(lambda x: x.prune(new_used))
667664

668665

669-
@dataclass(frozen=True)
666+
@dataclass(frozen=True, eq=False)
670667
class FilterNode(UnaryNode):
671668
predicate: ex.Expression
672669

673670
@property
674671
def row_preserving(self) -> bool:
675672
return False
676673

677-
def __hash__(self):
678-
return self._node_hash
679-
680674
@property
681675
def variables_introduced(self) -> int:
682676
return 1
@@ -687,13 +681,10 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
687681
return FilterNode(pruned_child, self.predicate)
688682

689683

690-
@dataclass(frozen=True)
684+
@dataclass(frozen=True, eq=False)
691685
class OrderByNode(UnaryNode):
692686
by: Tuple[OrderingExpression, ...]
693687

694-
def __hash__(self):
695-
return self._node_hash
696-
697688
@property
698689
def variables_introduced(self) -> int:
699690
return 0
@@ -716,14 +707,11 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
716707
return OrderByNode(pruned_child, self.by)
717708

718709

719-
@dataclass(frozen=True)
710+
@dataclass(frozen=True, eq=False)
720711
class ReversedNode(UnaryNode):
721712
# useless field to make sure has distinct hash
722713
reversed: bool = True
723714

724-
def __hash__(self):
725-
return self._node_hash
726-
727715
@property
728716
def variables_introduced(self) -> int:
729717
return 0
@@ -734,15 +722,12 @@ def relation_ops_created(self) -> int:
734722
return 0
735723

736724

737-
@dataclass(frozen=True)
725+
@dataclass(frozen=True, eq=False)
738726
class SelectionNode(UnaryNode):
739727
input_output_pairs: typing.Tuple[
740728
typing.Tuple[ex.DerefOp, bigframes.core.identifiers.ColumnId], ...
741729
]
742730

743-
def __hash__(self):
744-
return self._node_hash
745-
746731
@functools.cached_property
747732
def fields(self) -> Tuple[Field, ...]:
748733
return tuple(
@@ -772,7 +757,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
772757
return SelectionNode(pruned_child, pruned_selections)
773758

774759

775-
@dataclass(frozen=True)
760+
@dataclass(frozen=True, eq=False)
776761
class ProjectionNode(UnaryNode):
777762
"""Assigns new variables (without modifying existing ones)"""
778763

@@ -788,9 +773,6 @@ def __post_init__(self):
788773
# Cannot assign to existing variables - append only!
789774
assert all(name not in self.child.schema.names for _, name in self.assignments)
790775

791-
def __hash__(self):
792-
return self._node_hash
793-
794776
@functools.cached_property
795777
def fields(self) -> Tuple[Field, ...]:
796778
input_types = self.child._dtype_lookup
@@ -819,7 +801,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
819801

820802
# TODO: Merge RowCount into Aggregate Node?
821803
# Row count can sometimes be computed from table metadata, so it is a bit special.
822-
@dataclass(frozen=True)
804+
@dataclass(frozen=True, eq=False)
823805
class RowCountNode(UnaryNode):
824806
@property
825807
def row_preserving(self) -> bool:
@@ -842,7 +824,7 @@ def defines_namespace(self) -> bool:
842824
return True
843825

844826

845-
@dataclass(frozen=True)
827+
@dataclass(frozen=True, eq=False)
846828
class AggregateNode(UnaryNode):
847829
aggregations: typing.Tuple[
848830
typing.Tuple[ex.Aggregation, bigframes.core.identifiers.ColumnId], ...
@@ -854,9 +836,6 @@ class AggregateNode(UnaryNode):
854836
def row_preserving(self) -> bool:
855837
return False
856838

857-
def __hash__(self):
858-
return self._node_hash
859-
860839
@property
861840
def non_local(self) -> bool:
862841
return True
@@ -904,7 +883,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
904883
return AggregateNode(pruned_child, pruned_aggs, self.by_column_ids, self.dropna)
905884

906885

907-
@dataclass(frozen=True)
886+
@dataclass(frozen=True, eq=False)
908887
class WindowOpNode(UnaryNode):
909888
column_name: ex.DerefOp
910889
op: agg_ops.UnaryWindowOp
@@ -913,9 +892,6 @@ class WindowOpNode(UnaryNode):
913892
never_skip_nulls: bool = False
914893
skip_reproject_unsafe: bool = False
915894

916-
def __hash__(self):
917-
return self._node_hash
918-
919895
@property
920896
def non_local(self) -> bool:
921897
return True
@@ -945,11 +921,8 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
945921

946922

947923
# TODO: Remove this op
948-
@dataclass(frozen=True)
924+
@dataclass(frozen=True, eq=False)
949925
class ReprojectOpNode(UnaryNode):
950-
def __hash__(self):
951-
return self._node_hash
952-
953926
@property
954927
def variables_introduced(self) -> int:
955928
return 0
@@ -960,7 +933,7 @@ def relation_ops_created(self) -> int:
960933
return 0
961934

962935

963-
@dataclass(frozen=True)
936+
@dataclass(frozen=True, eq=False)
964937
class RandomSampleNode(UnaryNode):
965938
fraction: float
966939

@@ -972,26 +945,20 @@ def deterministic(self) -> bool:
972945
def row_preserving(self) -> bool:
973946
return False
974947

975-
def __hash__(self):
976-
return self._node_hash
977-
978948
@property
979949
def variables_introduced(self) -> int:
980950
return 1
981951

982952

983953
# TODO: Explode should create a new column instead of overriding the existing one
984-
@dataclass(frozen=True)
954+
@dataclass(frozen=True, eq=False)
985955
class ExplodeNode(UnaryNode):
986956
column_ids: typing.Tuple[ex.DerefOp, ...]
987957

988958
@property
989959
def row_preserving(self) -> bool:
990960
return False
991961

992-
def __hash__(self):
993-
return self._node_hash
994-
995962
@functools.cached_property
996963
def fields(self) -> Tuple[Field, ...]:
997964
return tuple(

0 commit comments

Comments
 (0)