How is Transfer Learning done in Neural Networks and Convolutional Neural Networks?

  • Rohit Dwivedi
  • Jun 23, 2020
  • Deep Learning

Humans often apply knowledge gained through experience to new tasks they have never done before. We check whether we have studied a related problem in the past and then apply what we know. For example, if C was the only language you knew when you started coding, that knowledge helped you when you later learned Python: your experience with C carried over to studying Python.

Transfer learning is a process in which a model built for one problem is reused for another, related problem. It is a very common approach in Computer Vision and Natural Language Processing, where pretrained models are used as the starting point for new problems in order to save time. Transfer learning tends to give good results when the new data is similar to the data the model was originally trained on.

For example, consider the ResNet model from Computer Vision, trained on the "ImageNet" dataset of about 14 million images. It can be reused directly for a new image classification task by changing only its fully connected layers, so that the output matches the new set of classes.
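In Keras, this head swap is usually written by loading tf.keras.applications.ResNet50(weights="imagenet", include_top=False) and adding new layers on top of it. The framework-free sketch below shows only the idea, so it runs without downloading any weights. A network is modeled as a list of layer dicts. The helper names (dense_layer, replace_classifier) and the toy sizes are hypothetical illustrations, not part of any library.

```python
import random

def dense_layer(name, n_in, n_out, rng):
    # Hypothetical fully connected layer: flat weight list (n_in * n_out) plus biases.
    return {
        "name": name,
        "weights": [rng.gauss(0.0, 0.01) for _ in range(n_in * n_out)],
        "bias": [0.0] * n_out,
        "n_out": n_out,
    }

def replace_classifier(pretrained, num_classes, feature_dim, seed=0):
    """Keep every pretrained layer except the last (classification) layer and
    attach a new, randomly initialized layer sized for the new classes."""
    rng = random.Random(seed)
    body = pretrained[:-1]  # the convolutional "body" is reused unchanged
    head = dense_layer("new_fc", feature_dim, num_classes, rng)
    return body + [head]

# Toy stand-in for a pretrained network: two conv blocks and a 1000-way classifier.
# The tiny sizes are illustrative only (real ResNet50 features are 2048-dimensional).
rng = random.Random(42)
pretrained = [
    {"name": "conv1", "weights": [0.5, -0.2], "bias": [0.0], "n_out": 1},
    {"name": "conv2", "weights": [0.1, 0.3], "bias": [0.0], "n_out": 1},
    dense_layer("fc1000", 4, 1000, rng),
]
model = replace_classifier(pretrained, num_classes=5, feature_dim=4)
print([layer["name"] for layer in model])  # ['conv1', 'conv2', 'new_fc']
print(model[-1]["n_out"])                   # 5
```

Only the output layer changes size. The pretrained layers are the same objects as before, so their learned weights carry over.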

 

 

Features of Transfer Learning

 

  • Knowledge learned on a source task is transferred to help learn, and refine, a related target task.

  • It has been observed that deep neural networks trained on natural images share a curious property: the first layer tends to learn features similar to Gabor filters.

  • These first-layer features turn out to be general: they are useful across many datasets.

  • Because such first-layer features emerge regardless of the image dataset, task, or loss function, they are considered general features.

  • If layers of a ConvNet learn this kind of general feature, the filters learned in those layers can be reused by a different ConvNet. This is transfer learning in ConvNets, and it works very well in practice.

  • Transfer learning saves training time and compute resources, and it also helps overcome the problem of having only a small amount of data.
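To make the Gabor comparison concrete, the sketch below samples the real part of a 2-D Gabor filter. This oriented, stripe-like pattern is what first-layer filters tend to resemble. It uses the common textbook form (a Gaussian envelope multiplied by a cosine along a rotated axis). The parameter names sigma, theta, lambd, gamma and psi are our own choice, not taken from the article.

```python
import math

def gabor_kernel(size, sigma, theta, lambd, gamma=0.5, psi=0.0):
    """Real part of a 2-D Gabor filter sampled on a size x size grid (size must be odd).

    sigma: width of the Gaussian envelope, theta: orientation in radians,
    lambd: wavelength of the sinusoid, gamma: spatial aspect ratio, psi: phase offset.
    """
    if size % 2 == 0:
        raise ValueError("size must be odd so the kernel has a center pixel")
    half = size // 2
    kernel = []
    for y in range(-half, half + 1):
        row = []
        for x in range(-half, half + 1):
            # Rotate coordinates so the sinusoid runs along direction theta.
            xr = x * math.cos(theta) + y * math.sin(theta)
            yr = -x * math.sin(theta) + y * math.cos(theta)
            envelope = math.exp(-(xr ** 2 + (gamma * yr) ** 2) / (2 * sigma ** 2))
            row.append(envelope * math.cos(2 * math.pi * xr / lambd + psi))
        kernel.append(row)
    return kernel

# theta = 0 gives vertical stripes: the cosine term varies only along x,
# while the Gaussian envelope fades the pattern away from the center.
k = gabor_kernel(5, sigma=1.5, theta=0.0, lambd=4.0)
for row in k:
    print(" ".join(f"{v:+.2f}" for v in row))
```

Changing theta rotates the stripes. A bank of such kernels at several orientations is a rough hand-made analogue of the edge detectors a ConvNet's first layer learns.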

 

Keras provides a variety of pretrained architectures, in its keras.applications module, that can be used for transfer learning.

Different scenarios of Transfer Learning

 

  1. ConvNet as a fixed feature extractor

  • Take a ConvNet pretrained on ImageNet and remove its fully connected layers.

  • Freeze the remaining layers and use them as a fixed feature extractor for the new data.

  • Extract features for all images in the new dataset and train a linear classifier, such as a softmax classifier, on those features.
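The fixed-feature-extractor recipe above can be sketched as follows. In Keras the freezing step is typically base_model.trainable = False. To stay self-contained, this sketch replaces the ConvNet with a hypothetical frozen_extractor function that has no trainable state. A softmax (multinomial logistic regression) classifier is then trained on the extracted features with plain gradient descent. The toy "images", learning rate and epoch count are arbitrary assumptions.

```python
import math
import random

def frozen_extractor(image):
    # Hypothetical stand-in for a pretrained ConvNet with its FC layers removed.
    # It has no trainable state: it only maps an input to a fixed feature vector.
    return [sum(image) / len(image), max(image) - min(image)]

def softmax(scores):
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def train_softmax(features, labels, num_classes, lr=0.5, epochs=500, seed=0):
    """Multinomial logistic (softmax) regression trained with full-batch gradient descent."""
    rng = random.Random(seed)
    dim = len(features[0])
    W = [[rng.gauss(0.0, 0.01) for _ in range(dim)] for _ in range(num_classes)]
    b = [0.0] * num_classes
    n = len(features)
    for _ in range(epochs):
        grad_W = [[0.0] * dim for _ in range(num_classes)]
        grad_b = [0.0] * num_classes
        for x, y in zip(features, labels):
            probs = softmax([sum(w * xi for w, xi in zip(W[k], x)) + b[k]
                             for k in range(num_classes)])
            for k in range(num_classes):
                err = probs[k] - (1.0 if k == y else 0.0)  # d(cross-entropy)/d(score_k)
                grad_b[k] += err / n
                for j in range(dim):
                    grad_W[k][j] += err * x[j] / n
        for k in range(num_classes):
            b[k] -= lr * grad_b[k]
            for j in range(dim):
                W[k][j] -= lr * grad_W[k][j]
    return W, b

def predict(W, b, x):
    scores = [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]
    return scores.index(max(scores))

# Toy "images": class 0 is dark, class 1 is bright.
images = [[0.1, 0.2, 0.1], [0.0, 0.1, 0.2], [0.9, 0.8, 1.0], [0.8, 0.9, 0.9]]
labels = [0, 0, 1, 1]
features = [frozen_extractor(img) for img in images]  # extracted once, never updated
W, b = train_softmax(features, labels, num_classes=2)
print([predict(W, b, f) for f in features])
```

Because the extractor is frozen, the features can be computed once and cached. Only the small linear classifier is trained, which is why this scenario is cheap.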

 

  2. Fine-tuning the ConvNet

  • Retrain the classifier on the new dataset, and also fine-tune the weights of the pretrained network by continuing backpropagation.

  • Either fine-tune all the layers of the ConvNet, or keep the earlier layers fixed and fine-tune only the higher-level portion of the network.

  • This works because the earlier layers of a ConvNet learn generic features, such as edge detectors, that are useful for many problems, while later layers become progressively more specific to the classes present in the original dataset.
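Partial fine-tuning can be sketched as choosing how many top layers stay trainable. In Keras this is commonly done by setting layer.trainable = False for base_model.layers[:-k]. The framework-free version below uses hypothetical layer dicts and made-up gradients. It only illustrates that frozen layers keep their pretrained weights during an update.

```python
def set_finetune_depth(layers, num_top_trainable):
    """Freeze the lower (more generic) layers and unfreeze the top
    `num_top_trainable` layers. num_top_trainable == len(layers) tunes everything."""
    cutoff = len(layers) - num_top_trainable
    for i, layer in enumerate(layers):
        layer["trainable"] = i >= cutoff
    return layers

def sgd_step(layers, grads, lr):
    """One gradient-descent update; frozen layers keep their pretrained weights."""
    for layer in layers:
        if layer["trainable"]:
            layer["weights"] = [w - lr * g
                                for w, g in zip(layer["weights"], grads[layer["name"]])]

# Hypothetical pretrained layers, ordered from input to output.
layers = [{"name": n, "weights": [1.0, 1.0], "trainable": True}
          for n in ["conv1", "conv2", "conv3", "fc"]]
set_finetune_depth(layers, num_top_trainable=2)          # tune only conv3 and fc
grads = {layer["name"]: [0.5, 0.5] for layer in layers}  # made-up gradients
sgd_step(layers, grads, lr=0.1)
for layer in layers:
    print(layer["name"], layer["trainable"], layer["weights"])
```

Increasing num_top_trainable moves the line between "generic, kept fixed" and "task-specific, tuned" further down the network.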

 

  3. Pretrained models

  • Training a ConvNet can take several days. Many researchers release their final ConvNet checkpoints as open source so that others can benefit from them.

  • These weights can be loaded and fine-tuned to save training time.
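Reusing a released checkpoint usually means copying every weight whose layer still matches and leaving the new classifier at its fresh initialization. The sketch below illustrates this with a JSON file keyed by layer name. That storage format is an assumption made for self-containment, since real frameworks use their own checkpoint formats. The helper names are also hypothetical.

```python
import json
import os
import tempfile

def save_checkpoint(path, layers):
    # Store weights keyed by layer name (real frameworks use their own file formats).
    with open(path, "w") as f:
        json.dump({layer["name"]: layer["weights"] for layer in layers}, f)

def load_matching_weights(path, layers):
    """Copy checkpoint weights into layers whose name and size match.
    Returns the names of layers left at their current (e.g. random) initialization."""
    with open(path) as f:
        checkpoint = json.load(f)
    skipped = []
    for layer in layers:
        saved = checkpoint.get(layer["name"])
        if saved is not None and len(saved) == len(layer["weights"]):
            layer["weights"] = list(saved)
        else:
            skipped.append(layer["name"])
    return skipped

# "Released" network with a 3-class head (6 weights); the new task has 2 classes (4 weights).
released = [{"name": "conv1", "weights": [0.2, -0.1]},
            {"name": "conv2", "weights": [0.4, 0.3]},
            {"name": "fc", "weights": [0.1] * 6}]
new_model = [{"name": "conv1", "weights": [0.0, 0.0]},
             {"name": "conv2", "weights": [0.0, 0.0]},
             {"name": "fc", "weights": [0.0] * 4}]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.json")
    save_checkpoint(path, released)
    skipped = load_matching_weights(path, new_model)

print(skipped)                  # ['fc']
print(new_model[0]["weights"])  # [0.2, -0.1]
```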

 

The Caffe library, for example, has a Model Zoo where people share their trained network weights for others to use.

Different Transfer Learning Strategies 

 

  1. The target dataset is similar to the base training dataset and is small in size

Fine-tuning a deep ConvNet on such a small dataset can lead to overfitting. Because the target data is similar to the base model's training data, the high-level features learned by the pretrained network are likely to be relevant for this data as well.

Strategy 

  • Freeze the weights of the pretrained ConvNet and use it to extract features.

  • Remove the fully connected layers from the pretrained base ConvNet and add new fully connected layers sized to the number of classes in the target dataset.

  • Keep all the pretrained weights frozen and randomly initialize the weights of the new FC layers.

  • Train the new network, updating only the weights of the new fully connected layers.
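The four steps of this strategy can be sketched as follows, using hypothetical layer dicts. The input pretrained_base is assumed to already have its fully connected layers removed. Every copied layer is frozen, and only the new randomly initialized head is trainable, which the parameter counts make visible.

```python
import random

def build_strategy1_model(pretrained_base, feature_dim, num_classes, seed=0):
    """Frozen pretrained base (FC layers already removed) + new randomly initialized FC head."""
    rng = random.Random(seed)
    model = [{"name": layer["name"], "weights": list(layer["weights"]), "trainable": False}
             for layer in pretrained_base]  # frozen: acts as a fixed feature extractor
    model.append({
        "name": "new_fc",
        "weights": [rng.uniform(-0.05, 0.05) for _ in range(feature_dim * num_classes)],
        "trainable": True,  # the only layer updated during training
    })
    return model

def count_params(model, trainable):
    return sum(len(layer["weights"]) for layer in model if layer["trainable"] == trainable)

base = [{"name": "conv1", "weights": [0.1] * 9}, {"name": "conv2", "weights": [0.2] * 18}]
model = build_strategy1_model(base, feature_dim=4, num_classes=3)
print("trainable:", count_params(model, True))  # trainable: 12
print("frozen:", count_params(model, False))    # frozen: 27
```

With a small target dataset, keeping the trainable part this small is what limits overfitting.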

 

  2. The target dataset is similar to the base training dataset and is large in size

In this case the dataset is large, so we can assume that fine-tuning the whole network will not lead to overfitting.

Strategy 

  • Remove the fully connected layers from the pretrained base ConvNet and add new fully connected layers sized to the number of classes in the target dataset.

  • Randomly initialize the weights of the new fully connected layers.

  • Initialize the rest of the weights from the pretrained model.

  • Train the entire network.
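For this strategy, the model can be sketched like this, again with hypothetical layer dicts. The old FC layer is dropped, every other layer starts from its pretrained weights, and all layers stay trainable. The weights are copied rather than shared, so fine-tuning does not overwrite the original checkpoint values.

```python
import random

def build_strategy2_model(pretrained, feature_dim, num_classes, seed=0):
    """Replace the pretrained FC layer, copy every other pretrained weight,
    and keep all layers trainable so the entire network is fine-tuned."""
    rng = random.Random(seed)
    body = [{"name": layer["name"], "weights": list(layer["weights"]), "trainable": True}
            for layer in pretrained[:-1]]  # pretrained[-1] is the old FC layer
    head = {"name": "new_fc",
            "weights": [rng.gauss(0.0, 0.01) for _ in range(feature_dim * num_classes)],
            "trainable": True}
    return body + [head]

pretrained = [{"name": "conv1", "weights": [0.5, -0.5]},
              {"name": "conv2", "weights": [0.3]},
              {"name": "fc", "weights": [0.0] * 8}]
model = build_strategy2_model(pretrained, feature_dim=2, num_classes=3)
print([(layer["name"], layer["trainable"]) for layer in model])
```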

 

  3. The target dataset is different from the base training dataset and is small in size

Overfitting is again a concern here because the dataset is small. The higher-level features of the pretrained ConvNet are unlikely to give good results on the target dataset, since it is not similar to the base dataset. So the network should make use of only the first few layers of the base ConvNet.

Strategy 

  • Remove most of the pretrained layers near the end of the ConvNet, keeping only the first few layers.

  • Add new FC layers on top of the remaining pretrained layers, sized to the number of classes in the new dataset.

  • Keep the remaining pretrained weights frozen and randomly initialize the weights of the new FC layers.

  • Train the network, updating only the weights of the new fully connected layers.
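The truncation step can be sketched with hypothetical layer dicts. Only the first keep_first pretrained layers are kept and frozen, and a new head is attached. In Keras, one common way to truncate is tf.keras.Model(inputs=base_model.input, outputs=base_model.get_layer(name).output). In this sketch, feature_dim is an assumed value that must match the output size of the last kept layer.

```python
import random

def build_strategy3_model(pretrained, keep_first, feature_dim, num_classes, seed=0):
    """Keep only the first `keep_first` pretrained layers (generic features), freeze them,
    and attach a new randomly initialized FC head. feature_dim must match the output
    size of the last kept layer."""
    rng = random.Random(seed)
    kept = [{"name": layer["name"], "weights": list(layer["weights"]), "trainable": False}
            for layer in pretrained[:keep_first]]
    head = {"name": "new_fc",
            "weights": [rng.gauss(0.0, 0.01) for _ in range(feature_dim * num_classes)],
            "trainable": True}
    return kept + [head]

pretrained = [{"name": f"conv{i}", "weights": [0.1 * i] * 4} for i in range(1, 6)]
pretrained.append({"name": "fc", "weights": [0.0] * 10})
model = build_strategy3_model(pretrained, keep_first=2, feature_dim=8, num_classes=4)
print([(layer["name"], layer["trainable"]) for layer in model])
```

The later, dataset-specific layers (conv3 to conv5 and the old fc here) are discarded rather than frozen, because their features are not expected to help on dissimilar data.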

 

  4. The target dataset is different from the base training dataset and is large in size

Because the target dataset is large and dissimilar to the base dataset, we could train the ConvNet from scratch. However, it is good practice to initialize the weights from the pretrained model and fine-tune them, which can make training faster. This is also referred to as domain adaptation.

Strategy 

  • Remove the fully connected layers from the pretrained base ConvNet and add new fully connected layers sized to the number of classes in the target dataset.

  • Randomly initialize the weights of the new fully connected layers.

  • Initialize the rest of the weights from the pretrained model instead of training them from scratch.

  • Train the entire network.

Summary


Figure: Overview of Transfer Learning Strategies (deciding on a strategy quickly)


The figure above gives an overview of the transfer learning strategies discussed in this blog, based on how similar the target data is to the base model's data and on the size of the target dataset.
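The two questions behind the summary (is the target data similar to the base data, and is the target dataset large?) map to the four strategies. A small sketch of that decision follows. The function name and the wording of the returned strings are our own paraphrase of the strategies above.

```python
def choose_strategy(similar_to_base: bool, large_dataset: bool) -> str:
    """Map the two questions from the summary (similarity, size) to one of the four strategies."""
    if similar_to_base and not large_dataset:
        return "freeze the pretrained ConvNet and train only a new FC classifier"
    if similar_to_base and large_dataset:
        return "replace the FC layers and fine-tune the whole network"
    if not large_dataset:
        return "keep only the early pretrained layers (frozen) and train a new FC classifier"
    return "initialize from pretrained weights and train the whole network"

for similar in (True, False):
    for large in (False, True):
        print(f"similar={similar!s:5} large={large!s:5} -> {choose_strategy(similar, large)}")
```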

 

 

Conclusion

 

Transfer learning is widely used today in training deep neural networks, because it saves training time and resources while still achieving good results.

Which transfer learning strategy to use depends on two factors discussed in this blog: the size of the target dataset and its similarity to the base data. The Summary section gives a quick way to decide which strategy to use.


Rohit Dwivedi

Data Science enthusiast who is currently pursuing a Post Graduate Program in Machine Learning and Artificial Intelligence from Great Learning. He has experience in Data Analytics, Machine Learning, Neural Networks, Computer Vision, and Natural Language Processing, and has worked on various projects in the analytics domain. His goal is to build use cases that apply Artificial Intelligence and Machine Learning to solve business problems.
