def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx, optimizer_idx):
    if optimizer_idx == 0:
        # do training_step with encoder
        ...
    if optimizer_idx == 1:
        # do training_step with decoder
        ...

# Truncated back-propagation through time
def training_step(self, batch, batch_idx, hiddens):
    # hiddens are the hidden states from the previous truncated backprop step
    ...
    out, hiddens = self.lstm(data, hiddens)
    ...
    return {'loss': loss, 'hiddens': hiddens}
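training_step only receives optimizer_idx when configure_optimizers returns more than one optimizer. A minimal sketch of that return for the GAN case, assuming the module defines self.generator and self.discriminator (the attribute names and learning rate are illustrative):

import torch
import pytorch_lightning as pl

class LitGAN(pl.LightningModule):
    ...  # generator / discriminator defined in __init__ (illustrative)

    def configure_optimizers(self):
        # one optimizer per sub-network; Lightning then calls training_step
        # once per batch for each optimizer, passing optimizer_idx 0 and 1
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)
        return [opt_g, opt_d]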
def fit(...):
    on_fit_start()

    if global_rank == 0:
        # prepare data is called on GLOBAL_ZERO only
        prepare_data()

    for gpu/tpu in gpu/tpus:
        train_on_device(model.copy())

    on_fit_end()

def train_on_device(model):
    # setup is called PER DEVICE
    setup()
    configure_optimizers()
    on_pretrain_routine_start()

    for epoch in epochs:
        train_loop()

    teardown()

def train_loop():
    on_train_epoch_start()
    train_outs = []
    for train_batch in train_dataloader():
        on_train_batch_start()

        # ----- train_step methods -------
        out = training_step(batch)
        train_outs.append(out)

        loss = out.loss

        backward()
        on_after_backward()
        optimizer_step()
        on_before_zero_grad()
        optimizer_zero_grad()

        on_train_batch_end(out)

        if should_check_val:
            val_loop()

    # end training epoch
    logs = training_epoch_end(outs)

def val_loop():
    model.eval()
    torch.set_grad_enabled(False)

    on_validation_epoch_start()
    val_outs = []
    for val_batch in val_dataloader():
        on_validation_batch_start()

        # -------- val step methods -------
        out = validation_step(val_batch)
        val_outs.append(out)

        on_validation_batch_end(out)

    validation_epoch_end(val_outs)
    on_validation_epoch_end()

    # set up for train
    model.train()
    torch.set_grad_enabled(True)
# default used by the Trainer (no scaling of batch size)
trainer = Trainer(auto_scale_batch_size=None)

# run batch size scaling, result overrides hparams.batch_size
trainer = Trainer(auto_scale_batch_size='binsearch')

# call tune to find the batch size
trainer.tune(model)
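For the search to have something to adjust, the tuner expects the model (or its datamodule) to expose a batch_size attribute that its dataloader actually uses; the found value is written back to that attribute. A minimal sketch under that assumption (the model, dataset and sizes are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size          # the tuner overwrites this attribute
        self.layer = torch.nn.Linear(32, 2)

    def train_dataloader(self):
        # the dataloader must read self.batch_size for the scaling to take effect
        dataset = TensorDataset(torch.randn(1024, 32), torch.randint(0, 2, (1024,)))
        return DataLoader(dataset, batch_size=self.batch_size)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

trainer = pl.Trainer(auto_scale_batch_size='binsearch')
trainer.tune(LitModel())  # the largest batch size that fits is stored on the model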
use (float) to check within a training epoch: the value is interpreted as a fraction of the training epoch, and validation runs once every such fraction.
use (int) to check every n steps (batches): validation runs once every n training batches.
# default used by the Trainer
trainer = Trainer(val_check_interval=1.0)

# check validation set 4 times during a training epoch
trainer = Trainer(val_check_interval=0.25)

# check validation set every 1000 training batches
# use this when using an IterableDataset and your dataset has no length
# (ie: production cases with streaming data)
trainer = Trainer(val_check_interval=1000)
# default used by the Trainer
trainer = Trainer(limit_train_batches=1.0)

# run through only 25% of the training set each epoch
trainer = Trainer(limit_train_batches=0.25)

# run through only 10 batches of the training set each epoch
trainer = Trainer(limit_train_batches=10)
fast_dev_run
: a bool or int. If set to True, only a single batch of train, val and test is run and the program then exits; use it only for debugging. Setting this argument will disable the tuner, checkpoint callbacks, early stopping callbacks, loggers and logger callbacks like LearningRateLogger, and runs for only 1 epoch.
# default used by the Trainer
trainer = Trainer(fast_dev_run=False)

# runs 1 train, val, test batch and program ends
trainer = Trainer(fast_dev_run=True)

# runs 7 train, val, test batches and program ends
trainer = Trainer(fast_dev_run=7)
train_dataloader (Optional[DataLoader]) – A PyTorch DataLoader with training samples. If the model has a predefined train_dataloader method, this will be skipped.
val_dataloaders (Union[DataLoader, List[DataLoader], None]) – Either a single PyTorch DataLoader or a list of them, specifying validation samples. If the model has a predefined val_dataloaders method, this will be skipped.
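These arguments let you keep the dataloaders outside the LightningModule and pass them straight to fit. A small sketch, assuming a LitModel like the ones above and a toy TensorDataset (both illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

train_ds = TensorDataset(torch.randn(800, 32), torch.randint(0, 2, (800,)))
val_ds = TensorDataset(torch.randn(200, 32), torch.randint(0, 2, (200,)))

trainer = Trainer(max_epochs=5)
trainer.fit(
    LitModel(),
    train_dataloader=DataLoader(train_ds, batch_size=32),
    val_dataloaders=DataLoader(val_ds, batch_size=32),
)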
.test() is not run unless you call it explicitly:

trainer.test()

.test() automatically loads the best checkpoint. model.eval() and torch.no_grad() are called automatically during testing.

By default, Trainer() runs on the CPU.
Adding the command-line arguments manually:
from argparse import ArgumentParser

def main(hparams):
    model = LightningModule()
    trainer = Trainer(gpus=hparams.gpus)
    trainer.fit(model)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument('--gpus', default=None)
    args = parser.parse_args()

    main(args)
Automatically adding every command-line argument the Trainer accepts:
from argparse import ArgumentParser

def main(args):
    model = LightningModule()
    trainer = Trainer.from_argparse_args(args)
    trainer.fit(model)

if __name__ == '__main__':
    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(
        # group the Trainer arguments together
        parser.add_argument_group(title="pl.Trainer args")
    )
    args = parser.parse_args()

    main(args)
monitor (str) – quantity to be monitored. Default: 'early_stop_on'.
min_delta (float) – minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute change of less than min_delta will count as no improvement. Default: 0.0.
patience (int) – number of validation epochs with no improvement after which training will be stopped. Default: 3.
mode (str) – one of 'min', 'max'. In 'min' mode, training will stop when the quantity monitored has stopped decreasing; in 'max' mode it will stop when the quantity monitored has stopped increasing.
strict (bool) – whether to crash the training if monitor is not found in the validation metrics. Default: True.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping

early_stopping = EarlyStopping('val_loss')
trainer = Trainer(callbacks=[early_stopping])
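For EarlyStopping('val_loss') to have anything to monitor, the model has to log a metric under that key during validation. A minimal sketch, assuming a LightningModule whose validation_step logs 'val_loss' via self.log (the metric name matches the callback above; the loss function is illustrative):

import torch.nn.functional as F
import pytorch_lightning as pl

class LitModel(pl.LightningModule):
    ...
    def validation_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        # the logged value is what EarlyStopping('val_loss') checks after each validation epoch
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        return loss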
ModelCheckpoint: see Saving and Loading above.
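As a reminder of the shape it usually takes, a brief sketch (the monitored metric and keep-best settings are illustrative, not a prescription):

from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',   # metric logged by the model
    save_top_k=1,         # keep only the best checkpoint
    mode='min',
)
trainer = Trainer(callbacks=[checkpoint_callback])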
PrintTableMetricsCallback: prints a table summarizing the metrics at the end of every epoch.
from pl_bolts.callbacks import PrintTableMetricsCallback

callback = PrintTableMetricsCallback()
trainer = pl.Trainer(callbacks=[callback])
trainer.fit(...)

# ------------------------------
# at the end of every epoch it will print
# ------------------------------
# loss│train_loss│val_loss│epoch
# ──────────────────────────────
# 2.2541470527648926│2.2541470527648926│2.2158432006835938│0
def training_step(...):
    ...

    # the logger you used (in this case tensorboard)
    tensorboard = self.logger.experiment
    tensorboard.add_image()
    tensorboard.add_histogram(...)
    tensorboard.add_figure(...)
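self.logger.experiment is the underlying SummaryWriter only when a TensorBoard logger is attached to the Trainer. A short sketch of attaching one explicitly (the log directory and experiment name are illustrative):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# with this logger, self.logger.experiment inside the module is a SummaryWriter
logger = TensorBoardLogger(save_dir='tb_logs', name='my_model')
trainer = Trainer(logger=logger)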
def __iter__(self):
    n = len(self.data_source)
    if self.replacement:
        return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
    return iter(torch.randperm(n).tolist())
The return statement is the important part; this is where the shuffling takes place, and it simply creates a random permutation of the indices. That means you will see your entire dataset every time you fully consume the iterator, just in a different order each time. Therefore no data is lost (except in cases with drop_last=True) and your model will see all of the data at every epoch.
To summarize: if you use shuffle=True, you still get a chance to see all of the data even if you never run a full epoch, because the dataset is reshuffled whenever the iterator is created, which in typical code means at the start of the inner for loop. If you do not use shuffle=True, however, you will only ever see the same first N samples.
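A small sketch illustrating the difference, where we deliberately consume only the first two batches of each epoch (the toy dataset and batch size are illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(12))
shuffled = DataLoader(dataset, batch_size=3, shuffle=True)
fixed = DataLoader(dataset, batch_size=3, shuffle=False)

for epoch in range(2):
    # a new iterator is created here, so shuffle=True reshuffles every epoch
    for i, (batch,) in enumerate(shuffled):
        if i == 2:
            break  # stop early: different samples still appear across epochs
        print('shuffled', epoch, batch.tolist())

    for i, (batch,) in enumerate(fixed):
        if i == 2:
            break  # stop early: always the same first 6 samples
        print('fixed   ', epoch, batch.tolist())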
# Single optimizer
for epoch in epochs:
    for batch in data:
        loss = model.training_step(batch, batch_idx, ...)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    for scheduler in schedulers:
        scheduler.step()

# Multiple optimizers
for epoch in epochs:
    for batch in data:
        for opt in optimizers:
            disable_grads_for_other_optimizers()
            train_step(opt)
            opt.step()

    for scheduler in schedulers:
        scheduler.step()
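The schedulers in this loop also come from configure_optimizers; by default Lightning steps them once per epoch, exactly as the pseudocode shows. A sketch of a single optimizer plus scheduler, with the dict form that would switch stepping to every batch left as a comment (hyperparameters are illustrative):

import torch

# inside your LightningModule
def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
    # default behaviour: the scheduler is stepped once per epoch
    return [optimizer], [scheduler]
    # per-batch stepping would instead use the dict form:
    # return [optimizer], [{'scheduler': scheduler, 'interval': 'step'}]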