@@ -51,8 +51,8 @@ class Field:
51
51
dtype : bigframes .dtypes .Dtype
52
52
53
53
54
- @dataclass (frozen = True )
55
- class BigFrameNode :
54
+ @dataclass (eq = False , frozen = True )
55
+ class BigFrameNode ( abc . ABC ) :
56
56
"""
57
57
Immutable node for representing 2D typed array as a tree of operators.
58
58
@@ -95,12 +95,30 @@ def session(self):
95
95
return sessions [0 ]
96
96
return None
97
97
98
+ def _as_tuple (self ) -> Tuple :
99
+ """Get all fields as tuple."""
100
+ return tuple (getattr (self , field .name ) for field in fields (self ))
101
+
102
+ def __hash__ (self ) -> int :
103
+ # Custom hash that uses cache to avoid costly recomputation
104
+ return self ._cached_hash
105
+
106
+ def __eq__ (self , other ) -> bool :
107
+ # Custom eq that tries to short-circuit full structural comparison
108
+ if not isinstance (other , self .__class__ ):
109
+ return False
110
+ if self is other :
111
+ return True
112
+ if hash (self ) != hash (other ):
113
+ return False
114
+ return self ._as_tuple () == other ._as_tuple ()
115
+
98
116
# BigFrameNode trees can be very deep so its important avoid recalculating the hash from scratch
99
117
# Each subclass of BigFrameNode should use this property to implement __hash__
100
118
# The default dataclass-generated __hash__ method is not cached
101
119
@functools .cached_property
102
- def _node_hash (self ):
103
- return hash (tuple ( hash ( getattr ( self , field . name )) for field in fields ( self ) ))
120
+ def _cached_hash (self ):
121
+ return hash (self . _as_tuple ( ))
104
122
105
123
@property
106
124
def roots (self ) -> typing .Set [BigFrameNode ]:
@@ -226,7 +244,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
226
244
return self .transform_children (lambda x : x .prune (used_cols ))
227
245
228
246
229
- @dataclass (frozen = True )
247
+ @dataclass (frozen = True , eq = False )
230
248
class UnaryNode (BigFrameNode ):
231
249
child : BigFrameNode
232
250
@@ -252,7 +270,7 @@ def order_ambiguous(self) -> bool:
252
270
return self .child .order_ambiguous
253
271
254
272
255
- @dataclass (frozen = True )
273
+ @dataclass (frozen = True , eq = False )
256
274
class JoinNode (BigFrameNode ):
257
275
left_child : BigFrameNode
258
276
right_child : BigFrameNode
@@ -285,9 +303,6 @@ def explicitly_ordered(self) -> bool:
285
303
# Do not consider user pre-join ordering intent - they need to re-order post-join in unordered mode.
286
304
return False
287
305
288
- def __hash__ (self ):
289
- return self ._node_hash
290
-
291
306
@functools .cached_property
292
307
def fields (self ) -> Tuple [Field , ...]:
293
308
return tuple (itertools .chain (self .left_child .fields , self .right_child .fields ))
@@ -320,7 +335,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
320
335
return self .transform_children (lambda x : x .prune (new_used ))
321
336
322
337
323
- @dataclass (frozen = True )
338
+ @dataclass (frozen = True , eq = False )
324
339
class ConcatNode (BigFrameNode ):
325
340
# TODO: Explcitly map column ids from each child
326
341
children : Tuple [BigFrameNode , ...]
@@ -345,9 +360,6 @@ def explicitly_ordered(self) -> bool:
345
360
# Consider concat as an ordered operations (even though input frames may not be ordered)
346
361
return True
347
362
348
- def __hash__ (self ):
349
- return self ._node_hash
350
-
351
363
@functools .cached_property
352
364
def fields (self ) -> Tuple [Field , ...]:
353
365
# TODO: Output names should probably be aligned beforehand or be part of concat definition
@@ -371,16 +383,13 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
371
383
return self
372
384
373
385
374
- @dataclass (frozen = True )
386
+ @dataclass (frozen = True , eq = False )
375
387
class FromRangeNode (BigFrameNode ):
376
388
# TODO: Enforce single-row, single column constraint
377
389
start : BigFrameNode
378
390
end : BigFrameNode
379
391
step : int
380
392
381
- def __hash__ (self ):
382
- return self ._node_hash
383
-
384
393
@property
385
394
def roots (self ) -> typing .Set [BigFrameNode ]:
386
395
return {self }
@@ -419,7 +428,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
419
428
# Input Nodex
420
429
# TODO: Most leaf nodes produce fixed column names based on the datasource
421
430
# They should support renaming
422
- @dataclass (frozen = True )
431
+ @dataclass (frozen = True , eq = False )
423
432
class LeafNode (BigFrameNode ):
424
433
@property
425
434
def roots (self ) -> typing .Set [BigFrameNode ]:
@@ -451,7 +460,7 @@ class ScanList:
451
460
items : typing .Tuple [ScanItem , ...]
452
461
453
462
454
- @dataclass (frozen = True )
463
+ @dataclass (frozen = True , eq = False )
455
464
class ReadLocalNode (LeafNode ):
456
465
feather_bytes : bytes
457
466
data_schema : schemata .ArraySchema
@@ -460,9 +469,6 @@ class ReadLocalNode(LeafNode):
460
469
scan_list : ScanList
461
470
session : typing .Optional [bigframes .session .Session ] = None
462
471
463
- def __hash__ (self ):
464
- return self ._node_hash
465
-
466
472
@functools .cached_property
467
473
def fields (self ) -> Tuple [Field , ...]:
468
474
return tuple (Field (col_id , dtype ) for col_id , dtype , _ in self .scan_list .items )
@@ -547,7 +553,7 @@ class BigqueryDataSource:
547
553
548
554
549
555
## Put ordering in here or just add order_by node above?
550
- @dataclass (frozen = True )
556
+ @dataclass (frozen = True , eq = False )
551
557
class ReadTableNode (LeafNode ):
552
558
source : BigqueryDataSource
553
559
# Subset of physical schema column
@@ -570,9 +576,6 @@ def __post_init__(self):
570
576
def session (self ):
571
577
return self .table_session
572
578
573
- def __hash__ (self ):
574
- return self ._node_hash
575
-
576
579
@functools .cached_property
577
580
def fields (self ) -> Tuple [Field , ...]:
578
581
return tuple (Field (col_id , dtype ) for col_id , dtype , _ in self .scan_list .items )
@@ -616,15 +619,12 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
616
619
return ReadTableNode (self .source , new_scan_list , self .table_session )
617
620
618
621
619
- @dataclass (frozen = True )
622
+ @dataclass (frozen = True , eq = False )
620
623
class CachedTableNode (ReadTableNode ):
621
624
# The original BFET subtree that was cached
622
625
# note: this isn't a "child" node.
623
626
original_node : BigFrameNode = field ()
624
627
625
- def __hash__ (self ):
626
- return self ._node_hash
627
-
628
628
def prune (self , used_cols : COLUMN_SET ) -> BigFrameNode :
629
629
new_scan_list = ScanList (
630
630
tuple (item for item in self .scan_list .items if item .id in used_cols )
@@ -635,13 +635,10 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
635
635
636
636
637
637
# Unary nodes
638
- @dataclass (frozen = True )
638
+ @dataclass (frozen = True , eq = False )
639
639
class PromoteOffsetsNode (UnaryNode ):
640
640
col_id : bigframes .core .identifiers .ColumnId
641
641
642
- def __hash__ (self ):
643
- return self ._node_hash
644
-
645
642
@property
646
643
def non_local (self ) -> bool :
647
644
return True
@@ -666,17 +663,14 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
666
663
return self .transform_children (lambda x : x .prune (new_used ))
667
664
668
665
669
- @dataclass (frozen = True )
666
+ @dataclass (frozen = True , eq = False )
670
667
class FilterNode (UnaryNode ):
671
668
predicate : ex .Expression
672
669
673
670
@property
674
671
def row_preserving (self ) -> bool :
675
672
return False
676
673
677
- def __hash__ (self ):
678
- return self ._node_hash
679
-
680
674
@property
681
675
def variables_introduced (self ) -> int :
682
676
return 1
@@ -687,13 +681,10 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
687
681
return FilterNode (pruned_child , self .predicate )
688
682
689
683
690
- @dataclass (frozen = True )
684
+ @dataclass (frozen = True , eq = False )
691
685
class OrderByNode (UnaryNode ):
692
686
by : Tuple [OrderingExpression , ...]
693
687
694
- def __hash__ (self ):
695
- return self ._node_hash
696
-
697
688
@property
698
689
def variables_introduced (self ) -> int :
699
690
return 0
@@ -716,14 +707,11 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
716
707
return OrderByNode (pruned_child , self .by )
717
708
718
709
719
- @dataclass (frozen = True )
710
+ @dataclass (frozen = True , eq = False )
720
711
class ReversedNode (UnaryNode ):
721
712
# useless field to make sure has distinct hash
722
713
reversed : bool = True
723
714
724
- def __hash__ (self ):
725
- return self ._node_hash
726
-
727
715
@property
728
716
def variables_introduced (self ) -> int :
729
717
return 0
@@ -734,15 +722,12 @@ def relation_ops_created(self) -> int:
734
722
return 0
735
723
736
724
737
- @dataclass (frozen = True )
725
+ @dataclass (frozen = True , eq = False )
738
726
class SelectionNode (UnaryNode ):
739
727
input_output_pairs : typing .Tuple [
740
728
typing .Tuple [ex .DerefOp , bigframes .core .identifiers .ColumnId ], ...
741
729
]
742
730
743
- def __hash__ (self ):
744
- return self ._node_hash
745
-
746
731
@functools .cached_property
747
732
def fields (self ) -> Tuple [Field , ...]:
748
733
return tuple (
@@ -772,7 +757,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
772
757
return SelectionNode (pruned_child , pruned_selections )
773
758
774
759
775
- @dataclass (frozen = True )
760
+ @dataclass (frozen = True , eq = False )
776
761
class ProjectionNode (UnaryNode ):
777
762
"""Assigns new variables (without modifying existing ones)"""
778
763
@@ -788,9 +773,6 @@ def __post_init__(self):
788
773
# Cannot assign to existing variables - append only!
789
774
assert all (name not in self .child .schema .names for _ , name in self .assignments )
790
775
791
- def __hash__ (self ):
792
- return self ._node_hash
793
-
794
776
@functools .cached_property
795
777
def fields (self ) -> Tuple [Field , ...]:
796
778
input_types = self .child ._dtype_lookup
@@ -819,7 +801,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
819
801
820
802
# TODO: Merge RowCount into Aggregate Node?
821
803
# Row count can be compute from table metadata sometimes, so it is a bit special.
822
- @dataclass (frozen = True )
804
+ @dataclass (frozen = True , eq = False )
823
805
class RowCountNode (UnaryNode ):
824
806
@property
825
807
def row_preserving (self ) -> bool :
@@ -842,7 +824,7 @@ def defines_namespace(self) -> bool:
842
824
return True
843
825
844
826
845
- @dataclass (frozen = True )
827
+ @dataclass (frozen = True , eq = False )
846
828
class AggregateNode (UnaryNode ):
847
829
aggregations : typing .Tuple [
848
830
typing .Tuple [ex .Aggregation , bigframes .core .identifiers .ColumnId ], ...
@@ -854,9 +836,6 @@ class AggregateNode(UnaryNode):
854
836
def row_preserving (self ) -> bool :
855
837
return False
856
838
857
- def __hash__ (self ):
858
- return self ._node_hash
859
-
860
839
@property
861
840
def non_local (self ) -> bool :
862
841
return True
@@ -904,7 +883,7 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
904
883
return AggregateNode (pruned_child , pruned_aggs , self .by_column_ids , self .dropna )
905
884
906
885
907
- @dataclass (frozen = True )
886
+ @dataclass (frozen = True , eq = False )
908
887
class WindowOpNode (UnaryNode ):
909
888
column_name : ex .DerefOp
910
889
op : agg_ops .UnaryWindowOp
@@ -913,9 +892,6 @@ class WindowOpNode(UnaryNode):
913
892
never_skip_nulls : bool = False
914
893
skip_reproject_unsafe : bool = False
915
894
916
- def __hash__ (self ):
917
- return self ._node_hash
918
-
919
895
@property
920
896
def non_local (self ) -> bool :
921
897
return True
@@ -945,11 +921,8 @@ def prune(self, used_cols: COLUMN_SET) -> BigFrameNode:
945
921
946
922
947
923
# TODO: Remove this op
948
- @dataclass (frozen = True )
924
+ @dataclass (frozen = True , eq = False )
949
925
class ReprojectOpNode (UnaryNode ):
950
- def __hash__ (self ):
951
- return self ._node_hash
952
-
953
926
@property
954
927
def variables_introduced (self ) -> int :
955
928
return 0
@@ -960,7 +933,7 @@ def relation_ops_created(self) -> int:
960
933
return 0
961
934
962
935
963
- @dataclass (frozen = True )
936
+ @dataclass (frozen = True , eq = False )
964
937
class RandomSampleNode (UnaryNode ):
965
938
fraction : float
966
939
@@ -972,26 +945,20 @@ def deterministic(self) -> bool:
972
945
def row_preserving (self ) -> bool :
973
946
return False
974
947
975
- def __hash__ (self ):
976
- return self ._node_hash
977
-
978
948
@property
979
949
def variables_introduced (self ) -> int :
980
950
return 1
981
951
982
952
983
953
# TODO: Explode should create a new column instead of overriding the existing one
984
- @dataclass (frozen = True )
954
+ @dataclass (frozen = True , eq = False )
985
955
class ExplodeNode (UnaryNode ):
986
956
column_ids : typing .Tuple [ex .DerefOp , ...]
987
957
988
958
@property
989
959
def row_preserving (self ) -> bool :
990
960
return False
991
961
992
- def __hash__ (self ):
993
- return self ._node_hash
994
-
995
962
@functools .cached_property
996
963
def fields (self ) -> Tuple [Field , ...]:
997
964
return tuple (
0 commit comments