
careamist_v2

Main interface for training and predicting with CAREamics.

CAREamistV2

Main interface for training and predicting with CAREamics.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `workdir` | `Path` | Working directory in which to save training outputs. |
| `config` | `NGConfiguration[AlgorithmConfig]` | CAREamics configuration. |
| `model` | `CAREamicsModule` | The PyTorch Lightning module to be trained and used for prediction. |
| `checkpoint_path` | `Path \| None` | Path to a checkpoint file from which model and configuration may be loaded. |
| `trainer` | `Trainer` | The PyTorch Lightning Trainer used for training and prediction. |
| `callbacks` | `list[Callback]` | List of callbacks used during training. |
| `prediction_writer` | `PredictionWriterCallback` | Callback used to write predictions to disk during prediction. |
| `train_datamodule` | `CareamicsDataModule \| None` | The datamodule used for training, set after calling `train()`. |

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `NGConfiguration \| Path` | CAREamics configuration, or a path to a configuration file. See `careamics.config.ng_factories` for methods to build configurations. | `None` |
| `checkpoint_path` | `Path` | Path to a checkpoint file from which to load the model and configuration. | `None` |
| `bmz_path` | `Path` | Path to a BioImage Model Zoo archive from which to load the model and configuration. | `None` |
| `work_dir` | `Path \| str` | Working directory in which to save training outputs. If None, the current working directory will be used. | `None` |
| `callbacks` | list of PyTorch Lightning Callbacks | List of callbacks to use during training. If None, no additional callbacks will be used. Note that `ModelCheckpoint` and `EarlyStopping` callbacks are already defined in CAREamics and should only be modified through the training configuration (see `NGConfiguration` and `TrainingConfig`). | `None` |
| `enable_progress_bar` | `bool` | Whether to show the progress bar during training. | `True` |
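The three model sources are mutually exclusive: exactly one of `config`, `checkpoint_path`, or `bmz_path` must be given, otherwise the constructor raises a `ValueError`. A minimal standalone sketch of that rule (it mirrors the validation in `_load_model`; `pick_model_source` is a hypothetical helper for illustration, not part of CAREamics):

```python
def pick_model_source(config=None, checkpoint_path=None, bmz_path=None):
    """Return the name of the single provided model source.

    Hypothetical helper mirroring the exactly-one-of check performed by
    CAREamistV2: passing zero sources, or more than one, is an error.
    """
    provided = [
        name
        for name, value in (
            ("config", config),
            ("checkpoint_path", checkpoint_path),
            ("bmz_path", bmz_path),
        )
        if value is not None
    ]
    if len(provided) != 1:
        raise ValueError(
            "Exactly one of `config`, `checkpoint_path`, or `bmz_path` "
            "must be provided."
        )
    return provided[0]
```

For example, `pick_model_source(checkpoint_path="last.ckpt")` returns `"checkpoint_path"`, while passing both a config and a checkpoint raises.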
Source code in src/careamics/careamist_v2.py
class CAREamistV2:
    """Main interface for training and predicting with CAREamics.

    Attributes
    ----------
    workdir : Path
        Working directory in which to save training outputs.
    config : NGConfiguration[AlgorithmConfig]
        CAREamics configuration.
    model : CAREamicsModule
        The PyTorch Lightning module to be trained and used for prediction.
    checkpoint_path : Path | None
        Path to a checkpoint file from which model and configuration may be loaded.
    trainer : Trainer
        The PyTorch Lightning Trainer used for training and prediction.
    callbacks : list[Callback]
        List of callbacks used during training.
    prediction_writer : PredictionWriterCallback
        Callback used to write predictions to disk during prediction.
    train_datamodule : CareamicsDataModule | None
        The datamodule used for training, set after calling `train()`.

    Parameters
    ----------
    config : NGConfiguration | Path, default=None
        CAREamics configuration, or a path to a configuration file. See
        `careamics.config.ng_factories` for methods to build configurations.
    checkpoint_path : Path, default=None
        Path to a checkpoint file from which to load the model and configuration.
    bmz_path : Path, default=None
        Path to a BioImage Model Zoo archive from which to load the model and
        configuration.
    work_dir : Path | str, default=None
        Working directory in which to save training outputs. If None, the current
        working directory will be used.
    callbacks : list of PyTorch Lightning Callbacks, default=None
        List of callbacks to use during training. If None, no additional callbacks
        will be used. Note that `ModelCheckpoint` and `EarlyStopping` callbacks are
        already defined in CAREamics and should only be modified through the
        training configuration (see NGConfiguration and TrainingConfig).
    enable_progress_bar : bool, default=True
        Whether to show the progress bar during training.
    """

    def __init__(
        self,
        config: ConfigurationType | Path | None = None,
        *,
        checkpoint_path: Path | None = None,
        bmz_path: Path | None = None,
        work_dir: Path | str | None = None,
        callbacks: list[Callback] | None = None,
        enable_progress_bar: bool = True,
    ) -> None:
        """Constructor for CAREamistV2.

        Exactly one of `config`, `checkpoint_path`, or `bmz_path` must be provided.

        Parameters
        ----------
        config : NGConfiguration | Path, default=None
            CAREamics configuration, or a path to a configuration file. See
            `careamics.config.ng_factories` for methods to build configurations.
            `config`
            is mutually exclusive with `checkpoint_path` and `bmz_path`.
        checkpoint_path : Path, default=None
            Path to a checkpoint file from which to load the model and configuration.
            `checkpoint_path` is mutually exclusive with `config` and `bmz_path`.
        bmz_path : Path, default=None
            Path to a BioImage Model Zoo archive from which to load the model and
            configuration. `bmz_path` is mutually exclusive with `config` and
            `checkpoint_path`.
        work_dir : Path | str, default=None
            Working directory in which to save training outputs. If None, the current
            working directory will be used.
        callbacks : list of PyTorch Lightning Callbacks, default=None
            List of callbacks to use during training. If None, no additional callbacks
            will be used. Note that `ModelCheckpoint` and `EarlyStopping` callbacks are
            already defined in CAREamics and should only be modified through the
            training configuration (see NGConfiguration and TrainingConfig).
        enable_progress_bar : bool, default=True
            Whether to show the progress bar during training.
        """
        self.checkpoint_path = checkpoint_path
        self.work_dir = self._resolve_work_dir(work_dir)

        self.config: ConfigurationType
        self.config, self.model = self._load_model(
            config, checkpoint_path, bmz_path
        )

        self.config.training_config.trainer_params["enable_progress_bar"] = (
            enable_progress_bar
        )
        self.callbacks = self._define_callbacks(callbacks, self.config, self.work_dir)

        self.prediction_writer = PredictionWriterCallback(
            self.work_dir, enable_writing=False
        )

        experiment_loggers = self._create_loggers(
            self.config.training_config.logger,
            self.config.get_safe_experiment_name(),
            self.work_dir,
        )

        self.trainer = Trainer(
            callbacks=[self.prediction_writer, *self.callbacks],
            default_root_dir=self.work_dir,
            logger=experiment_loggers,
            **self.config.training_config.trainer_params or {},
        )

        self.train_datamodule: CareamicsDataModule | None = None

    def _load_model(
        self,
        config: ConfigurationType | Path | None,
        checkpoint_path: Path | None,
        bmz_path: Path | None,
    ) -> tuple[ConfigurationType, CAREamicsModule]:
        """Load model.

        Parameters
        ----------
        config : NGConfiguration | Path | None
            CAREamics configuration, or a path to a configuration file.
        checkpoint_path : Path | None
            Path to a checkpoint file from which to load the model and configuration.
        bmz_path : Path | None
            Path to a BioImage Model Zoo archive from which to load the model and
            configuration.

        Returns
        -------
        NGConfiguration
            The loaded configuration.
        CAREamicsModule
            The loaded model.

        Raises
        ------
        ValueError
            If not exactly one of `config`, `checkpoint_path`, or `bmz_path` is
            provided.
        """
        n_inputs = sum(
            [config is not None, checkpoint_path is not None, bmz_path is not None]
        )
        if n_inputs != 1:
            raise ValueError(
                "Exactly one of `config`, `checkpoint_path`, or `bmz_path` "
                "must be provided."
            )
        if config is not None:
            return self._from_config(config)
        elif checkpoint_path is not None:
            return self._from_checkpoint(checkpoint_path)
        else:
            assert bmz_path is not None
            return self._from_bmz(bmz_path)

    @staticmethod
    def _from_config(
        config: ConfigurationType | Path,
    ) -> tuple[ConfigurationType, CAREamicsModule]:
        """Create model from configuration.

        Parameters
        ----------
        config : NGConfiguration | Path
            CAREamics configuration, or a path to a configuration file.

        Returns
        -------
        NGConfiguration
            The loaded configuration if a path was provided, otherwise the original
            configuration.
        CAREamicsModule
            The created model.
        """
        if isinstance(config, Path):
            config = load_configuration_ng(config)
        assert not isinstance(config, Path)

        model = create_module(config.algorithm_config)
        return config, model

    @staticmethod
    def _from_checkpoint(
        checkpoint_path: Path,
    ) -> tuple[ConfigurationType, CAREamicsModule]:
        """Load checkpoint and configuration from checkpoint file.

        Parameters
        ----------
        checkpoint_path : Path
            Path to a checkpoint file from which to load the model and configuration.

        Returns
        -------
        NGConfiguration
            The loaded configuration.
        CAREamicsModule
            The loaded model.
        """
        config = load_config_from_checkpoint(checkpoint_path)
        module = load_module_from_checkpoint(checkpoint_path)

        return config, module

    @staticmethod
    def _from_bmz(
        bmz_path: Path,
    ) -> tuple[ConfigurationType, CAREamicsModule]:
        """Load checkpoint and configuration from a BioImage Model Zoo archive.

        Parameters
        ----------
        bmz_path : Path
            Path to a BioImage Model Zoo archive from which to load the model and
            configuration.

        Returns
        -------
        NGConfiguration
            The loaded configuration.
        CAREamicsModule
            The loaded model.

        Raises
        ------
        NotImplementedError
            Loading from BMZ is not implemented yet.
        """
        raise NotImplementedError("Loading from BMZ is not implemented yet.")

    @staticmethod
    def _resolve_work_dir(work_dir: str | Path | None) -> Path:
        """Resolve working directory.

        Parameters
        ----------
        work_dir : str | Path | None
            The working directory to resolve. If None, the current working directory
            will be used.

        Returns
        -------
        Path
            The resolved working directory.
        """
        if work_dir is None:
            work_dir = Path.cwd().resolve()
            logger.warning(
                f"No working directory provided. Using current working directory: "
                f"{work_dir}."
            )
        else:
            work_dir = Path(work_dir).resolve()
        return work_dir

    @staticmethod
    def _define_callbacks(
        callbacks: list[Callback] | None,
        config: ConfigurationType,
        work_dir: Path,
    ) -> list[Callback]:
        """Define callbacks for the training process.

        Parameters
        ----------
        callbacks : list[Callback] | None
            List of callbacks to use during training. If None, no additional callbacks
            will be used. Note that `ModelCheckpoint` and `EarlyStopping` callbacks are
            already defined in CAREamics and instantiated in this method.
        config : NGConfiguration
            The CAREamics configuration, used to instantiate the callbacks.
        work_dir : Path
            The working directory, used as a parameter to the checkpointing callback.

        Returns
        -------
        list[Callback]
            The list of callbacks to use during training.

        Raises
        ------
        ValueError
            If `ModelCheckpoint` or `EarlyStopping` callbacks are included in the
            provided `callbacks` list, as these are already defined in CAREamics and
            should only be modified through the training configuration (see
            NGConfiguration and TrainingConfig).
        """
        callback_lst: list[Callback] = [] if callbacks is None else callbacks
        for c in callback_lst:
            if isinstance(c, (ModelCheckpoint, EarlyStopping)):
                raise ValueError(
                    "`ModelCheckpoint` and `EarlyStopping` callbacks are already "
                    "defined in CAREamics and should only be modified through the "
                    "training configuration (see TrainingConfig)."
                )

            if isinstance(c, (CareamicsCheckpointInfo, ProgressBarCallback)):
                raise ValueError(
                    "`CareamicsCheckpointInfo` and `ProgressBar` callbacks are defined "
                    "internally and should not be passed as callbacks."
                )

        checkpoint_callback = ModelCheckpoint(
            dirpath=work_dir / "checkpoints" / config.get_safe_experiment_name(),
            filename=(
                f"{config.get_safe_experiment_name()}_{{epoch:02d}}_step_{{step}}_"
                f"{{val_loss:.4f}}"
            ),
            **config.training_config.checkpoint_params,
        )
        checkpoint_callback.CHECKPOINT_NAME_LAST = (
            f"{config.get_safe_experiment_name()}_last"
        )
        internal_callbacks: list[Callback] = [
            checkpoint_callback,
            CareamicsCheckpointInfo(
                config.version,
                config.get_safe_experiment_name(),
                config.training_config
            ),
        ]

        enable_progress_bar = config.training_config.trainer_params.get(
            "enable_progress_bar", True
        )
        if enable_progress_bar:
            internal_callbacks.append(ProgressBarCallback())

        if config.training_config.early_stopping_params is not None:
            internal_callbacks.append(
                EarlyStopping(
                    **config.training_config.early_stopping_params
                )
            )

        return internal_callbacks + callback_lst

    @staticmethod
    def _create_loggers(
        logger: str | None, experiment_name: str, work_dir: Path
    ) -> list[TensorBoardLogger | WandbLogger | CSVLogger]:
        """Create loggers for the experiment.

        Parameters
        ----------
        logger : str | None
            Logger to use during training. If None, no logger will be used. Available
            loggers are defined in SupportedLogger.
        experiment_name : str
            Name of the experiment, used as a parameter to the loggers.
        work_dir : Path
            The working directory, used as a parameter to the loggers.

        Returns
        -------
        list[TensorBoardLogger | WandbLogger | CSVLogger]
            The list of loggers to use during training.
        """
        csv_logger = CSVLogger(name=experiment_name, save_dir=work_dir / "csv_logs")

        if logger is not None:
            logger = SupportedLogger(logger)

        match logger:
            case SupportedLogger.WANDB:
                return [
                    WandbLogger(name=experiment_name, save_dir=work_dir / "wandb_logs"),
                    csv_logger,
                ]
            case SupportedLogger.TENSORBOARD:
                return [
                    TensorBoardLogger(save_dir=work_dir / "tb_logs"),
                    csv_logger,
                ]
            case _:
                return [csv_logger]

    # Two overloads:
    # - 1st for supported data types & using ReadFuncLoading
    # - 2nd for ImageStackLoading
    # Why:
    #   ImageStackLoading supports any type as input, but we want to tell most users
    #   that they are only allowed Path, str, ndarray or a sequence of these.
    #   The first overload will be displayed first by most code editors; this is
    #   what most users will see.
    @overload
    def train( # numpydoc ignore=GL08
        self,
        *,
        # BASIC PARAMS
        train_data: InputVar | None = None,
        train_data_target: InputVar | None = None,
        val_data: InputVar | None = None,
        val_data_target: InputVar | None = None,
        # ADVANCED PARAMS
        filtering_mask: InputVar | None = None,
        loading: ReadFuncLoading | None = None,
    ) -> None: ...

    @overload  # any data input is allowed for ImageStackLoading
    def train( # numpydoc ignore=GL08
        self,
        *,
        # BASIC PARAMS
        train_data: Any | None = None,
        train_data_target: Any | None = None,
        val_data: Any | None = None,
        val_data_target: Any | None = None,
        # ADVANCED PARAMS
        filtering_mask: Any | None = None,
        loading: ImageStackLoading = ...,
    ) -> None: ...

    def train(
        self,
        *,
        # BASIC PARAMS
        train_data: Any | None = None,
        train_data_target: Any | None = None,
        val_data: Any | None = None,
        val_data_target: Any | None = None,
        # ADVANCED PARAMS
        filtering_mask: Any | None = None,
        loading: Loading = None,
    ) -> None:
        """Train the model on the provided data.

        The training data can be provided as arrays or paths.

        Parameters
        ----------
        train_data : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
            Training data, by default None.
        train_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these
            Training target data, by default None.
        val_data : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
            Validation data. If not provided, `data_config.n_val_patches` patches will
            be selected from the training data for validation.
        val_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these
            Validation target data, by default None.
        filtering_mask : pathlib.Path, str, numpy.ndarray, or sequence of these
            Filtering mask for coordinate-based patch filtering, by default None.
        loading : Loading, default=None
            Loading strategy to use for the training data. May be a ReadFuncLoading
            or ImageStackLoading. If None, uses the loading strategy from the
            training configuration.

        Raises
        ------
        ValueError
            If train_data is not provided.
        """
        if train_data is None:
            raise ValueError("Training data must be provided. Provide `train_data`.")

        if self.config.is_supervised() and train_data_target is None:
            raise ValueError(
                f"Training target data must be provided for supervised training (got "
                f"{self.config.get_algorithm_friendly_name()} algorithm). Provide "
                f"`train_data_target`."
            )

        if (
            self.config.is_supervised()
            and val_data is not None
            and val_data_target is None
        ):
            raise ValueError(
                f"Validation target data must be provided for supervised training (got "
                f"{self.config.get_algorithm_friendly_name()} algorithm). Provide "
                f"`val_data_target`."
            )

        datamodule = CareamicsDataModule( # type: ignore
            data_config=self.config.data_config,
            train_data=train_data,
            val_data=val_data,
            train_data_target=train_data_target,
            val_data_target=val_data_target,
            train_data_mask=filtering_mask,
            loading=loading, # type: ignore
        )

        self.train_datamodule = datamodule

        # set parameters back to defaults, this is a guard against `stop_training`
        # which changes them in order to interrupt training gracefully
        self.trainer.should_stop = False
        self.trainer.limit_val_batches = 1.0 # equivalent to all validation batches

        self.trainer.fit(
            self.model, datamodule=datamodule, ckpt_path=self.checkpoint_path
        )

    def _build_predict_datamodule(
        self,
        pred_data: Any,
        *,
        pred_data_target: Any | None = None,
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: Loading = None,
    ) -> CareamicsDataModule:
        """Create prediction data module.

        Parameters
        ----------
        pred_data : Any
            Prediction data.
        pred_data_target : Any | None, default=None
            Prediction target data, by default None. Can be used to compute metrics
            during prediction.
        batch_size : int | None, default=None
            Batch size for prediction. If None, uses the batch size from the training
            configuration.
        tile_size : tuple[int, ...] | None, default=None
            Tile size for prediction. If None, uses whole image prediction.
        tile_overlap : tuple[int, ...] | None, default=(48, 48)
            Tile overlap for prediction. If None, defaults to (48, 48).
        axes : str | None, default=None
            Axes for prediction. If None, uses training configuration axes.
        data_type : {"array", "tiff", "zarr", "czi", "custom"} | None, default=None
            Data type for prediction. If None, uses training configuration data type.
        num_workers : int | None, default=None
            Number of workers for data loading during prediction.
        channels : Sequence[int] | Literal["all"] | None, default=None
            Channels to use for prediction. If "all", uses all channels. If None, uses
            the channels from the training configuration.
        in_memory : bool | None, default=None
            Whether to load data into memory during prediction. If None, uses training
            configuration.
        loading : Loading, default=None
            Loading strategy for prediction data if data type (either from training
            configuration or specified) is `"custom"`.

        Returns
        -------
        CareamicsDataModule
            Prediction data module.
        """
        dataloader_params: dict[str, Any] | None = None
        if num_workers is not None:
            dataloader_params = {"num_workers": num_workers}

        pred_data_config = self.config.data_config.convert_mode(
            new_mode="predicting",
            new_patch_size=tile_size,
            overlap_size=tile_overlap,
            new_batch_size=batch_size,
            new_data_type=data_type,
            new_dataloader_params=dataloader_params,
            new_axes=axes,
            new_channels=channels,
            new_in_memory=in_memory,
        )
        return CareamicsDataModule(
            data_config=pred_data_config,
            pred_data=pred_data,
            pred_data_target=pred_data_target,
            loading=loading,
        )

    # see comment on train func for a description of why we have these two overloads
    @overload  # constrained input data type for supported data or ReadFuncLoading
    def predict( # numpydoc ignore=GL08
        self,
        # BASIC PARAMS
        pred_data: InputVar,
        *,
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: ReadFuncLoading | None = None,
    ) -> tuple[list[NDArray], list[str]]:
        ...

    @overload  # any data input is allowed for ImageStackLoading
    def predict( # numpydoc ignore=GL08
        self,
        # BASIC PARAMS
        pred_data: Any,
        *,
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: ImageStackLoading = ...,
    ) -> tuple[list[NDArray], list[str]]:
        ...

    def predict(
        self,
        # BASIC PARAMS
        pred_data: InputVar,
        *,
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: Loading = None,
    ) -> tuple[list[NDArray], list[str]]:
        """
        Predict on data and return the predictions.

        Input can be a path to a data file, a list of paths, a numpy array, or a
        list of numpy arrays.

        If `data_type` and `axes` are not provided, the training configuration
        parameters will be used. If `tile_size` is not provided, prediction will
        be performed on the whole image.

        Note that if you are using a UNet model and tiling, the tile size must be
        divisible in every dimension by 2**d, where d is the depth of the model. This
        avoids artefacts arising from the broken shift invariance induced by the
        pooling layers of the UNet. Images smaller than the tile size in any spatial
        dimension will be automatically zero-padded.

        Parameters
        ----------
        pred_data : pathlib.Path, str, numpy.ndarray, or sequence of these
            Data to predict on. Can be a single item or a sequence of paths/arrays.
        batch_size : int, optional
            Batch size for prediction. If not provided, uses the training configuration
            batch size.
        tile_size : tuple of int, optional
            Size of the tiles to use for prediction. If not provided, prediction
            will be performed on the whole image.
        tile_overlap : tuple of int, default=(48, 48)
            Overlap between adjacent tiles. May only be None when `tile_size` is
            None.
        axes : str, optional
            Axes of the input data. If None, uses the training configuration axes.
        data_type : {"array", "tiff", "czi", "zarr", "custom"}, optional
            Type of the input data. If None, uses the training configuration data
            type.
        num_workers : int, optional
            Number of workers for the dataloader, by default None.
        channels : sequence of int or "all", optional
            Channels to use from the data. If None, uses the training configuration
            channels.
        in_memory : bool, optional
            Whether to load all data into memory. If None, uses the training
            configuration setting.
        loading : Loading, default=None
            Loading strategy to use for the prediction data. May be a ReadFuncLoading or
            ImageStackLoading. If None, uses the loading strategy from the training
            configuration.

        Returns
        -------
        tuple of (list of NDArray, list of str)
            Predictions made by the model and their source identifiers.

        Raises
        ------
        ValueError
            If tile overlap is not specified when tile_size is provided.
        """
        datamodule = self._build_predict_datamodule(
            pred_data,
            batch_size=batch_size,
            tile_size=tile_size,
            tile_overlap=tile_overlap,
            axes=axes,
            data_type=data_type,
            num_workers=num_workers,
            channels=channels,
            in_memory=in_memory,
            loading=loading,
        )

        predictions: list[ImageRegionData] = self.trainer.predict(
            model=self.model, datamodule=datamodule
        )  # type: ignore[assignment]
        tiled = tile_size is not None
        predictions_output, sources = convert_prediction(
            predictions, tiled=tiled, restore_shape=True
        )

        return predictions_output, sources

    # see comment on train func for a description of why we have these two overloads
    @overload  # constrained input data type for supported data or ReadFuncLoading
    def predict_to_disk(  # numpydoc ignore=GL08
        self,
        # BASIC PARAMS
        pred_data: InputVar,
        *,
        pred_data_target: InputVar | None = None,
        prediction_dir: Path | str = "predictions",
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: ReadFuncLoading | None = None,
        # WRITE OPTIONS
        write_type: Literal["tiff", "zarr", "custom"] = "tiff",
        write_extension: str | None = None,
        write_func: WriteFunc | None = None,
        write_func_kwargs: dict[str, Any] | None = None,
    ) -> None: ...

    @overload  # any data input is allowed for ImageStackLoading
    def predict_to_disk(  # numpydoc ignore=GL08
        self,
        # BASIC PARAMS
        pred_data: Any,
        *,
        pred_data_target: Any | None = None,
        prediction_dir: Path | str = "predictions",
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: ImageStackLoading = ...,
        # WRITE OPTIONS
        write_type: Literal["tiff", "zarr", "custom"] = "tiff",
        write_extension: str | None = None,
        write_func: WriteFunc | None = None,
        write_func_kwargs: dict[str, Any] | None = None,
    ) -> None: ...

    def predict_to_disk(
        self,
        # BASIC PARAMS
        pred_data: Any,
        *,
        pred_data_target: Any | None = None,
        prediction_dir: Path | str = "predictions",
        batch_size: int | None = None,
        tile_size: tuple[int, ...] | None = None,
        tile_overlap: tuple[int, ...] | None = (48, 48),
        axes: str | None = None,
        data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
        # ADVANCED PARAMS
        num_workers: int | None = None,
        channels: Sequence[int] | Literal["all"] | None = None,
        in_memory: bool | None = None,
        loading: Loading = None,
        # WRITE OPTIONS
        write_type: Literal["tiff", "zarr", "custom"] = "tiff",
        write_extension: str | None = None,
        write_func: WriteFunc | None = None,
        write_func_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """
        Make predictions on the provided data and save outputs to files.

        Predictions are saved to `prediction_dir` (absolute paths are used as-is,
        relative paths are relative to `work_dir`). The directory structure matches
        the source directory.

        The file names of the predictions will match those of the source. If there is
        more than one sample within a file, the samples will be stacked along the sample
        dimension in the output file.

        If `data_type` and `axes` are not provided, the training configuration
        parameters will be used. If `tile_size` is not provided, prediction
        will be performed on whole images rather than in a tiled manner.

        Note that if you are using a UNet model and tiling, the tile size must be
        divisible in every dimension by 2**d, where d is the depth of the model. This
        avoids artefacts arising from the broken shift invariance induced by the
        pooling layers of the UNet. Images smaller than the tile size in any spatial
        dimension will be automatically zero-padded.

        Parameters
        ----------
        pred_data : pathlib.Path, str, numpy.ndarray, or sequence of these
            Data to predict on. Can be a single item or a sequence of paths/arrays.
        pred_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these
            Target data corresponding to `pred_data`, by default None.
        prediction_dir : Path | str, default="predictions"
            The path to save the prediction results to. If `prediction_dir` is an
            absolute path, it will be used as-is. If it is a relative path, it will
            be relative to the pre-set `work_dir`. If the directory does not exist it
            will be created.
        batch_size : int, optional
            Batch size for prediction. If not provided, uses the training configuration
            batch size.
        tile_size : tuple of int, optional
            Size of the tiles to use for prediction. If not provided, uses whole image
            strategy.
        tile_overlap : tuple of int, default=(48, 48)
            Overlap between adjacent tiles. May only be None when `tile_size` is
            None.
        axes : str, optional
            Axes of the input data. If None, uses the training configuration axes.
        data_type : {"array", "tiff", "czi", "zarr", "custom"}, optional
            Type of the input data. If None, uses the training configuration data
            type.
        num_workers : int, optional
            Number of workers for the dataloader, by default None.
        channels : sequence of int or "all", optional
            Channels to use from the data. If None, uses the training configuration
            channels.
        in_memory : bool, optional
            Whether to load all data into memory. If None, uses the training
            configuration setting.
        loading : Loading, default=None
            Loading strategy to use for the prediction data. May be a ReadFuncLoading or
            ImageStackLoading. If None, uses the loading strategy from the training
            configuration.
        write_type : {"tiff", "zarr", "custom"}, default="tiff"
            The data type to save as, includes custom.
        write_extension : str, optional
            If a known `write_type` is selected this argument is ignored. For a custom
            `write_type` an extension to save the data with must be passed.
        write_func : WriteFunc, optional
            If a known `write_type` is selected this argument is ignored. For a custom
            `write_type` a function to save the data must be passed. See notes below.
        write_func_kwargs : dict of {str: any}, optional
            Additional keyword arguments to be passed to the save function.

        Raises
        ------
        ValueError
            If `write_type` is custom and `write_extension` is None.
        ValueError
            If `write_type` is custom and `write_func` is None.
        """
        if write_func_kwargs is None:
            write_func_kwargs = {}

        if Path(prediction_dir).is_absolute():
            write_dir = Path(prediction_dir)
        else:
            write_dir = self.work_dir / prediction_dir
        self.prediction_writer.dirpath = write_dir

        if write_type == "custom":
            if write_extension is None:
                raise ValueError(
                    "A `write_extension` must be provided for custom write types."
                )
            if write_func is None:
                raise ValueError(
                    "A `write_func` must be provided for custom write types."
                )
        elif write_type == "zarr" and tile_size is None:
            raise ValueError(
                "Writing prediction to Zarr is only supported with tiling. Please "
                "provide a value for `tile_size`, and optionally `tile_overlap`."
            )
        else:
            write_func = get_write_func(write_type)
            write_extension = SupportedData.get_extension(write_type)

        tiled = tile_size is not None
        self.prediction_writer.set_writing_strategy(
            write_type=write_type,
            tiled=tiled,
            write_func=write_func,
            write_extension=write_extension,
            write_func_kwargs=write_func_kwargs,
        )

        self.prediction_writer.enable_writing(True)

        try:
            datamodule = self._build_predict_datamodule(
                pred_data,
                pred_data_target=pred_data_target,
                batch_size=batch_size,
                tile_size=tile_size,
                tile_overlap=tile_overlap,
                axes=axes,
                data_type=data_type,
                num_workers=num_workers,
                channels=channels,
                in_memory=in_memory,
                loading=loading,
            )

            self.trainer.predict(
                model=self.model, datamodule=datamodule, return_predictions=False
            )

        finally:
            self.prediction_writer.enable_writing(False)

    def export_to_bmz(
        self,
        path_to_archive: Path | str,
        friendly_model_name: str,
        input_array: NDArray,
        authors: list[dict],
        general_description: str,
        data_description: str,
        covers: list[Path | str] | None = None,
        channel_names: list[str] | None = None,
        model_version: str = "0.2.0",
    ) -> None:
        """Export the model to the BioImage Model Zoo format.

        This method packages the current weights into a zip file that can be uploaded
        to the BioImage Model Zoo. The archive consists of the model weights, the model
        specifications and various files (inputs, outputs, README, env.yaml etc.).

        `path_to_archive` should point to a file with a ".zip" extension.

        `friendly_model_name` is the name used for the model in the BMZ specs
        and website; it should consist of letters, numbers, dashes, underscores and
        parentheses only.

        Input array must be of the same dimensions as the axes recorded in the
        configuration of the `CAREamist`.

        Parameters
        ----------
        path_to_archive : pathlib.Path or str
            Path in which to save the model, including file name, which should end with
            ".zip".
        friendly_model_name : str
            Name of the model as used in the BMZ specs; it should consist of letters,
            numbers, dashes, underscores and parentheses only.
        input_array : NDArray
            Input array used to validate the model and as example.
        authors : list of dict
            List of authors of the model.
        general_description : str
            General description of the model used in the BMZ metadata.
        data_description : str
            Description of the data the model was trained on.
        covers : list of pathlib.Path or str, default=None
            Paths to the cover images.
        channel_names : list of str, default=None
            Channel names.
        model_version : str, default="0.2.0"
            Version of the model.
        """
        # from .model_io import export_to_bmz

        # output_patch = self.predict(
        #     pred_data=input_array,
        #     data_type=SupportedData.ARRAY.value,
        # )
        # output = np.concatenate(output_patch, axis=0)
        # input_array = reshape_array(input_array, self.config.data_config.axes)

        # export_to_bmz(
        #     model=self.model,
        #     config=self.config,
        #     path_to_archive=path_to_archive,
        #     model_name=friendly_model_name,
        #     general_description=general_description,
        #     data_description=data_description,
        #     authors=authors,
        #     input_array=input_array,
        #     output_array=output,
        #     covers=covers,
        #     channel_names=channel_names,
        #     model_version=model_version,
        # )
        raise NotImplementedError("Exporting to BMZ is not implemented yet.")

    def get_losses(self) -> dict[str, list]:
        """Return data that can be used to plot train and validation loss curves.

        Returns
        -------
        dict of str: list
            Dictionary containing losses for each epoch.
        """
        return read_csv_logger(
            self.config.get_safe_experiment_name(), self.work_dir / "csv_logs"
        )

    def stop_training(self) -> None:
        """Stop the training loop."""
        self.trainer.should_stop = True
        self.trainer.limit_val_batches = 0  # skip validation

__init__(config=None, *, checkpoint_path=None, bmz_path=None, work_dir=None, callbacks=None, enable_progress_bar=True) #

Constructor for CAREamistV2.

Exactly one of config, checkpoint_path, or bmz_path must be provided.
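
That constraint can be sketched as a standalone validation (a hypothetical helper for illustration; the actual check lives inside the constructor):

```python
def check_single_source(config, checkpoint_path, bmz_path):
    """Raise unless exactly one of the three model sources is given."""
    n_provided = sum(x is not None for x in (config, checkpoint_path, bmz_path))
    if n_provided != 1:
        raise ValueError(
            "Exactly one of `config`, `checkpoint_path` or `bmz_path` must be "
            f"provided, got {n_provided}."
        )
```

Passing zero sources, or more than one, raises a `ValueError`.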

Parameters:

Name Type Description Default
config NGConfiguration | Path

CAREamics configuration, or a path to a configuration file. See careamics.config.ng_factories for methods to build configurations. config is mutually exclusive with checkpoint_path and bmz_path.

None
checkpoint_path Path

Path to a checkpoint file from which to load the model and configuration. checkpoint_path is mutually exclusive with config and bmz_path.

None
bmz_path Path

Path to a BioImage Model Zoo archive from which to load the model and configuration. bmz_path is mutually exclusive with config and checkpoint_path.

None
work_dir Path | str

Working directory in which to save training outputs. If None, the current working directory will be used.

None
callbacks list of PyTorch Lightning Callbacks

List of callbacks to use during training. If None, no additional callbacks will be used. Note that ModelCheckpoint and EarlyStopping callbacks are already defined in CAREamics and should only be modified through the training configuration (see NGConfiguration and TrainingConfig).

None
enable_progress_bar bool

Whether to show the progress bar during training.

True
Source code in src/careamics/careamist_v2.py
def __init__(
    self,
    config: ConfigurationType | Path | None = None,
    *,
    checkpoint_path: Path | None = None,
    bmz_path: Path | None = None,
    work_dir: Path | str | None = None,
    callbacks: list[Callback] | None = None,
    enable_progress_bar: bool = True,
) -> None:
    """Constructor for CAREamistV2.

    Exactly one of `config`, `checkpoint_path`, or `bmz_path` must be provided.

    Parameters
    ----------
    config : NGConfiguration | Path, default=None
        CAREamics configuration, or a path to a configuration file. See
        `careamics.config.ng_factories` for methods to build configurations. `config`
        is mutually exclusive with `checkpoint_path` and `bmz_path`.
    checkpoint_path : Path, default=None
        Path to a checkpoint file from which to load the model and configuration.
        `checkpoint_path` is mutually exclusive with `config` and `bmz_path`.
    bmz_path : Path, default=None
        Path to a BioImage Model Zoo archive from which to load the model and
        configuration. `bmz_path` is mutually exclusive with `config` and
        `checkpoint_path`.
    work_dir : Path | str, default=None
        Working directory in which to save training outputs. If None, the current
        working directory will be used.
    callbacks : list of PyTorch Lightning Callbacks, default=None
        List of callbacks to use during training. If None, no additional callbacks
        will be used. Note that `ModelCheckpoint` and `EarlyStopping` callbacks are
        already defined in CAREamics and should only be modified through the
        training configuration (see NGConfiguration and TrainingConfig).
    enable_progress_bar : bool, default=True
        Whether to show the progress bar during training.
    """
    self.checkpoint_path = checkpoint_path
    self.work_dir = self._resolve_work_dir(work_dir)

    self.config: ConfigurationType
    self.config, self.model = self._load_model(
        config, checkpoint_path, bmz_path
    )

    self.config.training_config.trainer_params["enable_progress_bar"] = (
        enable_progress_bar
    )
    self.callbacks = self._define_callbacks(callbacks, self.config, self.work_dir)

    self.prediction_writer = PredictionWriterCallback(
        self.work_dir, enable_writing=False
    )

    experiment_loggers = self._create_loggers(
        self.config.training_config.logger,
        self.config.get_safe_experiment_name(),
        self.work_dir,
    )

    self.trainer = Trainer(
        callbacks=[self.prediction_writer, *self.callbacks],
        default_root_dir=self.work_dir,
        logger=experiment_loggers,
        **self.config.training_config.trainer_params or {},
    )

    self.train_datamodule: CareamicsDataModule | None = None

export_to_bmz(path_to_archive, friendly_model_name, input_array, authors, general_description, data_description, covers=None, channel_names=None, model_version='0.2.0') #

Export the model to the BioImage Model Zoo format.

This method packages the current weights into a zip file that can be uploaded to the BioImage Model Zoo. The archive consists of the model weights, the model specifications and various files (inputs, outputs, README, env.yaml etc.).

path_to_archive should point to a file with a ".zip" extension.

friendly_model_name is the name used for the model in the BMZ specs and website; it should consist of letters, numbers, dashes, underscores and parentheses only.

Input array must be of the same dimensions as the axes recorded in the configuration of the CAREamist.

Parameters:

Name Type Description Default
path_to_archive Path or str

Path in which to save the model, including file name, which should end with ".zip".

required
friendly_model_name str

Name of the model as used in the BMZ specs; it should consist of letters, numbers, dashes, underscores and parentheses only.

required
input_array NDArray

Input array used to validate the model and as example.

required
authors list of dict

List of authors of the model.

required
general_description str

General description of the model used in the BMZ metadata.

required
data_description str

Description of the data the model was trained on.

required
covers list of pathlib.Path or str

Paths to the cover images.

None
channel_names list of str

Channel names.

None
model_version str

Version of the model.

"0.2.0"
Source code in src/careamics/careamist_v2.py
def export_to_bmz(
    self,
    path_to_archive: Path | str,
    friendly_model_name: str,
    input_array: NDArray,
    authors: list[dict],
    general_description: str,
    data_description: str,
    covers: list[Path | str] | None = None,
    channel_names: list[str] | None = None,
    model_version: str = "0.2.0",
) -> None:
    """Export the model to the BioImage Model Zoo format.

    This method packages the current weights into a zip file that can be uploaded
    to the BioImage Model Zoo. The archive consists of the model weights, the model
    specifications and various files (inputs, outputs, README, env.yaml etc.).

    `path_to_archive` should point to a file with a ".zip" extension.

    `friendly_model_name` is the name used for the model in the BMZ specs
    and website; it should consist of letters, numbers, dashes, underscores and
    parentheses only.

    Input array must be of the same dimensions as the axes recorded in the
    configuration of the `CAREamist`.

    Parameters
    ----------
    path_to_archive : pathlib.Path or str
        Path in which to save the model, including file name, which should end with
        ".zip".
    friendly_model_name : str
        Name of the model as used in the BMZ specs; it should consist of letters,
        numbers, dashes, underscores and parentheses only.
    input_array : NDArray
        Input array used to validate the model and as example.
    authors : list of dict
        List of authors of the model.
    general_description : str
        General description of the model used in the BMZ metadata.
    data_description : str
        Description of the data the model was trained on.
    covers : list of pathlib.Path or str, default=None
        Paths to the cover images.
    channel_names : list of str, default=None
        Channel names.
    model_version : str, default="0.2.0"
        Version of the model.
    """
    # from .model_io import export_to_bmz

    # output_patch = self.predict(
    #     pred_data=input_array,
    #     data_type=SupportedData.ARRAY.value,
    # )
    # output = np.concatenate(output_patch, axis=0)
    # input_array = reshape_array(input_array, self.config.data_config.axes)

    # export_to_bmz(
    #     model=self.model,
    #     config=self.config,
    #     path_to_archive=path_to_archive,
    #     model_name=friendly_model_name,
    #     general_description=general_description,
    #     data_description=data_description,
    #     authors=authors,
    #     input_array=input_array,
    #     output_array=output,
    #     covers=covers,
    #     channel_names=channel_names,
    #     model_version=model_version,
    # )
    raise NotImplementedError("Exporting to BMZ is not implemented yet.")

get_losses() #

Return data that can be used to plot train and validation loss curves.

Returns:

Type Description
dict of str: list

Dictionary containing losses for each epoch.

Source code in src/careamics/careamist_v2.py
def get_losses(self) -> dict[str, list]:
    """Return data that can be used to plot train and validation loss curves.

    Returns
    -------
    dict of str: list
        Dictionary containing losses for each epoch.
    """
    return read_csv_logger(
        self.config.get_safe_experiment_name(), self.work_dir / "csv_logs"
    )
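
The returned dictionary maps CSV-logger column names to per-epoch values; the exact keys depend on the logged metrics. A minimal sketch of locating the best epoch, assuming `epoch` and `val_loss` columns are present:

```python
# Hypothetical get_losses() output; real keys depend on the CSV logger columns.
losses = {
    "epoch": [0, 1, 2, 3],
    "train_loss": [0.92, 0.55, 0.41, 0.40],
    "val_loss": [0.98, 0.61, 0.47, 0.49],
}

# Epoch with the lowest validation loss.
best_idx = min(range(len(losses["val_loss"])), key=losses["val_loss"].__getitem__)
best_epoch = losses["epoch"][best_idx]  # -> 2
```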

predict(pred_data, *, batch_size=None, tile_size=None, tile_overlap=(48, 48), axes=None, data_type=None, num_workers=None, channels=None, in_memory=None, loading=None) #

predict(pred_data: InputVar, *, batch_size: int | None = None, tile_size: tuple[int, ...] | None = None, tile_overlap: tuple[int, ...] | None = (48, 48), axes: str | None = None, data_type: Literal['array', 'tiff', 'zarr', 'czi', 'custom'] | None = None, num_workers: int | None = None, channels: Sequence[int] | Literal['all'] | None = None, in_memory: bool | None = None, loading: ReadFuncLoading | None = None) -> tuple[list[NDArray], list[str]]
predict(pred_data: Any, *, batch_size: int | None = None, tile_size: tuple[int, ...] | None = None, tile_overlap: tuple[int, ...] | None = (48, 48), axes: str | None = None, data_type: Literal['array', 'tiff', 'zarr', 'czi', 'custom'] | None = None, num_workers: int | None = None, channels: Sequence[int] | Literal['all'] | None = None, in_memory: bool | None = None, loading: ImageStackLoading = ...) -> tuple[list[NDArray], list[str]]

Predict on data and return the predictions.

Input can be a path to a data file, a list of paths, a numpy array, or a list of numpy arrays.

If data_type and axes are not provided, the training configuration parameters will be used. If tile_size is not provided, prediction will be performed on the whole image.

Note that if you are using a UNet model and tiling, the tile size must be divisible in every dimension by 2**d, where d is the depth of the model. This avoids artefacts arising from the broken shift invariance induced by the pooling layers of the UNet. Images smaller than the tile size in any spatial dimension will be automatically zero-padded.
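
A compliant tile size can be found by rounding each dimension up to the nearest multiple of 2**d (a sketch; `depth` must match the depth of your UNet configuration, and `next_valid_tile_size` is a hypothetical helper, not part of the CAREamics API):

```python
def next_valid_tile_size(tile_size: tuple[int, ...], depth: int) -> tuple[int, ...]:
    """Round each tile dimension up to the nearest multiple of 2**depth."""
    factor = 2 ** depth
    # -(-s // factor) is integer ceiling division of s by factor
    return tuple(-(-s // factor) * factor for s in tile_size)
```

For a depth-3 UNet, a requested (100, 100) tile becomes (104, 104), which is divisible by 2**3 = 8 in both dimensions.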

Parameters:

Name Type Description Default
pred_data pathlib.Path, str, numpy.ndarray, or sequence of these

Data to predict on. Can be a single item or a sequence of paths/arrays.

required
batch_size int

Batch size for prediction. If not provided, uses the training configuration batch size.

None
tile_size tuple of int

Size of the tiles to use for prediction. If not provided, prediction will be performed on the whole image.

None
tile_overlap tuple of int

Overlap between adjacent tiles. May only be None when tile_size is None.

(48, 48)
axes str

Axes of the input data. If None, uses the training configuration axes.

None
data_type (array, tiff, czi, zarr, custom)

Type of the input data. If None, uses the training configuration data type.

None
num_workers int

Number of workers for the dataloader, by default None.

None
channels sequence of int or "all"

Channels to use from the data. If None, uses the training configuration channels.

None
in_memory bool

Whether to load all data into memory. If None, uses the training configuration setting.

None
loading Loading

Loading strategy to use for the prediction data. May be a ReadFuncLoading or ImageStackLoading. If None, uses the loading strategy from the training configuration.

None

Returns:

Type Description
tuple of (list of NDArray, list of str)

Predictions made by the model and their source identifiers.

Raises:

Type Description
ValueError

If tile overlap is not specified when tile_size is provided.

Source code in src/careamics/careamist_v2.py
def predict(
    self,
    # BASIC PARAMS
    pred_data: InputVar,
    *,
    batch_size: int | None = None,
    tile_size: tuple[int, ...] | None = None,
    tile_overlap: tuple[int, ...] | None = (48, 48),
    axes: str | None = None,
    data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
    # ADVANCED PARAMS
    num_workers: int | None = None,
    channels: Sequence[int] | Literal["all"] | None = None,
    in_memory: bool | None = None,
    loading: Loading = None,
) -> tuple[list[NDArray], list[str]]:
    """
    Predict on data and return the predictions.

    Input can be a path to a data file, a list of paths, a numpy array, or a
    list of numpy arrays.

    If `data_type` and `axes` are not provided, the training configuration
    parameters will be used. If `tile_size` is not provided, prediction will
    be performed on the whole image.

    Note that if you are using a UNet model and tiling, the tile size must be
    divisible in every dimension by 2**d, where d is the depth of the model. This
    avoids artefacts arising from the broken shift invariance induced by the
    pooling layers of the UNet. Images smaller than the tile size in any spatial
    dimension will be automatically zero-padded.

    Parameters
    ----------
    pred_data : pathlib.Path, str, numpy.ndarray, or sequence of these
        Data to predict on. Can be a single item or a sequence of paths/arrays.
    batch_size : int, optional
        Batch size for prediction. If not provided, uses the training configuration
        batch size.
    tile_size : tuple of int, optional
        Size of the tiles to use for prediction. If not provided, prediction
        will be performed on the whole image.
    tile_overlap : tuple of int, default=(48, 48)
        Overlap between adjacent tiles. May only be None when `tile_size` is None.
    axes : str, optional
        Axes of the input data. If None, uses the training configuration axes.
    data_type : {"array", "tiff", "czi", "zarr", "custom"}, optional
        Type of the input data. If None, uses the training configuration data type.
    num_workers : int, optional
        Number of workers for the dataloader, by default None.
    channels : sequence of int or "all", optional
        Channels to use from the data. If None, uses the training configuration
        channels.
    in_memory : bool, optional
        Whether to load all data into memory. If None, uses the training
        configuration setting.
    loading : Loading, default=None
        Loading strategy to use for the prediction data. May be a ReadFuncLoading or
        ImageStackLoading. If None, uses the loading strategy from the training
        configuration.

    Returns
    -------
    tuple of (list of NDArray, list of str)
        Predictions made by the model and their source identifiers.

    Raises
    ------
    ValueError
        If tile overlap is not specified when tile_size is provided.
    """
    datamodule = self._build_predict_datamodule(
        pred_data,
        batch_size=batch_size,
        tile_size=tile_size,
        tile_overlap=tile_overlap,
        axes=axes,
        data_type=data_type,
        num_workers=num_workers,
        channels=channels,
        in_memory=in_memory,
        loading=loading,
    )

    predictions: list[ImageRegionData] = self.trainer.predict(
        model=self.model, datamodule=datamodule
    )  # type: ignore[assignment]
    tiled = tile_size is not None
    predictions_output, sources = convert_prediction(
        predictions, tiled=tiled, restore_shape=True
    )

    return predictions_output, sources

predict_to_disk(pred_data, *, pred_data_target=None, prediction_dir='predictions', batch_size=None, tile_size=None, tile_overlap=(48, 48), axes=None, data_type=None, num_workers=None, channels=None, in_memory=None, loading=None, write_type='tiff', write_extension=None, write_func=None, write_func_kwargs=None) #

predict_to_disk(pred_data: InputVar, *, pred_data_target: InputVar | None = None, prediction_dir: Path | str = 'predictions', batch_size: int | None = None, tile_size: tuple[int, ...] | None = None, tile_overlap: tuple[int, ...] | None = (48, 48), axes: str | None = None, data_type: Literal['array', 'tiff', 'zarr', 'czi', 'custom'] | None = None, num_workers: int | None = None, channels: Sequence[int] | Literal['all'] | None = None, in_memory: bool | None = None, loading: ReadFuncLoading | None = None, write_type: Literal['tiff', 'zarr', 'custom'] = 'tiff', write_extension: str | None = None, write_func: WriteFunc | None = None, write_func_kwargs: dict[str, Any] | None = None) -> None
predict_to_disk(pred_data: Any, *, pred_data_target: Any | None = None, prediction_dir: Path | str = 'predictions', batch_size: int | None = None, tile_size: tuple[int, ...] | None = None, tile_overlap: tuple[int, ...] | None = (48, 48), axes: str | None = None, data_type: Literal['array', 'tiff', 'zarr', 'czi', 'custom'] | None = None, num_workers: int | None = None, channels: Sequence[int] | Literal['all'] | None = None, in_memory: bool | None = None, loading: ImageStackLoading = ..., write_type: Literal['tiff', 'zarr', 'custom'] = 'tiff', write_extension: str | None = None, write_func: WriteFunc | None = None, write_func_kwargs: dict[str, Any] | None = None) -> None

Make predictions on the provided data and save outputs to files.

Predictions are saved to prediction_dir (absolute paths are used as-is, relative paths are relative to work_dir). The directory structure matches the source directory.

The file names of the predictions will match those of the source. If there is more than one sample within a file, the samples will be stacked along the sample dimension in the output file.

If data_type and axes are not provided, the training configuration parameters will be used. If tile_size is not provided, prediction will be performed on whole images rather than in a tiled manner.

Note that if you are using a UNet model and tiling, the tile size must be divisible in every dimension by 2**d, where d is the depth of the model. This avoids artefacts arising from the broken shift invariance induced by the pooling layers of the UNet. Images smaller than the tile size in any spatial dimension will be automatically zero-padded.
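The divisibility constraint above can be checked before calling this method. A minimal sketch, assuming a plain helper (valid_tile_size is illustrative and not part of the CAREamics API):

```python
def valid_tile_size(tile_size: tuple[int, ...], depth: int) -> bool:
    """Check that every tile dimension is divisible by 2**depth."""
    return all(s % (2**depth) == 0 for s in tile_size)

# For a UNet of depth 2, each dimension must be a multiple of 4:
valid_tile_size((128, 128), depth=2)  # True
valid_tile_size((96, 100), depth=3)   # False: 100 is not a multiple of 8
```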

Parameters:

Name Type Description Default
pred_data pathlib.Path, str, numpy.ndarray, or sequence of these

Data to predict on. Can be a single item or a sequence of paths/arrays.

required
pred_data_target pathlib.Path, str, numpy.ndarray, or sequence of these

Prediction data target, by default None.

None
prediction_dir Path | str

The path to save the prediction results to. If prediction_dir is an absolute path, it will be used as-is. If it is a relative path, it will be relative to the pre-set work_dir. If the directory does not exist it will be created.

"predictions"
batch_size int

Batch size for prediction. If not provided, uses the training configuration batch size.

None
tile_size tuple of int

Size of the tiles to use for prediction. If not provided, uses whole image strategy.

None
tile_overlap tuple of int

Overlap between tiles.

(48, 48)
axes str

Axes of the input data. If not provided, uses the training configuration axes.

None
data_type (array, tiff, zarr, czi, custom)

Type of the input data. If not provided, uses the training configuration data type.

None
num_workers int

Number of workers for the dataloader, by default None.

None
channels sequence of int or "all"

Channels to use from the data. If None, uses the training configuration channels.

None
in_memory bool

Whether to load all data into memory. If None, uses the training configuration setting.

None
loading Loading

Loading strategy to use for the prediction data. May be a ReadFuncLoading or ImageStackLoading. If None, uses the loading strategy from the training configuration.

None
write_type (tiff, zarr, custom)

The file type to save predictions as. For custom, write_func and write_extension must also be provided.

"tiff"
write_extension str

If a known write_type is selected this argument is ignored. For a custom write_type an extension to save the data with must be passed.

None
write_func WriteFunc

If a known write_type is selected this argument is ignored. For a custom write_type a function to save the data must be passed. See notes below.

None
write_func_kwargs dict of {str: any}

Additional keyword arguments to be passed to the save function.

None

Raises:

Type Description
ValueError

If write_type is custom and write_extension is None.

ValueError

If write_type is custom and write_func is None.

Source code in src/careamics/careamist_v2.py
def predict_to_disk(
    self,
    # BASIC PARAMS
    pred_data: Any,
    *,
    pred_data_target: Any | None = None,
    prediction_dir: Path | str = "predictions",
    batch_size: int | None = None,
    tile_size: tuple[int, ...] | None = None,
    tile_overlap: tuple[int, ...] | None = (48, 48),
    axes: str | None = None,
    data_type: Literal["array", "tiff", "zarr", "czi", "custom"] | None = None,
    # ADVANCED PARAMS
    num_workers: int | None = None,
    channels: Sequence[int] | Literal["all"] | None = None,
    in_memory: bool | None = None,
    loading: Loading = None,
    # WRITE OPTIONS
    write_type: Literal["tiff", "zarr", "custom"] = "tiff",
    write_extension: str | None = None,
    write_func: WriteFunc | None = None,
    write_func_kwargs: dict[str, Any] | None = None,
) -> None:
    """
    Make predictions on the provided data and save outputs to files.

    Predictions are saved to `prediction_dir` (absolute paths are used as-is,
    relative paths are relative to `work_dir`). The directory structure matches
    the source directory.

    The file names of the predictions will match those of the source. If there is
    more than one sample within a file, the samples will be stacked along the sample
    dimension in the output file.

    If `data_type` and `axes` are not provided, the training configuration
    parameters will be used. If `tile_size` is not provided, prediction
    will be performed on whole images rather than in a tiled manner.

    Note that if you are using a UNet model and tiling, the tile size must be
    divisible in every dimension by 2**d, where d is the depth of the model. This
    avoids artefacts arising from the broken shift invariance induced by the
    pooling layers of the UNet. Images smaller than the tile size in any spatial
    dimension will be automatically zero-padded.

    Parameters
    ----------
    pred_data : pathlib.Path, str, numpy.ndarray, or sequence of these
        Data to predict on. Can be a single item or a sequence of paths/arrays.
    pred_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these
        Prediction data target, by default None.
    prediction_dir : Path | str, default="predictions"
        The path to save the prediction results to. If `prediction_dir` is an
        absolute path, it will be used as-is. If it is a relative path, it will
        be relative to the pre-set `work_dir`. If the directory does not exist it
        will be created.
    batch_size : int, optional
        Batch size for prediction. If not provided, uses the training configuration
        batch size.
    tile_size : tuple of int, optional
        Size of the tiles to use for prediction. If not provided, uses whole image
        strategy.
    tile_overlap : tuple of int, default=(48, 48)
        Overlap between tiles.
    axes : str, optional
        Axes of the input data. If not provided, uses the training configuration
        axes.
    data_type : {"array", "tiff", "zarr", "czi", "custom"}, optional
        Type of the input data. If not provided, uses the training configuration
        data type.
    num_workers : int, optional
        Number of workers for the dataloader, by default None.
    channels : sequence of int or "all", optional
        Channels to use from the data. If None, uses the training configuration
        channels.
    in_memory : bool, optional
        Whether to load all data into memory. If None, uses the training
        configuration setting.
    loading : Loading, default=None
        Loading strategy to use for the prediction data. May be a ReadFuncLoading or
        ImageStackLoading. If None, uses the loading strategy from the training
        configuration.
    write_type : {"tiff", "zarr", "custom"}, default="tiff"
        The file type to save predictions as. For custom, `write_func` and
        `write_extension` must also be provided.
    write_extension : str, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` an extension to save the data with must be passed.
    write_func : WriteFunc, optional
        If a known `write_type` is selected this argument is ignored. For a custom
        `write_type` a function to save the data must be passed. See notes below.
    write_func_kwargs : dict of {str: any}, optional
        Additional keyword arguments to be passed to the save function.

    Raises
    ------
    ValueError
        If `write_type` is custom and `write_extension` is None.
    ValueError
        If `write_type` is custom and `write_func` is None.
    """
    if write_func_kwargs is None:
        write_func_kwargs = {}

    if Path(prediction_dir).is_absolute():
        write_dir = Path(prediction_dir)
    else:
        write_dir = self.work_dir / prediction_dir
    self.prediction_writer.dirpath = write_dir

    if write_type == "custom":
        if write_extension is None:
            raise ValueError(
                "A `write_extension` must be provided for custom write types."
            )
        if write_func is None:
            raise ValueError(
                "A `write_func` must be provided for custom write types."
            )
    elif write_type == "zarr" and tile_size is None:
        raise ValueError(
            "Writing prediction to Zarr is only supported with tiling. Please "
            "provide a value for `tile_size`, and optionally `tile_overlap`."
        )
    else:
        write_func = get_write_func(write_type)
        write_extension = SupportedData.get_extension(write_type)

    tiled = tile_size is not None
    self.prediction_writer.set_writing_strategy(
        write_type=write_type,
        tiled=tiled,
        write_func=write_func,
        write_extension=write_extension,
        write_func_kwargs=write_func_kwargs,
    )

    self.prediction_writer.enable_writing(True)

    try:
        datamodule = self._build_predict_datamodule(
            pred_data,
            pred_data_target=pred_data_target,
            batch_size=batch_size,
            tile_size=tile_size,
            tile_overlap=tile_overlap,
            axes=axes,
            data_type=data_type,
            num_workers=num_workers,
            channels=channels,
            in_memory=in_memory,
            loading=loading,
        )

        self.trainer.predict(
            model=self.model, datamodule=datamodule, return_predictions=False
        )

    finally:
        self.prediction_writer.enable_writing(False)

stop_training() #

Stop the training loop.

Source code in src/careamics/careamist_v2.py
def stop_training(self) -> None:
    """Stop the training loop."""
    self.trainer.should_stop = True
    self.trainer.limit_val_batches = 0  # skip validation

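stop_training() works cooperatively: it only sets flags that the Lightning Trainer checks between steps, so it is typically called from a callback or another thread while fit() is running. A toy sketch of the same pattern (Worker is illustrative, not Lightning code):

```python
class Worker:
    """Toy training loop that honours a cooperative stop flag."""

    def __init__(self) -> None:
        self.should_stop = False
        self.epochs_run = 0

    def fit(self, max_epochs: int) -> None:
        for _ in range(max_epochs):
            if self.should_stop:  # checked between epochs, like the Trainer
                break
            self.epochs_run += 1


w = Worker()
w.should_stop = True  # equivalent of stop_training()
w.fit(max_epochs=10)  # exits before running any epoch
```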
train(*, train_data=None, train_data_target=None, val_data=None, val_data_target=None, filtering_mask=None, loading=None) #

train(*, train_data: InputVar | None = None, train_data_target: InputVar | None = None, val_data: InputVar | None = None, val_data_target: InputVar | None = None, filtering_mask: InputVar | None = None, loading: ReadFuncLoading | None = None) -> None
train(*, train_data: Any | None = None, train_data_target: Any | None = None, val_data: Any | None = None, val_data_target: Any | None = None, filtering_mask: Any | None = None, loading: ImageStackLoading = ...) -> None

Train the model on the provided data.

The training data can be provided as arrays or paths.

Parameters:

Name Type Description Default
train_data pathlib.Path, str, numpy.ndarray, or sequence of these

Training data, by default None.

None
train_data_target pathlib.Path, str, numpy.ndarray, or sequence of these

Training target data, by default None.

None
val_data pathlib.Path, str, numpy.ndarray, or sequence of these

Validation data. If not provided, data_config.n_val_patches patches will be selected from the training data for validation.

None
val_data_target pathlib.Path, str, numpy.ndarray, or sequence of these

Validation target data, by default None.

None
filtering_mask pathlib.Path, str, numpy.ndarray, or sequence of these

Filtering mask for coordinate-based patch filtering, by default None.

None
loading Loading

Loading strategy to use for the training data. May be a ReadFuncLoading or ImageStackLoading. If None, uses the loading strategy from the training configuration.

None

Raises:

Type Description
ValueError

If train_data is not provided.
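The data requirements above can be summarised as a standalone check. This sketch (validate_train_inputs is illustrative, not part of the CAREamics API) mirrors the guards train() applies before building the datamodule:

```python
def validate_train_inputs(
    supervised: bool,
    train_data,
    train_data_target=None,
    val_data=None,
    val_data_target=None,
) -> None:
    """Raise ValueError if the provided data violates train()'s requirements."""
    if train_data is None:
        raise ValueError("Training data must be provided.")
    if supervised and train_data_target is None:
        raise ValueError("Supervised training requires `train_data_target`.")
    if supervised and val_data is not None and val_data_target is None:
        raise ValueError("Supervised training requires `val_data_target`.")
```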

Source code in src/careamics/careamist_v2.py
def train(
    self,
    *,
    # BASIC PARAMS
    train_data: Any | None = None,
    train_data_target: Any | None = None,
    val_data: Any | None = None,
    val_data_target: Any | None = None,
    # ADVANCED PARAMS
    filtering_mask: Any | None = None,
    loading: Loading = None,
) -> None:
    """Train the model on the provided data.

    The training data can be provided as arrays or paths.

    Parameters
    ----------
    train_data : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
        Training data, by default None.
    train_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these
        Training target data, by default None.
    val_data : pathlib.Path, str, numpy.ndarray, or sequence of these, optional
        Validation data. If not provided, `data_config.n_val_patches` patches will
        be selected from the training data for validation.
    val_data_target : pathlib.Path, str, numpy.ndarray, or sequence of these
        Validation target data, by default None.
    filtering_mask : pathlib.Path, str, numpy.ndarray, or sequence of these
        Filtering mask for coordinate-based patch filtering, by default None.
    loading : Loading, default=None
        Loading strategy to use for the training data. May be a ReadFuncLoading or
        ImageStackLoading. If None, uses the loading strategy from the training
        configuration.

    Raises
    ------
    ValueError
        If train_data is not provided.
    """
    if train_data is None:
        raise ValueError("Training data must be provided. Provide `train_data`.")

    if self.config.is_supervised() and train_data_target is None:
        raise ValueError(
            f"Training target data must be provided for supervised training (got "
            f"{self.config.get_algorithm_friendly_name()} algorithm). Provide "
            f"`train_data_target`."
        )

    if (
        self.config.is_supervised()
        and val_data is not None
        and val_data_target is None
    ):
        raise ValueError(
            f"Validation target data must be provided for supervised training (got "
            f"{self.config.get_algorithm_friendly_name()} algorithm). Provide "
            f"`val_data_target`."
        )

    datamodule = CareamicsDataModule(  # type: ignore
        data_config=self.config.data_config,
        train_data=train_data,
        val_data=val_data,
        train_data_target=train_data_target,
        val_data_target=val_data_target,
        train_data_mask=filtering_mask,
        loading=loading,  # type: ignore
    )

    self.train_datamodule = datamodule

    # set parameters back to defaults, this is a guard against `stop_training`
    # which changes them in order to interrupt training gracefully
    self.trainer.should_stop = False
    self.trainer.limit_val_batches = 1.0  # equivalent to all validation batches

    self.trainer.fit(
        self.model, datamodule=datamodule, ckpt_path=self.checkpoint_path
    )