Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
mcleber authored Dec 20, 2024
1 parent ce2ae21 commit 304cbd3
Showing 1 changed file with 126 additions and 0 deletions.
126 changes: 126 additions & 0 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import os
import mlflow
import json
import uvicorn
import numpy as np
from pydantic import BaseModel
from fastapi import FastAPI


class FetalHealthData(BaseModel):
accelerations: float
fetal_movement: float
uterine_contractions: float
severe_decelerations: float


app = FastAPI(title="Fetal Health API",
openapi_tags=[
{
"name": "Health",
"description": "Get API health status"
},
{
"name": "Prediction",
"description": "Model prediction"
}
])


def load_model():
"""
Loads a pre-trained model from an MLflow server.
This function connects to an MLflow server using the provided tracking URI, username,
and password.
It retrieves the latest version of the 'fetal_health' model registered on the server.
The function then loads the model using the specified run ID and returns the loaded model.
Returns:
loaded_model: The loaded pre-trained model.
Raises:
None
"""
print('reading model...')
MLFLOW_TRACKING_URI = 'https://dagshub.com/my_user/mlops_cardiotocography.mlflow'
MLFLOW_TRACKING_USERNAME = 'my_user'
MLFLOW_TRACKING_PASSWORD = 'my_token'
os.environ['MLFLOW_TRACKING_USERNAME'] = MLFLOW_TRACKING_USERNAME
os.environ['MLFLOW_TRACKING_PASSWORD'] = MLFLOW_TRACKING_PASSWORD

print("setting mlflow...")
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

print("creating client...")
client = mlflow.MlflowClient(tracking_uri=MLFLOW_TRACKING_URI)

print("getting registered model...")
registered_model = client.get_registered_model('fetal_health')

print("read model...")
run_id = registered_model.latest_versions[-1].run_id
logged_model = f'runs:/{run_id}/model'
loaded_model = mlflow.pyfunc.load_model(logged_model)

print(loaded_model)
return loaded_model


@app.on_event(event_type='startup')
def startup_event():
"""
A function that is called when the application starts up. It loads a model into the
global variable `loaded_model`.
Parameters:
None
Returns:
None
"""
global loaded_model
loaded_model = load_model()


@app.get(path='/',
tags=['Health'])
def api_health():
"""
A function that represents the health endpoint of the API.
Returns:
dict: A dictionary containing the status of the API, with the key "status" and
the value "healthy".
"""
return {"status": "healthy"}


@app.post(path='/predict',
tags=['Prediction'])
def predict(request: FetalHealthData):
"""
Predicts the fetal health based on the given request data.
Args:
request (FetalHealthData): The request data containing the fetal health parameters.
Returns:
dict: A dictionary containing the prediction of the fetal health.
Raises:
None
"""
global loaded_model
received_data = np.array([
request.accelerations,
request.fetal_movement,
request.uterine_contractions,
request.severe_decelerations,
]).reshape(1, -1)

print(received_data)
prediction = loaded_model.predict(received_data)
print(prediction)

return {"prediction": str(np.argmax(prediction[0]))}

0 comments on commit 304cbd3

Please sign in to comment.