
Knowledge Distillation


In this blog we will learn about knowledge distillation, its types, and its implementation in Python.

Introduction

Knowledge distillation is a technique in machine learning where knowledge is transferred from a large, complex model to a smaller, more compact model.


Algorithm

  • Define Teacher Network and Student Network
  • Train the teacher network
  • Train the student network in coordination with the teacher network

Benefits of Knowledge Distillation

  • Smaller model size (the distilled student is much smaller than the teacher)
  • Regularization (the student mimics the behavior of the teacher, which can act as a regularizer)
  • Transfer learning (knowledge learned by a large pre-trained model can be transferred to a smaller model trained on a related task)
  • Ensemble learning (the teacher can itself be an ensemble of multiple models)

Types of Knowledge Distillation Algorithms

Response-Based Distillation

  • Also known as logit distillation: the small model learns by copying the final outputs (predicted probabilities or logits) of the big model.
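As a minimal illustration of response-based distillation (separate from the scikit-learn pipeline below, with hypothetical logit values), the teacher's outputs are typically softened with a temperature before the student copies them: higher temperatures spread probability mass onto the non-top classes, exposing the teacher's "dark knowledge".

```python
import numpy as np

def softmax(z, T=1.0):
    # Temperature-scaled softmax: higher T produces a softer distribution
    z = np.asarray(z, dtype=float) / T
    e = np.exp(z - z.max())  # subtract max for numerical stability
    return e / e.sum()

teacher_logits = [5.0, 2.0, 0.5]       # hypothetical teacher outputs
hard = softmax(teacher_logits, T=1.0)  # peaked distribution
soft = softmax(teacher_logits, T=4.0)  # softened targets for the student
```

With T=4 the runner-up classes receive noticeably more probability than at T=1, which is the signal the student trains on.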

Feature-Based Distillation

  • The small model learns by copying the internal features or hidden-layer representations that the big model creates while processing data.
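A sketch of the feature-based idea, using hypothetical activation matrices rather than a real network: the loss penalizes the distance between teacher and student hidden representations.

```python
import numpy as np

rng = np.random.default_rng(0)
teacher_features = rng.normal(size=(8, 16))  # hypothetical hidden activations
# Pretend the student roughly tracks the teacher, with some noise
student_features = teacher_features + 0.1 * rng.normal(size=(8, 16))

# Feature-based loss: mean squared error between the two representations
feature_loss = np.mean((teacher_features - student_features) ** 2)
```

In practice, when the two networks have different widths, the student's features are usually passed through a small learned projection before computing this loss.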

Relation-Based Distillation

  • The small model learns by copying the relationships between different data points or features that the big model has learned.
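A sketch of the relation-based idea with hypothetical embeddings: instead of matching individual outputs, the student matches the structure of pairwise distances between samples.

```python
import numpy as np

def pairwise_dists(X):
    # Euclidean distance between every pair of rows in X
    diff = X[:, None, :] - X[None, :, :]
    return np.sqrt((diff ** 2).sum(-1))

rng = np.random.default_rng(0)
teacher_emb = rng.normal(size=(6, 4))                    # hypothetical embeddings
student_emb = teacher_emb + 0.05 * rng.normal(size=(6, 4))

# Relation-based loss: match the pairwise-distance structure, not the points
relation_loss = np.mean(
    (pairwise_dists(teacher_emb) - pairwise_dists(student_emb)) ** 2
)
```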

Implementation of Knowledge Distillation

The first step is to import all the necessary libraries.

import pandas as pd
from functools import partial
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import classification_report, f1_score

The second step is to load the Iris dataset with scikit-learn. The input features and target column are then placed into a pandas DataFrame, and the first 5 rows are displayed with the head function.

dataset = load_iris()
df = pd.DataFrame(data=dataset.data, columns=dataset.feature_names)
df['target'] = dataset.target
df.head()

The next step is to split the dataset into training and testing sets in a 70/30 ratio, then check their shapes with the shape attribute.

df_train, df_test = train_test_split(df, test_size=0.3, random_state=42)  # fixed seed for reproducibility
df_train.shape, df_test.shape, df_train.shape[0] / (df_train.shape[0] + df_test.shape[0])
# ((105, 5), (45, 5), 0.7)

The next step is to create a Random Forest classifier and fit it on the training data, with the following parameters:

  • n_estimators is set to 100
  • max_depth is set to 5
  • n_jobs is set to -1

rf = RandomForestClassifier(n_estimators=100, max_depth=5, n_jobs=-1)
rf.fit(df_train.drop(columns='target'), df_train['target'])

The next step is to evaluate the performance on the test dataset using classification report.

print(classification_report(df_test['target'], rf.predict(df_test.drop(columns='target'))))

The next step is to perform the Knowledge distillation by

  • Using a Random Forest (teacher) to make predictions
  • Creating a new dataset where the Random Forest’s predictions become new labels
  • Training a single Decision Tree (student) to learn from these labels

def __inp(df, exclude_columns=['target']):
    return df.drop(columns=list(set(exclude_columns) & set(df.columns)))

The __inp function takes a dataframe, removes the target column, and returns only the input features, so the model never sees the original labels.

def __out(df, target_column='target'):
    return df[target_column]

The __out function returns only the target column.

def relable(df, model):
    df = df.copy()
    df['relabel'] = model.predict(__inp(df))
    return df

The next step is to create the relable function, which makes a copy of the dataset, uses the teacher model to predict labels, and stores the predictions in a new column called relabel.

It means the student tree will learn to mimic the Random Forest.

# relable everything
df_train_tree = relable(df_train, rf)
df_test_tree = relable(df_test, rf)
df_tree = relable(df, rf)
df_train_tree.head()

In the above code we create relabeled versions of the training set, the test set, and the full dataset.

In the below blocks of code we are now training a Decision Tree (student model) to mimic the Random Forest (teacher model) using the relabeled data created earlier.

# Update helper functions to use relabeled data

__inp = partial(__inp, exclude_columns=['target', 'relabel'])
__rel = partial(__out, target_column='relabel')
__f1_score = partial(f1_score, average="macro")

In the above code the helper functions are updated to use the relabeled data.

  • __inp → returns only input features (removes both target and relabel)
  • __rel → extracts the relabel column (RandomForest’s predictions)
  • __f1_score → sets F1 score to macro average (useful for multi-class)

Next the decision tree classifier is trained on the relabeled data

dt = DecisionTreeClassifier(max_depth=None, min_samples_leaf=1)
dt.fit(__inp(df_tree), __rel(df_tree))

Leaving the tree unconstrained (max_depth=None, min_samples_leaf=1) lets it fit the Random Forest's labels as closely as possible.

The next step is to evaluate on the training data (relabeled data)

# Evaluate Decision Tree performance on relabeled training data
print(classification_report(__rel(df_train_tree), dt.predict(__inp(df_train_tree))))

The next step is to evaluate decision tree on real test labels.

# Evaluate Decision Tree on the original test set labels

print("This shows the performance on the actual `target` values of the test set (never seen during training).")
print(classification_report(__out(df_test_tree), dt.predict(__inp(df_test_tree))))
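Beyond accuracy on the true labels, a useful sanity check for distillation is fidelity: how often the student agrees with the teacher on held-out data. Here is a minimal, self-contained sketch of that check (it rebuilds a small teacher/student pair rather than reusing the variables above, and the hyperparameters are just illustrative):

```python
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

# Teacher: Random Forest trained on the real labels
teacher = RandomForestClassifier(n_estimators=100, max_depth=5, random_state=0)
teacher.fit(X_tr, y_tr)

# Student: Decision Tree trained on the teacher's predictions (distillation)
student = DecisionTreeClassifier(random_state=0)
student.fit(X_tr, teacher.predict(X_tr))

# Fidelity: fraction of test samples where student and teacher agree
fidelity = np.mean(student.predict(X_te) == teacher.predict(X_te))
print(f"student-teacher agreement on test set: {fidelity:.2f}")
```

High fidelity means the student really did absorb the teacher's decision function, even where the student's own test accuracy differs from the teacher's.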

Conclusion

Overall, knowledge distillation enables a small student model to learn effectively from a larger teacher model, achieving similar performance with lower complexity, which makes it ideal for deployment on mobile and embedded devices.

If you have any problems with the implementation, feel free to drop a comment, and our team will reply within 24 hours.
