Meritshot Tutorials
Deploying a Pretrained Model with Flask
Example: Deploying a scikit-learn Model
In this section, we will go through the steps of deploying a pre-trained scikit-learn model using Flask. The goal is to create an API endpoint that accepts input, applies the trained machine learning model, and returns the prediction to the user.
We will use a simple model as an example. Let’s assume we have a RandomForestRegressor model that predicts house prices based on features like the number of rooms and the area of the house.
Steps to Deploy a scikit-learn Model with Flask
1. Train and Save the Model
First, we need to train a simple scikit-learn model and save it using pickle. This model will then be loaded into the Flask application for predictions.
Training and Saving the Model:
import pickle

import pandas as pd
from sklearn.ensemble import RandomForestRegressor

# Load dataset (for example, a simple dataset with house features)
data = pd.DataFrame({
    'rooms': [3, 4, 5, 2, 6],
    'area': [1000, 1500, 2000, 1200, 2500],
    'price': [400000, 500000, 600000, 350000, 750000]
})

# Split the dataset into features and target
X = data[['rooms', 'area']]
y = data['price']

# Train the model
model = RandomForestRegressor()
model.fit(X, y)

# Save the model using pickle
with open('house_price_model.pkl', 'wb') as file:
    pickle.dump(model, file)

print("Model saved!")
In this example:
- We train a Random Forest model to predict house prices based on the number of rooms and the area of the house.
- After training, we save the model using pickle.
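Before wiring the file into Flask, it can help to confirm that the saved model loads and predicts the same values as the in-memory one. The following is a minimal sketch under assumptions: it retrains the toy model from above with a fixed `random_state` and a small `n_estimators` (both added here for repeatability and speed; the original sets neither) and writes to a temporary directory instead of the working directory.

```python
import os
import pickle
import tempfile

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

# Same toy dataset as in the training step above.
data = pd.DataFrame({
    "rooms": [3, 4, 5, 2, 6],
    "area": [1000, 1500, 2000, 1200, 2500],
    "price": [400000, 500000, 600000, 350000, 750000],
})
X = data[["rooms", "area"]]
y = data["price"]

# random_state and n_estimators are assumptions added for this sketch.
model = RandomForestRegressor(n_estimators=10, random_state=0)
model.fit(X, y)

# Write to a temporary directory so the sketch leaves no files behind.
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "house_price_model.pkl")
    with open(model_path, "wb") as file:
        pickle.dump(model, file)
    with open(model_path, "rb") as file:
        reloaded = pickle.load(file)

# The reloaded model should reproduce the original predictions.
round_trip_ok = bool(np.allclose(model.predict(X), reloaded.predict(X)))
print("Round trip OK:", round_trip_ok)
```

A pickled model is generally only safe to load with the same scikit-learn version that saved it, so a check like this is most useful when run in the environment that will host the Flask app.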
2. Flask Application for Prediction
Now, let’s build a simple Flask application that will load this pre-trained model and serve predictions via an API endpoint.
Flask Application for Deployment:
from flask import Flask, request, jsonify
import pickle
import numpy as np

app = Flask(__name__)

# Load the pre-trained model once, when the application starts
with open('house_price_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)

@app.route('/predict', methods=['POST'])
def predict():
    # Get the user input from the POST request
    try:
        rooms = float(request.form['rooms'])
        area = float(request.form['area'])
    except (KeyError, ValueError):
        # KeyError: a field is missing; ValueError: a field is not numeric
        return jsonify({'error': 'Invalid input. Please provide numeric "rooms" and "area"'}), 400

    # Prepare the input for the model (one row, two features)
    input_data = np.array([[rooms, area]])

    # Make the prediction
    prediction = model.predict(input_data)

    # Return the prediction as JSON (cast to a plain Python float)
    return jsonify({'predicted_price': float(prediction[0])})

if __name__ == "__main__":
    app.run(debug=True)
Explanation of the Flask Code:
- Loading the Model: We load the pre-trained model using pickle. This model is the one we saved earlier.
with open('house_price_model.pkl', 'rb') as model_file:
    model = pickle.load(model_file)
- Creating the /predict Endpoint: We define a route ‘/predict’ to handle POST requests. The user will send input (such as the number of rooms and area), and the model will return a predicted house price.
- Handling User Input: The input data (number of rooms and area) is extracted from the form data of the request. If the required fields are missing or incorrect, the API returns a 400 error with an appropriate message.
- Making the Prediction: After preparing the user input, we feed it to the model for prediction.
input_data = np.array([[rooms, area]])
prediction = model.predict(input_data)
- Returning the Result: The predicted price is returned as a JSON response, so the user can easily consume it in a front-end or other applications.
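The input-handling step described above can be pulled into a small helper, so that missing fields and non-numeric values each produce a clear message. This is a sketch, not part of the original application: the name `parse_features` and the rule that values must be positive and finite are assumptions made for illustration.

```python
import math


def parse_features(form, required=("rooms", "area")):
    """Return (values, error); exactly one of the two is None.

    `form` is any mapping of field names to strings, such as
    Flask's request.form. The function name is hypothetical.
    """
    values = []
    for name in required:
        raw = form.get(name)
        if raw is None:
            return None, f'Missing field "{name}"'
        try:
            number = float(raw)
        except (TypeError, ValueError):
            return None, f'Field "{name}" must be a number'
        # Assumed rule: reject NaN, infinity, zero, and negative values.
        if not math.isfinite(number) or number <= 0:
            return None, f'Field "{name}" must be a positive, finite number'
        values.append(number)
    return values, None


# Plain dicts stand in for request.form here.
print(parse_features({"rooms": "4", "area": "1500"}))
print(parse_features({"rooms": "four", "area": "1500"}))
print(parse_features({"rooms": "4"}))
```

Inside `predict()`, the helper could be called as `values, error = parse_features(request.form)`, returning a 400 response with `error` when it is not `None` and passing `[values]` to `model.predict` otherwise.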
3. Running the Flask Application
Save the Flask code as app.py in the same directory as house_price_model.pkl, then run it from your terminal:
python app.py
Once the application is running, you can make POST requests to the /predict endpoint with data for house features (rooms and area). For example, using Postman or cURL, you can send a POST request:
cURL Example:
curl -X POST -F "rooms=4" -F "area=1500" http://127.0.0.1:5000/predict
The response will look similar to this (the exact value varies, because a random forest is trained with randomness):
{
    "predicted_price": 490000.0
}
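The endpoint can also be exercised without starting a server, using Flask's built-in test client. The sketch below is self-contained under one assumption: `StubModel` is a hypothetical stand-in with a made-up linear formula, used in place of the pickled regressor so the example runs without the model file.

```python
from flask import Flask, jsonify, request


class StubModel:
    """Hypothetical stand-in for the pickled regressor."""

    def predict(self, rows):
        # Made-up formula, used only to keep the sketch self-contained.
        return [100000.0 * rooms + 100.0 * area for rooms, area in rows]


app = Flask(__name__)
model = StubModel()


@app.route("/predict", methods=["POST"])
def predict():
    try:
        rooms = float(request.form["rooms"])
        area = float(request.form["area"])
    except (KeyError, ValueError):
        return jsonify({"error": 'Invalid input. Please provide "rooms" and "area"'}), 400
    prediction = model.predict([[rooms, area]])
    return jsonify({"predicted_price": float(prediction[0])})


# The test client sends requests directly to the app, with no network involved.
client = app.test_client()
ok_response = client.post("/predict", data={"rooms": "4", "area": "1500"})
bad_response = client.post("/predict", data={"rooms": "4"})

print(ok_response.status_code, ok_response.get_json())
print(bad_response.status_code, bad_response.get_json())
```

The first request mirrors the cURL call (form fields `rooms` and `area`); the second omits `area` to show the 400 error path.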
Summary of the Steps:
- Train a model using scikit-learn (e.g., RandomForestRegressor).
- Save the model with pickle or joblib.
- Create a Flask API that loads the saved model and exposes a /predict endpoint.
- Handle POST requests and return predictions in JSON format.
Frequently Asked Questions
- How do I deploy multiple models with Flask?
- Answer: You can deploy multiple models by loading each model into a separate variable and creating different routes for each model. For example, you can have /predict_house_price for one model and /predict_loan_approval for another.
- How do I ensure the model works with new data?
- Answer: Ensure that the features in the new data match the features used during training. You may need to preprocess the data before passing it into the model (e.g., scaling, encoding categorical variables).
- Can I deploy models trained using libraries other than scikit-learn (e.g., TensorFlow, PyTorch)?
- Answer: Yes, you can deploy models from other libraries (e.g., TensorFlow, PyTorch). Simply load the model using the appropriate library (e.g., tensorflow.keras.models.load_model for Keras models) and follow the same process for serving predictions.
- How do I handle large model files?
- Answer: If your model file is large, consider using a cloud service to store the model (e.g., AWS S3, Google Cloud Storage) and load it dynamically when needed. Alternatively, use model compression techniques to reduce the file size.
- Can I use Flask for real-time prediction in production?
- Answer: Flask can serve real-time predictions, but the built-in development server started by app.run() (especially with debug=True) is not intended for production. Run the application behind a production WSGI server such as Gunicorn or uWSGI instead. For very high-throughput workloads, consider FastAPI or managed cloud options such as AWS Lambda or Google Cloud Functions.
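As an illustration of the multiple-models answer above, here is one possible layout. It is a variation on the separate-routes approach: a single parameterized route looks the model up in a registry. Everything named here is an assumption for the sketch: the `MODELS` dictionary, the `/predict/<model_name>` URL, the JSON request body, and the two stub functions, which stand in for real models that would be loaded once at startup.

```python
from flask import Flask, jsonify, request

app = Flask(__name__)


# Hypothetical stand-ins for real models loaded at startup.
def house_price_stub(features):
    return 100.0 * features["area"]


def loan_approval_stub(features):
    return 1.0 if features["income"] >= 50000 else 0.0


# Registry: model name -> (required numeric fields, prediction function).
MODELS = {
    "house_price": (("rooms", "area"), house_price_stub),
    "loan_approval": (("income",), loan_approval_stub),
}


@app.route("/predict/<model_name>", methods=["POST"])
def predict(model_name):
    if model_name not in MODELS:
        return jsonify({"error": f"Unknown model: {model_name}"}), 404
    fields, predict_fn = MODELS[model_name]
    # silent=True returns None instead of raising on a missing or invalid body.
    payload = request.get_json(silent=True) or {}
    try:
        features = {name: float(payload[name]) for name in fields}
    except (KeyError, TypeError, ValueError):
        return jsonify({"error": f"Expected numeric fields: {list(fields)}"}), 400
    return jsonify({"model": model_name, "prediction": predict_fn(features)})


client = app.test_client()
house = client.post("/predict/house_price", json={"rooms": 4, "area": 1500})
unknown = client.post("/predict/unknown", json={})
missing = client.post("/predict/loan_approval", json={})
print(house.status_code, house.get_json())
print(unknown.status_code, missing.status_code)
```

A registry keeps validation and error handling in one place as models are added; separate routes per model, as described in the answer, remain simpler when each model needs very different input handling.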
