
Flask Tutorial

Flask and RESTful APIs for ML

8.1 Building a REST API for Predictions

In this section, we will learn how to build a RESTful API using Flask to serve machine learning (ML) model predictions. A RESTful API allows users to interact with your ML model over HTTP, making it easy to integrate your model with other applications, including web or mobile applications.

We will go through the process of setting up the Flask application to handle HTTP requests, define RESTful routes, and return predictions in a structured format such as JSON.

What is a REST API?

A REST API (Representational State Transfer API) is an architectural style for designing networked applications. It uses HTTP requests to perform CRUD operations (Create, Read, Update, Delete) on resources. For machine learning models, we typically expose a POST endpoint that receives input data and returns predictions, along with GET endpoints for things like health or status checks.

In our case, the resource will be the model prediction, and users will be able to interact with the model by sending input data to the API and receiving a response with predictions.

Steps to Build a REST API for Predictions Using Flask

  1. Set up the Flask Application
    First, let’s create a simple Flask application that will expose a RESTful API endpoint for predictions. We will assume the model is already trained and saved.

Flask Application Setup:

from flask import Flask, request, jsonify
import torch
from torchvision import models, transforms
from PIL import Image
import io

app = Flask(__name__)

# Load the pre-trained PyTorch model (replace with your trained model)
model = models.resnet18(pretrained=True)  # example model; newer torchvision versions use weights=... instead
model.eval()  # set to evaluation mode for inference

# Define image transformations for input
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

@app.route('/predict', methods=['POST'])
def predict():
    try:
        # Get the image file from the request
        img_file = request.files['image']
        img = Image.open(io.BytesIO(img_file.read())).convert('RGB')  # ensure 3 channels

        # Preprocess the image and add a batch dimension
        img_tensor = transform(img).unsqueeze(0)

        # Make a prediction
        with torch.no_grad():
            outputs = model(img_tensor)
            _, predicted_class = torch.max(outputs, 1)

        # Return the prediction in JSON format
        return jsonify({'predicted_class': int(predicted_class.item())})

    except Exception as e:
        return jsonify({'error': str(e)}), 400

if __name__ == "__main__":
    app.run(debug=True)

In this Flask app:

  • /predict endpoint: Accepts a POST request where an image is sent via the request body, and the server responds with the predicted class of the image.
  • Image Processing: The image is processed to fit the input format expected by the PyTorch model.
  • Prediction: The model performs inference on the processed image, and the class with the highest score is returned.

Explanation of Code:

  • Flask Setup: We start by setting up a basic Flask application and importing the necessary libraries (torch, torchvision, PIL).
  • Model Loading: In this example, we use a pre-trained ResNet18 model from torchvision.models. This can be replaced with your own pre-trained model.
  • Image Transformation: The transform variable holds a series of transformations that are applied to the input image to ensure it is correctly formatted for the model.
  • /predict Endpoint: The route /predict accepts POST requests with an image file, which is processed and passed to the model for prediction. The result is returned in a JSON response.

Testing the API:

To test the REST API, we can send an HTTP POST request with an image using tools like Postman or cURL.

Example using Postman:

  • URL: http://127.0.0.1:5000/predict
  • Method: POST

  • Body: Choose form-data and attach an image file under the key image.

Example using cURL:

curl -X POST -F "image=@image.jpg" http://127.0.0.1:5000/predict

The expected response should be in JSON format, containing the predicted class ID:

{
    "predicted_class": 5
}
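
You can also test the endpoint from Python. Here is a minimal sketch using the requests library, assuming the server is running locally and an image.jpg file exists in your working directory:

import requests

# Send the image as multipart/form-data; the key 'image' must match
# what the server reads from request.files['image']
with open('image.jpg', 'rb') as f:
    response = requests.post('http://127.0.0.1:5000/predict', files={'image': f})

print(response.status_code)  # 200 on success, 400 on error
print(response.json())       # e.g. {'predicted_class': 5}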

Explanation of the Flask Code:

  • Loading the Model: The example above uses a pre-trained ResNet18, but in practice you will usually load your own trained weights with model.load_state_dict() (a fuller loading sketch follows this list):

model.load_state_dict(torch.load('cifar10_model.pth'))

  • Creating the /predict Endpoint: The /predict route is defined to handle POST requests, where the user uploads an image to be classified.
  • Handling User Input: The image is uploaded as part of the request (using the request.files object). It is then opened, resized, and normalized to match the input requirements of the model.
  • Making the Prediction: The image is passed to the model for prediction.

outputs = model(img_tensor)
_, predicted_class = torch.max(outputs, 1)

  • Returning the Result: The predicted class is returned as a JSON response. This response can then be consumed by the frontend or any other client application.
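
As referenced in the Loading the Model point above, here is a fuller sketch of loading a custom-trained model; the file name cifar10_model.pth and the 10-class output head are assumptions for illustration:

import torch
import torch.nn as nn
from torchvision import models

# Recreate the same architecture that was used during training
model = models.resnet18()
model.fc = nn.Linear(model.fc.in_features, 10)  # assumed: 10 classes (CIFAR-10)

# Load the saved weights; the file name is an assumption for illustration
model.load_state_dict(torch.load('cifar10_model.pth', map_location='cpu'))
model.eval()  # set to evaluation mode for inference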
  2. Running the Flask Application
    To run the Flask application, execute the script in your terminal:

python app.py

Once the application is running, you can send a POST request to the /predict endpoint with an image (for example, a CIFAR-10 image if you loaded the CIFAR-10 model described above). The response will contain the predicted class.

Example using cURL:

curl -X POST -F "image=@image.png" http://127.0.0.1:5000/predict

Response:

{
    "predicted_class": 3
}

Flask API Routing for ML Models

Flask allows us to define different routes to handle different HTTP methods. In this case, we defined a route for predictions (/predict), but you can expand this API to include other functionalities, such as:

  • Health check endpoint (to check if the model is available).
  • Model status (to check if the model is loaded and ready).

Here’s how you can add a health check route:

@app.route('/health', methods=['GET'])
def health_check():
    return jsonify({'status': 'Model is running!'})

This simple route will allow users to check if the Flask server and model are working.
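
A model-status route can be added the same way. The sketch below is one possible version, assuming model is the global model object loaded at startup:

@app.route('/status', methods=['GET'])
def model_status():
    # Report whether the global model object exists and is ready for inference
    loaded = model is not None
    return jsonify({
        'model_loaded': loaded,
        'eval_mode': (not model.training) if loaded else None
    })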

Deployment Considerations

  1. Model Size: If the model is large (e.g., a deep neural network), ensure your deployment server has enough memory (RAM) to load the model.
  2. Concurrency: Flask’s built-in development server is not designed to handle heavy concurrent traffic. For production use, deploy behind a production WSGI server such as Gunicorn or uWSGI (see the example command after this list).
  3. Error Handling: You should implement error handling for cases when the input data is invalid, or the model encounters an issue.
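
For example, assuming your script is named app.py and the Flask instance is called app, a typical Gunicorn invocation looks like this:

# 4 worker processes listening on port 8000; tune -w to your CPU and memory
gunicorn -w 4 -b 0.0.0.0:8000 app:app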

Frequently Asked Questions

  1. Can I use this API for other types of ML models (e.g., Regression, NLP models)?
    • Answer: Yes, the same principles apply to other types of models, such as regression or NLP models. You just need to modify the model loading, data preprocessing, and prediction steps according to the type of model you’re using (a minimal JSON-based sketch appears at the end of this section).
  2. How can I test my REST API before integrating with the frontend?
    • Answer: You can use Postman, Insomnia, or cURL to send requests to your Flask API and view the responses. These tools allow you to simulate real user interactions with your API.
  3. What should I do if my model is too large to fit in memory?
    • Answer: You can optimize the model (e.g., using model quantization or pruning) to reduce its size or use model compression techniques. Alternatively, you can host the model on a dedicated cloud service.
  4. Can I deploy this API in the cloud?
    • Answer: Yes, you can deploy this Flask API in cloud services like AWS (using Elastic Beanstalk or Lambda), Google Cloud (using App Engine), or Heroku. They provide easy ways to deploy Flask applications with minimal configuration.
  5. How can I add authentication to my API?
    • Answer: You can implement API key or OAuth 2.0 authentication to secure your REST API. Flask extensions like Flask-JWT-Extended can help with handling authentication using JWT tokens.
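
As noted in FAQ 1, the same pattern extends to non-image models. Below is a minimal, hypothetical sketch of a JSON-based endpoint for a scikit-learn regressor; the file name regressor.pkl and the feature layout are assumptions for illustration:

import joblib
from flask import Flask, request, jsonify

app = Flask(__name__)

# Hypothetical saved scikit-learn regressor; the file name is an assumption
regressor = joblib.load('regressor.pkl')

@app.route('/predict', methods=['POST'])
def predict():
    # Expect JSON input, e.g. {"features": [3.0, 1200.0, 2.0]}
    data = request.get_json()
    prediction = regressor.predict([data['features']])
    return jsonify({'prediction': float(prediction[0])})

if __name__ == '__main__':
    app.run(debug=True)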