-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #15 from databio/saanika
Added class AttrStandardizer
- Loading branch information
Showing
10 changed files
with
395 additions
and
202 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,43 @@ | ||
# bedmess | ||
# BEDMS | ||
|
||
bedmess is a tool used to standardize genomics/epigenomics metadata based on a schema chosen by the user ( eg. ENCODE, FAIRTRACKS). | ||
BEDMS (BED Metadata Standardizer) is a tool used to standardize genomics/epigenomics metadata based on a schema chosen by the user (e.g., ENCODE, FAIRTRACKS, BEDBASE). | ||
|
||
|
||
To install `attribute-standardizer`, you need to clone this repository first. Follow the steps given below to install: | ||
|
||
``` | ||
git clone https://github.com/databio/bedmess.git | ||
git clone https://github.com/databio/bedms.git | ||
cd bedmess | ||
cd bedms | ||
pip install . | ||
``` | ||
|
||
## Usage | ||
|
||
Using Python, this is how you can run `attribute_standardizer` : | ||
Using Python, this is how you can run `attribute_standardizer` and print the results: | ||
|
||
|
||
``` | ||
from attribute_standardizer.attribute_standardizer import attr_standardizer | ||
from attribute_standardizer import AttrStandardizer | ||
attr_standardizer(pep=/path/to/pep, schema="ENCODE") | ||
model = AttrStandardizer("ENCODE") | ||
model = AttrStandardizer("FAIRTRACKS") | ||
results = model.standardize(pep ="geo/gse178283:default") | ||
print(results) | ||
``` | ||
|
||
To see the available schemas, you can run: | ||
``` | ||
schemas = model.get_available_schemas() | ||
print(schemas) | ||
``` | ||
|
||
This will print the available schemas as a list. | ||
|
||
You can use the format provided in the `trial.py` script in this repository as a reference. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1 @@ | ||
from .attribute_standardizer import attr_standardizer | ||
from .attr_standardizer import AttrStandardizer |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,206 @@ | ||
import logging | ||
from typing import Dict, Tuple, Union | ||
|
||
import peppy | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as torch_functional | ||
|
||
from .const import ( | ||
CONFIDENCE_THRESHOLD, | ||
DROPOUT_PROB, | ||
EMBEDDING_SIZE, | ||
HIDDEN_SIZE, | ||
INPUT_SIZE_BOW_BEDBASE, | ||
INPUT_SIZE_BOW_ENCODE, | ||
INPUT_SIZE_BOW_FAIRTRACKS, | ||
OUTPUT_SIZE_BEDBASE, | ||
OUTPUT_SIZE_ENCODE, | ||
OUTPUT_SIZE_FAIRTRACKS, | ||
SENTENCE_TRANSFORMER_MODEL, | ||
PROJECT_NAME, | ||
) | ||
from .model import BoWSTModel | ||
from .utils import ( | ||
data_encoding, | ||
data_preprocessing, | ||
fetch_from_pephub, | ||
get_any_pep, | ||
load_from_huggingface, | ||
) | ||
|
||
# Logger shared by this module, looked up under the project-wide name.
logger = logging.getLogger(PROJECT_NAME)
# NOTE(review): basicConfig runs at import time and configures the root
# logger at INFO level (it is a no-op if the root logger already has
# handlers). Libraries usually leave this to the application -- confirm
# this is intended before relying on it.
logging.basicConfig(level=logging.INFO)
|
||
|
||
class AttrStandardizer:
    """
    Suggests standardized attribute names for the metadata of a PEP, using a
    pretrained model chosen by schema.

    The model is loaded once in the constructor and reused by every call to
    `standardize`.
    """

    # Schema name -> (bag-of-words input size, output size). Every other model
    # hyperparameter is shared by all schemas. This table is the single source
    # of truth for both `_get_parameters` and `get_available_schemas`; its
    # insertion order is the order in which schemas are reported.
    _SCHEMA_SIZES: Dict[str, Tuple[int, int]] = {
        "ENCODE": (INPUT_SIZE_BOW_ENCODE, OUTPUT_SIZE_ENCODE),
        "FAIRTRACKS": (INPUT_SIZE_BOW_FAIRTRACKS, OUTPUT_SIZE_FAIRTRACKS),
        "BEDBASE": (INPUT_SIZE_BOW_BEDBASE, OUTPUT_SIZE_BEDBASE),
    }

    # Upper bound on the number of suggestions returned per attribute.
    _TOP_K = 3

    def __init__(self, schema: str, confidence: float = CONFIDENCE_THRESHOLD) -> None:
        """
        Initializes the attribute standardizer with user provided schema, loads the model.

        :param str schema: User provided schema, can be "ENCODE", "FAIRTRACKS" or "BEDBASE".
        :param float confidence: Confidence threshold for the predictions. It is compared
            against softmax probabilities, so useful values lie between 0 and 1.
        :raises ValueError: If the schema is not one of the available schemas.
        """
        self.schema = schema
        self.model = self._load_model()
        self.conf_threshold = confidence

    def _get_parameters(self) -> Tuple[int, int, int, int, int, float]:
        """
        Get the model parameters as per the chosen schema.

        :return Tuple[int, int, int, int, int, float]: Bag-of-words input size, values
            embedding size, headers embedding size, hidden size, output size and
            dropout probability, in that order.
        :raises ValueError: If `self.schema` is not an available schema.
        """
        try:
            input_size_bow, output_size = self._SCHEMA_SIZES[self.schema]
        except (KeyError, TypeError):
            # TypeError covers an unhashable schema value, which is just as
            # unsupported as an unknown name.
            raise ValueError(
                f"Schema not available: {self.schema}. Presently, three schemas are available: ENCODE , FAIRTRACKS, BEDBASE"
            ) from None

        return (
            input_size_bow,
            EMBEDDING_SIZE,  # size of the values embeddings
            EMBEDDING_SIZE,  # size of the headers embeddings
            HIDDEN_SIZE,
            output_size,
            DROPOUT_PROB,
        )

    def _load_model(self) -> nn.Module:
        """
        Calls function to load the model from HuggingFace repository and sets to eval().

        :return nn.Module: Loaded Neural Network Model.
        :raises ValueError: If `self.schema` is not an available schema.
        :raises Exception: Any other error raised while fetching or loading the
            weights is logged and re-raised unchanged.
        """
        try:
            # Resolve the hyperparameters first so that an unsupported schema
            # is rejected before anything is fetched.
            (
                input_size_values,
                input_size_values_embeddings,
                input_size_headers,
                hidden_size,
                output_size,
                dropout_prob,
            ) = self._get_parameters()

            # presumably a local path or file-like object for the downloaded
            # weights -- TODO confirm against load_from_huggingface
            weights_source = load_from_huggingface(self.schema)

            # map_location="cpu" lets checkpoints saved on a GPU load on
            # CPU-only hosts; load_state_dict copies the values into the
            # freshly built module's own parameters afterwards.
            # NOTE(review): torch.load unpickles the file. Since it comes from
            # a remote repository, consider weights_only=True if the minimum
            # supported torch version allows it.
            state_dict = torch.load(weights_source, map_location=torch.device("cpu"))

            model = BoWSTModel(
                input_size_values,
                input_size_values_embeddings,
                input_size_headers,
                hidden_size,
                output_size,
                dropout_prob,
            )
            model.load_state_dict(state_dict)
            model.eval()
            return model

        except Exception as e:
            logger.error(f"Error loading the model: {str(e)}")
            raise

    @staticmethod
    def _select_suggestions(labels, confidences, threshold: float) -> Dict[str, float]:
        """
        Keep the leading predictions whose confidence reaches the threshold.

        :param labels: Decoded predictions for one attribute, best first.
        :param confidences: Matching confidences, in descending order.
        :param float threshold: Minimum confidence for a prediction to be kept.
        :return Dict[str, float]: Kept predictions mapped to their confidence, or
            {"Not Predictable": 0.0} when none reaches the threshold.
        """
        selected = {}
        for label, confidence in zip(labels, confidences):
            # Confidences are sorted in descending order, so the first one
            # below the threshold ends the scan.
            if not confidence >= threshold:
                break
            selected[label] = confidence

        if not selected:
            selected["Not Predictable"] = 0.0
        return selected

    def standardize(
        self, pep: Union[str, peppy.Project]
    ) -> Dict[str, Dict[str, float]]:
        """
        Fetches the user provided PEP from the PEPHub registry path, returns the predictions.

        :param pep: peppy.Project object or PEPHub registry path to PEP.
        :return Dict[str, Dict[str, float]]: Suggestions to the user: for every header
            produced by preprocessing, up to three predicted attributes with their
            confidence, or {"Not Predictable": 0.0} when no prediction reaches the
            confidence threshold.
        :raises ValueError: If `pep` is neither a string nor a peppy.Project.
        :raises Exception: Any error raised during preprocessing or inference is
            logged and re-raised, so a failure is never reported as a `None` result.
        """
        if isinstance(pep, str):
            pep = get_any_pep(pep)
        elif not isinstance(pep, peppy.Project):
            raise ValueError(
                "PEP should be either a path to PEPHub registry or peppy.Project object."
            )

        try:
            csv_file = fetch_from_pephub(pep)

            X_values_st, X_headers_st, X_values_bow, num_rows = data_preprocessing(
                csv_file
            )
            (
                X_headers_embeddings_tensor,
                X_values_embeddings_tensor,
                X_values_bow_tensor,
                label_encoder,
            ) = data_encoding(
                num_rows,
                X_values_st,
                X_headers_st,
                X_values_bow,
                self.schema,
                model_name=SENTENCE_TRANSFORMER_MODEL,
            )

            logger.info("Data Preprocessing completed.")

            with torch.no_grad():
                outputs = self.model(
                    X_values_bow_tensor,
                    X_values_embeddings_tensor,
                    X_headers_embeddings_tensor,
                )
                probabilities = torch_functional.softmax(outputs, dim=1)

            # Never ask topk for more candidates than the model has classes.
            top_k = min(self._TOP_K, probabilities.shape[1])
            values, indices = torch.topk(probabilities, k=top_k, dim=1)
            top_confidences = values.tolist()

            decoded_predictions = [
                label_encoder.inverse_transform(row_indices)
                for row_indices in indices.tolist()
            ]

            suggestions = {}
            for i, category in enumerate(X_headers_st):
                suggestions[category] = self._select_suggestions(
                    decoded_predictions[i],
                    top_confidences[i],
                    self.conf_threshold,
                )

            return suggestions

        except Exception as e:
            logger.error(
                f"Error occured during standardization in standardize function: {str(e)}"
            )
            # Re-raise so callers see the failure instead of an implicit None;
            # this matches the log-and-raise handling in _load_model.
            raise

    @staticmethod
    def get_available_schemas() -> list[str]:
        """
        Returns the names of the available schemas.

        :return list: List of available schemas.
        """
        return list(AttrStandardizer._SCHEMA_SIZES)
Oops, something went wrong.