Meritshot Tutorials

Flask Tutorial

Deploying a Pretrained Model with Flask

Example: Deploying a scikit-learn Model

In this section, we will go through the steps of deploying a pre-trained scikit-learn model using Flask. The goal is to create an API endpoint that accepts input, applies the trained machine learning model, and returns the prediction to the user.

We will use a simple model as an example. Let’s assume we have a RandomForestRegressor model that predicts house prices based on features like the number of rooms and the area of the house.

Steps to Deploy a scikit-learn Model with Flask

  1. Train and Save the Model
    First, we need to train a simple scikit-learn model and save it using pickle. This model will then be loaded into the Flask application for predictions.

Training and Saving the Model:

import pickle

import pandas as pd
from sklearn.ensemble import RandomForestRegressor

# Load dataset (for example, a simple dataset with house features)
data = pd.DataFrame({
    'rooms': [3, 4, 5, 2, 6],
    'area': [1000, 1500, 2000, 1200, 2500],
    'price': [400000, 500000, 600000, 350000, 750000]
})

# Split the dataset into features and target
X = data[['rooms', 'area']]
y = data['price']

# Train the model
model = RandomForestRegressor()
model.fit(X, y)

# Save the model using pickle
with open('house_price_model.pkl', 'wb') as file:
    pickle.dump(model, file)

print("Model saved!")

In this example:

    • We train a Random Forest model to predict house prices based on the number of rooms and the area of the house.
    • After training, we save the model using pickle.
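
Before wiring the model into Flask, it can be worth confirming that a pickled model predicts the same values once it is loaded back. The following is a minimal sketch of that check, not part of the original example: it serializes in memory with pickle.dumps instead of writing a file, passes the toy data as plain lists rather than a DataFrame, and sets random_state=0 only so that repeated runs build the same forest.

```python
import pickle

from sklearn.ensemble import RandomForestRegressor

# Same toy data as above, written as plain lists: [rooms, area] per house.
X = [[3, 1000], [4, 1500], [5, 2000], [2, 1200], [6, 2500]]
y = [400000, 500000, 600000, 350000, 750000]

# random_state is an assumption added here so repeated runs are repeatable.
model = RandomForestRegressor(random_state=0)
model.fit(X, y)

# Round-trip through pickle in memory; pickle.dump/pickle.load on a file
# produce the same bytes and the same restored object.
restored = pickle.loads(pickle.dumps(model))

sample = [[4, 1500]]
original_pred = model.predict(sample)[0]
restored_pred = restored.predict(sample)[0]

print("original:", original_pred)
print("restored:", restored_pred)
assert original_pred == restored_pred, "reloaded model should predict identically"
```

If the two values ever differ, the usual suspect is a different scikit-learn version between the training and serving environments.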

2. Flask Application for Prediction

Now, let’s build a simple Flask application that will load this pre-trained model and serve predictions via an API endpoint.

Flask Application for Deployment:

from flask import Flask, request, jsonify
import pickle
import numpy as np

app = Flask(__name__)

# Load the pre-trained model once at startup
with open('house_price_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)

@app.route('/predict', methods=['POST'])
def predict():
    # Get the user input from the POST request
    try:
        rooms = float(request.form['rooms'])
        area = float(request.form['area'])
    except (KeyError, ValueError):
        # KeyError: a field is missing; ValueError: a field is not numeric
        return jsonify({'error': 'Invalid input. Please provide numeric "rooms" and "area"'}), 400

    # Prepare the input for the model (same column order as in training)
    input_data = np.array([[rooms, area]])

    # Make the prediction
    prediction = model.predict(input_data)

    # Return the prediction as JSON (cast the NumPy scalar to a plain float)
    return jsonify({'predicted_price': float(prediction[0])})

if __name__ == "__main__":
    # debug=True is for local development only
    app.run(debug=True)

Explanation of the Flask Code:

  • Loading the Model: We load the pre-trained model using pickle. This model is the one we saved earlier.

with open('house_price_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)

  • Creating the /predict Endpoint: We define a route ‘/predict’ to handle POST requests. The user will send input (such as the number of rooms and area), and the model will return a predicted house price.
  • Handling User Input: The input data (number of rooms and area) is extracted from the form data of the request. If the required fields are missing or incorrect, the API returns a 400 error with an appropriate message.
  • Making the Prediction: After preparing the user input, we feed it to the model for prediction.

input_data = np.array([[rooms, area]])

prediction = model.predict(input_data)

  • Returning the Result: The predicted price is returned as a JSON response, so the user can easily consume it in a front-end or other applications.
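
The input handling described above can be factored into a small helper that the view function calls before predicting. The sketch below is one possible shape, under stated assumptions: the names FEATURE_ORDER and parse_features are hypothetical, and the rule that both values must be positive and finite is an added assumption, not something the original endpoint enforces. It accepts any mapping, so it works with Flask's request.form as well as a plain dict.

```python
import math

# Hypothetical constant: must match the column order used during training.
FEATURE_ORDER = ("rooms", "area")


def parse_features(form):
    """Return (features, error); exactly one of the two is None."""
    values = []
    for name in FEATURE_ORDER:
        if name not in form:
            return None, f"Missing field: {name}"
        try:
            value = float(form[name])
        except (TypeError, ValueError):
            return None, f"Field {name} must be a number"
        # Reject nan, inf, zero and negatives (an assumption for house data).
        if not math.isfinite(value) or value <= 0:
            return None, f"Field {name} must be a positive number"
        values.append(value)
    return values, None


print(parse_features({"rooms": "4", "area": "1500"}))
print(parse_features({"rooms": "4"}))
print(parse_features({"rooms": "four", "area": "1500"}))
```

Inside the view, a non-None error would be returned with status 400, and the features list would be wrapped in np.array([features]) before calling model.predict.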

3. Running the Flask Application

To run the Flask application, save the code above as app.py and execute it in your terminal:

python app.py

Once the application is running, you can make POST requests to the /predict endpoint with data for house features (rooms and area). For example, using Postman or cURL, you can send a POST request:

cURL Example:

curl -X POST -F "rooms=4" -F "area=1500" http://127.0.0.1:5000/predict

The response will look like this (the exact value varies between training runs, because the random forest is not seeded):

{
    "predicted_price": 490000.0
}
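
The same request can also be built from Python with only the standard library. This is a rough sketch: the helper names build_predict_request and fetch_prediction are hypothetical, and the base URL assumes the development server is running on its default address. Unlike curl -F, which sends multipart/form-data, this sends a URL-encoded form body; Flask's request.form reads both.

```python
import json
from urllib.parse import urlencode
from urllib.request import Request, urlopen


def build_predict_request(rooms, area, base_url="http://127.0.0.1:5000"):
    """Build a form-encoded POST request for the /predict endpoint."""
    body = urlencode({"rooms": rooms, "area": area}).encode("utf-8")
    return Request(
        f"{base_url}/predict",
        data=body,
        method="POST",
        headers={"Content-Type": "application/x-www-form-urlencoded"},
    )


def fetch_prediction(rooms, area):
    """Send the request; this only works while the Flask app is running."""
    with urlopen(build_predict_request(rooms, area), timeout=5) as response:
        return json.loads(response.read())["predicted_price"]


# Building the request needs no server, so it is safe to inspect here.
req = build_predict_request(4, 1500)
print(req.get_method(), req.full_url, req.data)
```

With the server running, fetch_prediction(4, 1500) would return the number from the JSON response.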

Summary of the Steps:

  1. Train a model using scikit-learn (e.g., RandomForestRegressor).
  2. Save the model with pickle or joblib.
  3. Create a Flask API that loads the saved model and exposes a /predict endpoint.
  4. Handle POST requests and return predictions in JSON format.

Frequently Asked Questions

  1. How do I deploy multiple models with Flask?
    • Answer: You can deploy multiple models by loading each model into a separate variable and creating different routes for each model. For example, you can have /predict_house_price for one model and /predict_loan_approval for another.
  2. How do I ensure the model works with new data?
    • Answer: Ensure that the features in the new data match the features used during training. You may need to preprocess the data before passing it into the model (e.g., scaling, encoding categorical variables).
  3. Can I deploy models trained using libraries other than scikit-learn (e.g., TensorFlow, PyTorch)?
    • Answer: Yes, you can deploy models from other libraries (e.g., TensorFlow, PyTorch). Simply load the model using the appropriate library (e.g., tensorflow.keras.models.load_model for Keras models) and follow the same process for serving predictions.
  4. How do I handle large model files?
    • Answer: If your model file is large, consider using a cloud service to store the model (e.g., AWS S3, Google Cloud Storage) and load it dynamically when needed. Alternatively, use model compression techniques to reduce the file size.
  5. Can I use Flask for real-time prediction in production?
    • Answer: Flask can serve real-time predictions, but its built-in development server (the one started by app.run) is not meant for production; run the app behind a production WSGI server such as Gunicorn or uWSGI, and turn off debug mode. For very high-throughput workloads, consider FastAPI or cloud-based options such as AWS Lambda or Google Cloud Functions.
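
As a variation on the first answer above, which suggests one route per model, several models can also be kept in a single registry and selected by name. The sketch below is illustrative only: ConstantModel is a hypothetical stand-in for a model that would normally be unpickled at startup, and MODELS and predict_with are made-up names. A Flask route with a variable part, such as /predict/<model_name>, could pass that name to predict_with and return the payload with jsonify.

```python
class ConstantModel:
    """Hypothetical stand-in for a loaded model; real code would unpickle one."""

    def __init__(self, value):
        self.value = value

    def predict(self, rows):
        # Mimic scikit-learn's predict(): one output per input row.
        return [self.value for _ in rows]


# Load each model once at startup and keep it under a short name.
MODELS = {
    "house_price": ConstantModel(490000.0),
    "loan_approval": ConstantModel(1),
}


def predict_with(model_name, features):
    """Return (payload, http_status) for the named model."""
    model = MODELS.get(model_name)
    if model is None:
        return {"error": f"Unknown model: {model_name}"}, 404
    prediction = model.predict([features])[0]
    return {"model": model_name, "prediction": prediction}, 200


print(predict_with("house_price", [4, 1500]))
print(predict_with("weather", [1, 2]))
```

Separate routes remain the simpler choice when each model needs its own input fields and validation; a registry mainly helps when the models share one request shape.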