diff --git a/deploy-model-with-fastapi/requirements.txt b/deploy-model-with-fastapi/requirements.txt index 6e69b62..c595247 100644 --- a/deploy-model-with-fastapi/requirements.txt +++ b/deploy-model-with-fastapi/requirements.txt @@ -5,3 +5,4 @@ pandas==2.2.3 scikit-learn==1.6.1 uvicorn-worker==0.3.0 uvicorn[standard]==0.34.3 +pydantic diff --git a/deploy-model-with-fastapi/server.py b/deploy-model-with-fastapi/server.py index 7cf36f2..bba9457 100644 --- a/deploy-model-with-fastapi/server.py +++ b/deploy-model-with-fastapi/server.py @@ -5,6 +5,7 @@ import joblib import pandas as pd from fastapi import FastAPI +from pydantic import BaseModel def _get_model_dir(): @@ -34,18 +35,23 @@ async def lifespan(app: FastAPI): async def health() -> Dict[str, bool]: return {"healthy": True} +class RequestBody(BaseModel): + sepal_length: float + sepal_width: float + petal_length: float + petal_width: float @app.post("/predict") def predict( - sepal_length: float, sepal_width: float, petal_length: float, petal_width: float + request_body: RequestBody ): global model class_names = ["setosa", "versicolor", "virginica"] data = dict( - sepal_length=sepal_length, - sepal_width=sepal_width, - petal_length=petal_length, - petal_width=petal_width, + sepal_length=request_body.sepal_length, + sepal_width=request_body.sepal_width, + petal_length=request_body.petal_length, + petal_width=request_body.petal_width, ) prediction = model.predict_proba(pd.DataFrame([data]))[0] predictions = []