Commit 203a7bb8 authored by Sascha Herzinger's avatar Sascha Herzinger

Added NelsonAalen Estimator for survival analysis

parent be7cd03e
Pipeline #5339 failed with stages
in 2 minutes and 48 seconds
"""This module provides statistics for a Kaplan Meier Survival Analysis.""" """This module provides statistics for a Survival Analysis."""
import logging import logging
from typing import List from typing import List
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from lifelines import KaplanMeierFitter from lifelines import KaplanMeierFitter, NelsonAalenFitter
from fractalis.analytics.task import AnalyticTask from fractalis.analytics.task import AnalyticTask
from fractalis.analytics.tasks.shared import utils from fractalis.analytics.tasks.shared import utils
...@@ -14,15 +14,16 @@ from fractalis.analytics.tasks.shared import utils ...@@ -14,15 +14,16 @@ from fractalis.analytics.tasks.shared import utils
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class KaplanMeierSurvivalTask(AnalyticTask): class SurvivalTask(AnalyticTask):
"""Kaplan Meier Survival Analysis Task implementing AnalyticTask. This """Survival Analysis Task implementing AnalyticTask.
class is a submittable celery task.""" This class is a submittable celery task."""
name = 'kaplan-meier-estimate' name = 'survival-analysis'
def main(self, durations: List[pd.DataFrame], def main(self, durations: List[pd.DataFrame],
categories: List[pd.DataFrame], categories: List[pd.DataFrame],
event_observed: List[pd.DataFrame], event_observed: List[pd.DataFrame],
estimator: str,
id_filter: List[str], id_filter: List[str],
subsets: List[List[str]]) -> dict: subsets: List[List[str]]) -> dict:
# TODO: Docstring # TODO: Docstring
...@@ -45,8 +46,9 @@ class KaplanMeierSurvivalTask(AnalyticTask): ...@@ -45,8 +46,9 @@ class KaplanMeierSurvivalTask(AnalyticTask):
stats = {} stats = {}
# for every category and subset combination estimate the survival fun. # for every category and subset combination estimate the survival fun.
for category in df['category'].unique().tolist(): for category in df['category'].unique().tolist():
if not stats.get(category):
stats[category] = {}
for subset in df['subset'].unique().tolist(): for subset in df['subset'].unique().tolist():
kmf = KaplanMeierFitter()
sub_df = df[(df['category'] == category) & sub_df = df[(df['category'] == category) &
(df['subset'] == subset)] (df['subset'] == subset)]
T = sub_df['value'] T = sub_df['value']
...@@ -56,17 +58,35 @@ class KaplanMeierSurvivalTask(AnalyticTask): ...@@ -56,17 +58,35 @@ class KaplanMeierSurvivalTask(AnalyticTask):
E = event_observed[0].merge(sub_df, how='right', on='id') E = event_observed[0].merge(sub_df, how='right', on='id')
E = [bool(x) and not np.isnan(x) for x in E['value']] E = [bool(x) and not np.isnan(x) for x in E['value']]
assert len(E) == len(T) assert len(E) == len(T)
kmf.fit(durations=T, event_observed=E) if estimator == 'NelsonAalen':
if not stats.get(category): fitter = NelsonAalenFitter()
stats[category] = {} fitter.fit(durations=T, event_observed=E)
# noinspection PyUnresolvedReferences estimate = fitter.cumulative_hazard_[
'NA_estimate'].tolist()
ci_lower = fitter.confidence_interval_[
'NA_estimate_lower_0.95'].tolist()
ci_upper = fitter.confidence_interval_[
'NA_estimate_upper_0.95'].tolist()
elif estimator == 'KaplanMeier':
fitter = KaplanMeierFitter()
fitter.fit(durations=T, event_observed=E)
# noinspection PyUnresolvedReferences
estimate = fitter.survival_function_[
'KM_estimate'].tolist()
ci_lower = fitter.confidence_interval_[
'KM_estimate_lower_0.95'].tolist()
ci_upper = fitter.confidence_interval_[
'KM_estimate_upper_0.95'].tolist()
else:
error = 'Unknown estimator: {}'.format(estimator)
logger.exception(error)
raise ValueError(error)
timeline = fitter.timeline.tolist()
stats[category][subset] = { stats[category][subset] = {
'timeline': kmf.timeline, 'timeline': timeline,
'median': kmf.median_, 'estimate': estimate,
'survival_function': 'ci_lower': ci_lower,
kmf.survival_function_.to_dict(orient='list'), 'ci_upper': ci_upper
'confidence_interval':
kmf.confidence_interval_.to_dict(orient='list')
} }
return { return {
......
"""This module contains tests for the kaplan_meier_survival module.""" """This module contains tests for the survival module."""
from lifelines.datasets import load_waltons from lifelines.datasets import load_waltons
from fractalis.analytics.tasks.kaplan_meier_survival.main \ from fractalis.analytics.tasks.kaplan_meier_survival.main \
import KaplanMeierSurvivalTask import SurvivalTask
class TestKaplanMeierSurvivalTask: class TestSurvivalTask:
task = KaplanMeierSurvivalTask() task = SurvivalTask()
def test_correct_output_for_simple_input(self): def test_correct_output_for_simple_input(self):
df = load_waltons() df = load_waltons()
...@@ -19,12 +19,13 @@ class TestKaplanMeierSurvivalTask: ...@@ -19,12 +19,13 @@ class TestKaplanMeierSurvivalTask:
results = self.task.main(durations=[duration], results = self.task.main(durations=[duration],
categories=[], categories=[],
event_observed=[], event_observed=[],
estimator='KaplanMeier',
id_filter=[], id_filter=[],
subsets=[]) subsets=[])
assert 'timeline' in results['stats'][''][0] assert results['stats'][''][0]['timeline']
assert 'median' in results['stats'][''][0] assert results['stats'][''][0]['estimate']
assert 'survival_function' in results['stats'][''][0] assert results['stats'][''][0]['ci_lower']
assert 'confidence_interval' in results['stats'][''][0] assert results['stats'][''][0]['ci_upper']
def test_correct_output_for_complex_input(self): def test_correct_output_for_complex_input(self):
df = load_waltons() df = load_waltons()
...@@ -39,7 +40,14 @@ class TestKaplanMeierSurvivalTask: ...@@ -39,7 +40,14 @@ class TestKaplanMeierSurvivalTask:
results = self.task.main(durations=[duration], results = self.task.main(durations=[duration],
categories=[categories], categories=[categories],
event_observed=[event_observed], event_observed=[event_observed],
estimator='NelsonAalen',
id_filter=[], id_filter=[],
subsets=[]) subsets=[])
assert results['stats']['control'] assert results['stats']['control'][0]['timeline']
assert results['stats']['miR-137'] assert results['stats']['control'][0]['estimate']
assert results['stats']['control'][0]['ci_lower']
assert results['stats']['control'][0]['ci_upper']
assert results['stats']['miR-137'][0]['timeline']
assert results['stats']['miR-137'][0]['estimate']
assert results['stats']['miR-137'][0]['ci_lower']
assert results['stats']['miR-137'][0]['ci_upper']
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment