Image classification is one of the most common use cases that comes to mind when we think about Artificial Intelligence, Machine Learning, or Deep Learning. There are many ways to develop an image classification model, such as PyTorch, TensorFlow, Fastai, etc. A less famous option, built on top of TensorFlow, is TFLite Model Maker, developed by Google.
In this blog post, I will guide you step by step through developing an image classification model using TFLite Model Maker. You can read more about it here.
## Requirements
To follow this blog end-to-end, you need to set up a new environment on your computer. However, it is not compulsory to use your local machine; you can train the model on, let's say, Google Colab and download the trained model to serve classification requests [this is out of scope for this blog, maybe I will cover it in my next one].
NOTE - it is not compulsory, but if you can use VS Code to write and run the code, things will be much easier.
### Create a virtual environment
[1] Open a terminal [Linux or Mac] or cmd tool [Windows], navigate to the directory where you want to keep the project files, and run

```
python3 -m venv tutorial-env
```

Here, `tutorial-env` is the name of the virtual environment. You can get more help here.
[2] Once the virtual environment is created, activate it by running

on Windows:

```
tutorial-env\Scripts\activate.bat
```

on Mac/Linux:

```
source tutorial-env/bin/activate
```
### Install required packages
Once the virtual environment is activated, run the following command to install the required packages…

```
pip install tflite-model-maker matplotlib numpy ipykernel
```

NOTE - this will take some time to complete, depending on your laptop or PC and your connection speed.
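To confirm the installation went through, you can try importing the packages; this is just a quick sanity check, and the printed version will vary depending on when you install.

```python
# Quick sanity check that the main packages import cleanly.
import tensorflow as tf
import tflite_model_maker

print(tf.__version__)
```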
## Get the Dataset
In any machine learning or deep learning work, the dataset is as important as choosing the right model. There are, again, many ways to get a dataset; Kaggle is one of the most famous sources. However, in this blog, I am going to show you a less well-known way to get a dataset: Roboflow. You need to open an account with Roboflow, and it is a simple process.
Once you log into Roboflow, you will see the following screen.
You can download any image classification dataset; I am going to use the Fruit Dataset for this blog. Download it to your local machine, or you can download it in a Google Colab notebook using the link they give you.
Once you have downloaded the dataset, extract it into the folder where you want to write the code, create a new file `train_model.ipynb`, and let's start writing the code.
## Develop a model
Now we are ready to start the training. Let's import the required packages first.

```python
from typing import Tuple
from tflite_model_maker.image_classifier import DataLoader
from tflite_model_maker import image_classifier
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
```
The `DataLoader` class is helpful for loading the data from a folder path, and `image_classifier` will create the model and train it. You may wonder why I imported something from `typing`; you will find out in a while…
### Create dataloader
Let's load the data. Rather than reading everything into memory up front, the `DataLoader` loads images lazily, only when they are needed.
```python
data = DataLoader.from_folder('Fruits_Dataset/train')
train_data, rest_data = data.split(0.8)
validation_data, test_data = rest_data.split(0.5)
```
Provide the path to the folder that contains only the different class folders. We split the data into three parts: 80% for training, 10% for validation, and the remaining 10% for testing.
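As a quick sanity check, you can list the class folders that `DataLoader.from_folder` picks up; the folder names become the class labels (the path below assumes you extracted the dataset as `Fruits_Dataset`).

```python
import os

# Each subfolder of train/ is treated as one class by DataLoader.from_folder.
print(os.listdir('Fruits_Dataset/train'))
# e.g. ['Tomato_3', 'Tomato_4', 'Tomato_Yellow', 'Tomato_Cherry_Red', 'Tomato_Maroon', 'Walnut']
```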
### Visualise the data
Now, let’s see what our data look like…
```python
plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(data.gen_dataset().unbatch().take(5)):
    plt.subplot(5, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image.numpy(), cmap=plt.cm.gray)
    plt.xlabel(data.index_to_label[label.numpy()], color='white')
plt.show()
```
It will look something like this…
We have 6 different classes to classify.
- Tomato_3
- Tomato_4
- Tomato_Yellow
- Tomato_Cherry_Red
- Tomato_Maroon
- Walnut
### Train the model
Training a model with the TFLite Model Maker library is a piece of cake, and this is the best thing about the library. You write only one line of code and it's done. Here is how you can do it.
```python
model = image_classifier.create(
    train_data,
    model_spec='efficientnet_lite0',
    epochs=5,
    validation_data=validation_data
)
```
This piece of code is mostly self-explanatory but, since this is a blog, let me walk through the arguments:

- the first argument is the training dataset.
- the second is the model spec; at this point in time TFLite Model Maker supports `EfficientNetLite0-4`, `MobileNet`, and `ResNet50`. You can choose one of them, and you can read more about it here.
- the number of epochs; keep this small while you are testing the library, then, once you have got the idea, you can train your model for a longer period of time.
- the validation data, used to validate the model during the training period.
NOTE - there are two more arguments that I must mention here: `train_whole_model`, in case you want to train the whole model rather than only the classification head, and `use_augmentation`, which you should use if you have a small number of training samples.
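For illustration, here is a sketch of the same `create` call with those two optional arguments switched on; whether you need them depends on your dataset.

```python
# Sketch: the same training call with the optional arguments from the note above.
model = image_classifier.create(
    train_data,
    model_spec='efficientnet_lite0',
    epochs=5,
    validation_data=validation_data,
    train_whole_model=True,   # fine-tune the whole network, not only the classification head
    use_augmentation=True     # apply data augmentation, useful for small datasets
)
```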
The training time depends on your hardware configuration and the number of samples, but for this Fruit Dataset it takes roughly 10 minutes on my 5-year-old laptop.
### Test the model
Now let’s test the performance of the model…
```python
loss, accuracy = model.evaluate(test_data)
predicts = model.predict_top_k(test_data, k=2)
```
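If you want to see the numbers (and a sample prediction) printed, something along these lines works; `predicts[i]` holds the top-2 predictions for the i-th test image, which is why the plotting code below reads `predicts[i][0][0]`.

```python
# Print the evaluation metrics and inspect one prediction.
print(f'Test loss: {loss:.4f}, test accuracy: {accuracy:.4f}')
print(predicts[0])  # top-2 predictions for the first test image
```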
Also, in order to visualize the performance of the model, we will run the following piece of code.
```python
def get_label_color(val1, val2):
    # The label is shown in white when the prediction is correct, red otherwise.
    if val1 == val2:
        return 'white'
    else:
        return 'red'

plt.figure(figsize=(10, 10))
for i, (image, label) in enumerate(test_data.gen_dataset().unbatch().take(10)):
    ax = plt.subplot(2, 5, i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(image.numpy(), cmap=plt.cm.gray)
    predict_label = predicts[i][0][0]
    color = get_label_color(predict_label,
                            test_data.index_to_label[label.numpy()])
    ax.xaxis.label.set_color(color)
    plt.xlabel(predict_label)
plt.show()
```
This will generate a plot like the one shown here; a label appears in red when the prediction is wrong.
### Save the model
The trained model can be saved on your local computer, and this lite model can even be shipped to mobile devices. Just run the following piece of code.
```python
model.export(
    export_dir='.',
    tflite_filename='fruit_classifier.tflite'
)
```
The parameter `export_dir` is the path where you want to save the model, and `tflite_filename` is the name of the model file; make sure the file name has the `.tflite` extension.
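If you want to double-check that the exported file behaves like the in-memory model, Model Maker lets you evaluate the `.tflite` file directly on the test split; the file name here is the one used in the export call above.

```python
# Evaluate the exported TFLite file on the held-out test data.
result = model.evaluate_tflite('fruit_classifier.tflite', test_data)
print(result)
```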
### Prediction on a single image
I have seen this in almost every blog on the internet: they show everything except one thing, how to get a prediction for a single image. You don't have to worry anymore, I will show it for sure, so here you go.
```python
MODEL_PATH = 'fruit_classifier.tflite'

def get_interpreter(model_path: str) -> Tuple:
    # Load the TFLite model and allocate its tensors.
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()

    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()

    return interpreter, input_details, output_details

def predict(image_path: str) -> int:
    interpreter, input_details, output_details = get_interpreter(MODEL_PATH)
    input_shape = input_details[0]['shape']

    # Read the image, resize it to the model's input size and add a batch dimension.
    img = tf.io.read_file(image_path)
    img = tf.io.decode_image(img, channels=3)
    img = tf.image.resize(img, (input_shape[1], input_shape[2]))
    img = tf.expand_dims(img, axis=0)
    resized_img = tf.cast(img, dtype=tf.uint8)

    # Run inference and return the index of the highest-scoring class.
    interpreter.set_tensor(input_details[0]['index'], resized_img)
    interpreter.invoke()

    output_data = interpreter.get_tensor(output_details[0]['index'])
    results = np.squeeze(output_data)
    return np.argmax(results, axis=0)
```
If you look at the code above, you will see why I imported `Tuple` from the `typing` module: all the functions are written with type hints, following the latest function-writing conventions.
The function `get_interpreter` returns the interpreter (a fancy word for the model here) along with the input and output details needed to prepare the image for prediction.
The second function, `predict`, takes the image path and returns the predicted label index.
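Putting it together, calling `predict` looks like this. The image path below is a hypothetical example, so point it at any image you have on disk; the returned value is an index, so you still need the index-to-label mapping (here taken from the `data` object created earlier) to get a readable class name.

```python
# Hypothetical image path; replace it with an image from your extracted dataset.
label_index = predict('Fruits_Dataset/test/Walnut/example.jpg')

# Map the index back to a class name using the DataLoader created earlier.
print(data.index_to_label[label_index])
```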
## Now where?
The code explained in this blog can be replicated for any image classification problem.
## Where is the code?
The code used in this blog can be downloaded from my GitHub.
## About me
I am Raj Kapadia. I am passionate about AI/ML/DL and their use in different domains, and I also love building chatbots using Google Dialogflow. For any work, you can reach out to me at…