Image classification is one of the most common use cases when we think about Artificial Intelligence, Machine Learning, or Deep Learning. There are many ways to develop an image classification model, with frameworks like PyTorch, TensorFlow, fastai, etc. A less well-known library built on top of TensorFlow is TFLite Model Maker, developed by Google.
In this blog post, I will guide you step by step through developing an image classification model with TFLite Model Maker. You can read more about it here.
To follow this blog end to end, you need to set up a new environment on your computer. However, you don't have to use your local machine. You can, for example, train a model on Google Colab and download the trained model to serve classification requests (that is out of the scope of this blog; maybe I will cover it in my next one).
NOTE: it is not compulsory, but using VS Code to write and run the code makes things much easier.
## Create a virtual environment
Open a terminal (Linux or macOS) or the Command Prompt (Windows), navigate to the directory where you want to keep the project files, and run:
```shell
python3 -m venv tutorial-env
```
`tutorial-env` is the name of the virtual environment.
You can get more help here.
Once the virtual environment is created, activate it.
On Windows, run `tutorial-env\Scripts\activate`.
On macOS/Linux, run `source tutorial-env/bin/activate`.
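If you are not sure the environment is really active, you can ask Python itself. This is a small sketch that uses only the standard library. Inside a venv, `sys.prefix` points at the environment, while `sys.base_prefix` still points at the base interpreter. The helper name `in_virtualenv` is my own and is not part of any library.

```python
import sys


def in_virtualenv() -> bool:
    """True when running inside a venv (hypothetical helper name)."""
    # In a venv, sys.prefix is the environment folder; sys.base_prefix is the base install
    return sys.prefix != sys.base_prefix


print("Inside a virtual environment:", in_virtualenv())
```

Run it with the environment's `python`. If it prints `False`, the activation step did not take effect in that terminal.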
## Install required packages
Once the virtual environment is activated, run the following command to install the required packages:
```shell
pip install tflite-model-maker matplotlib numpy ipykernel
```
NOTE: this can take a while to complete, depending on your laptop or PC and your connection speed.
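To double-check that the installation worked, here is a small standard-library sketch. It looks up each package with `importlib.util.find_spec`, without actually importing it. The import names below (`tflite_model_maker`, `matplotlib`, `numpy`, `ipykernel`) are my assumption of how the pip packages are imported, and `missing_modules` is a hypothetical helper.

```python
import importlib.util


def missing_modules(names):
    """Return the module names that cannot be found (hypothetical helper)."""
    # find_spec returns None for a top-level module that is not installed
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Assumed import names for the packages installed above
    required = ["tflite_model_maker", "matplotlib", "numpy", "ipykernel"]
    missing = missing_modules(required)
    print("Missing:", ", ".join(missing) if missing else "none")
```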
## Get the Dataset
In any machine learning or deep learning project, the dataset is as important as choosing the right model. Again, there are many ways to get a dataset, and Kaggle is one of the most popular sources. In this blog, however, I am going to show you a less well-known source: Roboflow. You need to create a Roboflow account, which is a simple process.
Once you log in to Roboflow, you will see the following screen.
You can download any image classification dataset; I am going to use the Fruit Dataset for this blog. Download it to your local machine, or alternatively download it into a Google Colab notebook using the link they give you.
Once you have downloaded the dataset, extract it into the folder where you want to write the code, create a new file `train_model.ipynb`, and let’s start writing the code.
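`DataLoader.from_folder` expects one subfolder per class, so it is worth checking the extracted folder before writing any model code. Here is a minimal sketch. It assumes a layout like `Fruits_Dataset/train/<class_name>/<image files>` and assumes the images are `.jpg`, `.jpeg` or `.png` files. `count_images_per_class` is a hypothetical helper, not part of TFLite Model Maker.

```python
from pathlib import Path

# Assumption: the dataset images use one of these extensions
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images_per_class(root):
    """Return {class_name: image_count} for a one-subfolder-per-class layout."""
    counts = {}
    # Each subfolder is treated as one class; loose files in root are ignored
    for class_dir in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTENSIONS
        )
    return counts
```

For example, `count_images_per_class('Fruits_Dataset/train')` should return one entry per fruit class. Very uneven counts are a hint that some classes may be harder to learn.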
## Develop a model
Now we are ready to start training. Let’s first import the required packages:
```python
from typing import Tuple

from tflite_model_maker.image_classifier import DataLoader
from tflite_model_maker import image_classifier
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Jupyter magic: render plots inside the notebook
%matplotlib inline
```
The `DataLoader` class loads the data from a folder path, and `image_classifier` creates the model and trains it. You may wonder why I have imported `Tuple` from `typing`; you will find out in a while…
Let’s load the data. It is not read into memory all at once; instead, the images are loaded into memory only when they are needed.
```python
# Each subfolder of Fruits_Dataset/train is one class
data = DataLoader.from_folder('Fruits_Dataset/train')
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
```
Provide the path to the folder that contains only the class folders (one subfolder per class). We split the data into three parts: 80% for training, 10% for validation, and the remaining 10% for testing.
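The two calls to `split` give 80/10/10 because the second call halves the remaining 20%. The plain-Python sketch below mimics that arithmetic on a hypothetical list of 1,000 samples. It assumes that the first part of each split receives `int(size * fraction)` items, taken in order. That is my reading of how `DataLoader.split` behaves, so treat it as an assumption.

```python
def split(items, fraction):
    """Split a sequence in order; the first part gets int(len * fraction) items."""
    cut = int(len(items) * fraction)
    return items[:cut], items[cut:]


samples = list(range(1000))          # hypothetical: 1,000 images
train, rest = split(samples, 0.8)    # 800 / 200
validation, test = split(rest, 0.5)  # 100 / 100 of the remaining 200
print(len(train), len(validation), len(test))  # 800 100 100
```

With a count that does not divide evenly, rounding down means the last split absorbs the remainder. For 1,003 samples, this sketch gives 802 / 100 / 101.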
## Visualise the data
Now, let’s see what our data looks like:
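```python
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(data.gen_dataset().unbatch().take(5)):
    plt.subplot(5, 5, i+1)
    # Hide the axis ticks and grid so only the image is shown
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image.numpy(), cmap=plt.cm.gray)
    # White labels suit a dark notebook theme; use 'black' on a light background
    plt.xlabel(data.index_to_label[label.numpy()], color='white')
plt.show()
```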
It will look something like this…
We have 6 different classes to classify.
## Train the model
Training a model with the TFLite Model Maker library is a piece of cake, and this is the best thing about the library: you write a single call, and it’s done. Here is how:
```python
model = image_classifier.create(
    train_data,
    model_spec='efficientnet_lite0',
    epochs=5,
    validation_data=validation_data,
)
```
This code is fairly self-explanatory, but still, this is a blog, so here is a quick breakdown:
- The first argument is the training dataset.
- The second, `model_spec`, is the pre-trained model to start from. At the time of writing, TFLite Model Maker supports EfficientNet-Lite0 to EfficientNet-Lite4, MobileNetV2, and ResNet50. You can choose any of them, and you can read more about them here.
- `epochs` is the number of training epochs. Keep it small while you are trying out the library; once you get the idea, you can train your model for longer.
- `validation_data` is the data used to validate the model during training.
NOTE: there are two more arguments worth explaining. `train_whole_model=True` trains the whole model instead of only the classification layer on top. `use_augmentation=True` applies data augmentation to the training images, which is useful when you have only a small number of training samples.
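To give a feel for what augmentation does, here is a toy NumPy sketch of two common augmentations: a random horizontal flip and a random crop. It only illustrates the idea. It is not the pipeline Model Maker uses internally when `use_augmentation=True`, and the `augment` function and the 8x8 dummy image are made up for the example.

```python
import numpy as np


def augment(image, rng, crop_size):
    """Toy augmentation (illustration only): random left-right flip, then a random crop."""
    if rng.random() < 0.5:
        image = image[:, ::-1, :]  # mirror the image horizontally
    height, width = image.shape[:2]
    crop_h, crop_w = crop_size
    # Pick a random top-left corner so the crop stays inside the image
    top = rng.integers(0, height - crop_h + 1)
    left = rng.integers(0, width - crop_w + 1)
    return image[top:top + crop_h, left:left + crop_w, :]


rng = np.random.default_rng(0)
image = np.arange(8 * 8 * 3, dtype=np.uint8).reshape(8, 8, 3)  # dummy 8x8 RGB image
augmented = augment(image, rng, crop_size=(6, 6))
print(augmented.shape)  # (6, 6, 3)
```

Each epoch sees a slightly different version of every image, which makes it harder for the model to memorise a small training set.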
Training time depends on your hardware and the number of samples. For this Fruit Dataset, it took roughly 10 minutes on my five-year-old laptop.
## Test the model
Now let’s test the performance of the model:
```python
loss, accuracy = model.evaluate(test_data)
predicts = model.predict_top_k(test_data, k=2)
```
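`model.evaluate` returns the loss and the accuracy on the test set. As I understand it, `predict_top_k` returns, for each test image, a list of its `k` most likely (label, probability) pairs, best first, so the top-1 label of image `i` is `predicts[i][0][0]`. The NumPy sketch below reproduces both ideas on made-up numbers so you can see what they compute. The class names and probabilities are hypothetical, and `top_k` and `accuracy` are illustrative helpers, not the library's implementation.

```python
import numpy as np


def top_k(probs, labels, k=2):
    """Return the k most likely (label, probability) pairs, best first."""
    probs = np.asarray(probs)
    order = np.argsort(probs)[::-1][:k]  # indices sorted by descending probability
    return [(labels[i], float(probs[i])) for i in order]


def accuracy(predicted, actual):
    """Fraction of predicted labels that match the true labels."""
    return float(np.mean(np.asarray(predicted) == np.asarray(actual)))


labels = ["apple", "banana", "orange"]  # hypothetical class names
print(top_k([0.1, 0.7, 0.2], labels, k=2))  # [('banana', 0.7), ('orange', 0.2)]
print(accuracy(["apple", "banana"], ["apple", "orange"]))  # 0.5
```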
To visualize how the model performs, run the following code:
```python
def get_label_color(val1, val2):
    # White if the prediction matches the true label, red otherwise
    if val1 == val2:
        return 'white'
    else:
        return 'red'


plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(10)):
    ax = plt.subplot(2, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image.numpy(), cmap=plt.cm.gray)

    # Top-1 predicted label of the i-th test image
    predict_label = predicts[i][0][0]
    color = get_label_color(predict_label,
                            test_data.index_to_label[label.numpy()])
    ax.xaxis.label.set_color(color)
    plt.xlabel(predict_label)
plt.show()
```
This will generate a plot like this… A label is shown in red when the prediction is wrong.
## Save the model
The trained model can be saved on your local computer. Since it is a TensorFlow Lite model, it can also be transferred to mobile devices. Just run the following code:
```python
model.export(
    export_dir='.',
    tflite_filename='fruit_classifier.tflite',
)
```
`export_dir` is the path where you want to save the model, and `tflite_filename` is the name of the model file. Make sure the file name has the `.tflite` extension.
## Prediction on a single image
Almost every blog I have seen on the internet shows everything except one thing: how to get the prediction for a single image. You don’t have to worry anymore; here you go.
```python
MODEL_PATH = 'fruit_classifier.tflite'


def get_interpreter(model_path: str) -> Tuple:
    # Load the TFLite model and allocate memory for its tensors
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    return interpreter, input_details, output_details


def predict(image_path: str) -> int:
    interpreter, input_details, output_details = get_interpreter(MODEL_PATH)
    # Expected input shape: [1, height, width, 3]
    input_shape = input_details[0]['shape']

    # Read the file, decode it as an RGB image and resize it to the model's input size
    img = tf.io.read_file(image_path)
    img = tf.io.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, (input_shape[1], input_shape[2]))
    # Add the batch dimension and cast to the dtype the model expects (uint8 for a quantized model)
    img = tf.expand_dims(img, axis=0)
    resized_img = tf.cast(img, dtype=input_details[0]['dtype'])

    interpreter.set_tensor(input_details[0]['index'], resized_img.numpy())
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    results = np.squeeze(output_data)
    # Index of the class with the highest score
    return int(np.argmax(results, axis=0))
```
Looking at the code above, you can now see why I imported the `typing` module: the functions use type hints, and `Tuple` is the return type of `get_interpreter`.
`get_interpreter` returns the interpreter (a fancy word for the loaded model), together with the input and output details needed to prepare the image and read the prediction.
The second function, `predict`, takes the image path and returns the index of the predicted class.
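`predict` returns only a class index. If the exported model is quantized, the raw output scores are also integers rather than probabilities. The NumPy-only sketch below turns a raw output row into a label and an approximate score, using the standard TFLite dequantization formula `real = scale * (q - zero_point)`. The `(scale, zero_point)` pair can be read from `output_details[0]['quantization']`, and a scale of 0 means the tensor is not quantized. The label list is assumed to be in the same order as `test_data.index_to_label`. The helper name and the sample numbers are hypothetical.

```python
import numpy as np


def decode_output(raw, scale, zero_point, labels):
    """Map one raw TFLite output row to (label, score). Hypothetical helper.

    If scale is 0 the tensor is not quantized and values are used as-is;
    otherwise real = scale * (q - zero_point).
    """
    raw = np.squeeze(np.asarray(raw)).astype(np.float32)
    scores = raw if scale == 0 else scale * (raw - zero_point)
    idx = int(np.argmax(scores))
    return labels[idx], float(scores[idx])


labels = ["apple", "banana", "orange"]                   # hypothetical class order
raw_output = np.array([[10, 200, 45]], dtype=np.uint8)   # made-up model output
print(decode_output(raw_output, 1 / 255, 0, labels))     # label 'banana', score about 0.784
```

Inside `predict`, you could pass `output_data` and the two values from `output_details[0]['quantization']` to this helper instead of calling `np.argmax` directly, and get a readable label back.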
The code explained in this blog can be reused for any image classification problem.
## Where is the code?
The code used in this blog can be downloaded from my GitHub.
## About me
I am Raj Kapadia. I am passionate about AI/ML/DL and their use in different domains, and I also love building with Google Dialogflow. For any work, you can reach out to me at…