Building a Pokemon Classifier
Here I try to build a Pokemon classifier to understand transfer learning for image classification.
In this notebook, I used the Pokemon images dataset from here, but unfortunately it is no longer available.
from fastai.vision import *
from fastai.metrics import error_rate
from fastai.callbacks.tracker import ReduceLROnPlateauCallback, SaveModelCallback
from fastai.callbacks import CSVLogger
path = Path(".")
Form a DataBunch object from the image folders.
data = ImageDataBunch.from_folder(path, train=".",
ds_tfms=get_transforms(),
size=128, bs=64, valid_pct=0.2).normalize(imagenet_stats)
Check how many different Pokemon classes we have.
len(data.classes)
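As an optional sanity check (not part of the original run), we can peek at a few of the class names and display a batch of the transformed, normalized images.
# Peek at the first few class (folder) names
data.classes[:5]
# Show a sample batch of transformed, normalized images
data.show_batch(rows=3, figsize=(7, 6))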
learn = cnn_learner(data, models.resnet18, metrics=error_rate).mixup().to_fp16()
Add callbacks to monitor the training process:
- Reduce the learning_rate using the ReduceLROnPlateauCallback.
- Save the model on every improvement in error_rate.
- Log the training stats to a CSV file.
callbacks_list = [
ReduceLROnPlateauCallback(learn=learn, monitor='error_rate', factor=1e-6, patience=5, min_delta=1e-5),
SaveModelCallback(learn, mode="min", every='improvement', monitor='error_rate', name='best'),
CSVLogger(learn=learn, append=True)
]
Now that all the setup is done, let's train the model with default parameters for 15 epochs.
learn.fit_one_cycle(15, callbacks=callbacks_list)
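If you want to see how the training progressed (this cell is an optional extra, not part of the original run), you can plot the losses recorded during fit_one_cycle.
# Plot the training and validation losses recorded by the Recorder
learn.recorder.plot_losses()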
Now that we have some decent accuracy, let's save the model and interpret its results.
In the following cell, I:
- Load the best weights saved by the callbacks during training.
- Convert the model back to 32-bit precision.
- Export the model as a whole.
- Export the weights alone.
learn.load("best");
learn.to_fp32()
learn.export("pokemon_resnet18_st1.pkl")
learn.save("pokemon_resnet18_st1_wgts")
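As a rough sketch of how the exported model could later be used for inference (the image path below is just a placeholder), we can load the .pkl back with load_learner and predict on a single image.
# Load the exported Learner back (no training data is needed for inference)
inference_learn = load_learner(path, "pokemon_resnet18_st1.pkl")
# Run a prediction on one image; the file name here is only a placeholder
img = open_image("some_pokemon.png")
pred_class, pred_idx, probs = inference_learn.predict(img)
pred_class, probs[pred_idx]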
It is important to understand what the model has learnt during training. We can do that with the help of the ClassificationInterpretation class from the fastai library.
interp = ClassificationInterpretation.from_learner(learn)
# Get the instances in the validation set where the model made the largest errors (by loss value).
losses,idxs = interp.top_losses()
# Check that the returned values have the same length as the validation set
len(data.valid_ds)==len(losses)==len(idxs)
Interpret the images where the model made the largest errors during validation.
The cell below shows
- the image.
- the model's prediction of that image.
- the actual label of that image.
- the loss and the probability (how confident the model is about its prediction).
You may notice that some regions of each image are highlighted; as far as I know, these are the regions the model looked at to make its prediction for that image.
interp.plot_top_losses(9, figsize=(15,11))
Let us also see which pokemon have confused the model the most.
interp.most_confused(min_val=3)
Apart from the second entry in this list, you can generally see why the model was confused: most of its confusion stems from evolved forms of the same Pokemon.
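For the full picture (an optional extra, not part of the original run), we can also plot the confusion matrix; with this many classes it gets quite large, so a bigger figure size helps.
# Plot the full confusion matrix (large, since there are many Pokemon classes)
interp.plot_confusion_matrix(figsize=(12, 12), dpi=60)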
Let's try to train the model a little bit differently this time.
learn.load('best');
Till now we have been training only the head of the model, i.e. only the last two or three layers, so essentially this model is almost the same as the model pretrained on the 1000 categories of the ImageNet dataset, with some minor tweaks for our problem here. We have some options to improve the model:
- Train all the layers so that the model can adapt to the current classification problem. We do that with unfreeze().
- Train with a very low learning rate so that the model doesn't forget what it learned from the pretrained weights.
Let's see how well we can improve the model.
learn.to_fp16()
learn.unfreeze()
Before we start training again, we need to figure out how fast the neural network should learn. This is controlled by the learning rate, and finding a good value for it is crucial to the training process.
Luckily, fastai's lr_find method will help us do just that.
learn.lr_find(start_lr=1e-20)
# Plot the learning rates and the corresponding losses.
learn.recorder.plot(suggestion=True)
# Get the suggested learning rate
min_grad_lr = learn.recorder.min_grad_lr
Use the same callbacks as before and train for 30 epochs.
learn.fit_one_cycle(30, min_grad_lr, callbacks=callbacks_list)
We can see that the model has improved slightly, but not by much. Other things we could try (sketched below) are:
- Using a different architecture instead of resnet18.
- Adding more image augmentation methods (even though fastai has some reasonable defaults).
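As a rough sketch of both ideas (the variable names and transform values here are only illustrative, not what I actually ran), we could rebuild the data bunch with heavier augmentation and swap in a deeper backbone such as resnet34.
# Heavier augmentation than the defaults; these values are illustrative only
tfms = get_transforms(max_rotate=20., max_zoom=1.2, max_lighting=0.3, max_warp=0.2)
data_aug = ImageDataBunch.from_folder(path, train=".",
                                      ds_tfms=tfms,
                                      size=128, bs=64, valid_pct=0.2).normalize(imagenet_stats)
# A deeper backbone than resnet18
learn_r34 = cnn_learner(data_aug, models.resnet34, metrics=error_rate)
learn_r34.fit_one_cycle(15)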
Persist the environment so that we can deploy the model later without any problems.
!pip freeze > resnet18.txt
That's it for this post. Please share it if you found it useful, and don't hesitate to leave a comment if any of my explanations need clarification.