From 0d47b1255a8108a1258c8e2633e84adfd6f29086 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Mon, 13 Jan 2025 11:01:50 +0100 Subject: [PATCH] Update paper_ranking.py --- src/bioregistry/analysis/paper_ranking.py | 26 +++++++++++------------ 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/bioregistry/analysis/paper_ranking.py b/src/bioregistry/analysis/paper_ranking.py index 64358c4f7..9c9909243 100644 --- a/src/bioregistry/analysis/paper_ranking.py +++ b/src/bioregistry/analysis/paper_ranking.py @@ -9,7 +9,7 @@ from collections import defaultdict from collections.abc import Iterable from pathlib import Path -from typing import Any, NamedTuple +from typing import Any, NamedTuple, Optional, Union import click import numpy as np @@ -45,8 +45,11 @@ XTest: TypeAlias = NDArray[np.str_] YTest: TypeAlias = NDArray[np.str_] +ClassifierHint: TypeAlias = Union[ClassifierMixin, LinearClassifierMixin] +Classifiers: TypeAlias = list[tuple[str, ClassifierHint]] -def get_publications_from_bioregistry(path: Path | None = None) -> pd.DataFrame: + +def get_publications_from_bioregistry(path: Optional[Path] = None) -> pd.DataFrame: """Load bioregistry data from a JSON file, extracting publication details and fetching abstracts if missing. :param path: Path to the bioregistry JSON file. @@ -101,7 +104,7 @@ def load_curated_papers(file_path: Path = CURATED_PAPERS_PATH) -> pd.DataFrame: return curated_df -def _get_metadata_for_ids(pubmed_ids: Iterable[int | str]) -> dict[str, dict[str, Any]]: +def _get_metadata_for_ids(pubmed_ids: Iterable[Union[int, str]]) -> dict[str, dict[str, Any]]: """Get metadata for articles in PubMed, wrapping the INDRA client.""" from indra.literature import pubmed_client @@ -187,7 +190,7 @@ def load_google_curation_df() -> pd.DataFrame: return df -def _map_labels(s: str) -> int | None: +def _map_labels(s: str) -> Optional[int]: """Map labels to binary values. :param s: Label value. @@ -200,9 +203,6 @@ def _map_labels(s: str) -> int | None: return None -Classifiers = list[tuple[str, ClassifierMixin | LinearClassifierMixin]] - - def train_classifiers(x_train: XTrain, y_train: YTrain) -> Classifiers: """Train multiple classifiers on the training data. @@ -241,16 +241,14 @@ def generate_meta_features( def _cross_val_predict( - clf: ClassifierMixin | LinearClassifierMixin, x_train: XTrain, y_train: YTrain, cv: int -) -> NDArray: + clf: ClassifierHint, x_train: XTrain, y_train: YTrain, cv: int +) -> NDArray[np.float64]: if not hasattr(clf, "predict_proba"): return cross_val_predict(clf, x_train, y_train, cv=cv, method="decision_function") return cross_val_predict(clf, x_train, y_train, cv=cv, method="predict_proba")[:, 1] -def _predict( - clf: ClassifierMixin | LinearClassifierMixin, x: NDArray[np.float64] -) -> NDArray[np.float64]: +def _predict(clf: ClassifierHint, x: NDArray[np.float64]) -> NDArray[np.float64]: if hasattr(clf, "predict_proba"): return clf.predict_proba(x)[:, 1] else: @@ -265,7 +263,7 @@ class MetaClassifierEvaluationResults(NamedTuple): def _evaluate_meta_classifier( - meta_clf: ClassifierMixin, x_test_meta: NDArray[np.float64], y_test: YTest + meta_clf: ClassifierMixin, x_test_meta: XTest, y_test: YTest ) -> MetaClassifierEvaluationResults: """Evaluate meta-classifier using MCC and AUC-ROC scores. @@ -285,7 +283,7 @@ def predict_and_save( vectorizer: TfidfVectorizer, classifiers: Classifiers, meta_clf: ClassifierMixin, - filename: str | Path, + filename: str, ) -> None: """Predict and save scores for new data using trained classifiers and meta-classifier.