Training a Model

This guide walks you through everything needed to run a training job: prerequisites, configuration, launching training, and understanding what happens at each stage.


Prerequisites

Environment
Ensure all dependencies are installed. The pipeline requires:

  • PyTorch
  • Hugging Face Accelerate
  • Kornia
  • Weights & Biases (WandB)

Data
Training expects:

  • A CSV file with predefined train / val split columns
  • A root directory containing the image data

The exact CSV schema depends on the dataset (EMBED or CSAW).
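
To make the split-column convention concrete, here is a minimal sketch that partitions CSV rows by a split column. The column names `split` and `image_path` and the example rows are hypothetical; the real EMBED and CSAW schemas may use different names.

```python
import csv
import io
from collections import defaultdict


def rows_by_split(csv_text, split_column="split"):
    """Group CSV rows by the value of the split column.

    `split_column` is a hypothetical name; use whatever your CSV defines.
    """
    groups = defaultdict(list)
    for row in csv.DictReader(io.StringIO(csv_text)):
        groups[row[split_column]].append(row)
    return dict(groups)


# Tiny in-memory example with made-up rows.
example_csv = (
    "image_path,split\n"
    "patient_a/exam_1.png,train\n"
    "patient_b/exam_1.png,val\n"
    "patient_c/exam_1.png,train\n"
)
splits = rows_by_split(example_csv)
print({name: len(rows) for name, rows in splits.items()})
```

Checking the row counts per split this way before launching a job is a cheap guard against a mislabelled or empty split.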

Model configuration
Each model has a YAML configuration file under config/models/. For example, LMV-Net uses:

config/models/lmv_net.yaml

These files define model-specific hyperparameters (e.g. distance, alpha_coeff, margin, dropout, num_attn_blocks, etc.) that extend the base CLI arguments. If no YAML is found, the pipeline falls back to CLI defaults.

Registration model
ImgFeatAlign and LMV-Net require a pretrained MammoRegNet registration model. The path is defined in:

config/config.py
  • paths.csaw_path_saved_reg_model
  • paths.embed_path_saved_reg_model

(depending on the dataset)


Running Training

Each model has a dedicated shell script that sets all required arguments:

bash scripts/train_lmv_net.sh
bash scripts/train_imgfeatalign.sh
bash scripts/train_vmra_mar.sh
bash scripts/train_oa_breacr.sh
bash scripts/train_mirai.sh

Use accelerate launch (instead of python) to enable multi-GPU training.


CLI Arguments

Required

Argument         Description
--model          Model name (e.g. Mirai, OA-BreaCR, VMRA-MaR, ImgFeatAlign, LMV-Net)
--dataset        Dataset name: EMBED or CSAW
--csv_file       Path to the CSV file containing the data splits
--data_root      Root directory of the image data
--path_out_dir   Base output directory (a timestamped subdirectory is created automatically)

Key Optional

Argument                  Default   Description
--num_epochs              100       Number of training epochs
--batch_size              12        Batch size per GPU
--learning_rate           5e-5      Learning rate for newly added modules
--weight_decay            1e-4      AdamW weight decay
--warmup_steps            5000      Number of warmup steps
--use_scheduler           True      Enable ReduceLROnPlateau
--patience_lr_scheduler   5         Epochs without improvement before the LR is reduced
--lr_decay                0.5       LR reduction factor
--patience                15        Early stopping patience (epochs)
--augmentations           True      Enable data augmentation
--resume_from             None      Path to a checkpoint to resume from
--wandb_id                None      WandB run ID for resuming
--seed                    2023      Random seed

Configuration System

Argument loading occurs in two stages:

  1. CLI arguments are parsed first (including --model)
  2. Model YAML config is loaded from:
config/models/<model_name>.yaml

YAML values are registered as argument defaults, so any value passed explicitly on the command line takes precedence.

Example:

--dropout 0.3

overrides:

dropout: 0.1
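
The two-stage loading described above can be sketched with argparse alone. This is an illustration under assumptions, not the code in main_train.py: the `MODEL_CONFIGS` dictionary stands in for the parsed contents of config/models/<model_name>.yaml, and the `num_attn_blocks` value is made up.

```python
import argparse


def parse_with_model_defaults(argv, model_configs):
    """Two-stage parsing: read --model first, then apply config values as defaults.

    `model_configs` maps a model name to the dict that would be loaded from
    its YAML file (a stand-in so the sketch runs without a YAML library).
    """
    # Stage 1: parse only what is needed to locate the model config.
    base = argparse.ArgumentParser(add_help=False)
    base.add_argument("--model", required=True)
    known, _ = base.parse_known_args(argv)

    # Stage 2: expose every config key as an option whose default is the
    # config value. A missing config yields no extra options.
    # Note: type(value) is fine for int/float/str; booleans need custom parsing.
    parser = argparse.ArgumentParser(parents=[base])
    for key, value in model_configs.get(known.model, {}).items():
        parser.add_argument(f"--{key}", type=type(value), default=value)
    return parser.parse_args(argv)


MODEL_CONFIGS = {"LMV-Net": {"dropout": 0.1, "num_attn_blocks": 2}}  # made-up values
args = parse_with_model_defaults(
    ["--model", "LMV-Net", "--dropout", "0.3"], MODEL_CONFIGS
)
print(args.dropout, args.num_attn_blocks)
```

Here `--dropout 0.3` wins over the config value, while `num_attn_blocks` keeps its config default because it was not passed on the command line.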

Training Pipeline

When training starts, the following steps occur:

1. Argument Parsing & Setup

main_train.py:

  • Parses CLI arguments
  • Loads YAML config
  • Creates a timestamped output directory:
{path_out_dir}_Model_{model}_lr_{lr}_wd_{wd}_epochs_{n}_bs_{bs}_{timestamp}/
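
As a rough illustration of the naming pattern above, the following sketch composes a directory name from the run settings. The function name and the timestamp format are assumptions; the pipeline may format the timestamp and the numeric values differently.

```python
from datetime import datetime


def build_output_dir(path_out_dir, model, lr, wd, n_epochs, batch_size, now=None):
    """Compose the run directory name from the pattern shown above.

    The timestamp format is an assumption, not taken from the pipeline.
    """
    now = now or datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    return (
        f"{path_out_dir}_Model_{model}_lr_{lr}_wd_{wd}"
        f"_epochs_{n_epochs}_bs_{batch_size}_{timestamp}/"
    )


name = build_output_dir(
    "results/run", "LMV-Net", 5e-5, 1e-4, 100, 12,
    now=datetime(2024, 1, 2, 3, 4, 5),
)
print(name)
```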

2. Accelerator Initialisation

A Hugging Face Accelerator is created. This enables:

  • Multi-GPU training
  • Mixed precision
  • Automatic device placement
  • Gradient synchronisation

3. Reproducibility

Seeds are set for:

  • random
  • torch
  • torch.cuda

cuDNN runs in deterministic mode.


4. Data Loading

get_dataset_and_loader():

  • Creates training and validation DataLoaders
  • Automatically shards training data across GPUs

No augmentation is applied to validation data.


5. Model Initialisation

Models are loaded via:

models/model_factory.py

Differential learning rates:

  • New modules → learning_rate
  • Encoder → learning_rate × 0.1 (for pretrained Mirai encoder)
  • OA-BreaCR → Encoder trained from scratch (no reduction)

The reduced encoder rate avoids destabilising pretrained weights.
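
One way to express the differential learning rates is as optimizer parameter groups. The sketch below is an illustration, not the logic in models/model_factory.py: the function name and the assumption that encoder parameters are named with an "encoder." prefix are both hypothetical.

```python
def build_param_groups(named_params, learning_rate,
                       encoder_prefix="encoder.", encoder_lr_scale=0.1):
    """Split parameters into new-module and encoder groups with separate LRs.

    `named_params` is any iterable of (name, parameter) pairs, such as the
    result of model.named_parameters(). The "encoder." prefix is an assumption
    about how the pretrained encoder's parameters are named.
    """
    encoder, new_modules = [], []
    for name, param in named_params:
        target = encoder if name.startswith(encoder_prefix) else new_modules
        target.append(param)
    return [
        {"params": new_modules, "lr": learning_rate},
        {"params": encoder, "lr": learning_rate * encoder_lr_scale},
    ]


# Placeholder strings stand in for tensors so the sketch runs without PyTorch.
groups = build_param_groups(
    [("encoder.conv1.weight", "w0"), ("risk_head.fc.weight", "w1")],
    learning_rate=5e-5,
)
print([(len(g["params"]), g["lr"]) for g in groups])
```

A list of this shape is what PyTorch optimizers such as torch.optim.AdamW accept in place of a flat parameter list. For an encoder trained from scratch, passing `encoder_lr_scale=1.0` gives both groups the same rate.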


6. Optimizer & Schedulers

Optimizer:

  • AdamW over parameter groups

Schedulers:

  • Warmup (LambdaLR)

    • Linear increase from 0 → base LR
    • Steps every batch
  • Plateau (ReduceLROnPlateau, optional)

    • Reduces LR when validation C-index plateaus
    • Activates only after warmup
All components are wrapped with:

accelerator.prepare()
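
The linear warmup can be written as a multiplier function of the kind LambdaLR expects. This is a minimal sketch; the function name is hypothetical and the pipeline's own lambda may handle edge cases differently.

```python
def linear_warmup_factor(step, warmup_steps=5000):
    """Multiplier on the base LR: ramps 0 -> 1 over `warmup_steps`, then stays at 1."""
    if warmup_steps <= 0:
        return 1.0  # warmup disabled
    return min(1.0, step / warmup_steps)


base_lr = 5e-5
for step in (0, 2500, 5000, 10000):
    print(step, base_lr * linear_warmup_factor(step))
```

With PyTorch, a function like this can be passed as `lr_lambda` to torch.optim.lr_scheduler.LambdaLR, which multiplies each group's base LR by the returned factor every time the scheduler steps.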

7. Epoch Loop

Each epoch consists of:

Training (train_one_epoch)

  • Forward pass
  • Loss computation
  • accelerator.backward()
  • Optimizer step
  • Warmup scheduler step (per batch)

Metrics:

  • C-index
  • Per-year AUC

Validation (evaluate)

  • Runs under torch.no_grad()
  • Computes:
    • Loss
    • C-index
    • Per-year AUC
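
For readers unfamiliar with the C-index, here is a minimal pure-Python sketch of Harrell's concordance index. It is a textbook version for illustration only; the pipeline's own metric code may treat censoring, tied times, or multi-year risk outputs differently.

```python
def concordance_index(event_times, risk_scores, event_observed):
    """Harrell's C-index: fraction of comparable pairs ranked correctly by risk.

    A pair is comparable when the subject with the shorter time had an event.
    It is concordant when that subject also has the higher risk score; ties
    in risk count as half.
    """
    concordant, comparable = 0.0, 0
    n = len(event_times)
    for i in range(n):
        if not event_observed[i]:
            continue  # a censored subject cannot be the earlier member of a pair
        for j in range(n):
            if event_times[i] < event_times[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5
    if comparable == 0:
        raise ValueError("no comparable pairs")
    return concordant / comparable


# Made-up example: earlier events have higher predicted risk.
times = [1, 3, 5]
events = [1, 1, 0]  # the last subject is censored
print(concordance_index(times, [0.9, 0.4, 0.1], events))
```

A value of 0.5 corresponds to random ranking and 1.0 to perfect ranking, which is why the schedulers and checkpointing treat a higher validation C-index as better.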

Logging

  • Written to log file
  • Logged to WandB (main process only)

Scheduler Step

  • Plateau scheduler updates after warmup
  • Uses validation C-index

Checkpointing & Early Stopping

  • Handled automatically (see below)

Learning Rate Scheduling

Two-phase schedule:

Steps 0 → warmup_steps     LR increases linearly (per batch)
Steps > warmup_steps       LR constant or reduced on plateau (per epoch)

The plateau scheduler is disabled during warmup to avoid premature LR decay.
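
The gating of the plateau phase can be illustrated with a small stand-in class. This is a simplified sketch, not a replica of ReduceLROnPlateau and not the pipeline's code: the class name, the exact patience semantics, and the choice not to track the metric during warmup are all assumptions.

```python
class PlateauAfterWarmup:
    """Simplified plateau scheduler that stays inactive during warmup.

    Tracks a metric where higher is better (validation C-index) and multiplies
    the LR by `factor` once the metric has failed to improve for more than
    `patience` epochs in a row.
    """

    def __init__(self, lr, warmup_steps=5000, patience=5, factor=0.5):
        self.lr = lr
        self.warmup_steps = warmup_steps
        self.patience = patience
        self.factor = factor
        self.best = float("-inf")
        self.bad_epochs = 0

    def epoch_end(self, global_step, val_cindex):
        """Call once per epoch; returns the (possibly reduced) learning rate."""
        if global_step < self.warmup_steps:
            return self.lr  # still warming up: leave the LR alone
        if val_cindex > self.best:
            self.best = val_cindex
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0
        return self.lr


sched = PlateauAfterWarmup(lr=1.0, warmup_steps=100, patience=2, factor=0.5)
history = [sched.epoch_end(step, cindex) for step, cindex in
           [(50, 0.10), (150, 0.70), (250, 0.69), (350, 0.69), (450, 0.69)]]
print(history)
```

The first call falls inside warmup and is ignored, so a poor early validation score cannot count toward the patience budget.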


Checkpointing

Checkpoints are saved:

  • Every 10 epochs
  • When a new best validation C-index is achieved

Files:

checkpoint_{epoch:04d}.pth
best_model_risk_prediction_id-{id}.pth

Each checkpoint includes:

  • Model state
  • Optimiser state
  • Scheduler states
  • Epoch
  • Global step
  • Best C-index
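
The save rules and file names above can be sketched as a small helper. The function name, the `run_id` parameter, and the assumption that epochs are counted from 1 are illustrative choices, not taken from the pipeline.

```python
def checkpoint_files(epoch, val_cindex, best_cindex, run_id, every=10):
    """Return the checkpoint file names to write at the end of `epoch`.

    Assumes epochs are counted from 1 and that `run_id` is the identifier
    that appears in the file names.
    """
    files = []
    if epoch % every == 0:
        files.append(f"checkpoint_{epoch:04d}.pth")  # periodic checkpoint
    if val_cindex > best_cindex:
        files.append(f"best_model_risk_prediction_id-{run_id}.pth")  # new best
    return files


print(checkpoint_files(10, 0.70, 0.72, "abc"))
print(checkpoint_files(20, 0.75, 0.72, "abc"))
```

An epoch can trigger both rules at once, in which case both files are written.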

Resume Training

accelerate launch main_train.py \
  ... \
  --resume_from /path/to/checkpoint_0050.pth \
  --wandb_id <run-id>

Passing --wandb_id ensures metrics continue in the same experiment.


Early Stopping

Training stops if the validation C-index does not improve for --patience epochs.

When early stopping triggers, a checkpoint is saved as:

early_stopping_risk_prediction_id-{id}.pth

The best model is always saved separately.
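
The patience rule can be captured in a few lines. This is a minimal sketch with a hypothetical class name; the pipeline's exact comparison (for example whether a tie counts as an improvement) is not specified here, so strict improvement is assumed.

```python
class EarlyStopping:
    """Signal a stop when the metric has not improved for `patience` epochs."""

    def __init__(self, patience=15):
        self.patience = patience
        self.best = float("-inf")
        self.epochs_without_improvement = 0

    def update(self, val_cindex):
        """Record one epoch's validation C-index; return True if training should stop."""
        if val_cindex > self.best:  # assumption: only strict improvement resets the counter
            self.best = val_cindex
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


stopper = EarlyStopping(patience=2)
decisions = [stopper.update(c) for c in (0.70, 0.69, 0.68)]
print(decisions, stopper.best)
```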


Output Directory

{results_dir}/
  train_risk_prediction_training_id_{id}.log
  checkpoint_0010.pth
  checkpoint_0020.pth
  ...
  best_model_risk_prediction_id-{id}.pth
  early_stopping_risk_prediction_id-{id}.pth
  model_risk_prediction_training_id_{id}_last_epoch.pth

Data Augmentation

Enable with:

--augmentations

Applied only to training data:

  • RandomCrop (1946 × 1581), p=0.2
  • Resize (2048 × 1664)
  • RandomAffine (translation ±10%, scale up to 1.1×), p=0.5
  • ColorJitter (±0.4), p=0.5
  • RandomGamma (0.8–1.2), p=0.5

No augmentation is applied to validation data.


Experiment Tracking (WandB)

Logged per epoch:

  • Training & Validation Loss
  • Training & Validation C-index
  • Year 1–5 AUC (train + validation)

All metrics use epoch as the step.

Resume Logging

--wandb_id <run-id>

This appends metrics to an existing run instead of creating a new one.