Prediction Writer Callback

Module containing PredictionWriterCallback class.

PredictionWriterCallback

Bases: BasePredictionWriter

A PyTorch Lightning callback to save predictions.

Parameters:

write_strategy (WriteStrategy, required): A strategy for writing predictions.

dirpath (Path or str, default "predictions"): The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.
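
The dirpath rule above can be illustrated with a short sketch that resolves a relative path against the current working directory using only pathlib. The helper name resolve_dirpath is hypothetical and is not part of the library's API. The sketch assumes nothing beyond the behaviour described in the parameter description.

```python
from pathlib import Path
from typing import Union


def resolve_dirpath(dirpath: Union[Path, str] = "predictions") -> Path:
    """Return an absolute output directory (hypothetical helper).

    Relative paths are interpreted relative to the current working directory;
    absolute paths are returned unchanged.
    """
    path = Path(dirpath)
    if not path.is_absolute():
        # Anchor relative paths at the current working directory.
        path = Path.cwd() / path
    return path


# Default: an absolute path ending in "predictions" under the working directory.
print(resolve_dirpath())
# An already-absolute path is passed through as-is.
print(resolve_dirpath(Path.cwd() / "outputs"))
```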

Attributes:

write_strategy (WriteStrategy): A strategy for writing predictions.

dirpath (pathlib.Path, default "predictions"): The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

writing_predictions (bool): Whether writing predictions is turned on or off.

from_write_func_params(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None, dirpath='predictions') classmethod

Initialize a PredictionWriterCallback from write function parameters.

This automatically creates a WriteStrategy and passes it to the PredictionWriterCallback constructor.

Parameters:

write_type ("tiff" or "custom", required): The file format to save predictions in. "custom" uses the user-supplied write_func and write_extension.

tiled (bool, required): Whether the prediction will be tiled.

write_func (WriteFunc, default None): Ignored if a known write_type is selected. For a custom write_type, a function to save the data must be passed.

write_extension (str, default None): Ignored if a known write_type is selected. For a custom write_type, the file extension to save the data with must be passed.

write_func_kwargs (dict of {str: Any}, default None): Additional keyword arguments passed to the save function.

dirpath (Path or str, default "predictions"): The path to the directory where prediction outputs will be saved. If dirpath is not absolute, it is assumed to be relative to the current working directory.

Returns:

PredictionWriterCallback: Callback for writing predictions.
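
The parameter rules of from_write_func_params can be sketched in plain Python. The sketch shows that a known write_type ignores write_func and write_extension, and that "custom" requires both. This is a hypothetical illustration, not the library's implementation. WriteConfig, make_write_config, and the placeholder TIFF writer are invented names. The ".tiff" extension is an assumption. The write-function shape (file_path, data, **kwargs) is also assumed and is not taken from the library's WriteFunc definition.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Literal, Optional

# Assumed write-function shape: (file_path, data, **kwargs) -> None.
WriteFuncT = Callable[..., None]


@dataclass
class WriteConfig:
    """Hypothetical stand-in for the WriteStrategy the classmethod builds."""

    tiled: bool
    write_func: WriteFuncT
    write_extension: str
    write_func_kwargs: Dict[str, Any] = field(default_factory=dict)


def _placeholder_tiff_writer(file_path: Path, data: Any, **kwargs: Any) -> None:
    # Placeholder: a real writer would call a TIFF library here.
    raise NotImplementedError("TIFF writing is outside this sketch")


def make_write_config(
    write_type: Literal["tiff", "custom"],
    tiled: bool,
    write_func: Optional[WriteFuncT] = None,
    write_extension: Optional[str] = None,
    write_func_kwargs: Optional[Dict[str, Any]] = None,
) -> WriteConfig:
    kwargs = write_func_kwargs or {}
    if write_type == "tiff":
        # Known type: any write_func/write_extension passed in are ignored.
        # The ".tiff" extension is an assumption of this sketch.
        return WriteConfig(tiled, _placeholder_tiff_writer, ".tiff", kwargs)
    if write_type == "custom":
        # Custom type: both a function and an extension are mandatory.
        if write_func is None or write_extension is None:
            raise ValueError(
                "write_type='custom' requires write_func and write_extension"
            )
        return WriteConfig(tiled, write_func, write_extension, kwargs)
    raise ValueError(f"Unsupported write_type: {write_type!r}")


# Usage: a custom writer that stores the prediction as text.
def write_text(file_path: Path, data: Any, **kwargs: Any) -> None:
    file_path.write_text(str(data), **kwargs)


config = make_write_config(
    "custom",
    tiled=False,
    write_func=write_text,
    write_extension=".txt",
    write_func_kwargs={"encoding": "utf-8"},
)
print(config.write_extension)  # .txt
```

Raising early for an incomplete "custom" configuration surfaces the mistake when the callback is built, rather than partway through a prediction run.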

setup(trainer, pl_module, stage)

Create the prediction output directory when prediction begins.

Lightning calls this hook when fit, validate, test, predict, or tune begins; the output directory is only created for the predict stage.

Parameters:

trainer (Trainer, required): PyTorch Lightning trainer.

pl_module (LightningModule, required): PyTorch Lightning module.

stage (str, required): The current stage, e.g. 'fit', 'validate', or 'predict'.
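
The stage check in setup can be sketched as follows: the directory is created only when stage is "predict". The helper name setup_prediction_dir is hypothetical. The use of Path.mkdir(parents=True, exist_ok=True) is an assumption of this sketch, chosen so that nested paths and repeated runs do not fail. The real hook also receives the trainer and module, which the sketch omits.

```python
import tempfile
from pathlib import Path


def setup_prediction_dir(dirpath: Path, stage: str) -> bool:
    """Create dirpath only for the predict stage (hypothetical helper).

    Returns True if the directory was created or already existed for prediction.
    """
    if stage != "predict":
        # fit/validate/test/tune: nothing to prepare for writing predictions.
        return False
    # Assumed flags: create missing parents and tolerate an existing directory.
    dirpath.mkdir(parents=True, exist_ok=True)
    return True


with tempfile.TemporaryDirectory() as tmp:
    out = Path(tmp) / "predictions" / "run1"
    setup_prediction_dir(out, "fit")
    print(out.exists())  # False
    setup_prediction_dir(out, "predict")
    print(out.is_dir())  # True
```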

write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)

Write predictions at the end of a batch.

How the predictions are written is determined by the write_strategy attribute.

Parameters:

trainer (Trainer, required): PyTorch Lightning trainer.

pl_module (LightningModule, required): PyTorch Lightning module.

prediction (Any, required): Prediction outputs of the batch.

batch_indices (sequence of Any, required): Batch indices.

batch (Any, required): Input batch.

batch_idx (int, required): Batch index.

dataloader_idx (int, required): Dataloader index.
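
The split of responsibilities described above can be illustrated with a simplified sketch. The callback checks writing_predictions and then delegates the actual I/O to its write strategy. Everything here is hypothetical. WriteStrategyLike, its write_batch method, RecordingStrategy, and WriterSketch are invented names. The real hook receives the full Lightning argument list (trainer, pl_module, batch_indices, batch, dataloader_idx), which this sketch drops for brevity.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, List, Protocol


class WriteStrategyLike(Protocol):
    """Hypothetical, simplified write-strategy interface."""

    def write_batch(self, dirpath: Path, prediction: Any, batch_idx: int) -> None:
        ...


@dataclass
class RecordingStrategy:
    """Toy strategy that records batch indices instead of writing files."""

    calls: List[int] = field(default_factory=list)

    def write_batch(self, dirpath: Path, prediction: Any, batch_idx: int) -> None:
        self.calls.append(batch_idx)


@dataclass
class WriterSketch:
    """Simplified stand-in for the callback's batch-end behaviour."""

    write_strategy: WriteStrategyLike
    dirpath: Path = Path("predictions")
    writing_predictions: bool = True

    def write_on_batch_end(self, prediction: Any, batch_idx: int) -> None:
        # Do nothing while writing is switched off.
        if not self.writing_predictions:
            return
        # Delegate the actual I/O to whichever strategy was configured.
        self.write_strategy.write_batch(self.dirpath, prediction, batch_idx)


strategy = RecordingStrategy()
writer = WriterSketch(strategy)
writer.write_on_batch_end(prediction=[0.1], batch_idx=0)
writer.writing_predictions = False
writer.write_on_batch_end(prediction=[0.2], batch_idx=1)
print(strategy.calls)  # [0]
```

Keeping the format-specific logic in the strategy means the callback itself does not change when a new output format is added.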