Prediction Writer Callback
Module containing PredictionWriterCallback class.
PredictionWriterCallback
Bases: BasePredictionWriter
A PyTorch Lightning callback to save predictions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_strategy | WriteStrategy | A strategy for writing predictions. | required |
| dirpath | Path or str | The path to the directory where prediction outputs will be saved. If … | "predictions" |
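The constructor follows a strategy pattern. The callback stores a write_strategy and a dirpath, and leaves the details of how files are written to the strategy. The stdlib-only sketch below illustrates that arrangement under stated assumptions. WriteStrategyLike, SketchWriterCallback and the simplified write_batch signature are hypothetical names, not CAREamics classes. The real callback also subclasses Lightning's BasePredictionWriter, which is left out here so the example runs without Lightning installed.

```python
from pathlib import Path
from typing import Any, Protocol, Union


class WriteStrategyLike(Protocol):
    """Hypothetical stand-in for WriteStrategy: anything with a write_batch method."""

    def write_batch(self, dirpath: Path, prediction: Any, batch_idx: int) -> None: ...


class SketchWriterCallback:
    """Hypothetical simplification of the PredictionWriterCallback constructor."""

    def __init__(
        self,
        write_strategy: WriteStrategyLike,
        dirpath: Union[Path, str] = "predictions",
    ) -> None:
        # Any object satisfying the protocol can be plugged in; the callback
        # does not need to know how the strategy writes files.
        self.write_strategy = write_strategy
        # Accept str or Path, but always store a pathlib.Path, matching the
        # documented attribute type.
        self.dirpath = Path(dirpath)
        # Assumption: writing is enabled by default.
        self.writing_predictions = True
```

Both SketchWriterCallback(strategy, dirpath="out") and SketchWriterCallback(strategy, dirpath=Path("out")) end up with the same Path attribute. Swapping the output format then only means passing a different strategy object.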
Attributes:
| Name | Type | Description |
|---|---|---|
| write_strategy | WriteStrategy | A strategy for writing predictions. |
| dirpath | pathlib.Path, default="predictions" | The path to the directory where prediction outputs will be saved. If … |
| writing_predictions | bool | Whether writing predictions is turned on or off. |
from_write_func_params(write_type, tiled, write_func=None, write_extension=None, write_func_kwargs=None, dirpath='predictions')
classmethod
Initialize a PredictionWriterCallback from write function parameters.
This automatically creates a WriteStrategy, which is then passed to the
PredictionWriterCallback constructor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| write_type | "tiff" or "custom" | The data type to save as, including custom. | required |
| tiled | bool | Whether the prediction will be tiled or not. | required |
| write_func | WriteFunc | If a known … | None |
| write_extension | str | If a known … | None |
| write_func_kwargs | dict of {str: Any} | Additional keyword arguments to be passed to the save function. | None |
| dirpath | Path or str | The path to the directory where prediction outputs will be saved. If … | "predictions" |
Returns:
| Type | Description |
|---|---|
| PredictionWriterCallback | Callback for writing predictions. |
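A classmethod factory like this turns a few plain parameters into a strategy object and then calls the normal constructor. The stdlib-only sketch below shows that shape. It is hypothetical throughout: SketchStrategy, SketchCallback, the placeholder TIFF writer, the ".tiff" extension, and the rule that "custom" requires both write_func and write_extension are illustrative assumptions, not the CAREamics implementation.

```python
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

# Hypothetical save-function signature: (file_path, data, **kwargs) -> None.
WriteFunc = Callable[..., None]


@dataclass
class SketchStrategy:
    """Hypothetical strategy bundling everything needed to save predictions."""

    tiled: bool
    write_func: WriteFunc
    write_extension: str
    write_func_kwargs: Dict[str, Any] = field(default_factory=dict)


def _placeholder_tiff_writer(file_path: Path, data: Any, **kwargs: Any) -> None:
    """Hypothetical placeholder; a real implementation would write a TIFF file."""
    raise NotImplementedError


class SketchCallback:
    def __init__(
        self, write_strategy: SketchStrategy, dirpath: Union[Path, str] = "predictions"
    ) -> None:
        self.write_strategy = write_strategy
        self.dirpath = Path(dirpath)

    @classmethod
    def from_write_func_params(
        cls,
        write_type: str,
        tiled: bool,
        write_func: Optional[WriteFunc] = None,
        write_extension: Optional[str] = None,
        write_func_kwargs: Optional[Dict[str, Any]] = None,
        dirpath: Union[Path, str] = "predictions",
    ) -> "SketchCallback":
        if write_type == "tiff":
            # Assumption: a known type supplies its own function and extension,
            # so any user-supplied write_func/write_extension is ignored.
            func, ext = _placeholder_tiff_writer, ".tiff"
        elif write_type == "custom":
            # Assumption: a custom type needs both pieces from the caller.
            if write_func is None or write_extension is None:
                raise ValueError(
                    "write_type='custom' needs write_func and write_extension"
                )
            func, ext = write_func, write_extension
        else:
            raise ValueError(f"Unsupported write_type: {write_type!r}")
        strategy = SketchStrategy(
            tiled=tiled,
            write_func=func,
            write_extension=ext,
            write_func_kwargs=write_func_kwargs or {},
        )
        # Delegate to the ordinary constructor, as the docs describe.
        return cls(strategy, dirpath=dirpath)
```

With this design, callers who only know "what format, tiled or not" never have to build a strategy by hand. Callers who already have a strategy can still use the plain constructor.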
setup(trainer, pl_module, stage)
Create the prediction output directory when predict begins.
Called when fit, validate, test, predict, or tune begins.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer. | required |
| pl_module | LightningModule | PyTorch Lightning module. | required |
| stage | str | Stage of training, e.g. 'predict', 'fit', 'validate'. | required |
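Lightning calls setup for every stage, so the directory only needs to be created when stage is "predict". Below is a minimal stdlib sketch of that check. It assumes the directory is created together with any missing parents and that an existing directory is not an error. sketch_setup is a hypothetical name, and it takes only dirpath and stage instead of the full (trainer, pl_module, stage) signature.

```python
from pathlib import Path
from typing import Union


def sketch_setup(dirpath: Union[Path, str], stage: str) -> bool:
    """Hypothetical stand-in for setup(trainer, pl_module, stage).

    Returns True if the output directory was ensured, False if the stage
    did not require it.
    """
    if stage != "predict":
        # fit, validate, test and tune do not need a prediction directory.
        return False
    # parents=True creates missing intermediate directories;
    # exist_ok=True makes repeated predict runs harmless.
    Path(dirpath).mkdir(parents=True, exist_ok=True)
    return True
```

If the directory is created lazily in setup instead of in __init__, constructing the callback for a training-only run leaves no empty "predictions" folder behind.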
write_on_batch_end(trainer, pl_module, prediction, batch_indices, batch, batch_idx, dataloader_idx)
Write predictions at the end of a batch.
How predictions are written is determined by the write_strategy attribute.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| trainer | Trainer | PyTorch Lightning trainer. | required |
| pl_module | LightningModule | PyTorch Lightning module. | required |
| prediction | Any | Prediction outputs of … | required |
| batch_indices | sequence of Any | Batch indices. | required |
| batch | Any | Input batch. | required |
| batch_idx | int | Batch index. | required |
| dataloader_idx | int | Dataloader index. | required |
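Because write_strategy decides how predictions are written, write_on_batch_end can stay a thin forwarding layer. The stdlib-only sketch below is written under two assumptions: the hook returns early when writing_predictions is off, and it forwards every hook argument plus dirpath to a strategy method. RecordingStrategy, SketchWriter, write_batch and its dirpath keyword are hypothetical names for illustration only.

```python
from pathlib import Path
from typing import Any, List, Optional, Sequence


class RecordingStrategy:
    """Hypothetical strategy that records calls instead of writing files."""

    def __init__(self) -> None:
        self.calls: List[dict] = []

    def write_batch(
        self, trainer: Any, pl_module: Any, prediction: Any,
        batch_indices: Optional[Sequence[Any]], batch: Any,
        batch_idx: int, dataloader_idx: int, dirpath: Path,
    ) -> None:
        self.calls.append({"prediction": prediction, "batch_idx": batch_idx,
                           "dataloader_idx": dataloader_idx, "dirpath": dirpath})


class SketchWriter:
    """Hypothetical callback showing only the batch-end delegation."""

    def __init__(self, write_strategy: RecordingStrategy, dirpath: str = "predictions") -> None:
        self.write_strategy = write_strategy
        self.dirpath = Path(dirpath)
        self.writing_predictions = True  # assumption: on by default

    def write_on_batch_end(self, trainer, pl_module, prediction, batch_indices,
                           batch, batch_idx, dataloader_idx) -> None:
        # Assumption: a disabled callback does nothing at all.
        if not self.writing_predictions:
            return
        # Forward every hook argument unchanged, plus the output directory.
        self.write_strategy.write_batch(
            trainer=trainer, pl_module=pl_module, prediction=prediction,
            batch_indices=batch_indices, batch=batch, batch_idx=batch_idx,
            dataloader_idx=dataloader_idx, dirpath=self.dirpath,
        )
```

In Lightning, BasePredictionWriter calls write_on_batch_end from its on_predict_batch_end hook during prediction. Here the method is called directly with placeholder arguments so the delegation can be inspected without a trainer.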