Training#
Training a simple Transformer model#
We start by training a simple Transformer model on the Setfit/emotion dataset, which covers six emotion classes (hence num_labels: 6 in the configuration below). We chose microsoft/xtremedistil-l6-h256-uncased as it is a relatively lightweight base model.
Note: in the following sections we limit the number of training steps in the Trainer since this is demo code; to achieve decent performance you will need to increase (or unset) the max_steps parameter.
[1]:
from bert_squeeze.assistants import TrainAssistant
from lightning.pytorch import Trainer
/Users/julesbelveze/Desktop/bert-squeeze/.venv/bert_squeeze/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
from .autonotebook import tqdm as notebook_tqdm
[2]:
config_assistant = {
    "name": "bert",
    "train_kwargs": {
        "objective": "ce"
    },
    "model_kwargs": {
        "pretrained_model": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 6
    },
    "data_kwargs": {
        "max_length": 64,
        "tokenizer_name": "microsoft/xtremedistil-l6-h256-uncased",
        "dataset_config": {
            "path": "Setfit/emotion"
        }
    }
}
[3]:
assistant = TrainAssistant(**config_assistant)
WARNING:root:Found value for `dataset_config.path` which conflicts with parameter `dataset_path`, using value from the latter.
[4]:
model = assistant.model
train_dataloader = assistant.data.train_dataloader()
test_dataloader = assistant.data.test_dataloader()
WARNING:datasets.builder:Using custom data configuration Setfit--emotion-89147fdf376d67e2
WARNING:datasets.builder:Reusing dataset json (/Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 365.50it/s]
INFO:root:Dataset 'Setfit/emotion' successfully loaded.
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-6e5f71b3bbea4a96.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-5a6336aaca39c01e.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-6f41d000bddb1cc5.arrow
DatasetDict({
    train: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16000
    })
    test: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})
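Before launching a run, it can be worth sanity-checking what the dataloader yields. Here is a minimal sketch, assuming batches are dictionaries mapping the feature names above to tensors:

batch = next(iter(train_dataloader))
for name, value in batch.items():
    # Non-tensor fields (e.g. 'label_text') carry no shape and are skipped.
    if hasattr(value, "shape"):
        print(name, tuple(value.shape))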
[5]:
basic_trainer = Trainer(
    max_steps=10
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[6]:
basic_trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader
)
| Name | Type | Params
------------------------------------------------
0 | objective | CrossEntropyLoss | 0
1 | encoder | CustomBertModel | 12.8 M
2 | classifier | Sequential | 67.8 K
------------------------------------------------
12.8 M Trainable params
0 Non-trainable params
12.8 M Total params
51.272 Total estimated model params size (MB)
Sanity Checking DataLoader 0: 0%| | 0/2 [00:00<?, ?it/s]
/Users/julesbelveze/Desktop/bert-squeeze/.venv/bert_squeeze/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
/Users/julesbelveze/Desktop/bert-squeeze/.venv/bert_squeeze/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
rank_zero_warn(
Epoch 0: 0%|▎ | 1/500 [00:01<09:27, 1.14s/it, v_num=14]
/Users/julesbelveze/Desktop/bert-squeeze/bert_squeeze/utils/optimizers/bert_adam.py:226: UserWarning: This overload of add_ is deprecated:
add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
add_(Tensor other, *, Number alpha) (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/python_arg_parser.cpp:1485.)
next_m.mul_(beta1).add_(1 - beta1, grad)
Epoch 0: 2%|███▌ | 10/500 [00:10<08:34, 1.05s/it, v_num=14]
`Trainer.fit` stopped: `max_steps=10` reached.
Epoch 0: 2%|███▌ | 10/500 [00:10<08:54, 1.09s/it, v_num=14]
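For a run beyond this demo, you would unset max_steps and train for full epochs, ideally with checkpointing. Below is a sketch using standard Lightning components; the monitored metric name and the hyperparameter values are assumptions, not tuned or taken from bert-squeeze:

from lightning.pytorch.callbacks import ModelCheckpoint

# Keep the checkpoint with the lowest validation loss. "val_loss" is an
# assumed key: adjust it to whatever metric the LightningModule logs.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")
full_trainer = Trainer(
    max_epochs=3,  # illustrative value
    callbacks=[checkpoint_callback],
)
full_trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader,
)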
Training FastBert#
Fine-tuning a FastBert model is as easy as fine-tuning a regular BERT. The only difference is that you need to use the FastBertLogic callback, which is in charge of freezing the model’s backbone after a given number of training steps.
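To give an idea of what such a callback does, here is a minimal hypothetical stand-in, assuming a Lightning 2.x Callback and a module that exposes its backbone as pl_module.encoder (both names are assumptions; FastBertLogic is the actual implementation shipped with bert-squeeze and exposed via the assistant):

from lightning.pytorch.callbacks import Callback

class FreezeBackboneAfterN(Callback):
    """Illustrative stand-in for FastBertLogic, not the actual implementation."""

    def __init__(self, freeze_at_step: int = 1000):
        self.freeze_at_step = freeze_at_step

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        if trainer.global_step == self.freeze_at_step:
            # Hypothetical layout: assumes the backbone lives at `pl_module.encoder`.
            for param in pl_module.encoder.parameters():
                param.requires_grad = False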
[7]:
config_assistant_fastbert = {
    "name": "fastbert",
    "train_kwargs": {
        "objective": "ce"
    },
    "model_kwargs": {
        "pretrained_model": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 6
    },
    "data_kwargs": {
        "max_length": 64,
        "tokenizer_name": "microsoft/xtremedistil-l6-h256-uncased",
        "dataset_config": {
            "path": "Setfit/emotion"
        }
    }
}
fastbert_assistant = TrainAssistant(**config_assistant_fastbert)
basic_trainer = Trainer(
    max_steps=10,
    callbacks=fastbert_assistant.callbacks
)
basic_trainer.fit(
    model=fastbert_assistant.model,
    train_dataloaders=fastbert_assistant.data.train_dataloader(),
    val_dataloaders=fastbert_assistant.data.test_dataloader()
)
WARNING:root:Found value for `dataset_config.path` which conflicts with parameter `dataset_path`, using value from the latter.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
WARNING:datasets.builder:Using custom data configuration Setfit--emotion-89147fdf376d67e2
WARNING:datasets.builder:Reusing dataset json (/Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 554.02it/s]
INFO:root:Dataset 'Setfit/emotion' successfully loaded.
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-59ba575e7730665b.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-889acfd77e30ecff.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-e076430c3e6abb98.arrow
| Name | Type | Params
------------------------------------------------
0 | objective | CrossEntropyLoss | 0
1 | embeddings | BertEmbeddings | 7.9 M
2 | encoder | FastBertGraph | 6.7 M
------------------------------------------------
14.7 M Trainable params
0 Non-trainable params
14.7 M Total params
58.669 Total estimated model params size (MB)
DatasetDict({
    train: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16000
    })
    test: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})
Epoch 0: 2%|███▌ | 10/500 [00:11<09:11, 1.13s/it, v_num=15]
`Trainer.fit` stopped: `max_steps=10` reached.
Epoch 0: 2%|███▌ | 10/500 [00:11<09:24, 1.15s/it, v_num=15]
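The payoff of FastBert comes at inference time: each layer carries a student classifier, and the forward pass exits early once a prediction is confident enough. Below is a conceptual sketch of the uncertainty criterion from the FastBERT paper (normalized entropy against a "speed" threshold); it is not the bert-squeeze inference API:

import math

import torch

def normalized_entropy(probs: torch.Tensor) -> torch.Tensor:
    # Entropy of the predicted class distribution, scaled to [0, 1]
    # by dividing by log(num_classes).
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)
    return entropy / math.log(probs.size(-1))

# A layer "answers" as soon as its uncertainty drops below the threshold;
# lower thresholds trade speed for accuracy.
probs = torch.softmax(torch.randn(1, 6), dim=-1)  # dummy 6-class prediction
exit_here = normalized_entropy(probs) < 0.5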
Training TheseusBert#
Similarly, fine-tuning a TheseusBert model is as simple as fine-tuning a regular BERT. For TheseusBert you do not even need a callback: the predecessor submodules are progressively swapped for their successors through a replacement scheduler (sketched below).
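Concretely, BERT-of-Theseus replaces each predecessor module with its successor with a probability that ramps up during training, so the successor network gradually takes over. A minimal sketch of such a linear replacement schedule (the constants are illustrative, not the ones bert-squeeze uses):

import random

def replacement_probability(step: int, base_p: float = 0.3, k: float = 1e-4) -> float:
    # Linear schedule from the BERT-of-Theseus paper: p(t) = min(1, base_p + k * t).
    return min(1.0, base_p + k * step)

def sample_modules(step: int, num_layers: int = 6) -> list:
    # Each position independently uses the successor layer with probability p.
    p = replacement_probability(step)
    return ["successor" if random.random() < p else "predecessor" for _ in range(num_layers)]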
[8]:
config_assistant_theseusbert = {
    "name": "theseusbert",
    "train_kwargs": {
        "objective": "ce"
    },
    "model_kwargs": {
        "pretrained_model": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 6
    },
    "data_kwargs": {
        "max_length": 64,
        "tokenizer_name": "microsoft/xtremedistil-l6-h256-uncased",
        "dataset_config": {
            "path": "Setfit/emotion"
        }
    }
}
theseusbert_assistant = TrainAssistant(**config_assistant_theseusbert)
basic_trainer = Trainer(
    max_steps=10
)
basic_trainer.fit(
    model=theseusbert_assistant.model,
    train_dataloaders=theseusbert_assistant.data.train_dataloader(),
    val_dataloaders=theseusbert_assistant.data.test_dataloader()
)
WARNING:root:Found value for `dataset_config.path` which conflicts with parameter `dataset_path`, using value from the latter.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Some weights of TheseusBertModel were not initialized from the model checkpoint at microsoft/xtremedistil-l6-h256-uncased and are newly initialized: ['encoder.successor_layers.2.attention.self.key.weight', 'encoder.successor_layers.2.attention.output.dense.weight', 'encoder.successor_layers.3.attention.output.dense.bias', ..., 'encoder.successor_layers.5.attention.self.key.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
WARNING:datasets.builder:Using custom data configuration Setfit--emotion-89147fdf376d67e2
WARNING:datasets.builder:Reusing dataset json (/Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 378.92it/s]
INFO:root:Dataset 'Setfit/emotion' successfully loaded.
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-59ba575e7730665b.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-889acfd77e30ecff.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-e076430c3e6abb98.arrow
| Name | Type | Params
------------------------------------------------
0 | objective | CrossEntropyLoss | 0
1 | encoder | TheseusBertModel | 17.5 M
2 | classifier | Sequential | 67.8 K
------------------------------------------------
17.6 M Trainable params
0 Non-trainable params
17.6 M Total params
70.226 Total estimated model params size (MB)
DatasetDict({
    train: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16000
    })
    test: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})
Epoch 0: 2%|███▌ | 10/500 [00:10<08:34, 1.05s/it, v_num=16]
`Trainer.fit` stopped: `max_steps=10` reached.
Epoch 0: 2%|███▌ | 10/500 [00:10<08:51, 1.08s/it, v_num=16]
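Once training has finished, the fine-tuned weights can be persisted with the standard Lightning utility (the file name below is illustrative):

# Save a Lightning checkpoint for later evaluation or export.
basic_trainer.save_checkpoint("theseusbert-emotion.ckpt")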