dnn_cool_train_loop
Here, you can find a full list of the things you can do with dnn_cool_train_loop.
dnn_cool_train_loop is a microlib that lets you train multi-task neural networks.
The train/valid loop always runs the model with a batch size of 1 and uses gradient accumulation to
simulate the effective batch size.
This means that a custom collator is not needed, even for inputs of different shapes, but training is somewhat slower.
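The idea of simulating an effective batch size with per-sample passes can be sketched in plain Python. This is not the library's implementation: the one-parameter model, the squared-error loss, and the helper names sample_grad, train_epoch_accumulated and train_epoch_batched are all assumptions made for illustration. It only shows that summing per-sample gradients, each scaled by 1 / batch_size_effective, gives the same update as an ordinary mini-batch step.

```python
# Illustrative sketch only: a one-parameter model y_hat = w * x with
# squared-error loss. None of these names come from dnn_cool_train_loop.

def sample_grad(w, x, y):
    """Gradient of (w * x - y) ** 2 with respect to w for one sample."""
    return 2.0 * (w * x - y) * x


def train_epoch_accumulated(w, samples, batch_size_effective, lr):
    """Process one sample at a time; update w once per effective batch."""
    accumulated = 0.0
    pending = 0
    for x, y in samples:
        # Scale each per-sample gradient so the running sum is the batch mean.
        accumulated += sample_grad(w, x, y) / batch_size_effective
        pending += 1
        if pending == batch_size_effective:
            w -= lr * accumulated
            accumulated = 0.0
            pending = 0
    # A trailing partial batch is dropped in this sketch.
    return w


def train_epoch_batched(w, samples, batch_size_effective, lr):
    """Reference: ordinary mini-batch SGD using the mean gradient per batch."""
    last_start = len(samples) - batch_size_effective
    for start in range(0, last_start + 1, batch_size_effective):
        batch = samples[start:start + batch_size_effective]
        mean_grad = sum(sample_grad(w, x, y) for x, y in batch) / len(batch)
        w -= lr * mean_grad
    return w


samples = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
w_acc = train_epoch_accumulated(0.0, samples, batch_size_effective=2, lr=0.01)
w_ref = train_epoch_batched(0.0, samples, batch_size_effective=2, lr=0.01)
```

Because each sample is processed on its own, inputs never have to be stacked into one tensor, which is why differently shaped inputs need no collation. The cost is one forward and backward pass per sample instead of one per batch.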
The dataset should be a dict-based dataset with two additional keys: idx (the index in the global dataset)
and tasks (the targets for which a backward pass should be done).
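A dataset of this shape might look like the sketch below. Only the key names idx and tasks come from the description above. The class name ToyMultiTaskDataset, the inputs key, the task names, and the choice to represent tasks as a mapping from task name to target value (omitting tasks that have no target for a sample) are assumptions for illustration, so check the library for the exact format it expects.

```python
# Hypothetical dict-based dataset; only the "idx" and "tasks" key names
# are taken from the text, everything else is an illustrative assumption.

class ToyMultiTaskDataset:
    def __init__(self, records):
        # Each record holds the model inputs and per-task targets.
        self.records = records

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        # Assumption: tasks without a target for this sample are left out,
        # so no backward pass would be done for them.
        tasks = {
            name: target
            for name, target in record["targets"].items()
            if target is not None
        }
        return {
            "inputs": record["inputs"],  # hypothetical input key
            "idx": idx,                  # index in the global dataset
            "tasks": tasks,              # targets to backpropagate through
        }


dataset = ToyMultiTaskDataset([
    {"inputs": [0.1, 0.2], "targets": {"is_cat": 1, "breed": 3}},
    {"inputs": [0.5, 0.4, 0.9], "targets": {"is_cat": 0, "breed": None}},
])
item = dataset[1]
```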
Main API
Train for one epoch
To train for one epoch, use the train_for_one_epoch function.
For example:
from dnn_cool_train_loop.train_loop import train_for_one_epoch

train_for_one_epoch(
    epoch=0,
    model=model,
    train_dataset=train_dataset,
    criterion_dict=criterion_dict,
    batch_size_effective=batch_size_effective,
    metrics_dir=metrics_dir,
    metrics_dict=metrics_dict,
    lr=lr,
)
Validate for one epoch
To validate for one epoch, use the validate_for_one_epoch function.
For example:
from dnn_cool_train_loop.valid_loop import validate_for_one_epoch

validate_for_one_epoch(
    epoch=0,
    model=model,
    valid_dataset=valid_dataset,
    criterion_dict=criterion_dict,
    batch_size_effective=batch_size_effective,
    metrics_dir=metrics_dir,
    metrics_dict=metrics_dict,
)
Full training loop
To run a full training loop, bind the shared arguments of train_for_one_epoch and validate_for_one_epoch
with functools.partial and iterate over epochs.
For example:
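from functools import partial

from dnn_cool_train_loop.train_loop import train_for_one_epoch
from dnn_cool_train_loop.valid_loop import validate_for_one_epoch

# Bind everything except the epoch, which is passed on each call.
train_fn = partial(
    train_for_one_epoch,
    model=model,
    train_dataset=train_dataset,
    criterion_dict=criterion_dict,
    batch_size_effective=batch_size_effective,
    metrics_dir=metrics_dir,
    metrics_dict=metrics_dict,
    lr=lr,
)
valid_fn = partial(
    validate_for_one_epoch,
    model=model,
    valid_dataset=valid_dataset,
    criterion_dict=criterion_dict,
    batch_size_effective=batch_size_effective,
    metrics_dir=metrics_dir,
    metrics_dict=metrics_dict,
)

# Validate once before any training, using "init" as the epoch label.
valid_fn("init")
for i in range(5):
    train_fn(i)
    valid_fn(i)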