Meritshot Tutorials


Flask Tutorial

Example: Deploying a TensorFlow/Keras Model

In this section, we will go through the steps of deploying a pre-trained TensorFlow/Keras model using Flask. The objective is to create a simple API endpoint that receives input data, applies the trained model to make predictions, and returns the result to the user.

We will walk through an example of deploying a Keras model trained on the MNIST dataset (handwritten digit classification). This example will show how to load the model, process input data, and return predictions via a Flask API.

Steps to Deploy a TensorFlow/Keras Model with Flask

  1. Train and Save the Keras Model
    First, we need to train a simple Keras model and save it with the model.save() method. The saved model will then be loaded into the Flask application to make predictions.

Training and Saving the Keras Model:

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist

# Load and preprocess the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images / 255.0  # Normalize to [0, 1]
test_images = test_images / 255.0    # Normalize to [0, 1]

# Build a simple neural network model
model = models.Sequential([
    layers.Flatten(input_shape=(28, 28)),    # 28x28 image -> 784-value vector
    layers.Dense(128, activation='relu'),    # hidden layer
    layers.Dropout(0.2),                     # randomly drop 20% of units during training
    layers.Dense(10, activation='softmax')   # one probability per digit 0-9
])

# Compile the model
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
model.fit(train_images, train_labels, epochs=5)

# Check accuracy on the held-out test set
test_loss, test_acc = model.evaluate(test_images, test_labels)
print(f"Test accuracy: {test_acc:.4f}")

# Save the model (HDF5 format)
model.save('mnist_model.h5')
print("Model saved!")

In this example:

    • We use the MNIST dataset, which consists of 28×28 pixel grayscale images of handwritten digits.
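    • The neural network has one hidden layer of 128 ReLU units followed by a dropout layer (rate 0.2) to reduce overfitting; the output layer has 10 softmax units, one per digit.
    • After training, we save the model to a file (mnist_model.h5) using the model.save() method. Recent Keras versions recommend the native .keras format and treat .h5 (HDF5) as legacy, but .h5 files can still be saved and loaded.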
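As a quick check on the architecture described above, the number of trainable parameters can be worked out by hand. The snippet below is plain Python arithmetic, not a Keras API, and dense_params is a hypothetical helper; for this model, model.summary() should list the same number of trainable parameters.

```python
def dense_params(n_in: int, n_out: int) -> int:
    # A Dense layer stores an n_in x n_out weight matrix plus one bias per output unit
    return n_in * n_out + n_out


flatten_units = 28 * 28                     # Flatten turns each 28x28 image into 784 values
hidden = dense_params(flatten_units, 128)   # 784*128 + 128
output = dense_params(128, 10)              # 128*10 + 10
# Flatten and Dropout have no trainable weights, so they add nothing
total = hidden + output

print(f"hidden={hidden:,} output={output:,} total={total:,}")
# → hidden=100,480 output=1,290 total=101,770
```

Almost all of the weights sit in the first Dense layer, because every one of the 784 input pixels connects to each of the 128 hidden units.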

2. Flask Application for Prediction

Next, we create a Flask application that loads the saved Keras model and serves predictions through an API endpoint.

Flask Application for Deployment:

from flask import Flask, request, jsonify
import tensorflow as tf
import numpy as np
from PIL import Image
import io

app = Flask(__name__)

# Load the pre-trained Keras model once, when the app starts
model = tf.keras.models.load_model('mnist_model.h5')

@app.route('/predict', methods=['POST'])
def predict():
    try:
        # Get the uploaded file from the "image" form field
        img_file = request.files['image']
        # Read the image and convert it to grayscale
        img = Image.open(io.BytesIO(img_file.read())).convert('L')
        img = img.resize((28, 28))  # Resize to 28×28 pixels
        # Convert the image to a NumPy array
        img_array = np.array(img) / 255.0  # Normalize to [0, 1]
        img_array = img_array.reshape(1, 28, 28)  # Add a batch dimension: (1, 28, 28)
    except Exception as e:
        # Missing form field or unreadable image: report it as a client error
        return jsonify({'error': str(e)}), 400

    # Make a prediction: one row of 10 class probabilities
    predictions = model.predict(img_array)
    predicted_class = np.argmax(predictions)

    # Return the prediction as JSON
    return jsonify({'predicted_digit': int(predicted_class)})

if __name__ == "__main__":
    # debug=True is meant for local development only
    app.run(debug=True)

Explanation of the Flask Code:

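  • Loading the Model: The pre-trained Keras model is loaded with tf.keras.models.load_model(). Because this happens at module level, the model is loaded once at startup rather than on every request.

model = tf.keras.models.load_model('mnist_model.h5')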

  • Creating the /predict Endpoint: The /predict route is defined to handle POST requests, where the user uploads an image to be classified.
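  • Handling User Input: The client uploads the image in the image form field, which Flask exposes through the request.files object. The image is opened, converted to grayscale, resized to 28×28 pixels (the input size expected by the MNIST model), normalized by dividing by 255, and reshaped to (1, 28, 28) because the model expects a batch of images. If the field is missing or the file cannot be read as an image, the exception is caught and a 400 response with the error message is returned.
  • Making the Prediction: The preprocessed array is passed to model.predict(), which returns a (1, 10) array of class probabilities; np.argmax() returns the index of the largest one, which is the predicted digit.

predictions = model.predict(img_array)
predicted_class = np.argmax(predictions)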

  • Returning the Result: The predicted class (digit) is returned as a JSON response. This response can then be consumed by the frontend or any other client application.
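The preprocessing steps described above can be pulled into a small helper that is easier to test on its own. This is a sketch layered on the tutorial's code: the preprocess_image name and the optional invert flag are hypothetical additions. Inversion is worth considering because MNIST digits are light strokes on a dark background, whereas a digit drawn on white paper or in a paint program is usually dark on light, and feeding such images in uninverted tends to give poor predictions.

```python
import io

import numpy as np
from PIL import Image, ImageDraw, ImageOps


def preprocess_image(raw_bytes: bytes, invert: bool = False) -> np.ndarray:
    """Convert uploaded image bytes into a (1, 28, 28) float array for the MNIST model."""
    img = Image.open(io.BytesIO(raw_bytes)).convert("L")  # single-channel grayscale
    if invert:
        # MNIST digits are light on dark, so flip dark-on-light input (hypothetical option)
        img = ImageOps.invert(img)
    img = img.resize((28, 28))                            # size the model was trained on
    arr = np.asarray(img, dtype=np.float32) / 255.0       # scale pixels to [0, 1]
    return arr.reshape(1, 28, 28)                         # add a batch dimension of 1


# Usage: a synthetic dark-on-light "digit" (a black bar on white), encoded as PNG bytes
canvas = Image.new("L", (100, 100), color=255)
ImageDraw.Draw(canvas).rectangle([45, 10, 55, 90], fill=0)
buf = io.BytesIO()
canvas.save(buf, format="PNG")

x = preprocess_image(buf.getvalue(), invert=True)
print(x.shape, x.dtype)  # → (1, 28, 28) float32
```

Inside the route, the body of the try block would then reduce to img_array = preprocess_image(request.files['image'].read()), with invert=True only if your clients send dark-on-light images.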
3. Running the Flask Application

To run the Flask application, execute the script in your terminal:

python app.py

Once the application is running, you can send a POST request to the /predict endpoint with an image of a handwritten digit. The response will contain the predicted digit.

Example using cURL:

curl -X POST -F "image=@digit.png" http://127.0.0.1:5000/predict

Response:

{
    "predicted_digit": 7
}
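The same request can also be sent from Python. The sketch below uses only the standard library (urllib.request), so it needs no extra packages; the third-party requests library would make it shorter. The helper names build_multipart and predict_digit are hypothetical, and the default URL assumes the development server address used in the cURL example.

```python
import json
import os
import urllib.request
import uuid


def build_multipart(field: str, filename: str, data: bytes, mime: str = "image/png"):
    """Encode one file upload as a multipart/form-data body, like curl's -F option."""
    boundary = uuid.uuid4().hex  # random separator, very unlikely to appear in the data
    head = (
        f"--{boundary}\r\n"
        f'Content-Disposition: form-data; name="{field}"; filename="{filename}"\r\n'
        f"Content-Type: {mime}\r\n\r\n"
    ).encode()
    tail = f"\r\n--{boundary}--\r\n".encode()
    return head + data + tail, f"multipart/form-data; boundary={boundary}"


def predict_digit(path: str, url: str = "http://127.0.0.1:5000/predict") -> dict:
    """POST an image to the /predict endpoint and return the decoded JSON reply."""
    with open(path, "rb") as f:
        body, content_type = build_multipart("image", os.path.basename(path), f.read())
    # Passing data= makes urllib send a POST request
    req = urllib.request.Request(url, data=body, headers={"Content-Type": content_type})
    # urlopen raises urllib.error.HTTPError if the endpoint answers with a 400
    with urllib.request.urlopen(req) as resp:
        return json.loads(resp.read())


# Usage, with app.py running: print(predict_digit("digit.png"))
```

The form field must be named image, since that is the key the Flask route reads from request.files.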

Summary of the Steps:

  1. Train a Keras model on a dataset (e.g., MNIST).
  2. Save the model using model.save() in the .h5 format.
  3. Create a Flask API that loads the saved model and exposes a /predict endpoint.
  4. Handle image input in the POST request, preprocess it, and pass it to the model for prediction.
  5. Return the prediction as a JSON response.
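Steps 3 to 5 can be checked without starting a server by using Flask's built-in test client (app.test_client()). To keep the sketch self-contained and free of TensorFlow, it swaps the trained model for a hypothetical StubModel that always predicts 7 and reproduces the tutorial's route, so both the success path and the 400 error path can be exercised. In a real project you would import app from app.py instead of redefining the route.

```python
import io

import numpy as np
from flask import Flask, jsonify, request
from PIL import Image


class StubModel:
    """Hypothetical stand-in for the Keras model: always predicts the digit 7."""

    def predict(self, batch):
        probs = np.zeros((len(batch), 10))  # one row of 10 class scores per image
        probs[:, 7] = 1.0
        return probs


model = StubModel()
app = Flask(__name__)


@app.route("/predict", methods=["POST"])
def predict():
    # Same steps as the tutorial's route, without loading TensorFlow
    try:
        img = Image.open(io.BytesIO(request.files["image"].read())).convert("L")
        img_array = (np.array(img.resize((28, 28))) / 255.0).reshape(1, 28, 28)
    except Exception as e:
        return jsonify({"error": str(e)}), 400
    predicted_class = np.argmax(model.predict(img_array))
    return jsonify({"predicted_digit": int(predicted_class)})


# An in-memory 28x28 black PNG stands in for digit.png
buf = io.BytesIO()
Image.new("L", (28, 28), color=0).save(buf, format="PNG")
buf.seek(0)

client = app.test_client()
ok = client.post("/predict", data={"image": (buf, "digit.png")})  # file upload
missing = client.post("/predict")                                 # no image field
print(ok.status_code, ok.get_json())   # → 200 {'predicted_digit': 7}
print(missing.status_code)             # → 400
```

Because the stub ignores its input, this only checks request handling and the JSON format, not model accuracy.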

Frequently Asked Questions

  1. How do I deploy a TensorFlow model for multiple classes?
    • Answer: The MNIST model in this tutorial is already a multi-class classifier (10 digits), so the same steps apply. Give the output layer one unit per class with a softmax activation, train with a loss such as sparse_categorical_crossentropy for integer labels, and use np.argmax() on the prediction to return the class with the highest probability.
  2. Can I use other image formats (e.g., JPEG) for predictions?
    • Answer: Yes. The upload is opened with Pillow's Image.open(), which reads common formats such as JPEG, PNG, BMP, and GIF. Whatever the format, make sure the image is converted to grayscale and resized to 28×28 before it is passed to the model, as the /predict route does.
  3. How can I improve the accuracy of my model?
    • Answer: To improve accuracy, you can experiment with different neural network architectures, hyperparameter tuning, data augmentation, or training the model with more data.
  4. What should I do if the model takes too long to load or make predictions?
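    • Answer: First, load the model once when the application starts (as the code above does) rather than inside the request handler. Beyond that, you can speed up inference with techniques such as quantization or pruning, or by converting the model to TensorFlow Lite for production environments.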
  5. Can I deploy models other than Keras with Flask?
    • Answer: Yes, Flask can be used to deploy models trained with other libraries like scikit-learn, PyTorch, or XGBoost. The process is the same: load the model into Flask and expose an API endpoint for predictions.
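For the multi-class question above: the MNIST model is itself a 10-class classifier, and argmax keeps only the winning class. If clients also need the confidence or the runner-up classes, the softmax row can be ranked with np.argsort. The top_k helper and the shape of its output are hypothetical, and the probabilities in the usage lines are made up for illustration.

```python
import numpy as np


def top_k(probs: np.ndarray, k: int = 3) -> list:
    """Return the k most likely classes from a (1, num_classes) softmax output."""
    scores = probs[0]                       # the single row for a batch of one image
    best = np.argsort(scores)[::-1][:k]     # class indices, highest probability first
    return [{"digit": int(i), "probability": float(scores[i])} for i in best]


# Usage with a made-up softmax output for 10 classes (values sum to 1)
probs = np.array([[0.01, 0.02, 0.06, 0.02, 0.01, 0.04, 0.01, 0.80, 0.02, 0.01]])
print(top_k(probs))
```

In the Flask route, the result could be returned alongside predicted_digit, for example jsonify({'predicted_digit': int(predicted_class), 'top_k': top_k(predictions)}).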