Skip to content

ENH - implement Cox estimator #157

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 32 commits into from
May 30, 2023

Conversation

Badr-MOUFAD
Copy link
Collaborator

@Badr-MOUFAD Badr-MOUFAD commented May 17, 2023

This adds a Cox estimator. It proceeds by

  • implementing a Cox datafit using Breslow estimate
  • Derive Cox estimator from it
  • unittest against lifelines.CoxPHFitter
  • Timing benchmarks against lifelines

Problem setup

This problem falls in the survival analysis setup. Given

  • $\mathbf{X} \in \mathbb{R}^{n \times p}$ a matrix of $p$ predictors and $n$ samples $x_i \in \mathbb{R}^p$
  • $y \in \mathbb{R}^n$ a vector recording the times of events occurrence
  • $s \in \{ 0, 1 \}^n$ a binary vector where $1$ means event occurred

The minus log-likelihood to be minimized reads [1]

$$ l(\beta) = \sum_{i=1}^{n} -s_i \langle x_i, \beta \rangle + \log(\textstyle\sum_{y_j \geq y_i} e^{\langle x_j, \beta \rangle}) $$

which represent the datafit to be considered.

Defining the matrix $\mathbf{B} \in \mathbb{R}^{n \times n}$ with $\mathbf{B}_{i,j} = 1 \ \mathrm{if} \ y_j \geq y_i$ and $0 \ \mathrm{otherwise}$, the datafit can be rewritten as

$$ l(\beta) = -\langle s, \mathbf{X}\beta \rangle + \langle s, \log(\mathbf{B}e^{\mathbf{X}\beta}) \rangle $$

Cox Estimator

Referring to skglm toolbox, the Cox estimator can be fit using a ProxNewton solver.


Benchmark

Links to results and repository to reproduce

Note

Revival of #124, and closes #121

References

[1] Lin, D. Y. "On the Breslow estimator." Lifetime data analysis 13 (2007): 471-480.

@Badr-MOUFAD
Copy link
Collaborator Author

@PABannier, WDYT about the input conversion (X, y, s)?

Also, I would be grateful if you could check the datafit expression 🙏

@PABannier
Copy link
Collaborator

PABannier commented May 17, 2023

@Badr-MOUFAD The input conversion is sound and greatly simplifies the expression of the datafit. However I'm not sure about Breslow's estimate, I think you miss a log in the second dot product. Could you compare it with lifeline's or PySurvival? Make sure you don't have tied ties otherwise Efron's and Breslow's estimates are different.

By the way, I don't think you need to compare it to an R library, lifelines or PySurvival will do the trick ;)

@Badr-MOUFAD
Copy link
Collaborator Author

You are right @PABannier, thanks for having a look.
I have updated the expression in the PR description

@Badr-MOUFAD
Copy link
Collaborator Author

When unit testing against lifeline

  1. I don't get the same solution as lifeline (Perhaps become of the optimization method ?)
  2. The objective values mismatch (with skglm having the lowest)

I pushed a debug_script where I fit a Cox estimator on a dummy dataset with different occurrence times to avoid data ties.

@PABannier, @mathurinm, any thoughts?

@mathurinm
Copy link
Collaborator

mathurinm commented May 24, 2023

The script fails on my machine with:

(click to expand error logs)
In [1]: %run debug_cox_lifeline.py
---------------------------------------------------------------------------
TypingError                               Traceback (most recent call last)
File ~/workspace/skglm/debug_cox_lifeline.py:33
     30 datafit = compiled_clone(Cox())
     31 penalty = compiled_clone(L1(alpha))
---> 33 datafit.initialize(X, (tm, s))
     35 w, _, _ = ProxNewton(fit_intercept=False).solve(
     36     X, (tm, s), datafit, penalty
     37 )
     39 # fit lifeline estimator

File ~/mambaforge/lib/python3.10/site-packages/numba/experimental/jitclass/boxing.py:61, in _generate_method.<locals>.wrapper(*args, **kwargs)
     59 @wraps(func)
     60 def wrapper(*args, **kwargs):
---> 61     return method(*args, **kwargs)

File ~/mambaforge/lib/python3.10/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
    464         msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
    465                f"by the following argument(s):\n{args_str}\n")
    466         e.patch_message(msg)
--> 468     error_rewrite(e, 'typing')
    469 except errors.UnsupportedError as e:
    470     # Something unsupported is present in the user code, add help info
    471     error_rewrite(e, 'unsupported_error')

File ~/mambaforge/lib/python3.10/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
    407     raise e
    408 else:
--> 409     raise e.with_traceback(None)

TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<built-in function getitem>) found for signature:

 >>> getitem(array(float64, 1d, C), Tuple(slice<a:b>, none))

There are 22 candidate implementations:
  - Of which 20 did not match due to:
  Overload of function 'getitem': File: <numerous>: Line N/A.
    With argument(s): '(array(float64, 1d, C), Tuple(slice<a:b>, none))':
   No match.
  - Of which 2 did not match due to:
  Overload in function 'GetItemBuffer.generic': File: numba/core/typing/arraydecl.py: Line 166.
    With argument(s): '(array(float64, 1d, C), Tuple(slice<a:b>, none))':
   Rejected as the implementation raised a specific error:
     NumbaTypeError: unsupported array index type none in Tuple(slice<a:b>, none)
  raised from /home/mathurin/mambaforge/lib/python3.10/site-packages/numba/core/typing/arraydecl.py:72

During: typing of intrinsic-call at /home/mathurin/workspace/skglm/skglm/datafits/single_task.py (593)
During: typing of static-get-item at /home/mathurin/workspace/skglm/skglm/datafits/single_task.py (593)

File "skglm/datafits/single_task.py", line 593:
    def initialize(self, X, y):
        <source elided>
        tm, s = y
        self.B = (tm >= tm[:, None]).astype(X.dtype)
        ^

- Resolution failure for non-literal arguments:
None

During: resolving callee type: BoundFunction((<class 'numba.core.types.misc.ClassInstanceType'>, 'initialize') for instance.jitclass.Cox#7f376353e2c0<B:array(float64, 2d, C)>)
During: typing of call at <string> (3)

updating numba from 0.56.3 to 0.57 fixed the issue : if we can avoid requiring numba>=0.57 it's better, otherwise we should put it in the requirements

@Badr-MOUFAD
Copy link
Collaborator Author

My bad, forgot to require numba 0.57.

@Badr-MOUFAD
Copy link
Collaborator Author

Here is a link to benchmark results on a dummy dataset.
Also, a link to the benchopt repo to reproduce

@mathurinm mathurinm requested a review from PABannier May 25, 2023 06:25
Copy link
Collaborator

@PABannier PABannier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Open question: do we want to make it an estimator? Once Efron is implemented, I think it's worth implementing, especially if we aim at a wider adoption by the biostat community.

Copy link
Collaborator

@mathurinm mathurinm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

merge once the test for gradient and hessian has been added

@Badr-MOUFAD Badr-MOUFAD merged commit 399dfc6 into scikit-learn-contrib:main May 30, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

FEAT Add Cox loss
3 participants