Skip to content

Commit 768b82a

Browse files
authored
feat: Support callable for series where method (#2005)
1 parent 46994d7 commit 768b82a

File tree

4 files changed

+111
-0
lines changed

4 files changed

+111
-0
lines changed

bigframes/series.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1478,7 +1478,23 @@ def items(self):
14781478
for item in batch_df.squeeze(axis=1).items():
14791479
yield item
14801480

1481+
def _apply_callable(self, condition):
    """Resolve a possibly-callable argument against this series.

    Args:
        condition: Either a callable or any other value (for example a
            series or a scalar).

    Returns:
        The result of ``self.apply(...)`` when ``condition`` is callable;
        otherwise ``condition`` itself, unchanged.
    """
    # Anything that is not callable is passed through untouched.
    if not callable(condition):
        return condition

    # A bigframes function is recognized by its
    # ``bigframes_bigquery_function`` attribute and goes through ``apply``
    # with the default arguments.
    if hasattr(condition, "bigframes_bigquery_function"):
        return self.apply(condition)

    # A plain Python function is applied with ``by_row=False``.
    # NOTE(review): presumably so the callable receives the whole series
    # rather than one element at a time -- confirm against ``apply``.
    return self.apply(condition, by_row=False)
1493+
14811494
def where(self, cond, other=None):
1495+
cond = self._apply_callable(cond)
1496+
other = self._apply_callable(other)
1497+
14821498
value_id, cond_id, other_id, block = self._align3(cond, other)
14831499
block, result_id = block.project_expr(
14841500
ops.where_op.as_expr(value_id, cond_id, other_id)

tests/system/large/functions/test_managed_function.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,3 +1075,39 @@ def func_for_other(x):
10751075
cleanup_function_assets(
10761076
is_sum_positive_series_mf, session.bqclient, ignore_failures=False
10771077
)
1078+
1079+
1080+
def test_managed_function_series_where(session, dataset_id, scalars_dfs):
    """Compare ``Series.where`` with a managed-function ``cond`` to pandas.

    The managed function is built from ``_is_positive`` via ``session.udf``
    and passed as ``cond``; ``other`` is a non-callable series. The pandas
    reference uses the plain Python ``_is_positive`` instead.
    """

    # The return type has to be bool type for callable where condition.
    def _is_positive(s):
        return s + 1000 > 0

    # Create the function before entering ``try`` so that a creation failure
    # surfaces as-is instead of being followed by an unbound-name error when
    # the ``finally`` clause tries to clean up a function that never existed.
    is_positive_mf = session.udf(
        input_types=int,
        output_type=bool,
        dataset=dataset_id,
        name=prefixer.create_prefix(),
    )(_is_positive)

    try:
        scalars, scalars_pandas = scalars_dfs

        bf_int64 = scalars["int64_col"]
        bf_int64_filtered = bf_int64.dropna()
        pd_int64 = scalars_pandas["int64_col"]
        pd_int64_filtered = pd_int64.dropna()

        # The cond is a callable (managed function) and the other is not a
        # callable in series.where method.
        bf_result = bf_int64_filtered.where(
            cond=is_positive_mf, other=-bf_int64_filtered
        ).to_pandas()
        pd_result = pd_int64_filtered.where(cond=_is_positive, other=-pd_int64_filtered)

        # Ignore any dtype difference.
        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the managed function.
        cleanup_function_assets(is_positive_mf, session.bqclient, ignore_failures=False)

tests/system/large/functions/test_remote_function.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2930,3 +2930,42 @@ def func_for_other(x):
29302930
cleanup_function_assets(
29312931
is_sum_positive_series_mf, session.bqclient, ignore_failures=False
29322932
)
2933+
2934+
2935+
@pytest.mark.flaky(retries=2, delay=120)
def test_remote_function_series_where(session, dataset_id, scalars_dfs):
    """Compare ``Series.where`` with a remote-function ``other`` to pandas.

    ``cond`` is a non-callable boolean series; ``other`` is a remote function
    built from ``_ten_times`` via ``session.remote_function``. The pandas
    reference uses the plain Python ``_ten_times`` instead.
    """

    def _ten_times(x):
        return x * 10

    # Create the function before entering ``try`` so that a creation failure
    # surfaces as-is instead of being followed by an unbound-name error when
    # the ``finally`` clause tries to clean up a function that never existed.
    ten_times_rf = session.remote_function(
        input_types=float,
        output_type=float,
        dataset=dataset_id,
        reuse=False,
        cloud_function_service_account="default",
    )(_ten_times)

    try:
        scalars, scalars_pandas = scalars_dfs

        # Local names reflect the column actually read ("float64_col").
        bf_float64 = scalars["float64_col"]
        bf_float64_filtered = bf_float64.dropna()
        pd_float64 = scalars_pandas["float64_col"]
        pd_float64_filtered = pd_float64.dropna()

        # The cond is not a callable and the other is a callable (remote
        # function) in series.where method.
        bf_result = bf_float64_filtered.where(
            cond=bf_float64_filtered < 0, other=ten_times_rf
        ).to_pandas()
        pd_result = pd_float64_filtered.where(
            cond=pd_float64_filtered < 0, other=_ten_times
        )

        # Ignore any dtype difference.
        pandas.testing.assert_series_equal(bf_result, pd_result, check_dtype=False)

    finally:
        # Clean up the gcp assets created for the remote function.
        cleanup_function_assets(ten_times_rf, session.bqclient, ignore_failures=False)

tests/system/small/test_series.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3109,6 +3109,26 @@ def test_where_with_default(scalars_df_index, scalars_pandas_df_index):
31093109
)
31103110

31113111

3112+
def test_where_with_callable(scalars_df_index, scalars_pandas_df_index):
    """Check ``Series.where`` against pandas when both arguments are callables."""

    def _is_positive(x):
        return x > 0

    def _times_ten(x):
        return x * 10

    bf_series = scalars_df_index["int64_col"]
    pd_series = scalars_pandas_df_index["int64_col"]

    # Both cond and other are callable.
    actual = bf_series.where(cond=_is_positive, other=_times_ten).to_pandas()
    expected = pd_series.where(cond=_is_positive, other=_times_ten)

    pd.testing.assert_series_equal(actual, expected)
3130+
3131+
31123132
@pytest.mark.parametrize(
31133133
("ordered"),
31143134
[

0 commit comments

Comments
 (0)