Meritshot Tutorials


Flask Tutorial

Example: Deploying a PyTorch Model

In this section, we will go through the steps to deploy a pre-trained PyTorch model using Flask. The objective is to create a simple API endpoint that receives input data, applies the trained model to make predictions, and returns the result to the user.

We will walk through an example of deploying a PyTorch model trained on the CIFAR-10 dataset (an image-classification dataset with 10 classes, such as airplanes, cars, and birds). This example shows how to load the model, process input data, and return predictions via a Flask API.

Steps to Deploy a PyTorch Model with Flask

  1. Train and Save the PyTorch Model
    First, we fine-tune a PyTorch model and save its weights using PyTorch's torch.save() function. The saved weights are then loaded into the Flask application to make predictions.

Training and Saving the PyTorch Model:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.models as models

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transformations for data preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Upsample the 32×32 images to 224×224
    transforms.ToTensor(),          # Convert to a [3, 224, 224] tensor with values in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
])

# Load CIFAR-10 dataset
train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=32, shuffle=True)

# Start from a ResNet18 pre-trained on ImageNet
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
model.fc = nn.Linear(model.fc.in_features, 10)  # Replace the final layer for 10 classes (CIFAR-10)
model = model.to(device)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
epochs = 5
model.train()
for epoch in range(epochs):
    running_loss = 0.0
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()              # Clear gradients from the previous step
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()                    # Backpropagate
        optimizer.step()                   # Update the weights

        running_loss += loss.item()
    print(f"Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}")

# Save the trained model weights
torch.save(model.state_dict(), 'cifar10_model.pth')
print("Model saved!")

In this example:

    • We use the CIFAR-10 dataset, which consists of 60,000 32×32 color images in 10 different classes.
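    • We start from a ResNet18 model pre-trained on ImageNet (from torchvision.models), replace its final fully connected layer with a new 10-output layer, and fine-tune the network on CIFAR-10. The 32×32 images are resized to 224×224 and normalized with ImageNet statistics to match what the pre-trained network expects.
    • After training, we save the model weights using torch.save(model.state_dict(), 'cifar10_model.pth').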

2. Flask Application for Prediction

Now, let’s create a Flask application that loads the saved PyTorch model and serves predictions through an API endpoint.

Flask Application for Deployment:

from flask import Flask, request, jsonify
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import io

app = Flask(__name__)

# Load the fine-tuned PyTorch model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.resnet18(weights=None)  # Same architecture as in training; weights come from the file
model.fc = nn.Linear(model.fc.in_features, 10)
model.load_state_dict(torch.load('cifar10_model.pth', map_location=device))
model = model.to(device)
model.eval()  # Set model to evaluation mode

# Image transformation (must match the preprocessing used during training)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

@app.route('/predict', methods=['POST'])
def predict():
    try:
        # Get the image file from the request
        img_file = request.files['image']
        img = Image.open(io.BytesIO(img_file.read())).convert('RGB')

        # Apply transformations to the image and add a batch dimension
        img_tensor = transform(img).unsqueeze(0).to(device)
    except Exception as e:
        return jsonify({'error': str(e)}), 400

    # Make a prediction
    with torch.no_grad():
        outputs = model(img_tensor)
        _, predicted_class = torch.max(outputs, 1)

    # Return the prediction as JSON
    return jsonify({'predicted_class': int(predicted_class.item())})

if __name__ == "__main__":
    app.run(host="127.0.0.1", port=5000)

Explanation of the Flask Code:

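  • Loading the Model: The model architecture (ResNet18 with a 10-class final layer) is rebuilt exactly as in training, and the saved weights are loaded into it using model.load_state_dict().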
  • Creating the /predict Endpoint: The /predict route is defined to handle POST requests, where the user uploads an image to be classified.
  • Handling User Input: The image is uploaded as part of the request (in the image field of the request.files object). It is then opened, converted to RGB, resized to 224×224, converted to a tensor, and normalized to match the input requirements of the model.
  • Making the Prediction: The image is passed to the model for prediction.

outputs = model(img_tensor)

_, predicted_class = torch.max(outputs, 1)

  • Returning the Result: The predicted class is returned as a JSON response. This response can then be consumed by the frontend or any other client application.
  3. Running the Flask Application
    To run the Flask application, execute the script in your terminal. Assuming you saved the code as app.py:

python app.py


Once the application is running, you can send a POST request to the /predict endpoint with an image file (for example, an image from the CIFAR-10 dataset) in the image form field. The response contains the index (0 to 9) of the predicted CIFAR-10 class.

Example using cURL:

curl -X POST -F "image=@image.png" http://127.0.0.1:5000/predict



Response:

{
    "predicted_class": 3
}


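The endpoint returns only the class index. If a client needs a readable label, it can map the index using the standard CIFAR-10 label order (the same order torchvision exposes as datasets.CIFAR10(...).classes). A minimal sketch; the names CIFAR10_CLASSES and label_for are hypothetical:

```python
# CIFAR-10 class names in label order (indices 0-9), matching torchvision's
# datasets.CIFAR10(...).classes.
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]


def label_for(predicted_class: int) -> str:
    """Hypothetical helper: map the integer returned by /predict to a CIFAR-10 label."""
    if not 0 <= predicted_class < len(CIFAR10_CLASSES):
        raise ValueError(f"class index out of range: {predicted_class}")
    return CIFAR10_CLASSES[predicted_class]


response = {"predicted_class": 3}  # example response from the endpoint
print(label_for(response["predicted_class"]))  # cat
```

The same list could also be used inside the predict() route to return the label alongside the index, for example under an extra 'label' key (a hypothetical addition to the response format shown above).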
Summary of the Steps:

  1. Train a PyTorch model on a dataset (e.g., CIFAR-10).
  2. Save the model's weights with torch.save() in the .pth format.
  3. Create a Flask API that loads the saved model and exposes a /predict endpoint.
  4. Handle image input in the POST request, preprocess it, and pass it to the model for prediction.
  5. Return the prediction as a JSON response.

Frequently Asked Questions

  1. Can I deploy models trained on other datasets, like ImageNet?
    • Answer: Yes. You can deploy any PyTorch model you have trained, as long as the Flask app rebuilds the same architecture before loading the weights and preprocesses inputs the same way as during training (e.g., image size, color channels, normalization).
  2. What happens if the model fails to predict?
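    • Answer: There are two cases. If the request itself fails (for example, the image field is missing or the file is not a valid image), the try/except block in the /predict route returns a JSON error with status 400; you can extend this exception handling to cover other errors. If the model runs but predicts the wrong class, check that the input is preprocessed exactly as during training and that the correct weights were loaded.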
  3. Can I deploy the model with multiple classes?
    • Answer: Yes. The CIFAR-10 model already distinguishes 10 classes. To support a different number of classes, replace the final layer (model.fc) with one that has the new number of outputs, then retrain the model on data that includes those classes.
  4. How can I improve prediction performance on large images?
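    • Answer: Resize large images to the input size the model was trained with (224×224 in this example) before passing them to the model; the transform in the Flask app already does this, which keeps inference time independent of the uploaded image's resolution. Standard image-classification architectures such as ResNet or Inception are designed for inputs of roughly this size.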
  5. Can I use GPU for faster inference?
    • Answer: Yes, if you have a GPU available, you can run the PyTorch model on the GPU by transferring the model and input tensor to the GPU using .to(device).