training_factory
Convenience functions to create training configurations.
create_training_configuration(algorithm, trainer_params, logger, checkpoint_params=None, early_stopping_params=None, monitor_metric='val_loss') #
Create a training configuration with the specified parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `algorithm` | `"care"`, `"n2n"` or `"n2v"` | Algorithm type, used to select the default checkpointing preset. | `"care"` |
| `trainer_params` | `dict` | Parameters for the Lightning `Trainer` class, see the PyTorch Lightning documentation. | required |
| `logger` | `"wandb"`, `"tensorboard"` or `"none"` | Logger to use. | `"wandb"` |
| `checkpoint_params` | `dict` | Parameters for the checkpoint callback, see the PyTorch Lightning documentation. | `None` |
| `early_stopping_params` | `dict` | Parameters for the early stopping callback, see the PyTorch Lightning documentation. | `None` |
| `monitor_metric` | `str` | Metric to monitor for early stopping. | `"val_loss"` |
Returns:
| Type | Description |
|---|---|
| `NGTrainingConfig` | Training configuration with the specified parameters. |
Source code in src/careamics/config/ng_factories/training_factory.py
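To illustrate what the factory assembles, here is a minimal sketch of the parameter merging described above, using plain dictionaries in place of `NGTrainingConfig`. The keys, defaults, and the helper name `build_training_config` are illustrative assumptions, not the actual careamics implementation:

```python
def build_training_config(
    algorithm="care",
    trainer_params=None,
    logger="wandb",
    checkpoint_params=None,
    early_stopping_params=None,
    monitor_metric="val_loss",
):
    """Sketch of the factory logic: merge user parameters over per-algorithm defaults."""
    # Hypothetical per-algorithm checkpointing presets.
    default_checkpoints = {
        "care": {"save_top_k": 1, "mode": "min"},
        "n2n": {"save_top_k": 1, "mode": "min"},
        "n2v": {"save_top_k": 1, "mode": "min"},
    }
    # User-supplied checkpoint_params override the preset; the monitored
    # metric is wired into the checkpoint callback parameters.
    checkpoint = {
        **default_checkpoints[algorithm],
        "monitor": monitor_metric,
        **(checkpoint_params or {}),
    }
    config = {
        "trainer": dict(trainer_params or {}),
        "logger": logger,
        "checkpointing": checkpoint,
    }
    # Early stopping is only configured when parameters are provided.
    if early_stopping_params is not None:
        config["early_stopping"] = {"monitor": monitor_metric, **early_stopping_params}
    return config
```

For example, `build_training_config(algorithm="n2v", trainer_params={"accelerator": "gpu"})` yields a configuration whose checkpoint callback monitors `val_loss` and which contains no early-stopping entry.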
update_trainer_params(trainer_params=None, num_epochs=None, num_steps=None) #
Update trainer parameters with `num_epochs` and `num_steps`.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `trainer_params` | `dict` | Parameters for the Lightning `Trainer` class. | `None` |
| `num_epochs` | `int` | Number of epochs to train for. If provided, added to `trainer_params` as `max_epochs`. | `None` |
| `num_steps` | `int` | Number of steps (batches) per epoch. If provided, added to `trainer_params` as `limit_train_batches`. | `None` |
Returns:
| Type | Description |
|---|---|
| `dict` | Updated trainer parameters dictionary. |
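The merging behavior described above can be sketched as follows. This is an illustrative re-implementation of the documented contract, not the actual source code:

```python
def update_trainer_params(trainer_params=None, num_epochs=None, num_steps=None):
    """Merge num_epochs / num_steps into a trainer-parameter dictionary."""
    # Copy so the caller's dictionary is never mutated in place.
    params = dict(trainer_params or {})
    if num_epochs is not None:
        # Mapped to the Lightning Trainer argument max_epochs.
        params["max_epochs"] = num_epochs
    if num_steps is not None:
        # Mapped to the Lightning Trainer argument limit_train_batches.
        params["limit_train_batches"] = num_steps
    return params
```

Called as `update_trainer_params({"accelerator": "gpu"}, num_epochs=5, num_steps=100)`, this returns `{"accelerator": "gpu", "max_epochs": 5, "limit_train_batches": 100}`.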