Skip to content

feat: add CNN3DAD model for Alzheimer's disease classification#1121

Open
willlee497 wants to merge 9 commits intosunlabuiuc:masterfrom
paul-nguyen-1:paultn2/CNN4AD
Open

feat: add CNN3DAD model for Alzheimer's disease classification#1121
willlee497 wants to merge 9 commits intosunlabuiuc:masterfrom
paul-nguyen-1:paultn2/CNN4AD

Conversation

@willlee497
Copy link
Copy Markdown

@willlee497 willlee497 commented Apr 23, 2026

Contributor

Paul Nguyen (paultn2)

Shayan Jaffar (sjaffa2)

William Lee (wlee2)

Contribution Type

Model

Original Paper

Liu, S., Yadav, C., Fernandez-Granda, C., & Razavian, N. (2020). On the Design of Convolutional Neural Networks for Automatic Detection of Alzheimer's Disease. Machine Learning for Health Workshop, PMLR 116:184–201. (Link: http://proceedings.mlr.press/v116/liu20a)

Code reference: https://github.com/NYUMedML/CNN_design_for_AD

Summary

This PR adds CNN3DAD, a PyHealth-native 3D convolutional neural network for Alzheimer's disease classification from structural MRI scans. The model performs 3-class classification: Cognitively Normal (CN), Mild Cognitive Impairment (MCI), and Alzheimer's Disease (AD).

The implementation includes:

  • 4-block dilated convolutional backbone with instance normalization, matching Table 2 of the paper
  • Configurable widening factor for channel scaling (default f=8, the paper's best-performing variant at 66.9% accuracy)
  • Sinusoidal age encoding with MLP fusion (positional encoding → Linear → LayerNorm → Linear, added residually to the feature vector)
  • Adaptive global average pooling for input-size robustness
  • Optional class-weighted cross-entropy loss for handling class imbalance (CN/MCI/AD)
  • Full integration with PyHealth's BaseModel, SampleDataset, create_sample_dataset, and get_dataloader APIs

Architecture Details (from Table 2 of the paper)

Block Kernel Channels (×f) Dilation Padding Pool Kernel (stride=2)
1 1 4 1 0 3
2 3 32 2 0 3
3 5 64 2 2 3
4 3 64 2 1 5

Files to Review

  • pyhealth/models/cnn3d_ad.py → model implementation (_make_norm, ConvBlock3D, CNN3DAD)
  • pyhealth/models/__init__.py → module registration
  • tests/core/test_cnn3d_ad.py → unit tests
  • examples/adni_alzheimer_cnn3dad.py → example usage and ablation study
  • docs/api/models/pyhealth.models.cnn3d_ad.rst → API documentation
  • docs/api/models.rst → imported model

Validation

  • Unit tests pass (python -m pytest tests/core/test_cnn3d_ad.py -v)
  • Smoke test runs successfully (python pyhealth/models/cnn3d_ad.py)
  • Example script demonstrates model instantiation, forward pass, and ablation across widening factors and age encoding configurations
  • Tests cover: normalization helper, ConvBlock3D shapes and activations, model instantiation with various configs, forward pass output correctness (keys, shapes, softmax constraints, label passthrough), and gradient flow through backbone, age MLP, and classifier

Notes

  • This implementation is based on the reference code at NYUMedML/CNN_design_for_AD and verified against the paper's Table 2 architecture specification
  • The age encoding MLP follows the reference's AgeEncoding.fc6 structure: Linear → LayerNorm → Linear with no ReLU between layers
  • AdaptiveAvgPool3d(1) is used instead of the reference's hardcoded 64*exp*5*5*5 flatten, making the model robust to varying input spatial dimensions
  • Optional class_weights parameter supports weighted cross-entropy for handling the class imbalance typical in ADNI datasets

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants