@@ -1716,10 +1716,13 @@ def kurt(self, *, numeric_only: bool = False):
17161716
17171717 kurtosis = kurt
17181718
1719- def pivot (
1719+ def _pivot (
17201720 self ,
17211721 * ,
17221722 columns : typing .Union [blocks .Label , Sequence [blocks .Label ]],
1723+ columns_unique_values : typing .Optional [
1724+ typing .Union [pandas .Index , Sequence [object ]]
1725+ ] = None ,
17231726 index : typing .Optional [
17241727 typing .Union [blocks .Label , Sequence [blocks .Label ]]
17251728 ] = None ,
@@ -1743,10 +1746,24 @@ def pivot(
17431746 pivot_block = block .pivot (
17441747 columns = column_ids ,
17451748 values = value_col_ids ,
1749+ columns_unique_values = columns_unique_values ,
17461750 values_in_index = utils .is_list_like (values ),
17471751 )
17481752 return DataFrame (pivot_block )
17491753
def pivot(
    self,
    *,
    columns: typing.Union[blocks.Label, Sequence[blocks.Label]],
    index: typing.Optional[
        typing.Union[blocks.Label, Sequence[blocks.Label]]
    ] = None,
    values: typing.Optional[
        typing.Union[blocks.Label, Sequence[blocks.Label]]
    ] = None,
) -> DataFrame:
    """Reshape the frame by pivoting on ``columns``.

    Public entry point that forwards to ``_pivot``. It does not pass
    ``columns_unique_values``, so ``_pivot`` runs with that parameter's
    default.

    Args:
        columns: Label or labels forwarded as ``_pivot``'s ``columns``.
        index: Optional label or labels forwarded as ``_pivot``'s ``index``.
        values: Optional label or labels forwarded as ``_pivot``'s ``values``.

    Returns:
        Whatever ``_pivot`` returns for the forwarded arguments.
    """
    pivoted = self._pivot(
        columns=columns,
        index=index,
        values=values,
    )
    return pivoted
1766+
17501767 def stack (self , level : LevelsType = - 1 ):
17511768 if not isinstance (self .columns , pandas .MultiIndex ):
17521769 if level not in [0 , - 1 , self .columns .name ]:
@@ -2578,3 +2595,86 @@ def _get_block(self) -> blocks.Block:
25782595
def _cached(self) -> DataFrame:
    """Return a new DataFrame built from ``self._block.cached()``."""
    cached_block = self._block.cached()
    return DataFrame(cached_block)
2598+
_DataFrameOrSeries = typing.TypeVar("_DataFrameOrSeries")

def dot(self, other: _DataFrameOrSeries) -> _DataFrameOrSeries:
    """Compute the matrix product of this DataFrame and ``other``.

    The product is assembled from relational operations: both operands are
    stacked into (row, col, val) records, joined where the left column label
    equals the right row label, multiplied pairwise, summed per
    (left row, right column) pair, and pivoted back into matrix form.

    Args:
        other: Right-hand operand; must be a DataFrame or a Series.

    Returns:
        The pivoted product, with columns selected in the right operand's
        order. When ``other`` is a Series, the single column keyed by
        ``other.name`` is returned after ``rename()``.

    Raises:
        NotImplementedError: If ``other`` is neither a DataFrame nor a
            Series, if either operand has more than one index level, or if
            either operand has more than one column level.
        RuntimeError: If the pivoted result contains a column that is not
            among the right operand's columns.
    """
    if not isinstance(other, (DataFrame, bf_series.Series)):
        raise NotImplementedError(
            f"Only DataFrame or Series operand is supported. {constants.FEEDBACK_LINK}"
        )
    other_is_frame = isinstance(other, DataFrame)

    if any(len(operand.index.names) > 1 for operand in (self, other)):
        raise NotImplementedError(
            f"Multi-index input is not supported. {constants.FEEDBACK_LINK}"
        )

    # A Series has no column axis, so only a DataFrame operand is inspected.
    if len(self.columns.names) > 1 or (
        other_is_frame and len(other.columns.names) > 1
    ):
        raise NotImplementedError(
            f"Multi-level column input is not supported. {constants.FEEDBACK_LINK}"
        )

    # Column names of the "long" representation, in which every matrix cell
    # occupies its own (row, col, val) record.
    row_key, col_key, val_key = "row", "col", "val"
    lhs_suffix, rhs_suffix = "_left", "_right"
    long_columns = [row_key, col_key, val_key]

    # Names the overlapping columns take after the suffixed merge below.
    lhs_row_key = f"{row_key}{lhs_suffix}"
    lhs_val_key = f"{val_key}{lhs_suffix}"
    rhs_col_key = f"{col_key}{rhs_suffix}"
    rhs_val_key = f"{val_key}{rhs_suffix}"

    rhs_frame = other if other_is_frame else other.to_frame()

    lhs_long = self.stack().reset_index()
    lhs_long.columns = long_columns

    rhs_long = rhs_frame.stack().reset_index()
    rhs_long.columns = long_columns

    # Pair each left cell with every right cell whose row label equals the
    # left cell's column label.
    paired = lhs_long.merge(
        rhs_long,
        left_on=col_key,
        right_on=row_key,
        suffixes=(lhs_suffix, rhs_suffix),
    )

    # Multiply the paired values, then sum the products for each
    # (left row, right column) pair.
    cell_products = paired[lhs_val_key] * paired[rhs_val_key]
    with_products = paired.assign(val=cell_products)
    group_keys = [lhs_row_key, rhs_col_key]
    summed = (
        with_products[[lhs_row_key, rhs_col_key, val_key]]
        .groupby(group_keys)
        .sum(numeric_only=True)
    )

    summed_long = summed.reset_index()
    summed_long.columns = long_columns
    product = summed_long._pivot(
        columns=col_key,
        columns_unique_values=rhs_frame.columns,
        index=row_key,
    )

    # The result's index carries the same names as the left operand's index.
    product.index.names = self.index.names

    # Every pivoted column must exist in the right operand; once that holds,
    # select the columns in the right operand's order.
    unexpected_columns = product.columns.difference(rhs_frame.columns)
    if not unexpected_columns.empty:
        raise RuntimeError(
            f"Could not construct all columns. {constants.FEEDBACK_LINK}"
        )
    product = product[rhs_frame.columns]

    if isinstance(other, bf_series.Series):
        # NOTE(review): rename() with no argument presumably clears the
        # resulting Series' name — confirm.
        product = product[other.name].rename()

    return product
0 commit comments