Skip to content

Commit

Permalink
Update paper_ranking.py
Browse files (browse the repository at this point in the history)
  • Loading branch information
cthoyt committed Jan 13, 2025
1 parent 55f133c commit 0d47b12
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions src/bioregistry/analysis/paper_ranking.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections import defaultdict
from collections.abc import Iterable
from pathlib import Path
from typing import Any, NamedTuple
from typing import Any, NamedTuple, Optional, Union

import click
import numpy as np
Expand Down Expand Up @@ -45,8 +45,11 @@
XTest: TypeAlias = NDArray[np.str_]
YTest: TypeAlias = NDArray[np.str_]

ClassifierHint: TypeAlias = Union[ClassifierMixin, LinearClassifierMixin]
Classifiers: TypeAlias = list[tuple[str, ClassifierHint]]

def get_publications_from_bioregistry(path: Path | None = None) -> pd.DataFrame:

def get_publications_from_bioregistry(path: Optional[Path] = None) -> pd.DataFrame:
"""Load bioregistry data from a JSON file, extracting publication details and fetching abstracts if missing.
:param path: Path to the bioregistry JSON file.
Expand Down Expand Up @@ -101,7 +104,7 @@ def load_curated_papers(file_path: Path = CURATED_PAPERS_PATH) -> pd.DataFrame:
return curated_df


def _get_metadata_for_ids(pubmed_ids: Iterable[int | str]) -> dict[str, dict[str, Any]]:
def _get_metadata_for_ids(pubmed_ids: Iterable[Union[int, str]]) -> dict[str, dict[str, Any]]:
"""Get metadata for articles in PubMed, wrapping the INDRA client."""
from indra.literature import pubmed_client

Expand Down Expand Up @@ -187,7 +190,7 @@ def load_google_curation_df() -> pd.DataFrame:
return df


def _map_labels(s: str) -> int | None:
def _map_labels(s: str) -> Optional[int]:
"""Map labels to binary values.
:param s: Label value.
Expand All @@ -200,9 +203,6 @@ def _map_labels(s: str) -> int | None:
return None


Classifiers = list[tuple[str, ClassifierMixin | LinearClassifierMixin]]


def train_classifiers(x_train: XTrain, y_train: YTrain) -> Classifiers:
"""Train multiple classifiers on the training data.
Expand Down Expand Up @@ -241,16 +241,14 @@ def generate_meta_features(


def _cross_val_predict(
clf: ClassifierMixin | LinearClassifierMixin, x_train: XTrain, y_train: YTrain, cv: int
) -> NDArray:
clf: ClassifierHint, x_train: XTrain, y_train: YTrain, cv: int
) -> NDArray[np.float64]:
if not hasattr(clf, "predict_proba"):
return cross_val_predict(clf, x_train, y_train, cv=cv, method="decision_function")
return cross_val_predict(clf, x_train, y_train, cv=cv, method="predict_proba")[:, 1]


def _predict(
clf: ClassifierMixin | LinearClassifierMixin, x: NDArray[np.float64]
) -> NDArray[np.float64]:
def _predict(clf: ClassifierHint, x: NDArray[np.float64]) -> NDArray[np.float64]:
if hasattr(clf, "predict_proba"):
return clf.predict_proba(x)[:, 1]
else:
Expand All @@ -265,7 +263,7 @@ class MetaClassifierEvaluationResults(NamedTuple):


def _evaluate_meta_classifier(
meta_clf: ClassifierMixin, x_test_meta: NDArray[np.float64], y_test: YTest
meta_clf: ClassifierMixin, x_test_meta: XTest, y_test: YTest
) -> MetaClassifierEvaluationResults:
"""Evaluate meta-classifier using MCC and AUC-ROC scores.
Expand All @@ -285,7 +283,7 @@ def predict_and_save(
vectorizer: TfidfVectorizer,
classifiers: Classifiers,
meta_clf: ClassifierMixin,
filename: str | Path,
filename: str,
) -> None:
"""Predict and save scores for new data using trained classifiers and meta-classifier.
Expand Down

0 comments on commit 0d47b12

Please sign in to comment.