Distillation + Pruning + Quantization
The following code snippets allow you to distil a microsoft/xtremedistil-l6-h384-uncased teacher into a microsoft/xtremedistil-l6-h256-uncased student.
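Conceptually, distillation trains the student to match the teacher's outputs as well as the ground-truth labels. The model summary printed during training lists a label-smoothing classification loss (loss_ce) and an MSE loss (loss_distill). A minimal sketch of such a combined objective is shown below. It assumes plain cross-entropy as a stand-in for the label-smoothing loss, MSE computed on raw logits, and a hypothetical mixing weight alpha. The helper distillation_loss is illustrative and is not bert_squeeze's implementation.

```python
import torch
import torch.nn.functional as F


def distillation_loss(
    student_logits: torch.Tensor,
    teacher_logits: torch.Tensor,
    labels: torch.Tensor,
    alpha: float = 0.5,
) -> torch.Tensor:
    """Hypothetical helper: mix a hard-label loss with a teacher-matching loss."""
    # Supervised term: the student should predict the true labels.
    # (Plain cross-entropy stands in for the label-smoothing loss.)
    ce = F.cross_entropy(student_logits, labels)
    # Distillation term: the student's logits should match the frozen teacher's.
    mse = F.mse_loss(student_logits, teacher_logits.detach())
    return alpha * ce + (1 - alpha) * mse


torch.manual_seed(0)
student_logits = torch.randn(4, 2, requires_grad=True)  # batch of 4, 2 labels
teacher_logits = torch.randn(4, 2)
labels = torch.tensor([0, 1, 1, 0])

loss = distillation_loss(student_logits, teacher_logits, labels)
loss.backward()  # gradients flow into the student side only
print(loss.item())
```

Setting alpha=1.0 reduces the objective to ordinary supervised training. Lower values push the student harder towards imitating the teacher.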
At the end of training, some weights are pruned based on their magnitude, and the student is then dynamically quantized.
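Magnitude-based pruning zeroes out weights whose absolute value falls below a threshold. The sketch below assumes a simple absolute-value cutoff applied to the weights of nn.Linear layers, with biases left untouched. The helper threshold_prune is hypothetical, and the ThresholdBasedPruning callback may select weights or layers differently. The threshold of 0.2 mirrors the value used in the configuration below.

```python
import torch
from torch import nn


@torch.no_grad()
def threshold_prune(model: nn.Module, threshold: float) -> float:
    """Hypothetical helper: zero every Linear weight with |w| < threshold.

    Returns the fraction of Linear weights that are zero afterwards.
    Biases are not pruned.
    """
    zeros, total = 0, 0
    for module in model.modules():
        if isinstance(module, nn.Linear):
            small = module.weight.abs() < threshold  # weights to drop
            module.weight.masked_fill_(small, 0.0)   # zero them in place
            zeros += (module.weight == 0).sum().item()
            total += module.weight.numel()
    return zeros / total


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 2))
sparsity = threshold_prune(model, threshold=0.2)
print(f"sparsity: {sparsity:.2%}")
```

Zeroing weights does not by itself shrink a dense tensor in memory. The savings only materialize with sparse storage, compression, or hardware that skips zeros.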
If you do not want to perform quantization and/or pruning, simply remove the corresponding callback from the callbacks list of the configuration.
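As an illustration, one callback can be dropped programmatically by filtering on its _target_ path. The snippet works on a standalone copy of the callbacks list so that it runs on its own. In the notebook, you would apply the same filter to config_assistant["callbacks"] (defined in the next cell) before building the DistilAssistant.

```python
# Standalone copy of the callbacks section of the configuration below.
callbacks_config = [
    {
        "_target_": "bert_squeeze.utils.callbacks.pruning.ThresholdBasedPruning",
        "threshold": 0.2,
        "start_pruning_epoch": -1,
    },
    {"_target_": "bert_squeeze.utils.callbacks.quantization.DynamicQuantization"},
]

# Keep pruning but skip quantization by filtering on the class path.
pruning_only = [
    cb for cb in callbacks_config if not cb["_target_"].endswith("DynamicQuantization")
]
print([cb["_target_"].rsplit(".", 1)[-1] for cb in pruning_only])
# ['ThresholdBasedPruning']
```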
[1]:
from bert_squeeze.assistants import DistilAssistant
from lightning.pytorch import Trainer
[2]:
# We use xtremedistil models because they are lightweight, but feel free
# to replace them with the base models of your choice.
config_assistant = {
    "name": "distil",
    # Larger model whose predictions the student learns to reproduce
    "teacher_kwargs": {
        "pretrained_model": "microsoft/xtremedistil-l6-h384-uncased",
        "num_labels": 2  # CoLA is a binary (acceptable / unacceptable) task
    },
    # Smaller model that is trained and then compressed
    "student_kwargs": {
        "pretrained_model": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 2
    },
    "data_kwargs": {
        "teacher_module": {
            "dataset_config": {
                "path": "linxinyuan/cola",  # Hugging Face Hub dataset
            }
        }
    },
    # Remove an entry to skip the corresponding compression step
    "callbacks": [
        {
            "_target_": "bert_squeeze.utils.callbacks.pruning.ThresholdBasedPruning",
            "threshold": 0.2,
            "start_pruning_epoch": -1
        },
        {
            "_target_": "bert_squeeze.utils.callbacks.quantization.DynamicQuantization"
        }
    ]
}
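Each callback entry stores the dotted import path of its class under a _target_ key, and the remaining keys act as constructor arguments. This follows the convention popularized by Hydra's hydra.utils.instantiate. Assuming bert_squeeze follows the same convention, a stripped-down resolver looks roughly like the sketch below. The helper instantiate_target is hypothetical, and the usage targets a standard-library class so the snippet runs without bert_squeeze installed.

```python
import importlib
from typing import Any, Dict


def instantiate_target(spec: Dict[str, Any]) -> Any:
    """Hypothetical helper: build the object described by a `_target_` dict."""
    kwargs = dict(spec)  # copy so the original config is not mutated
    module_path, _, class_name = kwargs.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**kwargs)  # remaining keys become keyword arguments


spec = {"_target_": "datetime.timedelta", "minutes": 2, "seconds": 30}
delta = instantiate_target(spec)
print(delta)                   # 0:02:30
print(delta.total_seconds())   # 150.0
```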
[3]:
assistant = DistilAssistant(**config_assistant)
model = assistant.model
callbacks = assistant.callbacks  # pruning and quantization callbacks built from the config
train_dataloader = assistant.data.train_dataloader()
test_dataloader = assistant.data.test_dataloader()

# max_steps=2 keeps this demo short; raise it (or use max_epochs) for real training
basic_trainer = Trainer(
    max_steps=2,
    callbacks=callbacks
)
basic_trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader
)
WARNING:datasets.builder:Using custom data configuration default
WARNING:datasets.builder:Reusing dataset mind (/Users/julesbelveze/.cache/huggingface/datasets/linxinyuan___mind/default/0.0.0/0871d55203d4de46ef1815400998ed8f219236694f0d03786bde849741f04cd4)
INFO:root:Dataset 'linxinyuan/cola' successfully loaded.
WARNING:datasets.builder:Using custom data configuration default
WARNING:datasets.builder:Reusing dataset mind (/Users/julesbelveze/.cache/huggingface/datasets/linxinyuan___mind/default/0.0.0/0871d55203d4de46ef1815400998ed8f219236694f0d03786bde849741f04cd4)
INFO:root:Dataset 'linxinyuan/cola' successfully loaded.
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/linxinyuan___mind/default/0.0.0/0871d55203d4de46ef1815400998ed8f219236694f0d03786bde849741f04cd4/cache-e56c581864af9d3c.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/linxinyuan___mind/default/0.0.0/0871d55203d4de46ef1815400998ed8f219236694f0d03786bde849741f04cd4/cache-dc929e775e97a506.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/linxinyuan___mind/default/0.0.0/0871d55203d4de46ef1815400998ed8f219236694f0d03786bde849741f04cd4/cache-545d3be6487c373a.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/linxinyuan___mind/default/0.0.0/0871d55203d4de46ef1815400998ed8f219236694f0d03786bde849741f04cd4/cache-3037a309a170717b.arrow
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  | Name         | Type               | Params
----------------------------------------------------
0 | teacher      | LtCustomBert       | 22.9 M
1 | student      | LtCustomBert       | 12.8 M
2 | loss_ce      | LabelSmoothingLoss | 0
3 | loss_distill | MSELoss            | 0
----------------------------------------------------
35.7 M    Trainable params
0         Non-trainable params
35.7 M    Total params
142.718   Total estimated model params size (MB)
/Users/julesbelveze/Desktop/bert-squeeze/.venv/bert_squeeze/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
/Users/julesbelveze/Desktop/bert-squeeze/.venv/bert_squeeze/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
/Users/julesbelveze/Desktop/bert-squeeze/bert_squeeze/utils/optimizers/bert_adam.py:226: UserWarning: This overload of add_ is deprecated:
add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
add_(Tensor other, *, Number alpha) (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/python_arg_parser.cpp:1485.)
next_m.mul_(beta1).add_(1 - beta1, grad)
`Trainer.fit` stopped: `max_steps=2` reached.
Pruning model...
INFO:root:Model quantized and saved - size (MB): 142.818233
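The last log line reports the size of the saved quantized model. To see what dynamic quantization does in isolation, the sketch below applies PyTorch's torch.ao.quantization.quantize_dynamic to a toy model and compares serialized sizes. Whether the DynamicQuantization callback uses this exact call, layer set, and dtype is an assumption. The helper serialized_size_mb is hypothetical.

```python
import io

import torch
from torch import nn
from torch.ao.quantization import quantize_dynamic


def serialized_size_mb(model: nn.Module) -> float:
    """Hypothetical helper: size of the model's state_dict saved with torch.save, in MB."""
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)
    return len(buffer.getvalue()) / 1e6


# Toy stand-in for an encoder; only nn.Linear layers are quantized below.
model = nn.Sequential(nn.Linear(256, 1024), nn.ReLU(), nn.Linear(1024, 256))
model.eval()

# Weights are stored as int8; activations are quantized on the fly at inference.
# quantize_dynamic returns a copy, so the float model is left unchanged.
quantized = quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)

print(f"fp32: {serialized_size_mb(model):.2f} MB")
print(f"int8: {serialized_size_mb(quantized):.2f} MB")

x = torch.randn(1, 256)
with torch.no_grad():
    err = (model(x) - quantized(x)).abs().max().item()
print(f"max abs difference: {err:.4f}")  # small approximation error
```

Dynamic quantization only changes layers in the given set. Parameters that live elsewhere (for example, embedding tables) keep their original precision, which limits how much a full model shrinks.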