Adding a New Model to the Benchmarking Framework
This repository provides a unified pipeline for training and evaluating deep learning models for risk prediction. To ensure compatibility, all new models must follow a standard structure and interface.
This guide explains how to integrate a new model into the framework.
📦 Overview
To add a new model, you need to:
1. Create a new model module
2. Inherit from the base model class
3. Define model-specific configuration
4. Register the model in the factory
1. Create a Model Folder
Navigate to the models/ directory and create a new subfolder for your model:
models/
└── your_model_name/
    ├── model.py
    └── model_utils.py
Recommended structure:
- model.py: main model implementation
- model_utils.py: helper functions, custom layers, and utilities
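If you prefer to script this step, here is a minimal sketch of a hypothetical helper (scaffold_model is not part of the repository) that creates the layout shown above using only pathlib. It assumes you run it from the repository root, and it never overwrites files that already exist.

```python
from pathlib import Path


def scaffold_model(name, root="."):
    """Create models/<name>/ with model.py and model_utils.py under root.

    Existing files are left as they are, so re-running is safe.
    """
    model_dir = Path(root) / "models" / name
    # parents=True also creates models/ itself; exist_ok=True makes reruns a no-op
    model_dir.mkdir(parents=True, exist_ok=True)
    for filename in ("model.py", "model_utils.py"):
        path = model_dir / filename
        if not path.exists():
            path.write_text("")  # create an empty file; never truncate an existing one
    return model_dir
```

For example, scaffold_model("your_model_name") creates models/your_model_name/ containing two empty files.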
🧩 2. Inherit from BaseRiskModel
All models must inherit from the shared base class BaseRiskModel, defined in models/common_parts/base_models.py.
Base Class Interface
from abc import ABC, abstractmethod

import torch.nn as nn


class BaseRiskModel(nn.Module, ABC):
    def __init__(self, args):
        super().__init__()
        self.args = args

    @abstractmethod
    def forward(self, batch):
        """Run the forward pass. Returns a dict of outputs."""
        pass

    @abstractmethod
    def get_risk_heads(self, outputs, batch):
        """
        Returns a dict of {head_name: (logits, target, mask)}
        used for loss computation.
        """
        pass

    @abstractmethod
    def get_primary_risk_head(self, outputs):
        """
        Returns the main prediction tensor used for evaluation
        (e.g., AUC, C-index).
        """
        pass
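Because BaseRiskModel mixes ABC into nn.Module, a subclass that forgets any of the three abstract methods cannot be instantiated: Python raises TypeError at construction time, before any training starts. The sketch below illustrates this with a stdlib-only stand-in (BaseRiskModelStandIn, IncompleteModel, and missing_methods are hypothetical names) so it runs without PyTorch; the enforcement works the same way for the real class because its metaclass is still ABCMeta.

```python
from abc import ABC, abstractmethod


class BaseRiskModelStandIn(ABC):
    """Stand-in mirroring BaseRiskModel's abstract methods (without nn.Module)."""

    def __init__(self, args):
        self.args = args

    @abstractmethod
    def forward(self, batch): ...

    @abstractmethod
    def get_risk_heads(self, outputs, batch): ...

    @abstractmethod
    def get_primary_risk_head(self, outputs): ...


class IncompleteModel(BaseRiskModelStandIn):
    # Hypothetical subclass that forgets to implement get_primary_risk_head
    def forward(self, batch):
        return {}

    def get_risk_heads(self, outputs, batch):
        return {}


def missing_methods(cls):
    """Return the sorted names of abstract methods cls still has to implement."""
    return sorted(cls.__abstractmethods__)
```

Calling IncompleteModel(args=None) raises TypeError, and missing_methods(IncompleteModel) names the method that is still missing.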
Example:
import torch

from models.common_parts.base_models import BaseRiskModel


class YourModel(BaseRiskModel):
    def __init__(self, args):
        super().__init__(args)
        # Define layers here

    def forward(self, batch):
        # Define the flow of data through the network; this template assumes it
        # computes `logits` and the intermediate features `f_cur` and `f_pri`
        outputs = {"logit": logits, "fcur": f_cur, "fpri": f_pri}
        return outputs

    def get_risk_heads(self, outputs, batch):
        target = batch["target"]
        mask = batch["y_mask"]
        return {
            "main": (outputs["logit"], target, mask)
        }

    def get_primary_risk_head(self, outputs):
        # Convert logits to probabilities for evaluation metrics
        return torch.sigmoid(outputs["logit"])
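The template above leaves the forward computation open. As a sketch under stated assumptions, here is a complete toy model plus one possible way the (logits, target, mask) tuples returned by get_risk_heads() could be turned into a masked binary cross-entropy loss. The batch key "x", the per-output target layout, ToyRiskModel, and masked_bce_loss are all hypothetical, and the framework's actual loss may differ. BaseRiskModel is copied locally from the interface shown earlier so the block runs on its own.

```python
from abc import ABC, abstractmethod

import torch
import torch.nn as nn
import torch.nn.functional as F


class BaseRiskModel(nn.Module, ABC):
    # Local copy of the documented interface so this sketch is self-contained;
    # inside the repository, import it from models.common_parts.base_models.
    def __init__(self, args):
        super().__init__()
        self.args = args

    @abstractmethod
    def forward(self, batch): ...

    @abstractmethod
    def get_risk_heads(self, outputs, batch): ...

    @abstractmethod
    def get_primary_risk_head(self, outputs): ...


class ToyRiskModel(BaseRiskModel):
    """Hypothetical toy model: one linear layer over a flat feature vector."""

    def __init__(self, args, in_dim=8, num_outputs=5):
        super().__init__(args)
        self.head = nn.Linear(in_dim, num_outputs)

    def forward(self, batch):
        return {"logit": self.head(batch["x"])}

    def get_risk_heads(self, outputs, batch):
        return {"main": (outputs["logit"], batch["target"], batch["y_mask"])}

    def get_primary_risk_head(self, outputs):
        return torch.sigmoid(outputs["logit"])


def masked_bce_loss(risk_heads):
    """Hypothetical loss: sum over heads of the BCE averaged over unmasked entries."""
    total = 0.0
    for logits, target, mask in risk_heads.values():
        per_elem = F.binary_cross_entropy_with_logits(
            logits, target.float(), reduction="none")
        mask = mask.float()
        # clamp avoids division by zero when every entry is masked out
        total = total + (per_elem * mask).sum() / mask.sum().clamp(min=1.0)
    return total
```

Computing the per-element loss with reduction="none" and then taking a masked mean means entries with mask 0 contribute neither to the loss nor to the gradient.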
⚙️ 3. Add Model Configuration
Create a YAML configuration file for your model in:
config/models/your_model_name.yaml
Purpose:
Store model-specific hyperparameters that:
- have fixed default values
- can still be overridden by users
Example:
model_name: your_model_name
dropout: 0.3
num_heads: 4
hidden_dim: 256
num_layers: 3
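This guide does not show how the framework merges these YAML values with user input. One common pattern, sketched here as an assumption, is to turn every key into a command-line flag whose default comes from the file. The dict stands in for the parsed YAML (for example via PyYAML's yaml.safe_load) so the sketch needs only the standard library, and build_args is a hypothetical helper.

```python
import argparse

# Stand-in for the parsed contents of config/models/your_model_name.yaml
MODEL_DEFAULTS = {
    "model_name": "your_model_name",
    "dropout": 0.3,
    "num_heads": 4,
    "hidden_dim": 256,
    "num_layers": 3,
}


def build_args(defaults, argv=None):
    """Hypothetical helper: expose each config key as a --flag whose default
    comes from the config, so users can override individual values."""
    parser = argparse.ArgumentParser()
    for key, value in defaults.items():
        # type(value) keeps ints as ints and floats as floats when parsing overrides
        parser.add_argument(f"--{key}", type=type(value), default=value)
    return parser.parse_args(argv)
```

Deriving each flag's type from its default keeps num_layers an int and dropout a float. Boolean options would need extra handling, since argparse's type=bool treats any non-empty string, including "False", as True.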
4. Register the Model in model_factory.py
To make your model available in the pipeline, register it in:
models/model_factory.py
Step 1: Add a builder function
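def _build_your_model(args, **kwargs):
    # Accept args and **kwargs (match the signature of the existing builders);
    # importing here means the module is only loaded when this model is built
    from models.your_model_name.model import YourModel
    return _build_model(YourModel, args=args, **kwargs)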
Step 2: Add it to the registry
MODEL_REGISTRY = {
    "Mirai": _build_mirai,
    "ImgFeatAlign": _build_imgfeatalign,
    "LMV-Net": _build_lmvnet,
    "VMRA-MaR": _build_vmramar,
    "OA-BreaCR": _build_oa_breacr,
    "YourModel": _build_your_model,  # ← add here
}
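To see why the key in MODEL_REGISTRY matters, here is a hypothetical sketch of how a factory could dispatch on it. _build_model and YourModel are replaced by minimal stand-ins (the real helper's behavior is not shown in this guide), and get_model is an illustrative name, not necessarily the factory's actual entry point.

```python
class YourModel:
    # Stand-in for the real model class
    def __init__(self, args=None, **kwargs):
        self.args = args
        self.kwargs = kwargs


def _build_model(model_cls, args=None, **kwargs):
    # Stand-in for the factory's shared helper; the real one may do more
    return model_cls(args=args, **kwargs)


def _build_your_model(args, **kwargs):
    return _build_model(YourModel, args=args, **kwargs)


MODEL_REGISTRY = {
    "YourModel": _build_your_model,
}


def get_model(name, args, **kwargs):
    """Hypothetical lookup that reports the registered names on a typo."""
    try:
        builder = MODEL_REGISTRY[name]
    except KeyError:
        known = ", ".join(sorted(MODEL_REGISTRY))
        raise ValueError(f"Unknown model {name!r}; registered: {known}") from None
    return builder(args, **kwargs)
```

Raising an error that lists the registered names makes a typo in the model name easy to diagnose.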
⚠️ Special Case: Registration-Based Models
If your model uses image registration (e.g., MammoRegNet), you must:
1. Add your model's registry name (the key used in MODEL_REGISTRY) to REGISTRATION_MODELS:
REGISTRATION_MODELS = {"ImgFeatAlign", "LMV-Net", "YourModel"}
2. Accept mammo_reg_net in your constructor:
def __init__(self, mammo_reg_net=None, args=None):
    super().__init__(args)
    self.mammo_reg_net = mammo_reg_net
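A hypothetical sketch of how the factory might use REGISTRATION_MODELS: only listed models receive the registration network as a keyword argument. build_kwargs and the stand-in YourModel (without nn.Module) are illustrative, not the repository's code.

```python
REGISTRATION_MODELS = {"ImgFeatAlign", "LMV-Net", "YourModel"}


def build_kwargs(model_name, mammo_reg_net):
    """Hypothetical: pass the registration network only to models that need it."""
    if model_name in REGISTRATION_MODELS:
        return {"mammo_reg_net": mammo_reg_net}
    return {}


class YourModel:
    # Stand-in with the constructor signature described above
    def __init__(self, mammo_reg_net=None, args=None):
        self.args = args
        self.mammo_reg_net = mammo_reg_net
```

Because mammo_reg_net is the first positional parameter, always pass args by keyword (args=args, as the builder does). A positional call such as YourModel(args) silently binds the arguments to mammo_reg_net and leaves args as None.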
Final Checklist
Before using your model, ensure:
- Model folder created in models/
- Inherits from BaseRiskModel
- Implements the required methods:
  - __init__()
  - forward()
  - get_risk_heads()
  - get_primary_risk_head()
- YAML config added in config/models/
- Model registered in model_factory.py
- (If applicable) Model name added to REGISTRATION_MODELS
- Model runs on a small test batch without errors
- Output format matches the pipeline interface: forward() returns a dict, get_risk_heads() returns {head_name: (logits, target, mask)}, and get_primary_risk_head() returns the prediction tensor used for evaluation
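The last two checklist items can be automated with a small smoke test. The sketch below is a hypothetical helper (check_risk_model and TinyModel are not part of the repository) that runs one batch and checks the output contract described in this guide. It only compares batch sizes within each (logits, target, mask) tuple, since stricter shape rules depend on the model.

```python
import torch
import torch.nn as nn


def check_risk_model(model, batch):
    """Hypothetical smoke test: run one batch through `model` and check that
    its outputs follow the interface described in this guide."""
    model.eval()
    with torch.no_grad():
        outputs = model(batch)
        heads = model.get_risk_heads(outputs, batch)
        primary = model.get_primary_risk_head(outputs)

    assert isinstance(outputs, dict), "forward() must return a dict"
    assert isinstance(heads, dict) and heads, "get_risk_heads() must return a non-empty dict"
    for name, head in heads.items():
        assert len(head) == 3, f"head {name!r} must be a (logits, target, mask) tuple"
        logits, target, mask = head
        # Only the batch dimension is compared; target/mask layouts may vary by model
        sizes = (logits.shape[0], target.shape[0], mask.shape[0])
        assert len(set(sizes)) == 1, f"head {name!r}: batch sizes differ {sizes}"
    assert torch.is_tensor(primary), "get_primary_risk_head() must return a tensor"
    assert bool(torch.isfinite(primary).all()), "primary risk head contains NaN or inf"
    return primary


class TinyModel(nn.Module):
    """Hypothetical minimal model used only to exercise check_risk_model."""

    def __init__(self, in_dim=3, out_dim=2):
        super().__init__()
        self.fc = nn.Linear(in_dim, out_dim)

    def forward(self, batch):
        return {"logit": self.fc(batch["x"])}

    def get_risk_heads(self, outputs, batch):
        return {"main": (outputs["logit"], batch["target"], batch["y_mask"])}

    def get_primary_risk_head(self, outputs):
        return torch.sigmoid(outputs["logit"])
```

Replace TinyModel with your own model and a real (small) batch from the data loader to check the model before full training.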
💡 Tips
- Keep your model modular and readable
- Use model_utils.py for reusable components
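- Follow the naming conventions of existing models (e.g., use the same your_model_name for the folder in models/ and the YAML file in config/models/)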
- Test your model with a small batch before full training
🤝 Contributing
If you're adding a new model for benchmarking:
- Ensure it follows this structure
- Provide a short description of the model (add a model page in the Documentation)
- Optionally include references or related papers