Training#

Training a simple Transformer model#

We start by training a simple Transformer model on the Setfit/emotion dataset, using microsoft/xtremedistil-l6-h256-uncased as the base model because it is relatively lightweight.

Note: in the following sections we limit the number of training steps in the Trainer since this is demo code; to achieve decent performance you will need to increase (or unset) the max_steps parameter.
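For reference, a full run might configure the Trainer along the following lines. This is only a sketch: the epoch count and validation interval are illustrative assumptions, not values prescribed by bert-squeeze.

[ ]:
from lightning.pytorch import Trainer

# Hypothetical full-run configuration: train for whole epochs instead of
# capping the run at a handful of steps. Both values below are assumptions.
full_trainer = Trainer(
    max_epochs=3,             # leave max_steps at its default (-1, no cap)
    val_check_interval=0.25,  # validate four times per epoch
)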

[1]:
from bert_squeeze.assistants import TrainAssistant
from lightning.pytorch import Trainer
[2]:
config_assistant = {
    "name": "bert",
    "train_kwargs": {
        "objective": "ce"
    },
    "model_kwargs": {
        "pretrained_model": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 6
    },
    "data_kwargs": {
        "max_length": 64,
        "tokenizer_name": "microsoft/xtremedistil-l6-h256-uncased",
        "dataset_config": {
            "path": "Setfit/emotion"
        }
    }
}
[3]:
assistant = TrainAssistant(**config_assistant)
WARNING:root:Found value for `dataset_config.path` which conflicts with parameter `dataset_path`, using value from the latter.
[4]:
model = assistant.model

train_dataloader = assistant.data.train_dataloader()
test_dataloader = assistant.data.test_dataloader()
WARNING:datasets.builder:Using custom data configuration Setfit--emotion-89147fdf376d67e2
WARNING:datasets.builder:Reusing dataset json (/Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b)
100%|██████████| 3/3 [00:00<00:00, 365.50it/s]
INFO:root:Dataset 'Setfit/emotion' successfully loaded.
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-6e5f71b3bbea4a96.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-5a6336aaca39c01e.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-6f41d000bddb1cc5.arrow
DatasetDict({
    train: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16000
    })
    test: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})

[5]:
basic_trainer = Trainer(
    max_steps=10
)
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[6]:
basic_trainer.fit(
    model=model,
    train_dataloaders=train_dataloader,
    val_dataloaders=test_dataloader
)

  | Name       | Type             | Params
------------------------------------------------
0 | objective  | CrossEntropyLoss | 0
1 | encoder    | CustomBertModel  | 12.8 M
2 | classifier | Sequential       | 67.8 K
------------------------------------------------
12.8 M    Trainable params
0         Non-trainable params
12.8 M    Total params
51.272    Total estimated model params size (MB)
Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]
/Users/julesbelveze/Desktop/bert-squeeze/.venv/bert_squeeze/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, val_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(

/Users/julesbelveze/Desktop/bert-squeeze/.venv/bert_squeeze/lib/python3.8/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:430: PossibleUserWarning: The dataloader, train_dataloader, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument (try 8 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
  rank_zero_warn(
Epoch 0:   0%|          | 1/500 [00:01<09:27,  1.14s/it, v_num=14]
/Users/julesbelveze/Desktop/bert-squeeze/bert_squeeze/utils/optimizers/bert_adam.py:226: UserWarning: This overload of add_ is deprecated:
        add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
        add_(Tensor other, *, Number alpha) (Triggered internally at /Users/runner/work/pytorch/pytorch/pytorch/torch/csrc/utils/python_arg_parser.cpp:1485.)
  next_m.mul_(beta1).add_(1 - beta1, grad)
Epoch 0:   2%|▏         | 10/500 [00:10<08:34,  1.05s/it, v_num=14]
`Trainer.fit` stopped: `max_steps=10` reached.
Epoch 0:   2%|▏         | 10/500 [00:10<08:54,  1.09s/it, v_num=14]

Training FastBert#

Fine-tuning a FastBert model is as easy as fine-tuning a regular BERT. The only difference is that you need to use the FastBertLogic callback, which is in charge of freezing the model's backbone after a given number of steps.
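To make the mechanism concrete, here is a minimal sketch of what such a freezing callback can look like. This is not the actual FastBertLogic implementation: the freeze_at threshold is a made-up parameter and the encoder attribute is assumed from the model summary printed below.

[ ]:
from lightning.pytorch.callbacks import Callback


class FreezeBackboneAfterNSteps(Callback):
    """Hypothetical FastBertLogic-style callback: once `freeze_at`
    training steps have run, freeze the backbone so that only the
    remaining branches (e.g. the student classifiers) keep training."""

    def __init__(self, freeze_at: int = 1000):
        self.freeze_at = freeze_at

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if trainer.global_step == self.freeze_at:
            for param in pl_module.encoder.parameters():
                param.requires_grad = False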

[7]:
config_assistant_fastbert = {
    "name": "fastbert",
    "train_kwargs": {
        "objective": "ce"
    },
    "model_kwargs": {
        "pretrained_model": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 6
    },
    "data_kwargs": {
        "max_length": 64,
        "tokenizer_name": "microsoft/xtremedistil-l6-h256-uncased",
        "dataset_config": {
            "path": "Setfit/emotion"
        }
    }
}

fastbert_assistant = TrainAssistant(**config_assistant_fastbert)

basic_trainer = Trainer(
    max_steps=10,
    callbacks=fastbert_assistant.callbacks
)

basic_trainer.fit(
    model=fastbert_assistant.model,
    train_dataloaders=fastbert_assistant.data.train_dataloader(),
    val_dataloaders=fastbert_assistant.data.test_dataloader()
)
WARNING:root:Found value for `dataset_config.path` which conflicts with parameter `dataset_path`, using value from the latter.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
WARNING:datasets.builder:Using custom data configuration Setfit--emotion-89147fdf376d67e2
WARNING:datasets.builder:Reusing dataset json (/Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b)
100%|██████████| 3/3 [00:00<00:00, 554.02it/s]
INFO:root:Dataset 'Setfit/emotion' successfully loaded.
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-59ba575e7730665b.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-889acfd77e30ecff.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-e076430c3e6abb98.arrow


  | Name       | Type             | Params
------------------------------------------------
0 | objective  | CrossEntropyLoss | 0
1 | embeddings | BertEmbeddings   | 7.9 M
2 | encoder    | FastBertGraph    | 6.7 M
------------------------------------------------
14.7 M    Trainable params
0         Non-trainable params
14.7 M    Total params
58.669    Total estimated model params size (MB)
DatasetDict({
    train: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16000
    })
    test: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})
Epoch 0:   2%|▏         | 10/500 [00:11<09:11,  1.13s/it, v_num=15]
`Trainer.fit` stopped: `max_steps=10` reached.
Epoch 0:   2%|▏         | 10/500 [00:11<09:24,  1.15s/it, v_num=15]

Training TheseusBert#

Similarly, fine-tuning a TheseusBert model is as simple as fine-tuning a regular BERT. TheseusBert does not even require a callback: its submodules are progressively substituted with their successors through a replacement scheduler.
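As a rough illustration of the idea, here is a sketch of a linear replacement scheduler in the spirit of BERT-of-Theseus. It is not the library's actual scheduler; the function names and default rates are assumptions.

[ ]:
import random


def replacement_probability(step: int, base_rate: float = 0.3,
                            anneal_steps: int = 5000) -> float:
    """Hypothetical linear scheduler: the probability of routing through
    a successor layer instead of its predecessor grows from `base_rate`
    to 1.0 over `anneal_steps` training steps."""
    return min(1.0, base_rate + step * (1.0 - base_rate) / anneal_steps)


def pick_module(predecessor, successor, step: int):
    # During training each forward pass randomly goes through either the
    # original (predecessor) layer or its compressed successor; by the
    # end of training only the successors are used.
    if random.random() < replacement_probability(step):
        return successor
    return predecessor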

[8]:
config_assistant_theseusbert = {
    "name": "theseusbert",
    "train_kwargs": {
        "objective": "ce"
    },
    "model_kwargs": {
        "pretrained_model": "microsoft/xtremedistil-l6-h256-uncased",
        "num_labels": 6
    },
    "data_kwargs": {
        "max_length": 64,
        "tokenizer_name": "microsoft/xtremedistil-l6-h256-uncased",
        "dataset_config": {
            "path": "Setfit/emotion"
        }
    }
}

theseusbert_assistant = TrainAssistant(**config_assistant_theseusbert)

basic_trainer = Trainer(
    max_steps=10
)

basic_trainer.fit(
    model=theseusbert_assistant.model,
    train_dataloaders=theseusbert_assistant.data.train_dataloader(),
    val_dataloaders=theseusbert_assistant.data.test_dataloader()
)
WARNING:root:Found value for `dataset_config.path` which conflicts with parameter `dataset_path`, using value from the latter.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Some weights of TheseusBertModel were not initialized from the model checkpoint at microsoft/xtremedistil-l6-h256-uncased and are newly initialized: ['encoder.successor_layers.2.attention.self.key.weight', 'encoder.successor_layers.2.attention.output.dense.weight', 'encoder.successor_layers.3.attention.output.dense.bias', 'encoder.successor_layers.2.attention.self.key.bias', 'encoder.successor_layers.5.attention.output.dense.bias', 'encoder.successor_layers.1.attention.self.query.bias', 'encoder.successor_layers.1.attention.self.value.bias', 'encoder.successor_layers.2.attention.output.LayerNorm.weight', 'encoder.successor_layers.1.output.dense.bias', 'encoder.successor_layers.2.attention.output.dense.bias', 'encoder.successor_layers.3.intermediate.dense.bias', 'encoder.successor_layers.3.attention.self.query.weight', 'encoder.successor_layers.1.attention.self.key.weight', 'encoder.successor_layers.3.attention.output.dense.weight', 'encoder.successor_layers.4.attention.self.value.bias', 'encoder.successor_layers.2.intermediate.dense.bias', 'encoder.successor_layers.1.intermediate.dense.bias', 'encoder.successor_layers.5.output.dense.bias', 'encoder.successor_layers.0.attention.self.key.bias', 'encoder.successor_layers.1.attention.self.value.weight', 'encoder.successor_layers.5.attention.self.query.weight', 'encoder.successor_layers.0.output.LayerNorm.bias', 'encoder.successor_layers.0.attention.output.dense.bias', 'encoder.successor_layers.0.attention.output.LayerNorm.bias', 'encoder.successor_layers.4.output.LayerNorm.bias', 'encoder.successor_layers.0.attention.self.value.bias', 'encoder.successor_layers.3.attention.self.key.weight', 'encoder.successor_layers.0.output.dense.bias', 'encoder.successor_layers.3.attention.self.value.weight', 'encoder.successor_layers.1.intermediate.dense.weight', 'encoder.successor_layers.3.output.dense.weight', 'encoder.successor_layers.1.attention.output.dense.weight', 'encoder.successor_layers.0.attention.self.query.bias', 'encoder.successor_layers.2.attention.self.query.bias', 'encoder.successor_layers.5.attention.output.LayerNorm.weight', 'encoder.successor_layers.1.attention.output.dense.bias', 'encoder.successor_layers.1.attention.output.LayerNorm.weight', 'encoder.successor_layers.2.intermediate.dense.weight', 'encoder.successor_layers.1.output.LayerNorm.bias', 'encoder.successor_layers.5.attention.self.key.bias', 'encoder.successor_layers.0.attention.self.key.weight', 'encoder.successor_layers.3.attention.self.query.bias', 'encoder.successor_layers.4.attention.output.dense.weight', 'encoder.successor_layers.5.attention.self.query.bias', 'encoder.successor_layers.0.attention.output.dense.weight', 'encoder.successor_layers.5.output.dense.weight', 'encoder.successor_layers.4.attention.self.query.bias', 'encoder.successor_layers.0.attention.self.query.weight', 'encoder.successor_layers.2.output.dense.bias', 'encoder.successor_layers.0.output.LayerNorm.weight', 'encoder.successor_layers.4.output.LayerNorm.weight', 'encoder.successor_layers.3.intermediate.dense.weight', 'encoder.successor_layers.1.attention.self.key.bias', 'encoder.successor_layers.2.output.LayerNorm.bias', 'encoder.successor_layers.4.attention.self.key.bias', 'encoder.successor_layers.0.intermediate.dense.weight', 'encoder.successor_layers.1.output.dense.weight', 'encoder.successor_layers.5.output.LayerNorm.bias', 'encoder.successor_layers.4.output.dense.bias', 'encoder.successor_layers.3.attention.output.LayerNorm.bias', 'encoder.successor_layers.0.attention.self.value.weight', 
'encoder.successor_layers.3.attention.self.key.bias', 'encoder.successor_layers.5.attention.output.LayerNorm.bias', 'encoder.successor_layers.4.output.dense.weight', 'encoder.successor_layers.4.attention.output.LayerNorm.weight', 'encoder.successor_layers.4.attention.self.value.weight', 'encoder.successor_layers.2.attention.self.value.bias', 'encoder.successor_layers.4.attention.output.LayerNorm.bias', 'encoder.successor_layers.5.attention.self.value.bias', 'encoder.successor_layers.0.attention.output.LayerNorm.weight', 'encoder.successor_layers.3.output.dense.bias', 'encoder.successor_layers.4.attention.output.dense.bias', 'encoder.successor_layers.3.output.LayerNorm.weight', 'encoder.successor_layers.2.output.LayerNorm.weight', 'encoder.successor_layers.4.attention.self.key.weight', 'encoder.successor_layers.5.attention.output.dense.weight', 'encoder.successor_layers.0.output.dense.weight', 'encoder.successor_layers.3.output.LayerNorm.bias', 'encoder.successor_layers.4.attention.self.query.weight', 'encoder.successor_layers.2.attention.self.value.weight', 'encoder.successor_layers.0.intermediate.dense.bias', 'encoder.successor_layers.4.intermediate.dense.weight', 'encoder.successor_layers.5.intermediate.dense.bias', 'encoder.successor_layers.5.output.LayerNorm.weight', 'encoder.successor_layers.2.attention.self.query.weight', 'encoder.successor_layers.3.attention.self.value.bias', 'encoder.successor_layers.1.attention.output.LayerNorm.bias', 'encoder.successor_layers.2.output.dense.weight', 'encoder.successor_layers.5.intermediate.dense.weight', 'encoder.successor_layers.4.intermediate.dense.bias', 'encoder.successor_layers.1.attention.self.query.weight', 'encoder.successor_layers.3.attention.output.LayerNorm.weight', 'encoder.successor_layers.5.attention.self.value.weight', 'encoder.successor_layers.1.output.LayerNorm.weight', 'encoder.successor_layers.2.attention.output.LayerNorm.bias', 'encoder.successor_layers.5.attention.self.key.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
WARNING:datasets.builder:Using custom data configuration Setfit--emotion-89147fdf376d67e2
WARNING:datasets.builder:Reusing dataset json (/Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b)
100%|██████████| 3/3 [00:00<00:00, 378.92it/s]
INFO:root:Dataset 'Setfit/emotion' successfully loaded.
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-59ba575e7730665b.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-889acfd77e30ecff.arrow
WARNING:datasets.arrow_dataset:Loading cached processed dataset at /Users/julesbelveze/.cache/huggingface/datasets/json/Setfit--emotion-89147fdf376d67e2/0.0.0/ac0ca5f5289a6cf108e706efcf040422dbbfa8e658dee6a819f20d76bb84d26b/cache-e076430c3e6abb98.arrow


  | Name       | Type             | Params
------------------------------------------------
0 | objective  | CrossEntropyLoss | 0
1 | encoder    | TheseusBertModel | 17.5 M
2 | classifier | Sequential       | 67.8 K
------------------------------------------------
17.6 M    Trainable params
0         Non-trainable params
17.6 M    Total params
70.226    Total estimated model params size (MB)
DatasetDict({
    train: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 16000
    })
    test: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
    validation: Dataset({
        features: ['labels', 'label_text', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2000
    })
})
Epoch 0:   2%|▏         | 10/500 [00:10<08:34,  1.05s/it, v_num=16]
`Trainer.fit` stopped: `max_steps=10` reached.
Epoch 0:   2%|▏         | 10/500 [00:10<08:51,  1.08s/it, v_num=16]