MAINT: stats.dirichlet: fix interface inconsistency #16042

Merged: 6 commits merged on May 25, 2022

Conversation

mdhaber
Contributor

@mdhaber mdhaber commented Apr 24, 2022

Reference issue

Closes gh-6006
gh-4984

What does this implement/fix?

The dirichlet distribution interface is inconsistent with other multivariate distributions and even itself: the pdf method expects the transpose of what the rvs method produces. This PR introduces multivariate_beta, which is the same distribution without this inconsistency.
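
A minimal sketch of the mismatch, assuming a SciPy release in which `dirichlet` still behaves as described above. The parameter values, sample count, and seed are arbitrary choices for illustration.

```python
import numpy as np
from scipy import stats

alpha = np.array([2.0, 3.0, 4.0])          # K = 3 concentration parameters
samples = stats.dirichlet.rvs(alpha, size=5, random_state=0)
rvs_shape = samples.shape                   # one variate per row: (5, 3)

# pdf wants the K components along the first axis, so the rvs output
# has to be transposed before it can be evaluated.
densities = stats.dirichlet.pdf(samples.T, alpha)

# Passing the rvs output unchanged is expected to be rejected, because
# the first axis then has 5 entries rather than K (or K - 1).
try:
    stats.dirichlet.pdf(samples, alpha)
    untransposed_ok = True
except ValueError:
    untransposed_ok = False
```

The sample count is deliberately different from K and K - 1, since those are the two first-axis lengths the input check accepts.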

The plan is to deprecate dirichlet in favor of multivariate_beta. (@h-vetinari can you help with this, either by making a PR against my branch or a follow-up to this one? I can review it.) After the deprecation cycle, we can make dirichlet an alias for multivariate_beta if desired.

Additional information

There are many other things that could be improved about the distribution and its documentation (e.g. pdf input x can only be 2D). This PR does not fix all of them, but it does fix a defect reported as early as gh-4984. Let's get this messy stuff out of the way, then those other things can be cleaned up in future PRs.

The three commits are pretty clean. I'd suggest reviewing them separately and in order.

LMK if I should address gh-6474 the same sort of way... maybe inv_wishart?

@mdhaber mdhaber added the defect (A clear bug or issue that prevents SciPy from being installed or used as expected) and scipy.stats labels Apr 24, 2022
@mdhaber mdhaber requested a review from ev-br April 24, 2022 21:58
Contributor Author

@mdhaber mdhaber left a comment

Some comments to facilitate review.

@@ -1578,6 +1590,128 @@ def rvs(self, size=1, random_state=None):
        return self._dist.rvs(self.alpha, size, random_state)


class multivariate_beta_gen(dirichlet_gen):
Contributor Author

Copy-paste with minimal changes (dirichlet -> multivariate_beta as needed)

@@ -637,120 +638,125 @@ def test_moments(self):
N*num_cols,num_rows).T)
assert_allclose(sample_rowcov, U, atol=0.1)

-class TestDirichlet:

+class DirichletTest:
Contributor Author

Will have two subclasses - one with dist = dirichlet, the other with dist = multivariate_beta.

@@ -1230,7 +1231,6 @@ def _dirichlet_check_parameters(alpha):


def _dirichlet_check_input(alpha, x):
    x = np.asarray(x)
Contributor Author

dirichlet and multivariate_beta both have a new _check_input method.
dirichlet does x = np.asarray(x) then calls _dirichlet_check_input(alpha, x).
multivariate_beta does x = np.moveaxis(x, -1, 0) then calls _dirichlet_check_input(alpha, x).
That is the only thing that's different between the two distributions, other than documentation.
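
The two wrappers can be sketched roughly as follows. This is an illustration, not the code in the diff: `check_components_first` is a hypothetical, simplified stand-in for `_dirichlet_check_input` (the real function also accepts one component fewer and validates the values), and the wrapper names are made up.

```python
import numpy as np

def check_components_first(alpha, x):
    """Hypothetical stand-in for the shared validator: components on axis 0."""
    x = np.asarray(x)
    if x.shape[0] != alpha.shape[0]:
        raise ValueError(
            f"expected {alpha.shape[0]} components on axis 0, got shape {x.shape}")
    return x

def check_input_dirichlet(alpha, x):
    # Old convention: the caller already passes components on the first axis.
    return check_components_first(alpha, np.asarray(x))

def check_input_multivariate_beta(alpha, x):
    # New convention: components on the last axis, as produced by rvs,
    # so that axis is moved to the front before the shared check runs.
    return check_components_first(alpha, np.moveaxis(x, -1, 0))

alpha = np.array([2.0, 3.0, 4.0])
x = np.random.default_rng(0).dirichlet(alpha, size=5)   # shape (5, 3)

old = check_input_dirichlet(alpha, x.T)           # caller transposes by hand
new = check_input_multivariate_beta(alpha, x)     # rvs-style input works as is
```

Both calls hand the same array to the shared check, which is why the rest of the distribution code needs no changes.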

        return multivariate_beta_frozen(alpha, seed=seed)

    def _check_input(self, alpha, x):
        x = np.moveaxis(x, -1, 0)
Contributor Author

Probably could have just done transpose. I didn't realize that dirichlet was only written for 2D x.
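
For reference, a small NumPy-only check of when the two operations agree. The array shapes are arbitrary examples.

```python
import numpy as np

# For 2D input, moving the last axis to the front is exactly a transpose.
x2 = np.arange(6).reshape(2, 3)
same_for_2d = np.array_equal(np.moveaxis(x2, -1, 0), x2.T)

# For higher-dimensional input the two differ: moveaxis keeps the order of
# the remaining axes, while .T reverses all of them.
x3 = np.zeros((4, 2, 3))
moved_shape = np.moveaxis(x3, -1, 0).shape    # (3, 4, 2)
transposed_shape = x3.T.shape                 # (3, 2, 4)
```

So the choice only matters if the distribution is later extended beyond 2D input.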

@h-vetinari
Member

> @h-vetinari can you help with this, either by making a PR against my branch or a follow-up to this one?

Hey, sorry I haven't gotten around to this yet. I'm in the middle of preparing a move, and am on low availability until ~middle of May. I'll try to take a look when I can, but no promises, unfortunately... 🙈

Member

@tirthasheshpatel tirthasheshpatel left a comment

Only small nitpicks. Otherwise LGTM.

One question: unlike with univariate distributions, the infrastructure does not stop us from adding new parameters to the pdf and logpdf methods. Instead of deprecating the whole distribution, we could add a new keyword to the pdf and logpdf methods (e.g. transpose) that defaults to True, emit a deprecation warning while it is True, and switch the default to False in 1.11.0. To me, this sounds simpler than adding a new distribution. Have you considered doing this?
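
A rough sketch of what such a keyword could look like. Everything here is hypothetical: `dirichlet_logpdf` is a standalone toy function written for this illustration, not SciPy's implementation, and it skips all input validation.

```python
import math
import warnings

import numpy as np

def dirichlet_logpdf(x, alpha, transpose=True):
    """Hypothetical sketch of a logpdf with a transitional `transpose` keyword."""
    x = np.asarray(x, dtype=float)
    alpha = np.asarray(alpha, dtype=float)
    if transpose:
        warnings.warn(
            "components-first input is deprecated; pass transpose=False "
            "and put the components on the last axis",
            DeprecationWarning, stacklevel=2)
        x = np.moveaxis(x, 0, -1)   # legacy layout -> components last
    # Log of the multivariate beta function B(alpha).
    ln_b = sum(math.lgamma(a) for a in alpha) - math.lgamma(alpha.sum())
    return np.sum((alpha - 1.0) * np.log(x), axis=-1) - ln_b

alpha = [2.0, 3.0, 4.0]
x = np.random.default_rng(0).dirichlet(alpha, size=5)   # shape (5, 3)

new_style = dirichlet_logpdf(x, alpha, transpose=False)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    old_style = dirichlet_logpdf(x.T, alpha)   # legacy call still works, but warns
```

Existing callers keep getting the same numbers plus a warning, and new callers opt in to the rvs-compatible layout explicitly.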

@mdhaber
Contributor Author

mdhaber commented May 16, 2022

I don't remember! I think I got carried away after gh-15889 (which also might interest you) and maybe didn't stop to think. Good idea.

On the other hand, we'd have the unfortunate choice of getting stuck with a transpose keyword or backward incompatibility (default to transpose=None but give the option of transpose=True for temporary relief).

Would you like to submit a PR for that, and if it's merged, we'd close this?

@mdhaber
Contributor Author

mdhaber commented May 16, 2022

Actually this needs an email to the mailing list. I'll send one with both options and let the default be to merge this one if there are no comments in favor of the other? Might as well since the work is done.

@mdhaber
Contributor Author

mdhaber commented May 16, 2022

Email sent 5/16/2022.

@mdhaber
Contributor Author

mdhaber commented May 19, 2022

OK, if that commit resolved the PEP8 issues, is this ready to merge after giving some time for people to respond to the email?

Member

@tirthasheshpatel tirthasheshpatel left a comment

LGTM! Added just one small comment. We can wait till Sunday for feedback on the mail. If there is none, this should be good to go in!

@mdhaber
Contributor Author

mdhaber commented May 23, 2022

+1 -0 from the mailing list, @tirthasheshpatel. Would you like to follow up with multivariate_gamma/multivariate_invgamma and I review, or would you be willing to review another few like these?

@mdhaber mdhaber requested a review from tirthasheshpatel May 24, 2022 20:35
@mdhaber mdhaber added this to the 1.9.0 milestone May 24, 2022
Contributor

@tylerjereddy tylerjereddy left a comment

Between the PR review and mailing list comments, I'll take this as +2 core developers in favor and CI is passing apart from a timeout.

I scanned through the diff and it looked well-done from a mechanical (non-stats-expert) point of view, but I'm mostly leaning on the current code review/feedback.

@tylerjereddy tylerjereddy merged commit d71d8ca into scipy:main May 25, 2022
@tylerjereddy
Contributor

thanks

@tirthasheshpatel
Member

Thanks @mdhaber!

> Would you like to follow up with multivariate_gamma/multivariate_invgamma and I review, or would you be willing to review another few like these?

Sure, I can propose a PR for multivariate_gamma and multivariate_invgamma.

@oscarbenjamin

I just picked this up in SymPy CI (sympy/sympy#23513). I went to follow the instruction to use multivariate_beta but it doesn't exist in older SciPy versions:

  File "/opt/hostedtoolcache/Python/3.11.0-beta.1/x64/lib/python3.11/site-packages/numpy/lib/utils.py", line 95, in newfunc
    warnings.warn(depdoc, DeprecationWarning, stacklevel=2)
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
DeprecationWarning:         `dirichlet.rvs` is deprecated, use `multivariate_beta.rvs` instead!
          # noqa
        `dirichlet` is deprecated due to an interface inconsistency: compared to
        other distributions, methods `pdf` and `logpdf` expect the transpose of the
        input `x`. Please use `multivariate_beta`, which corrects this inconsistency.
        In SciPy 1.11.0, this deprecation warning will be removed and `dirichlet` will
        become an alias for `multivariate_beta`.

What should be the expected code to use if wanting to suppress the warning while supporting multiple SciPy versions?

Something like this?

from scipy import stats

dirichlet = getattr(stats, 'multivariate_beta', None) or stats.dirichlet

And then in the future when we don't need to worry about old SciPy versions should we change the code again to just use multivariate_beta unconditionally? I think it's better if downstream code only needs to be updated once in response to something like this.

Generally I think that if there does not already exist an undeprecated alternative API then it is better to have a period of "soft deprecation" before emitting warnings or making any breaking change. By soft deprecation I mean something like:

  1. Add the new better API.
  2. Point to the new API in the docs and warn (in the docs) against using the old API.
  3. Maybe emit PendingDeprecationWarning so it can be seen in CI but not by end users.
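
The third step could look something like the sketch below. The function names `old_api` and `new_api` are hypothetical placeholders, not anything in SciPy.

```python
import warnings

def new_api(values):
    """The replacement that the docs would point to."""
    return sorted(values)

def old_api(values):
    """Soft-deprecated: still works, only signals a pending deprecation."""
    warnings.warn("old_api will be deprecated; use new_api instead",
                  PendingDeprecationWarning, stacklevel=2)
    return new_api(values)

# Python's default warning filters ignore PendingDeprecationWarning, so end
# users normally see nothing. A test run that enables all warnings does.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_api([3, 1, 2])
```

Downstream projects that run their test suite with warnings enabled would notice the message early, while interactive users are not interrupted.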

It's not completely clear to me what is different between dirichlet and multivariate_beta or what will be changed about dirichlet in future. Is it only the pdf and logpdf methods that are different? Or is it the rvs method that was changed to make it consistent with the others?

If we were using those methods then presumably we would need to change something else in the code rather than simply replacing dirichlet with multivariate_beta. I think it's good to make that clear like:

The multivariate_beta class is a drop-in replacement for dirichlet
except that the pdf and logpdf methods use a different convention.
If you use these methods like this

    from scipy.stats import dirichlet
    p = dirichlet.pdf(x, alpha)
    logp = dirichlet.logpdf(x, alpha)

then you should change that code to

    import numpy as np
    from scipy.stats import multivariate_beta
    p = multivariate_beta.pdf(np.transpose(x), alpha)
    logp = multivariate_beta.logpdf(np.transpose(x), alpha)

All other methods of multivariate_beta and dirichlet are identical
and any code using them does not need to be changed.

It's important when writing something like this to consider that the person (e.g. maintainer of large codebase) trying to fix the downstream code might know very little about the API and what it is used for and really needs clear instructions for what to do.

On the other hand SymPy doesn't use the pdf or logpdf methods so if that's all that has been changed then maybe there is no reason to give a warning if only the rvs method is being used (assuming that hasn't been changed and won't be in future).

@mdhaber
Contributor Author

mdhaber commented May 25, 2022

> What should be the expected code to use if wanting to suppress the warning while supporting multiple SciPy versions?

Yes, that, or try the multivariate_beta import and fall back to dirichlet.
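
The import-and-fall-back variant might look like this. It is only a sketch: `multivariate_beta` exists only in SciPy builds that include this PR, and `dirichlet_like`, `components_last`, and `pdf_components_last` are hypothetical names introduced for the example.

```python
import numpy as np

try:
    # Only present in SciPy builds that include this PR.
    from scipy.stats import multivariate_beta as dirichlet_like
    components_last = True
except ImportError:
    from scipy.stats import dirichlet as dirichlet_like
    components_last = False

def pdf_components_last(x, alpha):
    """Evaluate the pdf for x whose last axis holds the components."""
    x = np.asarray(x)
    if not components_last:
        # Older interface: components must be on the first axis.
        x = np.moveaxis(x, -1, 0)
    return dirichlet_like.pdf(x, alpha)

alpha = np.array([2.0, 3.0, 4.0])
x = np.random.default_rng(0).dirichlet(alpha, size=5)   # shape (5, 3)
p = pdf_components_last(x, alpha)
```

Wrapping the layout difference in one helper means downstream code only has to change once, whichever name is available.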

> PendingDeprecationWarning

That sounds like a project level decision. I don't think I've seen that used before.

> if that's all that has been changed then maybe there is no reason to give a warning if only the rvs method is being used

For now, only pdf and logpdf have changed, and the only difference is that the expected input is the transpose of what it was before. So we could change this to emit the warning only on use of pdf and logpdf.
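
One way to restrict the warning to the affected methods is a small decorator, sketched below. `warns_about_transpose` and `LegacyDist` are hypothetical toy names, and the method bodies return placeholder values rather than real densities.

```python
import functools
import warnings

def warns_about_transpose(method):
    """Hypothetical decorator: warn only when the affected method is called."""
    @functools.wraps(method)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"`{method.__name__}` expects the transpose of the usual input layout",
            DeprecationWarning, stacklevel=2)
        return method(*args, **kwargs)
    return wrapper

class LegacyDist:
    """Toy stand-in for a distribution object, not SciPy code."""

    def rvs(self):
        return [0.2, 0.3, 0.5]      # unaffected method, so no warning

    @warns_about_transpose
    def pdf(self, x):
        return 1.0                  # placeholder value

dist = LegacyDist()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    dist.rvs()
    after_rvs = len(caught)         # still no warnings recorded
    dist.pdf([0.2, 0.3, 0.5])
    after_pdf = len(caught)         # the pdf call added one
```

Callers that only draw samples, as in the SymPy case, would then see no warning at all.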

There is a separate change that may be made to rvs soon to fix an old bug.

@mdhaber
Contributor Author

mdhaber commented May 26, 2022

@tirthasheshpatel Thanks for offering to add multivariate_gamma and multivariate_invgamma. In the end, I think that in gh-16277 we'll add multivariate_beta without yet deprecating dirichlet, so I'd suggest adding those two classes without deprecating wishart and invwishart.

You are also slated to address gh-7689, which would fix the shape of the rvs output for all of these distributions. It would be nice if we could at least fix this in multivariate_beta, multivariate_gamma, and multivariate_invgamma before branching. (Otherwise we should probably hold up the release of multivariate_beta so we don't have to change it later.) Let me know if this is too much with the little time we have!

@tirthasheshpatel
Member

tirthasheshpatel commented May 28, 2022

> It would be nice if we could at least fix this in multivariate_beta, multivariate_gamma, and multivariate_invgamma before branching. (Otherwise we should probably hold up the release of multivariate_beta so we don't have to change it later.) Let me know if this is too much with the little time we have!

I will submit a PR for multivariate_beta, multivariate_gamma, and multivariate_invgamma ASAP (should I do that in one PR or propose a separate one for each?). I will also submit one fixing the size=1 issue and see if we can get it merged before 1.9 branches.

@mdhaber
Contributor Author

mdhaber commented May 28, 2022

Let's do separate ones. (I think the first of those is done here, though.) Thank you!

@mdhaber
Contributor Author

mdhaber commented May 29, 2022

@tirthasheshpatel Actually, I would prioritize fixing the squeeze issues of multivariate_beta in all methods. If we can't do that before branch, I'd suggest holding the addition of multivariate_beta until we can (i.e. reverting).

Development

Successfully merging this pull request may close these issues.

Dirichlet doesn't accept its own random variates as input to pdf
5 participants