Example: Deploying a PyTorch Model
In this section, we will go through the steps to deploy a pre-trained PyTorch model using Flask. The objective is to create a simple API endpoint that receives input data, applies the trained model to make predictions, and returns the result to the user.
We will walk through an example of deploying a PyTorch model trained on the CIFAR-10 dataset (image classification of 10 different classes like airplanes, cars, birds, etc.). This example will show how to load the model, process input data, and return predictions via a Flask API.
Steps to Deploy a PyTorch Model with Flask
1. Train and Save the PyTorch Model
First, we need to train a simple PyTorch model and save it using PyTorch’s torch.save() function. This model will then be loaded into the Flask application for making predictions.
Training and Saving the PyTorch Model:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.models as models

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transformations for data preprocessing
transform = transforms.Compose([
    transforms.Resize(224),  # Resize the image to 224x224
    transforms.ToTensor(),   # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize
])

# Load CIFAR-10 dataset
train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# Define the model: a pre-trained ResNet18 fine-tuned for CIFAR-10
model = models.resnet18(pretrained=True)
model.fc = nn.Linear(model.fc.in_features, 10)  # Modify the final layer for 10 classes (CIFAR-10)
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
epochs = 5
for epoch in range(epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}")

# Save the trained model
torch.save(model.state_dict(), 'cifar10_model.pth')
print("Model saved!")
In this example:
- We use the CIFAR-10 dataset, which consists of 60,000 32×32 color images in 10 different classes.
- We use a pre-trained ResNet18 model from the torchvision.models library and fine-tune the final layer to classify the 10 classes of the CIFAR-10 dataset.
- After training, we save the model weights using torch.save().
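Before wiring the model into Flask, it is worth confirming that the saved weights reload cleanly. The following is a minimal sanity-check sketch, assuming the cifar10_model.pth file produced above is in the working directory:

import torch
import torch.nn as nn
import torchvision.models as models

# Rebuild the same architecture that was trained above
model = models.resnet18(pretrained=False)
model.fc = nn.Linear(model.fc.in_features, 10)

# Load the saved weights (map_location allows loading on a CPU-only machine)
model.load_state_dict(torch.load('cifar10_model.pth', map_location='cpu'))
model.eval()

# Dummy forward pass: the output should contain 10 class scores
with torch.no_grad():
    print(model(torch.randn(1, 3, 224, 224)).shape)  # Expected: torch.Size([1, 10])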
2. Flask Application for Prediction
Now, let’s create a Flask application that loads the saved PyTorch model and serves predictions through an API endpoint.
Flask Application for Deployment:
from flask import Flask, request, jsonify
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
import io

app = Flask(__name__)

# Load the trained PyTorch model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(pretrained=False)  # Architecture only; the weights come from the saved file
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load('cifar10_model.pth', map_location=device))
model = model.to(device)
model.eval()  # Set model to evaluation mode

# Image transformation matching the training preprocessing
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

@app.route('/predict', methods=['POST'])
def predict():
    try:
        # Get the image file from the request
        img_file = request.files['image']
        img = Image.open(io.BytesIO(img_file.read())).convert('RGB')  # Ensure 3 color channels
        # Apply transformations to the image
        img_tensor = transform(img).unsqueeze(0).to(device)
    except Exception as e:
        return jsonify({'error': str(e)}), 400

    # Make a prediction
    with torch.no_grad():
        outputs = model(img_tensor)
        _, predicted_class = torch.max(outputs, 1)

    # Return the prediction as JSON
    return jsonify({'predicted_class': int(predicted_class.item())})

if __name__ == "__main__":
    app.run(debug=True)
Explanation of the Flask Code:
- Loading the Model: The model architecture is rebuilt first, and the saved weights are then loaded with model.load_state_dict().
model.load_state_dict(torch.load('cifar10_model.pth', map_location=device))
- Creating the /predict Endpoint: The /predict route is defined to handle POST requests, where the user uploads an image to be classified.
- Handling User Input: The image is uploaded as part of the request (using the request.files object). It is then opened, resized, and normalized to match the input requirements of the model.
- Making the Prediction: The image is passed to the model for prediction.
outputs = model(img_tensor)
_, predicted_class = torch.max(outputs, 1)
- Returning the Result: The predicted class is returned as a JSON response, which can be consumed by the frontend or any other client application (a sketch for returning a human-readable label as well follows below).
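The endpoint above returns only the numeric class index. To also return a human-readable label, one option is to map the index through the CIFAR-10 class list before building the JSON response. This is a minimal sketch; the class_names ordering below follows torchvision's CIFAR10 dataset:

# Standard CIFAR-10 class order (as in torchvision.datasets.CIFAR10.classes)
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

# Inside predict(), after computing predicted_class:
idx = int(predicted_class.item())
return jsonify({'predicted_class': idx, 'label': class_names[idx]})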
3. Running the Flask Application
To run the Flask application, execute the script in your terminal:
python app.py
Once the application is running, you can send a POST request to the /predict endpoint with an image from the CIFAR-10 dataset. The response will contain the predicted class.
Example using cURL:
curl -X POST -F "image=@image.png" http://127.0.0.1:5000/predict
Response:
{
  "predicted_class": 3
}
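If you prefer to call the endpoint from Python instead of cURL, a minimal client sketch using the requests library (assuming a local file named image.png) would look like this:

import requests

# Send the image as multipart/form-data under the "image" field the endpoint expects
with open('image.png', 'rb') as f:
    response = requests.post('http://127.0.0.1:5000/predict', files={'image': f})

print(response.status_code)
print(response.json())  # e.g. {'predicted_class': 3}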
Summary of the Steps:
- Train a PyTorch model on a dataset (e.g., CIFAR-10).
- Save the model weights with torch.save() to a .pth file.
- Create a Flask API that loads the saved model and exposes a /predict endpoint.
- Handle image input in the POST request, preprocess it, and pass it to the model for prediction.
- Return the prediction as a JSON response.
Frequently Asked Questions
- Can I deploy models trained on other datasets, like ImageNet?
- Answer: Yes, you can deploy any PyTorch model you've trained, as long as the input you send matches the format the model expects (e.g., image size and number of color channels).
- What happens if the model fails to predict?
- Answer: If predictions fail or look wrong, check that the input data is preprocessed the same way as during training and that the model weights were loaded correctly. Adding exception handling to your Flask routes, as in the /predict example above, lets you return a clear error message instead of a server error.
- Can I deploy the model with multiple classes?
- Answer: Yes. The CIFAR-10 model above is trained with 10 classes, and you can support a different number of classes by modifying the final layer of your model before training (see the sketch after this list).
- How can I improve prediction performance on large images?
- Answer: Resize large images to the model's expected input size before inference (the Resize transform in the example already does this), and consider an architecture suited to higher-resolution inputs, such as larger ResNet or Inception variants.
- Can I use GPU for faster inference?
- Answer: Yes, if you have a GPU available, you can run the PyTorch model on the GPU by transferring the model and input tensor to the GPU using .to(device).
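As mentioned in the multi-class question above, adapting the final layer to a different number of classes is a one-line change. The sketch below assumes a hypothetical dataset with 25 classes:

import torch.nn as nn
import torchvision.models as models

num_classes = 25  # hypothetical class count for your own dataset

model = models.resnet18(pretrained=True)
# Replace the final fully connected layer so it outputs num_classes scores
model.fc = nn.Linear(model.fc.in_features, num_classes)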
