Meritshot Tutorials
Example: Deploying a TensorFlow/Keras Model
In this section, we will go through the steps of deploying a pre-trained TensorFlow/Keras model using Flask. The objective is to create a simple API endpoint that receives input data, applies the trained model to make predictions, and returns the result to the user.
We will walk through an example of deploying a Keras model trained on the MNIST dataset (handwritten digit classification). This example will show how to load the model, process input data, and return predictions via a Flask API.
Steps to Deploy a TensorFlow/Keras Model with Flask
1. Train and Save the Keras Model
First, we need to train a simple Keras model and save it using TensorFlow’s model.save() function. This model will then be loaded into the Flask application for making predictions.
Training and Saving the Keras Model:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist

# Load and preprocess the MNIST dataset
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
train_images = train_images / 255.0  # Normalize to [0, 1]
test_images = test_images / 255.0  # Normalize to [0, 1]

# Build a simple neural network model
model = models.Sequential([
    layers.Flatten(input_shape=(28, 28)),
    layers.Dense(128, activation='relu'),
    layers.Dropout(0.2),
    layers.Dense(10, activation='softmax')
])

# Compile the model
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Train the model
model.fit(train_images, train_labels, epochs=5)

# Save the model
model.save('mnist_model.h5')
print("Model saved!")
In this example:
- We use the MNIST dataset, which consists of 28×28 pixel grayscale images of handwritten digits.
- The neural network has one hidden layer with 128 units and applies dropout (rate 0.2) to reduce overfitting.
- After training, we save the model to a file (mnist_model.h5) in HDF5 format using the Keras model.save() method.
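To make the shapes in this architecture concrete, the following is a minimal NumPy sketch of the forward pass the model performs at inference time. The weights here are random stand-ins (an assumption for illustration, not trained values), and the function name forward is hypothetical; the point is only to show how a 28×28 image becomes 10 class probabilities.

```python
import numpy as np

rng = np.random.default_rng(0)


def forward(images, w1, b1, w2, b2):
    """Flatten -> Dense(128, relu) -> Dense(10, softmax).

    Dropout is left out because it is only active during training.
    """
    x = images.reshape(len(images), -1)          # Flatten: (n, 28, 28) -> (n, 784)
    hidden = np.maximum(x @ w1 + b1, 0.0)        # ReLU activation
    logits = hidden @ w2 + b2
    logits -= logits.max(axis=1, keepdims=True)  # numerically stable softmax
    exp = np.exp(logits)
    return exp / exp.sum(axis=1, keepdims=True)


# Random stand-ins for trained weights (illustrative only).
w1 = rng.normal(scale=0.05, size=(784, 128))
b1 = np.zeros(128)
w2 = rng.normal(scale=0.05, size=(128, 10))
b2 = np.zeros(10)

# One fake image with raw pixel values 0..255, normalized as in the training code.
raw = rng.integers(0, 256, size=(1, 28, 28))
probs = forward(raw / 255.0, w1, b1, w2, b2)
print(probs.shape)  # (1, 10): one probability per digit class
```

Each row of the output sums to 1 (up to floating-point error), which is why taking the index of the largest value gives the predicted digit.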
2. Flask Application for Prediction
Next, we create a Flask application that loads the saved Keras model and serves predictions through an API endpoint.
Flask Application for Deployment:
from flask import Flask, request, jsonify
import tensorflow as tf
import numpy as np
from PIL import Image
import io

app = Flask(__name__)

# Load the pre-trained Keras model
model = tf.keras.models.load_model('mnist_model.h5')

@app.route('/predict', methods=['POST'])
def predict():
    try:
        # Get the image file from the request
        img_file = request.files['image']
        # Read the image and convert it to grayscale
        img = Image.open(io.BytesIO(img_file.read())).convert('L')
        img = img.resize((28, 28))  # Resize to 28×28 pixels
        # Convert the image to a numpy array
        img_array = np.array(img) / 255.0  # Normalize the image
        img_array = img_array.reshape(1, 28, 28)  # Reshape for model input
    except Exception as e:
        return jsonify({'error': str(e)}), 400

    # Make a prediction
    predictions = model.predict(img_array)
    predicted_class = np.argmax(predictions)

    # Return the prediction as JSON
    return jsonify({'predicted_digit': int(predicted_class)})

if __name__ == "__main__":
    app.run(debug=True)
Explanation of the Flask Code:
- Loading the Model: The pre-trained Keras model is loaded using tf.keras.models.load_model().
model = tf.keras.models.load_model('mnist_model.h5')
- Creating the /predict Endpoint: The /predict route is defined to handle POST requests, where the user uploads an image to be classified.
- Handling User Input: The image is uploaded as part of the request (using the request.files object). It is then opened, converted to grayscale, resized to 28×28 pixels (the input size expected by the MNIST model), and normalized by dividing by 255.
- Making the Prediction: The image is passed to the model for prediction.
predictions = model.predict(img_array)
predicted_class = np.argmax(predictions)
- Returning the Result: The predicted class (digit) is returned as a JSON response. This response can then be consumed by the frontend or any other client application.
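The input-handling step described above can be pulled into a small helper so it is easy to check in isolation. This is a sketch under assumptions: the function name preprocess_image and the invert flag are hypothetical additions, not part of the original application. The flag is there because MNIST digits are light strokes on a dark background, so a photo of dark ink on white paper may need inverting to resemble the training data.

```python
import io

import numpy as np
from PIL import Image


def preprocess_image(raw_bytes, invert=False):
    """Turn uploaded image bytes into a (1, 28, 28) float array in [0, 1]."""
    img = Image.open(io.BytesIO(raw_bytes)).convert("L")  # grayscale
    img = img.resize((28, 28))                            # model input size
    arr = np.asarray(img, dtype=np.float32) / 255.0       # normalize
    if invert:
        # Flip dark-on-light input so it matches MNIST's light-on-dark digits.
        arr = 1.0 - arr
    return arr.reshape(1, 28, 28)


# Build a small in-memory PNG to exercise the helper (no file on disk needed).
buffer = io.BytesIO()
Image.new("RGB", (64, 48), color=(255, 255, 255)).save(buffer, format="PNG")
batch = preprocess_image(buffer.getvalue())
print(batch.shape)  # (1, 28, 28)
```

Keeping preprocessing separate from the route also makes it easier to confirm that the serving code applies the same normalization as the training code.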
3. Running the Flask Application
To run the Flask application, save the code above as app.py and execute it in your terminal:
python app.py
Once the application is running, you can send a POST request to the /predict endpoint with an image of a handwritten digit. The response will contain the predicted digit.
Example using cURL:
curl -X POST -F "image=@digit.png" http://127.0.0.1:5000/predict
Response:
{
    "predicted_digit": 7
}
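The same request can also be exercised from Python without starting a server, using Flask's built-in test client. The sketch below makes two assumptions for the sake of a self-contained example: StubModel is a hypothetical stand-in that always favors digit 7 (so TensorFlow and a trained model file are not needed), and create_app is a hypothetical factory that wires the route from this tutorial to whichever model is passed in.

```python
import io

import numpy as np
from flask import Flask, jsonify, request
from PIL import Image


class StubModel:
    """Hypothetical stand-in for the Keras model; always favors digit 7."""

    def predict(self, batch):
        probs = np.full((len(batch), 10), 0.01)
        probs[:, 7] = 0.91
        return probs


def create_app(model):
    app = Flask(__name__)

    @app.route("/predict", methods=["POST"])
    def predict():
        try:
            raw = request.files["image"].read()
            img = Image.open(io.BytesIO(raw)).convert("L").resize((28, 28))
            batch = (np.array(img) / 255.0).reshape(1, 28, 28)
        except Exception as exc:
            return jsonify({"error": str(exc)}), 400
        digit = int(np.argmax(model.predict(batch)))
        return jsonify({"predicted_digit": digit})

    return app


# Build an in-memory PNG and post it as multipart form data.
png = io.BytesIO()
Image.new("L", (28, 28), color=0).save(png, format="PNG")
png.seek(0)

client = create_app(StubModel()).test_client()
response = client.post(
    "/predict",
    data={"image": (png, "digit.png")},
    content_type="multipart/form-data",
)
print(response.status_code, response.get_json())

# A request without the "image" field should hit the error branch.
bad_response = client.post("/predict", data={})
print(bad_response.status_code)
```

Swapping StubModel for the loaded Keras model would test the real prediction path, at the cost of loading TensorFlow in the test run.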
Summary of the Steps:
- Train a Keras model on a dataset (e.g., MNIST).
- Save the model using model.save() in the .h5 format.
- Create a Flask API that loads the saved model and exposes a /predict endpoint.
- Handle image input in the POST request, preprocess it, and pass it to the model for prediction.
- Return the prediction as a JSON response.
Frequently Asked Questions
- How do I deploy a TensorFlow model for multiple classes?
- Answer: The MNIST model above is already a multi-class classifier, so the same process applies. The output layer needs one unit per class (10 here, with softmax activation), and np.argmax() on the prediction returns the index of the class with the highest probability.
- Can I use other image formats (e.g., JPEG) for predictions?
- Answer: Yes, you can use other image formats (e.g., JPEG, PNG). Ensure the image is processed correctly (e.g., converting to grayscale and resizing) before passing it to the model.
- How can I improve the accuracy of my model?
- Answer: To improve accuracy, you can experiment with different neural network architectures, hyperparameter tuning, data augmentation, or training the model with more data.
- What should I do if the model takes too long to load or make predictions?
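- Answer: Load the model once when the application starts, as in the example above, rather than inside the request handler. Beyond that, you can shrink or speed up the model with techniques such as quantization or pruning, or convert it to TensorFlow Lite for lighter-weight inference.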
- Can I deploy models other than Keras with Flask?
- Answer: Yes, Flask can be used to deploy models trained with other libraries like scikit-learn, PyTorch, or XGBoost. The process is the same: load the model into Flask and expose an API endpoint for predictions.
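Returning to the multi-class question above: argmax gives only the single most likely class, but a client may also want to know how confident the model was. The following is a small sketch of one way to return the top few classes with their probabilities. The helper name top_k and the example probabilities are made up for illustration.

```python
import numpy as np


def top_k(probabilities, k=3):
    """Return the k most likely classes as (class_index, probability) pairs."""
    probs = np.asarray(probabilities).ravel()
    order = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(int(i), float(probs[i])) for i in order]


# Made-up softmax output for a single image (10 classes, sums to 1).
example = [0.01, 0.02, 0.05, 0.02, 0.10, 0.03, 0.02, 0.70, 0.03, 0.02]
print(top_k(example))  # [(7, 0.7), (4, 0.1), (2, 0.05)]
```

The resulting list contains only plain Python ints and floats, so it can be passed straight to jsonify alongside the predicted digit.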
