ENH: special: fix premature overflow in boxcox #20073

Merged: 3 commits, Apr 15, 2024

Conversation

xuefeng-xu
Contributor

Reference issue

#19604 (comment)
Towards #19016

What does this implement/fix?

Fix premature overflow in the following functions:
special.boxcox, special.inv_boxcox, special.boxcox1p, special.inv_boxcox1p
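
A minimal sketch of what "premature overflow" means here, using only Python's `math` module rather than the Cython code this PR touches. The two formulations mirror the ones quoted later in this thread; the helper names and the sample point are illustrative assumptions, and both helpers assume `x > 0` and `lmbda != 0`.

```python
import math


def boxcox_expm1(x, lmbda):
    """Original-style formulation: (x**lmbda - 1) / lmbda via expm1."""
    try:
        return math.expm1(lmbda * math.log(x)) / lmbda
    except OverflowError:
        # x**lmbda alone exceeds the double range, so the result is lost
        # even when the final quotient would have been representable.
        return math.copysign(math.inf, lmbda)


def boxcox_logspace(x, lmbda):
    """Reformulation from this thread: fold the division into the exponent."""
    sign = math.copysign(1.0, lmbda)
    return sign * math.exp(lmbda * math.log(x) - math.log(abs(lmbda))) - 1.0 / lmbda


# Illustrative point: x**2 is about exp(710), just past the largest double,
# while (x**2 - 1) / 2 is about exp(709.3), which is still representable.
x, lmbda = math.exp(355.0), 2.0
old = boxcox_expm1(x, lmbda)
new = boxcox_logspace(x, lmbda)
print(old, new)
```

The first value printed is infinite and the second is finite, which is the gap the PR aims to close.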

Additional information

@github-actions github-actions bot added the scipy.special, Cython, and enhancement labels on Feb 12, 2024
@mdhaber
Contributor

mdhaber commented Feb 12, 2024

Thanks @xuefeng-xu! Since this modifies the formulation for all values (not just in cases of overflow), I'd like a scipy.special regular to take a look. Did you benchmark the accuracy across the full ranges of $\lambda$ and $x$ values as suggested? We want to make sure this does no harm. It looks like it currently does: genpareto is implemented in terms of these special functions, and those test failures look real. Let's get those resolved, and then we can ask someone to take a look.

@xuefeng-xu
Contributor Author

The modified formulation loses precision when:

  1. $\lambda\approx0$
import mpmath
import numpy as np
from scipy.special._mptestutils import (
    Arg, assert_mpmath_equal, exception_to_nan)


np.seterr(over='ignore')


def boxcox(x, lmbda):
    if abs(lmbda) < 1e-14: # change 1e-19 to 1e-14
        return np.log(x)
    else:
        # return np.expm1(lmbda * np.log(x)) / lmbda
        return np.sign(lmbda) * np.exp(lmbda * np.log(x) - np.log(abs(lmbda))) - 1 / lmbda


def test_boxcox():

    def mp_boxcox(x, lmbda):
        x = mpmath.mp.mpf(x)
        lmbda = mpmath.mp.mpf(lmbda)
        if lmbda == 0:
            return mpmath.mp.log(x)
        else:
            return mpmath.mp.powm1(x, lmbda) / lmbda

    assert_mpmath_equal(
        boxcox,
        exception_to_nan(mp_boxcox),
        [Arg(a=0, inclusive_a=False), Arg()],
        n=1000,
        dps=100,
        rtol=1e-13,
    )
==================================================== test session starts ====================================================
platform darwin -- Python 3.11.6, pytest-7.4.3, pluggy-1.3.0
rootdir: /Users/xxf/code
plugins: anyio-3.5.0, timeout-2.2.0, xdist-3.4.0, cov-4.1.0, hypothesis-6.89.0
collected 1 item                                                                                                            

mpboxcox.py F                                                                                                         [100%]

========================================================= FAILURES ==========================================================
________________________________________________________ test_boxcox ________________________________________________________

    def test_boxcox():
    
        def mp_boxcox(x, lmbda):
            x = mpmath.mp.mpf(x)
            lmbda = mpmath.mp.mpf(lmbda)
            if lmbda == 0:
                return mpmath.mp.log(x)
            else:
                return mpmath.mp.powm1(x, lmbda) / lmbda
    
>       assert_mpmath_equal(
            boxcox,
            exception_to_nan(mp_boxcox),
            [Arg(a=0, inclusive_a=False), Arg()],
            n=1000,
            dps=100,
            rtol=1e-13,
        )

mpboxcox.py:28: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
scipy/scipy/special/_mptestutils.py:295: in assert_mpmath_equal
    d.check()
scipy/scipy/special/_mptestutils.py:282: in check
    raise value
scipy/scipy/special/_mptestutils.py:263: in check
    assert_func_equal(
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

func = <function boxcox at 0x10624ca40>, results = <function MpmathData.check.<locals>.<lambda> at 0x1201414e0>
points = array([[ 1.00000000e-030, -8.98846567e+307],
       [ 1.00000000e-030, -4.32305117e+205],
       [ 1.00000000e-030, -2...567e+307,  2.07919484e+103],
       [ 8.98846567e+307,  4.32305117e+205],
       [ 8.98846567e+307,  8.98846567e+307]])
rtol = 1e-13, atol = 1e-300, param_filter = None, knownfailure = None, vectorized = False, dtype = None, nan_ok = True
ignore_inf_sign = False, distinguish_nan_and_inf = True

    def assert_func_equal(func, results, points, rtol=None, atol=None,
                          param_filter=None, knownfailure=None,
                          vectorized=True, dtype=None, nan_ok=False,
                          ignore_inf_sign=False, distinguish_nan_and_inf=True):
        if hasattr(points, 'next'):
            # it's a generator
            points = list(points)
    
        points = np.asarray(points)
        if points.ndim == 1:
            points = points[:,None]
        nparams = points.shape[1]
    
        if hasattr(results, '__name__'):
            # function
            data = points
            result_columns = None
            result_func = results
        else:
            # dataset
            data = np.c_[points, results]
            result_columns = list(range(nparams, data.shape[1]))
            result_func = None
    
        fdata = FuncData(func, data, list(range(nparams)),
                         result_columns=result_columns, result_func=result_func,
                         rtol=rtol, atol=atol, param_filter=param_filter,
                         knownfailure=knownfailure, nan_ok=nan_ok, vectorized=vectorized,
                         ignore_inf_sign=ignore_inf_sign,
                         distinguish_nan_and_inf=distinguish_nan_and_inf)
>       fdata.check()
E       AssertionError: 
E       Max |adiff|: 7.36617e-08
E       Max |rdiff|: 3.7072e-07
E       Bad results (76 out of 1024) for the following points (in output 0):
E                               1.e-30         -3.4223003202678016e-08 =>             -69.07763443514705 !=             -69.07763444097992  (rdiff          8.443939614851335e-11)
E                               1.e-30            3.73837195305305e-08 =>              -69.0774636156857 !=             -69.07746359779577  (rdiff         2.5898361610806074e-10)
E               5.2625215202708194e-27         -3.4223003202678016e-08 =>             -60.50924983993173 !=            -60.509249873916964  (rdiff          5.616535952096334e-10)
E               5.2625215202708194e-27            3.73837195305305e-08 =>            -60.509118776768446 !=             -60.50911878486791  (rdiff         1.3385523697463226e-10)
E                 2.76941327513135e-23         -3.4223003202678016e-08 =>             -51.94086775556207 !=             -51.94086781940562  (rdiff         1.2291583404492937e-09)
E                 2.76941327513135e-23            3.73837195305305e-08 =>             -51.94077119231224 !=            -51.940771227349884  (rdiff          6.745691885085158e-10)
E               1.4574096958902303e-19         -3.4223003202678016e-08 =>             -43.37248829007149 !=            -43.372488277445164  (rdiff           2.91113643685989e-10)
E               1.4574096958902303e-19            3.73837195305305e-08 =>             -43.37242095917463 !=             -43.37242092524083  (rdiff          7.823820027240829e-10)
E                7.669649888473719e-16         -3.4223003202678016e-08 =>             -34.80411123111844 !=             -34.80411124803483  (rdiff          4.860456906829046e-10)
E                7.669649888473719e-16            3.73837195305305e-08 =>             -34.80406788736582 !=             -34.80406787853985  (rdiff         2.5359021525611324e-10)
E                4.036169759103547e-12         -3.4223003202678016e-08 =>            -26.235736686736345 !=            -26.235736731173912  (rdiff         1.6937800135262329e-09)
E                4.036169759103547e-12            3.73837195305305e-08 =>            -26.235712070018053 !=            -26.235712087246075  (rdiff           6.56663010274995e-10)
E               2.1240430216748705e-08         -3.4223003202678016e-08 =>             -17.66736465319991 !=            -17.667364726861653  (rdiff          4.169367799561624e-09)
E               2.1240430216748705e-08            3.73837195305305e-08 =>            -17.667353603988886 !=            -17.667353551358623  (rdiff         2.9789556574815125e-09)
E                 0.000111778221115451         -3.4223003202678016e-08 =>             -9.098995238542557 !=             -9.098995235097325  (rdiff          3.786386921305008e-10)
E                 0.000111778221115451            3.73837195305305e-08 =>             -9.098992295563221 !=             -9.098992270876613  (rdiff         2.7131145491938895e-09)
E                   0.5882352941176471         -3.4223003202678016e-08 =>            -0.5306282304227352 !=            -0.5306282558801932  (rdiff         4.7976069401507286e-08)
E                   0.5882352941176471            3.73837195305305e-08 =>            -0.5306282453238964 !=            -0.5306282457991719  (rdiff          8.956844823031987e-10)
E                   1.1764705882352942         -3.4223003202678016e-08 =>            0.16251898929476738 !=            0.16251892904581908  (rdiff          3.707195749831828e-07)
E                   1.1764705882352942            3.73837195305305e-08 =>             0.1625189520418644 !=            0.16251892999147188  (rdiff         1.3567891760710356e-07)
E                   1.7647058823529411         -3.4223003202678016e-08 =>             0.5679840371012688 !=             0.5679840320856685  (rdiff          8.830530411261046e-09)
E                   1.7647058823529411            3.73837195305305e-08 =>             0.5679840035736561 !=             0.5679840436360429  (rdiff          7.053435268343966e-08)
E                   2.3529411764705883         -3.4223003202678016e-08 =>             0.8556660898029804 !=             0.8556660975292865  (rdiff          9.029580652676584e-09)
E                   2.3529411764705883            3.73837195305305e-08 =>             0.8556660749018192 !=             0.8556661237432364  (rdiff         5.7079993910757565e-08)
E                   2.9411764705882355         -3.4223003202678016e-08 =>             1.0788096599280834 !=             1.0788096414570465  (rdiff         1.7121683207484437e-08)
E                   2.9411764705882355            3.73837195305305e-08 =>             1.0788097083568573 !=             1.0788096831260827  (rdiff          2.338760485316419e-08)
E                   3.5294117647058822         -3.4223003202678016e-08 =>             1.2611312307417393 !=             1.2611311909508638  (rdiff          3.155173368874185e-08)
E                   3.5294117647058822            3.73837195305305e-08 =>             1.2611312307417393 !=               1.26113124789439  (rdiff         1.3601003592462826e-08)
E                     4.11764705882353         -3.4223003202678016e-08 =>             1.4152818992733955 !=             1.4152818637184148  (rdiff          2.512219060304968e-08)
E                     4.11764705882353            3.73837195305305e-08 =>             1.4152819067239761 !=              1.415281935433366  (rdiff         2.0285279586798455e-08)
E                    4.705882352941177         -3.4223003202678016e-08 =>             1.5488132759928703 !=             1.5488132495702094  (rdiff          1.705993989619612e-08)
E                    4.705882352941177            3.73837195305305e-08 =>              1.548813309520483 !=             1.5488133354561222  (rdiff          1.674549062699351e-08)
E                    5.294117647058823         -3.4223003202678016e-08 =>             1.6665962599217892 !=              1.666596278746113  (rdiff         1.1295071348422783e-08)
E                    5.294117647058823            3.73837195305305e-08 =>             1.6665963903069496 !=             1.6665963781915003  (rdiff          7.269576166728482e-09)
E                    5.882352941176471         -3.4223003202678016e-08 =>             1.7719568386673927 !=             1.7719567882046523  (rdiff         2.8478538961367388e-08)
E                    5.882352941176471            3.73837195305305e-08 =>              1.771956853568554 !=             1.7719569006211584  (rdiff         2.6554034391498504e-08)
E                    6.470588235294118         -3.4223003202678016e-08 =>             1.8672669865190983 !=              1.867266962073766  (rdiff         1.3091503625045225e-08)
E                    6.470588235294118            3.73837195305305e-08 =>              1.867267120629549 !=               1.86726708690885  (rdiff         1.8058851547428588e-08)
E                   7.0588235294117645         -3.4223003202678016e-08 =>             1.9542784057557583 !=              1.954278333373515  (rdiff         3.7037837484100626e-08)
E                   7.0588235294117645            3.73837195305305e-08 =>              1.954278476536274 !=             1.9542784701138582  (rdiff         3.2863360102343935e-09)
E                    7.647058823529412         -3.4223003202678016e-08 =>             2.0343210138380527 !=             2.0343210355840626  (rdiff         1.0689566445667878e-08)
E                    7.647058823529412            3.73837195305305e-08 =>             2.0343211963772774 !=             2.0343211837549267  (rdiff          6.204699051943513e-09)
E                     8.23529411764706         -3.4223003202678016e-08 =>             2.1084290742874146 !=              2.108429002484369  (rdiff          3.405523527080793e-08)
E                     8.23529411764706            3.73837195305305e-08 =>              2.108429156243801 !=             2.1084291616472517  (rdiff         2.5627849747855775e-09)
E                    8.823529411764707         -3.4223003202678016e-08 =>             2.1774218641221523 !=              2.177421868911566  (rdiff         2.1995800735087972e-09)
E                    8.823529411764707            3.73837195305305e-08 =>              2.177422020584345 !=             2.1774220386612586  (rdiff           8.30197978283715e-09)
E                    9.411764705882353         -3.4223003202678016e-08 =>              2.241960447281599 !=             2.2419603851685883  (rdiff         2.7704776206429715e-08)
E                    9.411764705882353            3.73837195305305e-08 =>             2.2419605627655983 !=               2.24196056513013  (rdiff         1.0546713971426796e-09)
E                                  10.         -3.4223003202678016e-08 =>              2.302585057914257 !=               2.30258500227061  (rdiff         2.4165729807369755e-08)
E                                  10.            3.73837195305305e-08 =>             2.3025851398706436 !=             2.3025851920963847  (rdiff         2.2681350181691292e-08)
E               7.0880457030678306e+44         -3.4223003202678016e-08 =>              103.2719712741673 !=             103.27197125666322  (rdiff         1.6949495563902615e-10)
E               7.0880457030678306e+44            3.73837195305305e-08 =>              103.2723530754447 !=             103.27235310448533  (rdiff          2.812043486259432e-10)
E                5.024039188877834e+88         -3.4223003202678016e-08 =>             204.24100859835744 !=             204.24100861377485  (rdiff          7.548637381191282e-11)
E                5.024039188877834e+88         -1.0540925533894595e-15 =>             204.24172241294545 !=             204.24172241292348  (rdiff         1.0756852777261997e-13)
E                5.024039188877834e+88          1.1180339887498947e-15 =>             204.24172241294545 !=             204.24172241296878  (rdiff         1.1424807412846235e-13)
E                5.024039188877834e+88            3.73837195305305e-08 =>             204.24250213429332 !=              204.2425021399006  (rdiff          2.745408312699838e-11)
E               3.561061938476992e+132         -3.4223003202678016e-08 =>              305.2096971273422 !=              305.2096970748111  (rdiff         1.7211480198446786e-10)
E               3.561061938476992e+132         -1.0540925533894595e-15 =>             305.21129107292114 !=              305.2112910728721  (rdiff         1.6072757434509447e-13)
E               3.561061938476992e+132          1.1180339887498947e-15 =>             305.21129107292114 !=             305.21129107297327  (rdiff          1.707846879193543e-13)
E               3.561061938476992e+132            3.73837195305305e-08 =>              305.2130323201418 !=              305.2130322997808  (rdiff          6.671068866885632e-11)
E              2.5240969771380244e+176         -3.4223003202678016e-08 =>              406.1780366562307 !=             406.17803664097755  (rdiff          3.755283217569957e-11)
E              2.5240969771380244e+176         -1.0540925533894595e-15 =>             406.18085973289686 !=              406.1808597328099  (rdiff         2.1411750153428288e-13)
E              2.5240969771380244e+176          1.1180339887498947e-15 =>             406.18085973289686 !=              406.1808597329891  (rdiff         2.2713248692156064e-13)
E              2.5240969771380244e+176            3.73837195305305e-08 =>              406.1839435324073 !=              406.1839435855646  (rdiff          1.308700024149117e-10)
E              1.7890914732929676e+220         -3.4223003202678016e-08 =>             507.14602729678154 !=             507.14602731347986  (rdiff           3.29260643389193e-11)
E              1.7890914732929676e+220         -1.0540925533894595e-15 =>              507.1504283928726 !=               507.150428392737  (rdiff         2.6732020006900317e-13)
E              1.7890914732929676e+220          1.1180339887498947e-15 =>              507.1504283928726 !=             507.15042839301634  (rdiff         2.8346028762018307e-13)
E              1.7890914732929676e+220            3.73837195305305e-08 =>             507.15523597225547 !=              507.1552359986905  (rdiff         5.2124142796559876e-11)
E              1.2681162129669514e+264         -3.4223003202678016e-08 =>               608.113669142127 !=              608.1136690935235  (rdiff          7.992504201070365e-11)
E              1.2681162129669514e+264         -1.0540925533894595e-15 =>              608.1199970528482 !=              608.1199970526534  (rdiff         3.2042892981527494e-13)
E              1.2681162129669514e+264          1.1180339887498947e-15 =>              608.1199970528482 !=               608.119997053055  (rdiff          3.400584733568262e-13)
E              1.2681162129669514e+264            3.73837195305305e-08 =>              608.1269095353782 !=              608.1269095405971  (rdiff          8.581938885735336e-12)
E                8.98846567431105e+307         -3.4223003202678016e-08 =>              709.0809619985521 !=              709.0809619823142  (rdiff          2.289991115596109e-11)
E                8.98846567431105e+307         -1.0540925533894595e-15 =>               709.089565712824 !=               709.089565712559  (rdiff         3.7372432418010035e-13)
E                8.98846567431105e+307          1.1180339887498947e-15 =>               709.089565712824 !=              709.0895657131051  (rdiff         3.9649088532675946e-13)
E                8.98846567431105e+307            3.73837195305305e-08 =>              709.0989642329514 !=              709.0989642127232  (rdiff          2.852660400008869e-11)

scipy/scipy/special/_testutils.py:85: AssertionError
================================================== short test summary info ==================================================
FAILED mpboxcox.py::test_boxcox - AssertionError: 
===================================================== 1 failed in 0.64s =====================================================
  2. $x=1$
print(boxcox(1, 10)) # -2.7755575615628914e-17 (not 0)
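
Both cases look like cancellation: for small $|\lambda|$ the reformulated expression subtracts two terms of size about $1/|\lambda|$ to produce a result near $\log x$, so rounding errors of order $\epsilon/|\lambda|$ survive in the difference. The sketch below reproduces this with the standard library only; the helper names are illustrative assumptions, and the sample point and reference value are taken from the failure log above.

```python
import math


def boxcox_expm1(x, lmbda):
    """Original-style formulation, accurate near lmbda = 0 and x = 1."""
    return math.expm1(lmbda * math.log(x)) / lmbda


def boxcox_logspace(x, lmbda):
    """Reformulated expression from the comment above."""
    sign = math.copysign(1.0, lmbda)
    return sign * math.exp(lmbda * math.log(x) - math.log(abs(lmbda))) - 1.0 / lmbda


# Point and mpmath reference value copied from the failure log.
x, lmbda = 10.0, 3.73837195305305e-08
reference = 2.3025851920963847

a = boxcox_expm1(x, lmbda)
b = boxcox_logspace(x, lmbda)
rel_gap = abs(a - b) / abs(a)
print(a, b, rel_gap)

# At x = 1 the expm1 form returns exactly zero because log(1) == 0.
print(boxcox_expm1(1.0, 10.0))
```

The gap between the two formulations at this point is far larger than the `rtol=1e-13` used in the test, while the `expm1` form stays within that tolerance of the reference.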

@dschmitz89
Contributor

dschmitz89 commented Feb 15, 2024

Maybe we can use the same formula as the mpmath implementation, using scipy's powm1?

Another possibility would be to use the old formulation for small values and the new one for large ones.
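
The second suggestion could be sketched roughly as follows, again with the standard library only. The switch point `709.78` (approximately the log of the largest finite double) and the function name are assumptions made for illustration, not details confirmed in this thread; the `1e-19` cutoff is the existing threshold mentioned in the code comment earlier. The sketch assumes `x > 0`.

```python
import math

# Assumed switch point: roughly log(largest finite double).
LOG_MAX_APPROX = 709.78


def boxcox_hybrid(x, lmbda):
    """Use expm1 where it cannot overflow, the log-space form otherwise."""
    if abs(lmbda) < 1e-19:
        return math.log(x)
    t = lmbda * math.log(x)
    if t < LOG_MAX_APPROX:
        # Accurate branch: no cancellation near lmbda = 0 or x = 1.
        return math.expm1(t) / lmbda
    try:
        # Overflow-avoiding branch: divide by |lmbda| inside the exponent.
        scaled = math.exp(t - math.log(abs(lmbda)))
    except OverflowError:
        # The true result is out of range as well.
        return math.copysign(math.inf, lmbda)
    return math.copysign(1.0, lmbda) * scaled - 1.0 / lmbda


print(boxcox_hybrid(10.0, 3.73837195305305e-08))  # small-lambda point from the log
print(boxcox_hybrid(1.0, 10.0))                   # x = 1 case
print(boxcox_hybrid(math.exp(355.0), 2.0))        # previously overflowing point
```

With this split, the cancellation-prone expression is only evaluated where the `expm1` form would overflow anyway, so the accuracy of the original formulation is kept everywhere else.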

@xuefeng-xu
Contributor Author

Thanks @mdhaber @dschmitz89 !
I now use the new formulation only for large values, and the tests pass.

@mdhaber
Contributor

mdhaber commented Apr 13, 2024

Thanks for your patience @xuefeng-xu. @steppi is a scipy.special maintainer, and he will take a look shortly.

@lucascolley lucascolley changed the title from "ENH: fix premature overflow in boxcox" to "ENH: special: fix premature overflow in boxcox" on Apr 14, 2024
Co-authored-by: Jake Bowhay
Contributor

@steppi steppi left a comment

This looks good to me.

@steppi steppi merged commit ebe1da8 into scipy:main Apr 15, 2024
@lucascolley lucascolley added this to the 1.14.0 milestone Apr 15, 2024
@mdhaber
Contributor

mdhaber commented Apr 15, 2024

Thanks @steppi @xuefeng-xu!

@xuefeng-xu xuefeng-xu deleted the boxcox_premature_overflow branch April 16, 2024 02:34
6 participants