Skip to content

Commit 4208044

Browse files
authored
fix: Improve Series.replace for dict input (#907)
1 parent 0a90b11 commit 4208044

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

bigframes/series.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -581,6 +581,9 @@ def _simple_replace(self, to_replace_list: typing.Sequence, value):
581581
return Series(block.select_column(result_col))
582582

583583
def _mapping_replace(self, mapping: dict[typing.Hashable, typing.Hashable]):
584+
if not mapping:
585+
return self.copy()
586+
584587
tuples = []
585588
lcd_types: list[typing.Optional[bigframes.dtypes.Dtype]] = []
586589
for key, value in mapping.items():
@@ -597,6 +600,7 @@ def _mapping_replace(self, mapping: dict[typing.Hashable, typing.Hashable]):
597600
result_dtype = functools.reduce(
598601
lambda t1, t2: bigframes.dtypes.lcd_type(t1, t2) if (t1 and t2) else None,
599602
lcd_types,
603+
self.dtype,
600604
)
601605
if not result_dtype:
602606
raise NotImplementedError(
@@ -605,7 +609,9 @@ def _mapping_replace(self, mapping: dict[typing.Hashable, typing.Hashable]):
605609
block, result = self._block.apply_unary_op(
606610
self._value_column, ops.MapOp(tuple(tuples))
607611
)
608-
return Series(block.select_column(result))
612+
replaced = Series(block.select_column(result))
613+
replaced.name = self.name
614+
return replaced
609615

610616
@validations.requires_ordering()
611617
@validations.requires_index

tests/system/small/test_series.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,29 @@ def test_series_replace_list_scalar(scalars_dfs):
458458
)
459459

460460

461+
@pytest.mark.parametrize(
462+
("replacement_dict",),
463+
(
464+
({"Hello, World!": "Howdy, Planet!", "T": "R"},),
465+
({},),
466+
),
467+
ids=[
468+
"non-empty",
469+
"empty",
470+
],
471+
)
472+
def test_series_replace_dict(scalars_dfs, replacement_dict):
473+
scalars_df, scalars_pandas_df = scalars_dfs
474+
col_name = "string_col"
475+
bf_result = scalars_df[col_name].replace(replacement_dict).to_pandas()
476+
pd_result = scalars_pandas_df[col_name].replace(replacement_dict)
477+
478+
pd.testing.assert_series_equal(
479+
pd_result,
480+
bf_result,
481+
)
482+
483+
461484
@pytest.mark.parametrize(
462485
("method",),
463486
(

0 commit comments

Comments
 (0)