Top ten ways to tackle overfitting models

Exploring the lesser-known approaches to a well-known problem

Overfitting models are high in variance, low in bias, and cannot generalize to unseen data. If the training accuracy is very high but the validation accuracy is much lower, or if the training and validation losses diverge, the model is overfitting. Below are ten techniques that I use to tackle overfitting in my models that should also work for you, plus two bonus points to give you more insight into overfitting models.

Photo by Author
  • Reducing Model Complexity
    This is fairly self-explanatory. Overfitting happens when your model predicts almost perfectly on the training set but poorly on the validation set, because it has captured not only the underlying features but also the incidental details (noise) of the training data, so it fails to generalize. Reducing the model's complexity, for example by using fewer layers, fewer units, or fewer parameters, therefore often helps substantially because it reduces the model's variance.
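    A minimal sketch of the idea, using NumPy and polynomial regression as a stand-in for a neural network (an assumption for illustration: the polynomial degree plays the role of layers and units). The same 10 noisy training points are fitted with degree 1, 3, and 9 polynomials. The degree-9 fit has as many coefficients as there are points, so it can pass through every one of them, noise included. The sine target and noise level are made up.

```python
import numpy as np

def true_fn(x):
    """The underlying pattern the model should learn (assumed for illustration)."""
    return np.sin(np.pi * x)

def poly_errors(degree, x_train, y_train, x_val, y_val):
    """Fit a polynomial of the given degree; return (train MSE, validation MSE)."""
    coeffs = np.polyfit(x_train, y_train, deg=degree)
    train_mse = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
    val_mse = np.mean((np.polyval(coeffs, x_val) - y_val) ** 2)
    return train_mse, val_mse

rng = np.random.default_rng(42)
x_train = np.linspace(-1, 1, 10)
y_train = true_fn(x_train) + rng.normal(scale=0.2, size=x_train.size)  # noisy labels
x_val = rng.uniform(-1, 1, size=200)
y_val = true_fn(x_val) + rng.normal(scale=0.2, size=x_val.size)

results = {}
for degree in (1, 3, 9):                       # low, moderate, high complexity
    results[degree] = poly_errors(degree, x_train, y_train, x_val, y_val)
    train_mse, val_mse = results[degree]
    print(f"degree {degree}: train MSE {train_mse:.4f}, val MSE {val_mse:.4f}")
```

    The training error can only go down as the degree rises; compare it with the validation error to see where extra complexity stops paying off. In a Keras model, the equivalent move is removing layers or reducing the number of units per layer.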
  • Regularize your layers
    The aim of regularization is to keep all of the features but impose a constraint on the magnitude of the coefficients.
    This is often preferred over dropping features because you retain all of them and instead shrink the weights assigned to them by penalizing large values. With this constraint on the parameters, the model is less prone to overfitting because it produces a smoother function. The penalty factor (the regularization strength) controls how strongly the weights are penalized and keeps the model from overtraining on the training data. A small penalty regularizes gently, whereas a penalty that is too large dominates the loss function and can push the model toward underfitting.
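    To make the penalty concrete, here is a minimal NumPy sketch of L2 (ridge) regularization on a linear model, using its closed-form solution. The data and the lambda values are made up for illustration. As the penalty factor lambda grows, the weight vector shrinks, yet no feature is removed outright, matching the idea of keeping every feature while constraining its coefficient. In Keras, the analogous knob is a layer's kernel_regularizer argument, for example tf.keras.regularizers.l2(0.01).

```python
import numpy as np

def ridge_fit(X, y, lam):
    """Closed-form L2-regularized least squares: minimizes ||Xw - y||^2 + lam * ||w||^2."""
    n_features = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(n_features), X.T @ y)

rng = np.random.default_rng(0)
X = rng.normal(size=(30, 10))
w_true = np.zeros(10)
w_true[:3] = [2.0, -1.0, 0.5]              # only three features actually matter
y = X @ w_true + rng.normal(scale=0.5, size=30)

norms = []
for lam in (0.0, 1.0, 10.0, 100.0):        # lam = 0 is plain, unregularized least squares
    w = ridge_fit(X, y, lam)
    norms.append(np.linalg.norm(w))
    print(f"lambda={lam:>6}: ||w|| = {norms[-1]:.3f}")
```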
  • Use dropout layers
    Techniques like ridge (L2) and lasso (L1) regularization modify the cost function by adding a penalty on the weights, whereas dropout modifies the network itself. Since a common cause of overfitting is model complexity, dropout randomly turns off a fraction of the perceptrons/neurons, along with their connections, at each training step. This reduces the effective complexity of the model and helps it generalize; at inference time, all neurons are used.
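    A minimal NumPy sketch of (inverted) dropout applied to a batch of activations, as an illustration rather than a full layer. During training each unit is zeroed with probability rate and the survivors are scaled by 1 / (1 - rate), so the expected activation stays the same; at inference the input passes through untouched. tf.keras.layers.Dropout uses the same scaling convention.

```python
import numpy as np

def dropout(activations, rate, rng, training=True):
    """Inverted dropout: zero each unit with probability `rate` during training
    and scale the survivors by 1 / (1 - rate) so the expected activation is unchanged."""
    if not training or rate == 0.0:
        return activations                  # at inference time every neuron is used
    keep_mask = rng.random(activations.shape) >= rate
    return activations * keep_mask / (1.0 - rate)

rng = np.random.default_rng(0)
h = np.ones((1000, 100))                    # a batch of hidden-layer activations
h_train = dropout(h, rate=0.5, rng=rng)
h_eval = dropout(h, rate=0.5, rng=rng, training=False)
print("fraction dropped:", np.mean(h_train == 0))   # roughly 0.5
print("mean activation:", h_train.mean())           # roughly 1.0
```

    Because a fresh mask is drawn on every call, each training step effectively trains a different thinned sub-network.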
  • Increase your data
    Overfitting occurs when the model captures the noise in the data along with the features that contribute to accurate predictions. A model that learns the noise of the training data has high variance and therefore cannot generalize well. With more data, it becomes much harder for the model to memorize the noise of individual examples, so it is pushed toward the patterns that hold across the whole dataset, which reduces its variance and keeps it from overfitting.
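    The memorization argument can be checked with a small NumPy experiment (an illustrative sketch; the sine target, noise level, and dataset sizes are assumptions). The same degree-9 polynomial, which has 10 coefficients, is fitted to 10 and to 1,000 noisy points. With 10 points it threads through every one, so its training error is essentially zero: it has memorized the noise. With 1,000 points it cannot, and its training error settles near the noise variance, which means it is fitting the underlying pattern instead.

```python
import numpy as np

def fit_train_mse(n_points, degree=9, noise=0.3, seed=0):
    """Fit a fixed-capacity polynomial to n noisy points and return its training MSE."""
    rng = np.random.default_rng(seed)
    x = np.linspace(-1, 1, n_points)
    y = np.sin(np.pi * x) + rng.normal(scale=noise, size=n_points)
    coeffs = np.polyfit(x, y, deg=degree)
    return np.mean((np.polyval(coeffs, x) - y) ** 2)

small = fit_train_mse(10)      # 10 parameters, 10 points: the noise can be memorized
large = fit_train_mse(1000)    # same model, 1000 points: it cannot memorize them all
print(f"train MSE with 10 points:   {small:.2e}")
print(f"train MSE with 1000 points: {large:.3f} (noise variance is {0.3 ** 2:.2f})")
```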
  • Augment your data
    Augmentation is usually performed on image data. When you feed an image into the network, you are feeding it an arrangement of pixel values that together form the image, so even a slight variation, such as a small rotation or a flip, can look like a different data point to the model. Data augmentation applies transforms like rotation, flipping, and scaling during training, so you get multiple data points from the same set of images, all of which help the model generalize. The preprocessing utilities in Keras are really helpful for data augmentation because they do not require you to edit your raw images or modify them on disk: the augmentation is applied on the fly during training, so you can experiment without modifying your original dataset.
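    Here is a minimal NumPy sketch of on-the-fly augmentation: each call returns a randomly flipped, rotated, and shifted copy of the input, so one stored image yields many training variants while the original stays untouched. The specific transforms are assumptions for illustration; in Keras you would typically use preprocessing layers such as tf.keras.layers.RandomFlip and tf.keras.layers.RandomRotation instead. Note that np.roll wraps pixels around the border, a simplification; real pipelines usually pad or fill the empty region.

```python
import numpy as np

def augment(image, rng, max_shift=2):
    """Return a randomly flipped, rotated, and shifted copy of a square HxWxC image."""
    out = image
    if rng.random() < 0.5:
        out = out[:, ::-1, :]                          # horizontal flip
    k = int(rng.integers(0, 4))
    out = np.rot90(out, k=k, axes=(0, 1))              # rotate by k * 90 degrees
    dy, dx = rng.integers(-max_shift, max_shift + 1, size=2)
    out = np.roll(out, shift=(int(dy), int(dx)), axis=(0, 1))  # shift; pixels wrap around
    return out

rng = np.random.default_rng(0)
image = rng.random((32, 32, 3))                        # stand-in for one training image
batch = np.stack([augment(image, rng) for _ in range(8)])  # 8 variants of one image
print(batch.shape)                                     # (8, 32, 32, 3)
```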
  • Early Stopping
    Early stopping halts training as soon as the model stops improving on the validation set.
    In TensorFlow’s Keras API, you can do this with the callback tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5), where monitor is the metric to watch and patience is the number of epochs without improvement to wait before training is stopped, so the model stops before it goes on to capture the noise in the data.
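    To make the patience logic concrete, here is a small pure-Python sketch of the rule such a callback applies (a simplified re-implementation for illustration, not Keras's actual source): training stops once the validation loss has failed to improve for patience consecutive epochs. The loss values are made up. In Keras you can additionally pass restore_best_weights=True so the model is rolled back to the weights from its best epoch.

```python
class EarlyStopper:
    """Stop when the monitored loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = None
        self.wait = 0

    def update(self, epoch, val_loss):
        """Record this epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:   # improvement: remember it, reset counter
            self.best, self.best_epoch, self.wait = val_loss, epoch, 0
        else:                                       # no improvement: count another epoch
            self.wait += 1
        return self.wait >= self.patience

# Hypothetical validation losses: the model starts overfitting after epoch 3.
val_losses = [1.00, 0.80, 0.70, 0.65, 0.66, 0.67, 0.70, 0.72, 0.75, 0.80, 0.90]
stopper = EarlyStopper(patience=5)
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.update(epoch, loss):
        stopped_at = epoch
        break
print(f"stopped at epoch {stopped_at}, best epoch {stopper.best_epoch}")
# stopped at epoch 8, best epoch 3
```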
  • Ensemble models
    Ensemble models are easiest to understand with an example. Let’s say you’re trying to build a model that predicts whether a person is happy or sad, and your dataset comprises visual, audio, and language modalities, capturing the facial features (the person's expressions), the audio features (the person's tone), and the context (what the person actually says), respectively.
    You have three models, one for each modality, but none of them performs well on its own, so you combine them, either with multi-modal fusion or with other ensemble techniques such as cross-modality approaches. With multi-modal fusion, the model learns representations from all the modalities and can therefore predict much better than when trained on a single modality.
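    One of the simplest ways to combine the three modality models is late fusion, also called soft voting: average each model's predicted probability and threshold the result. The sketch below assumes each model outputs P(happy) and uses made-up numbers. It is only one ensemble technique; multi-modal fusion typically combines learned representations inside a single network, which this sketch does not show.

```python
import numpy as np

def late_fusion(prob_list, weights=None):
    """Weighted average of per-model probabilities (soft voting / late fusion)."""
    probs = np.stack(prob_list)                 # shape: (n_models, n_samples)
    if weights is None:
        weights = np.ones(len(prob_list))
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()           # normalize so the result stays a probability
    return weights @ probs                      # shape: (n_samples,)

# Hypothetical P(happy) for three people from each single-modality model.
visual = np.array([0.90, 0.30, 0.55])
audio = np.array([0.60, 0.40, 0.35])
text = np.array([0.80, 0.10, 0.70])
fused = late_fusion([visual, audio, text])
print(fused.round(3))                    # [0.767 0.267 0.533]
print((fused >= 0.5).astype(int))        # [1 0 1] -> happy, sad, happy
```

    Passing weights lets you trust a stronger modality more, for example late_fusion([visual, audio, text], weights=[2, 1, 1]).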
  • Pruning
    Pruning in horticulture means cutting off unnecessary or dead branches. Neural network pruning works in a similar fashion: you remove redundant weights or nodes, making your model simpler and lighter. For decision trees, pruning comes in two forms:
    a. Pre-pruning: This approach stops the tree from growing before it perfectly classifies the training set.
    b. Post-pruning: This approach lets the tree grow fully until it classifies the training set, and then prunes it to remove branches or nodes that contribute little.
    Removing these redundant parameters can help the model generalize, often resulting in a higher validation accuracy than that of the overfitting network. One of the biggest advantages of pruning as a form of regularization is that you can apply it after training rather than during training, so you get to see both sides of the coin and decide whether you need regularization at all.
Photo by Author [neural-network-pruning]
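    A minimal NumPy sketch of magnitude pruning, one common criterion for deciding which connections are redundant (an assumption here, since the text does not name a criterion): the smallest-magnitude weights of an already trained layer are set to zero. The weight matrix is a random stand-in. For decision trees, scikit-learn exposes pre-pruning through parameters such as max_depth and min_samples_leaf, and post-pruning through ccp_alpha (cost-complexity pruning).

```python
import numpy as np

def magnitude_prune(weights, sparsity):
    """Zero out the `sparsity` fraction of weights with the smallest absolute values.

    Returns the pruned weights and the boolean mask of surviving connections.
    """
    magnitudes = np.abs(weights).ravel()
    k = int(round(sparsity * magnitudes.size))          # number of weights to remove
    if k == 0:
        return weights.copy(), np.ones(weights.shape, dtype=bool)
    threshold = np.partition(magnitudes, k - 1)[k - 1]  # k-th smallest magnitude
    mask = np.abs(weights) > threshold                  # ties at the threshold are also pruned
    return weights * mask, mask

rng = np.random.default_rng(0)
w = rng.normal(size=(100, 50))                     # stand-in for a trained layer's weights
w_pruned, mask = magnitude_prune(w, sparsity=0.8)
print("remaining weights:", mask.sum(), "of", mask.size)   # remaining weights: 1000 of 5000
```

    Because pruning acts on trained weights, you can evaluate the validation accuracy before and after pruning (and after a short fine-tuning pass) to judge whether it helps.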
  • Transfer Learning
    Transfer learning targets the problem of insufficient data, which can cause a model to fail to generalize and thus overfit. Let’s look at another example. Suppose you have to classify product images into 30 categories, such as crop tops, blazers, tuxedos, shirts, tops, jeans, trousers, pants, and skirts, but you only have 100 images per class, leaving you with a dataset of 3,000 images across 30 categories. You have built a few models, but whatever model you create, even the simplest of them seem to overfit and cannot classify unseen data correctly. One way to tackle this is to use a pre-trained model, in our case one trained on the Fashion MNIST dataset, which contains 70,000 images from 10 categories such as t-shirts, trousers, dresses, sandals, and bags, and has therefore learned distinctive features of clothing products. You can then reuse these learned features and retrain the model on your own dataset.
    Transfer learning also helps the model learn more general features than it could extract from your small dataset alone, rather than fitting that dataset's noise.
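    The mechanics of reusing learned features can be sketched in NumPy: a base feature extractor is frozen and only a new 30-class head is trained on the small dataset. This is a hypothetical sketch: base_w is a random placeholder for weights you would load from a model pretrained on Fashion MNIST, and the images are random pixels, so it demonstrates the training setup, not a real accuracy gain. In Keras, the equivalent step is loading the pretrained base, setting base_model.trainable = False, and adding a new Dense output layer.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)          # subtract row max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def train_head(features, labels, n_classes, lr=0.05, epochs=200):
    """Fit a new softmax classification head on frozen features by gradient descent."""
    n, k = features.shape
    w, b = np.zeros((k, n_classes)), np.zeros(n_classes)
    onehot = np.eye(n_classes)[labels]
    losses = []
    for _ in range(epochs):
        p = softmax(features @ w + b)
        losses.append(-np.mean(np.sum(onehot * np.log(p + 1e-12), axis=1)))
        grad = (p - onehot) / n                   # d(mean cross-entropy)/d(logits)
        w -= lr * features.T @ grad
        b -= lr * grad.sum(axis=0)
    return w, b, losses

rng = np.random.default_rng(0)
# HYPOTHETICAL stand-in for a base network pretrained on Fashion MNIST: one frozen
# dense layer + ReLU mapping 28x28 = 784 pixels to 64 features. In practice you
# would load real pretrained weights instead of random ones.
base_w = rng.normal(scale=0.05, size=(784, 64))

def extract_features(images):
    return np.maximum(images @ base_w, 0.0)       # base_w is only read, never updated

# Target dataset: 30 product classes x 100 images (random placeholder pixels).
x_target = rng.random((3000, 784))
y_target = np.repeat(np.arange(30), 100)
base_before = base_w.copy()
w_head, b_head, losses = train_head(extract_features(x_target), y_target, n_classes=30)
print(f"head loss: {losses[0]:.3f} -> {losses[-1]:.3f}")
```

    Only the 64 x 30 head weights and 30 biases are learned from the small dataset, which is far fewer parameters than training the whole network from scratch.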
  • Multi-task Learning
    Multi-task learning, in simple terms, means solving two or more problems with one model. It can reduce or overcome overfitting because learning multiple tasks simultaneously forces the shared representation to capture information relevant to all of the tasks, which reduces the chance of overfitting to any single one.
    As illustrated in the figure below, Model 1 and Model 2 give different outputs for the same input image because Model 1 is trained on animal data, while Model 2 is trained on human data. We needed to detect humans and animals simultaneously in the input image, so we used multi-task learning and trained Model 3 on both the animal and the human data, allowing it to detect both at the same time.
Photo by Author [multi-task-learning]
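    A minimal NumPy sketch of hard parameter sharing, one common form of multi-task learning: a shared layer feeds two task-specific heads, and the training loss is the sum of both task losses, so the shared weights receive gradients from both tasks. Linear layers, regression losses, and random placeholder targets are simplifying assumptions; the sketch does not reproduce the detection models in the figure.

```python
import numpy as np

rng = np.random.default_rng(0)
n, d, k = 64, 16, 8
x = rng.normal(size=(n, d))                   # shared inputs (placeholder features)
targets = {"animal": rng.normal(size=n),      # placeholder target for task 1
           "human": rng.normal(size=n)}       # placeholder target for task 2

# Hard parameter sharing: one shared layer, one small head per task.
w_shared = rng.normal(scale=0.1, size=(d, k))
heads = {task: rng.normal(scale=0.1, size=k) for task in targets}

def total_loss():
    """Multi-task objective: the sum of each task's mean squared error."""
    h = x @ w_shared
    return sum(np.mean((h @ heads[t] - targets[t]) ** 2) for t in targets)

lr = 0.01
initial = total_loss()
for _ in range(200):
    h = x @ w_shared                          # shared representation
    grad_shared = np.zeros_like(w_shared)
    for task in targets:
        err = h @ heads[task] - targets[task] # this task's residuals
        # every task's error flows back into the shared weights...
        grad_shared += (2 / n) * x.T @ np.outer(err, heads[task])
        # ...while each head is updated only from its own task
        heads[task] = heads[task] - lr * (2 / n) * (h.T @ err)
    w_shared = w_shared - lr * grad_shared
print(f"combined loss: {initial:.3f} -> {total_loss():.3f}")
```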

Some more insights into high-variance or overfitting models:

  • Understanding the extent of overfitting:
    Although k-fold cross-validation does not directly help you overcome overfitting, it can help you understand how much your model overfits. Take 3-fold cross-validation as an example: the data is split into three folds, and each fold takes one turn as the test set, giving three ways to split the train and test data.
    > Approach 1: [Train][Train][Test]
    > Approach 2: [Train][Test][Train]
    > Approach 3: [Test][Train][Train]
    The accuracies from these three runs (assuming accuracy is the metric you are optimizing) are then averaged to get the final accuracy. With this approach you are effectively performing several hold-out validations, and by comparing the averaged validation accuracy with the training accuracy you can conclude whether the model actually overfits and, if it does, to what extent.
Photo by Author [3-fold-cross-validation]
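    A minimal NumPy sketch of 3-fold cross-validation that reports both the average training accuracy and the average validation accuracy, since the gap between them is what indicates overfitting. The 1-nearest-neighbour classifier and the noisy toy data are assumptions, chosen because 1-NN memorizes its training set, so its training accuracy is always 1.0. scikit-learn's KFold and cross_validate (with return_train_score=True) provide the same splitting and scoring for real models.

```python
import numpy as np

def k_fold_splits(n_samples, k):
    """Yield (train_idx, val_idx) pairs; each sample is in exactly one validation fold."""
    folds = np.array_split(np.arange(n_samples), k)
    for i in range(k):
        val_idx = folds[i]
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        yield train_idx, val_idx

def one_nn_predict(x_train, y_train, x_query):
    """1-nearest-neighbour: memorizes the training set, so it easily overfits."""
    dists = ((x_query[:, None, :] - x_train[None, :, :]) ** 2).sum(axis=-1)
    return y_train[dists.argmin(axis=1)]

def cross_validate_1nn(x, y, k=3):
    train_accs, val_accs = [], []
    for train_idx, val_idx in k_fold_splits(len(x), k):
        xt, yt = x[train_idx], y[train_idx]
        train_accs.append(np.mean(one_nn_predict(xt, yt, xt) == yt))
        val_accs.append(np.mean(one_nn_predict(xt, yt, x[val_idx]) == y[val_idx]))
    return np.mean(train_accs), np.mean(val_accs)

rng = np.random.default_rng(0)
x = rng.normal(size=(90, 2))
y = (x[:, 0] + rng.normal(scale=1.0, size=90) > 0).astype(int)  # noisy labels
idx = rng.permutation(len(x))                 # shuffle before splitting
train_acc, val_acc = cross_validate_1nn(x[idx], y[idx], k=3)
print(f"train {train_acc:.2f}, val {val_acc:.2f}, gap {train_acc - val_acc:.2f}")
```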
  • What should you not do:
    As Andrew Ng and many other deep learning practitioners point out, using Principal Component Analysis (PCA) to reduce the number of features as a way to prevent overfitting is a common practice, but a poor one. PCA compresses the features into a smaller number of components without looking at the labels, so it can throw away information that matters for the prediction, and even when it seems to help, regularization is usually the better tool. If you want fewer features, you will have to pick the most effective ones and drop the redundant ones instead.

That’s all from my side, hope you guys learned something new. Let me know which one of the above techniques you tried and if it worked for you.

Happy Coding!

Diving into AI, working towards finding a closer link between science and humanity. Presently working on devising novel depression prediction methods.