# -- Slack thread excerpt: Slackbot; Oluwaseyi Gbadamosi (10/16/2022, 9:27 PM / 9:34 PM) --
from __future__ import annotations
import bentoml
from data_validation.schema import InputSchema
from pycaret.classification import load_model, predict_model
from fastapi.responses import JSONResponse
from fastapi import status, Response
from typing import IO, Dict, Any
from numpy.typing import NDArray
import numpy as np
import pandas as pd
from scripts.helper import model_predict
import os
from fastapi import FastAPI
import csv
from fastapi import File, UploadFile, HTTPException
from scripts.custom_runner import PyCaretRunnable, churn_model
# Model & Service names
SERVICE_NAME = 'churn-api'
# BentoML model tag: "<name>:<version>"; the hash pins an exact stored model.
MODEL_NAME = 'churn-model:aumvrosds2s2n4cx'
# Fetch the model reference from the local BentoML model store
# (raises bentoml.exceptions.NotFound if the tag is not present).
churn_model = bentoml.picklable_model.get(MODEL_NAME)
class PyCaretRunnable(bentoml.Runnable):
    """BentoML runnable wrapping the pickled PyCaret churn model.

    The model is loaded once per worker in ``__init__`` and exposed
    through a single non-batchable prediction method.
    """

    # BUG FIX: the original had ("cpu"), which is just the string "cpu",
    # not a one-element tuple — a trailing comma is required.
    SUPPORTED_RESOURCES = ("cpu",)
    SUPPORTS_CPU_MULTI_THREADING = True

    def __init__(self):
        """Load the trained model from the BentoML model store."""
        self.model = bentoml.picklable_model.load_model(churn_model)

    @bentoml.Runnable.method(batchable=False)
    def make_prediction(self, input_data: dict) -> list:
        """Generate predictions with the trained model.

        Args:
            input_data (dict): The JSON data sent by the client.

        Returns:
            list: A list with the predicted values.
        """
        return model_predict(self.model, input_data)
# Wrap the runnable in a Runner so BentoML can schedule/scale it, then
# attach it to the Service. The FastAPI app is mounted onto the service
# so the custom HTTP routes below are served alongside BentoML's own
# endpoints under the same ASGI server.
model_runner = bentoml.Runner(PyCaretRunnable, models=[churn_model])
svc = bentoml.Service(SERVICE_NAME, runners=[model_runner])
fastapi_app = FastAPI()
svc.mount_asgi_app(fastapi_app)
@fastapi_app.get("/metadata")
def metadata():
    """Report the name and version of the model this service is serving."""
    tag = churn_model.tag
    return {"name": tag.name, "version": tag.version}
@fastapi_app.post("/predict/batch/api/v1")
async def predict_batch(file: UploadFile = File(...)):
    """Batch prediction endpoint: accepts a CSV upload, returns predictions.

    Args:
        file: The uploaded file; must have a ``.csv`` extension.

    Returns:
        dict: ``{"predictions": <model output>}``.

    Raises:
        HTTPException: 400 when the uploaded file is not a CSV.
    """
    # Guard clause: reject anything that is not a CSV (case-insensitive).
    if not file.filename.lower().endswith(".csv"):
        # Raise a HTTP 400 Exception, indicating Bad Request
        raise HTTPException(status_code=400, detail="Invalid file format. Only CSV Files accepted.")
    # SECURITY/BUG FIX: the original wrote the upload to disk under the
    # client-supplied filename (a path-traversal risk) and then deleted it.
    # pandas can read directly from the upload's file-like object, so no
    # temporary file is needed at all.
    data = pd.read_csv(file.file)
    pred = await model_runner.make_prediction.async_run(data)
    # BUG FIX: the original returned `{pred}` — a *set* literal containing
    # a list, which raises TypeError (lists are unhashable). Return a
    # proper JSON object keyed by "predictions" instead.
    return {"predictions": pred}
@fastapi_app.post("/predict_fastapi")
def predict(features: InputSchema):
    """Single-record prediction endpoint for schema-validated JSON input."""
    # One validated record becomes a one-row frame for the model runner.
    frame = pd.DataFrame.from_records([features.dict()])
    output = model_runner.make_prediction.run(frame)
    first_prediction = output.tolist()[0]
    return {"prediction": first_prediction}
# -- End of code. Remaining Slack thread metadata (not part of the source): --
# Chaoyu (10/17/2022 1:07 AM, 2:16 AM, 2:17 AM x2);
# Oluwaseyi Gbadamosi (10/17/2022 2:01 AM, 7:02 AM, 7:21 AM; 10/18/2022 5:14 AM)