FIX Remove warnings when fitting a dataframe #21578
Conversation
LGTM. It would be good if @jeremiedbb could have a look, since we worked on this before the release but did not think about these methods.
I assume it is relatively difficult to write a general test, since it would depend on specific parameters for some estimators. We might still have such corner cases for some of the estimators.
sklearn/ensemble/_forest.py
Outdated
@@ -512,7 +513,7 @@ def _compute_oob_predictions(self, X, y):
             (n_samples, 1, n_outputs)
             The OOB predictions.
         """
-        X = self._validate_data(X, dtype=DTYPE, accept_sparse="csr", reset=False)
+        X = check_array(X, dtype=DTYPE, accept_sparse="csr", copy=False)
I feel like we either want to validate the data or not. If this is a private method and we know validation is done somewhere else, why do we need to call `check_array`?
In this case, `check_array` is required to convert the CSC matrix (used during fit) into a CSR matrix for prediction. I'm undecided on who should be responsible for this: either the caller of `_compute_oob_predictions` or `_compute_oob_predictions` itself.
In any case, I updated the PR with a comment regarding this behavior.
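As an aside, a minimal sketch of the conversion described above (the toy matrix is illustrative, not from the PR): `check_array` with `accept_sparse="csr"` converts a CSC input into CSR.

```python
import numpy as np
from scipy.sparse import csc_matrix
from sklearn.utils import check_array

X_csc = csc_matrix(np.eye(3))
X_csr = check_array(X_csc, accept_sparse="csr")  # converts to the accepted CSR format
print(X_csc.format, "->", X_csr.format)  # csc -> csr
```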
@@ -745,4 +749,4 @@ def _global_clustering(self, X=None):
         self.subcluster_labels_ = clusterer.fit_predict(self.subcluster_centers_)

         if compute_labels:
-            self.labels_ = self.predict(X)
+            self.labels_ = self._predict(X)
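For context, a minimal sketch of the pattern this change relies on, using a hypothetical toy estimator rather than the actual Birch code: the public `predict` validates its input and delegates to a private `_predict` that assumes already-validated input, so `fit` can reuse `_predict` without triggering a second validation.

```python
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted


class ToyNearestCentroid(BaseEstimator):  # hypothetical estimator, for illustration only
    def fit(self, X, y):
        X, y = self._validate_data(X, y)
        self.centroids_ = np.array(
            [X[y == label].mean(axis=0) for label in np.unique(y)]
        )
        # Reuse the already-validated X internally without re-checking it.
        self.labels_ = self._predict(X)
        return self

    def predict(self, X):
        check_is_fitted(self)
        X = self._validate_data(X, reset=False)  # public entry point: validate
        return self._predict(X)

    def _predict(self, X):
        # Private helper: assumes X has already been validated by the caller.
        distances = np.linalg.norm(X[:, None, :] - self.centroids_[None, :, :], axis=2)
        return distances.argmin(axis=1)
```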
Wouldn't a cleaner API be something like `self.predict(X, validate_input=False)`, or `self.validate_input(predict=False).predict(X)`?
> self.predict(X, validate_input=False)

I don't think introducing a public arg for that is cleaner. I find it clean that we use a private function internally and expose a public function that does the extra validation.

> self.validate_input(predict=False).predict(X)

I don't get the `predict=False` arg, could you explain more what you have in mind?
There are other reasons why a public, or a developer, API would be nice to have when it comes to skipping validation: #16653 (comment)
The `predict=False` would kinda set a flag in the estimator to skip the validation in a certain method.
> The predict=False would kinda set a flag in the estimator to skip the validation in a certain method.

I think adding more state to the estimator after `__init__` is outside the scope of this PR, but we can use this PR as motivation to do it. It would kind of be like an "inference mode".

> self.predict(X, validate_input=False)

I think it would be very nice to have this type of kwarg everywhere. It would be similar to the `check_finite` flag in SciPy. (Every year I see the "scikit-learn is slow during prediction" complaint, and it comes down to the validation we do.)
In both cases, I do not think we should change the public API in a bug-fix PR.
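For reference, the SciPy analogy mentioned above: many `scipy.linalg` routines expose a `check_finite` flag that lets callers skip input validation for speed (the values below are purely illustrative).

```python
import numpy as np
from scipy.linalg import solve

A = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])

# check_finite=False skips the NaN/Inf scan on the inputs, trading safety for speed.
x = solve(A, b, check_finite=False)
```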
I find the introduction of a private method like this a nice trigger to introduce the developer or public API we've talked about to disable input validation in these estimators. Not sure what others think @scikit-learn/core-devs
LGTM. I agree that having a generic public API to disable validation checks at inference is beyond the scope of this bugfix PR and just calling ad-hoc private functions when necessary is good enough for now.
I just have a question:
sklearn/ensemble/_forest.py
Outdated
""" | ||
# Prediction requires X to be in CSR format | ||
if issparse(X): | ||
X = check_array(X, accept_sparse="csr", force_all_finite=True) |
Why `force_all_finite=True` here if input validation has already been performed in the caller?
Wouldn't the following be enough (see the sketch after this comment)?

-            X = check_array(X, accept_sparse="csr", force_all_finite=True)
+            X = X.tocsr()
Also, this comment about adding a comment to explain the test has not been addressed: https://github.com/scikit-learn/scikit-learn/pull/21578/files#r745491797
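A quick illustration of the suggestion above (toy matrix, not from the PR): `.tocsr()` only converts the sparse format and does not re-run any finiteness checks, and it is a no-op for input that is already CSR.

```python
import numpy as np
from scipy.sparse import csc_matrix

X_csc = csc_matrix(np.eye(3))
X_csr = X_csc.tocsr()          # format conversion only, no data validation
print(X_csr.format)            # csr
print(X_csr.tocsr() is X_csr)  # True: already-CSR input is returned as-is
```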
I updated the PR with the comment and suggestion.
Some test froze in the macOS CI and triggered the 60-minute timeout. I pushed another commit to check whether this happens deterministically in this PR or is a rare random event.
Here is where it happened in the run for the 5c22dec commit:
Should we open an issue for this to be discussed? Or is there already one?
@lorentzenchr we could open a new issue, some background is present here: #16653 |
LGTM. Thank you, @thomasjpfan.
I agree with you: this PR is fine as is, and a general API can be introduced in another one.
And as always, Julien is nitpicking with some tiny suggestions even though nothing fundamentally has to be changed.
Done in #21804.
OK with the follow-up PR for MLP, and with merging this one after resolving the problem on the CI.
Co-authored-by: Olivier Grisel <[email protected]>
Reference Issues/PRs
Fixes #21577
Fixes #21618
What does this implement/fix? Explain your changes.
This PR removes redundant input validation in methods whose input has already been validated. The common test was updated to catch warnings raised during `fit`.
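A minimal sketch of the behavior such a test guards against, not the project's actual common test; the choice of `RandomForestClassifier` with `oob_score=True` is an assumption based on the `_forest.py` diff above, and the data is made up.

```python
import warnings

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier

rng = np.random.RandomState(0)
X = pd.DataFrame(rng.rand(100, 3), columns=["a", "b", "c"])
y = rng.randint(0, 2, size=100)

with warnings.catch_warnings():
    # Turn every warning raised during fit into an error so spurious
    # feature-name warnings make this snippet fail loudly.
    warnings.simplefilter("error")
    RandomForestClassifier(n_estimators=50, oob_score=True, random_state=0).fit(X, y)
```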