feat: add CNN3DAD model for Alzheimer's disease classification#1121
Open
willlee497 wants to merge 9 commits intosunlabuiuc:masterfrom
Open
feat: add CNN3DAD model for Alzheimer's disease classification#1121willlee497 wants to merge 9 commits intosunlabuiuc:masterfrom
willlee497 wants to merge 9 commits intosunlabuiuc:masterfrom
Conversation
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Contributor
Paul Nguyen (paultn2)
Shayan Jaffar (sjaffa2)
William Lee (wlee2)
Contribution Type
Model
Original Paper
Liu, S., Yadav, C., Fernandez-Granda, C., & Razavian, N. (2020). On the Design of Convolutional Neural Networks for Automatic Detection of Alzheimer's Disease. Machine Learning for Health Workshop, PMLR 116:184–201. (Link: http://proceedings.mlr.press/v116/liu20a)
Code reference: https://github.com/NYUMedML/CNN_design_for_AD
Summary
This PR adds CNN3DAD, a PyHealth-native 3D convolutional neural network for Alzheimer's disease classification from structural MRI scans. The model performs 3-class classification: Cognitively Normal (CN), Mild Cognitive Impairment (MCI), and Alzheimer's Disease (AD).
The implementation includes:
BaseModel,SampleDataset,create_sample_dataset, andget_dataloaderAPIsArchitecture Details (from Table 2 of the paper)
Files to Review
pyhealth/models/cnn3d_ad.py→ model implementation (_make_norm,ConvBlock3D,CNN3DAD)pyhealth/models/__init__.py→ module registrationtests/core/test_cnn3d_ad.py→ unit testsexamples/adni_alzheimer_cnn3dad.py→ example usage and ablation studydocs/api/models/pyhealth.models.cnn3d_ad.rst→ API documentationdocs/api/models.rst→ imported modelValidation
python -m pytest tests/core/test_cnn3d_ad.py -v)python pyhealth/models/cnn3d_ad.py)Notes
AgeEncoding.fc6structure: Linear → LayerNorm → Linear with no ReLU between layersAdaptiveAvgPool3d(1)is used instead of the reference's hardcoded64*exp*5*5*5flatten, making the model robust to varying input spatial dimensionsclass_weightsparameter supports weighted cross-entropy for handling the class imbalance typical in ADNI datasets