Introduction
We investigated the application of computer vision to predicting brain age from MRI scans. Neurodegenerative disorders like Alzheimer's disease affect an estimated 5% of people older than 65 and 30% of those older than 85. Using convolutional neural networks (CNNs), we predict the physiological age of a brain from its MRI scan. A predicted brain age well above a person's chronological age can flag accelerated aging, which supports early detection of such disorders.
By pre-processing the MRI datasets with the FMRIB Software Library (FSL) and training CNNs (both pre-trained ResNets fine-tuned via transfer learning and custom architectures trained from scratch), we achieved validation mean absolute errors (MAE) as low as 0.37 years and test MAEs of 5.83 to 7.76 years with the ResNet models, comparable to state-of-the-art performance.
Dataset
We collected brain MRI data from three sources, totaling 3,549 subjects (3,321 after cleaning):
- IXI (Information eXtraction from Images) Dataset: 619 healthy subjects from 3 London hospitals (brain-development.org)
- FCP (1000 Functional Connectomes Project) Dataset: 1,206 subjects from 32 hospitals worldwide (International Neuroimaging Data-Sharing Initiative)
- ADNI (Alzheimer's Disease Neuroimaging Initiative) Dataset: 1,724 subjects, including cognitively normal subjects and patients with Alzheimer's disease or mild cognitive impairment (MCI)
After cleaning (removing missing data, duplicates, and outliers), we retained 563 IXI subjects, 1,034 FCP subjects, and all 1,724 ADNI subjects.
Data Pre-processing Pipeline
We used FSL's Anatomical Processing Script (fsl_anat) with five key steps:
- Automatic cropping: remove the neck and lower head using robustfov
- Bias-field correction: correct spatial intensity inhomogeneities using FAST
- Linear registration: register to standard (MNI152) space using FLIRT
- Brain extraction: remove non-brain tissue using BET
- Tissue-type segmentation: segment into grey matter, white matter, and cerebrospinal fluid (CSF) using FAST
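A batch of scans can be pushed through this pipeline by calling fsl_anat once per file. The sketch below is illustrative rather than the project's actual script. The helper names `fsl_anat_command` and `preprocess_all`, the `*.nii.gz` file layout, and the choice to pass `--nosubcortseg` are assumptions. fsl_anat runs FIRST subcortical segmentation by default, and that stage is not among the five steps above. Only fsl_anat's `-i` (input image) and `-o` (output basename, to which fsl_anat appends `.anat`) options and the `--nosubcortseg` flag are used.

```python
from __future__ import annotations

import subprocess
from pathlib import Path


def fsl_anat_command(scan_path: str, out_basename: str,
                     skip_subcortical: bool = True) -> list[str]:
    """Build the fsl_anat command line for one structural scan."""
    cmd = ["fsl_anat", "-i", scan_path, "-o", out_basename]
    if skip_subcortical:
        # FIRST subcortical segmentation is not one of the five steps listed,
        # so it is switched off here (assumption) to save runtime.
        cmd.append("--nosubcortseg")
    return cmd


def preprocess_all(scan_dir: str, dry_run: bool = True) -> list[list[str]]:
    """Run (or, with dry_run=True, only list) fsl_anat for every .nii.gz in scan_dir."""
    commands = []
    for scan in sorted(Path(scan_dir).glob("*.nii.gz")):
        # "sub01.nii.gz" -> basename "sub01"; fsl_anat then writes "sub01.anat/"
        out_basename = str(scan.with_name(scan.name[: -len(".nii.gz")]))
        cmd = fsl_anat_command(str(scan), out_basename)
        if not dry_run:
            subprocess.run(cmd, check=True)  # requires a local FSL installation
        commands.append(cmd)
    return commands
```

Running with `dry_run=True` first makes it easy to inspect the exact commands before committing hours of compute per scan.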
Methods
Baseline: Linear Regression
We built a histogram of voxel intensities for each skull-stripped MRI scan (capturing the grey- and white-matter intensity distributions) and fed these histograms into a ridge regression model with a cross-validated regularization strength (RidgeCV), achieving a test mean absolute error of 7.99 years.
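This is a minimal sketch of the baseline, assuming NumPy and scikit-learn's `RidgeCV`. The 64 bins, the per-scan min-max intensity scaling, treating nonzero voxels as brain, and the alpha grid are illustrative assumptions, not values reported for the project. The function names are hypothetical.

```python
import numpy as np
from sklearn.linear_model import RidgeCV
from sklearn.metrics import mean_absolute_error


def histogram_features(volumes, n_bins=64):
    """Turn each skull-stripped volume into a normalized intensity histogram."""
    features = []
    for vol in volumes:
        # After skull stripping the background is zero, so keep only
        # nonzero (brain) voxels (assumption about the masking convention).
        voxels = vol[vol > 0].astype(np.float64)
        # Per-scan min-max scaling so scans from different scanners share one
        # intensity range (assumption; not specified in the report).
        voxels = (voxels - voxels.min()) / (voxels.max() - voxels.min() + 1e-8)
        hist, _ = np.histogram(voxels, bins=n_bins, range=(0.0, 1.0), density=True)
        features.append(hist)
    return np.stack(features)


def fit_ridge_baseline(train_volumes, train_ages):
    X = histogram_features(train_volumes)
    # With the default cv=None, RidgeCV selects alpha by efficient leave-one-out CV
    model = RidgeCV(alphas=np.logspace(-3, 3, 13))
    model.fit(X, train_ages)
    return model


def evaluate_mae(model, volumes, ages):
    """Mean absolute error in years on held-out volumes."""
    return mean_absolute_error(ages, model.predict(histogram_features(volumes)))
```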
Transfer Learning: ResNet18 and ResNet50
Since MRI scans are 3D grayscale images while ResNets expect 2D RGB inputs, we:
- Sliced the volume along the axial axis to obtain a 2D representation
- Brought each slice to 224×224 by center-cropping or zero-padding
- Stacked the grayscale slice three times for the RGB channels
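The three conversion steps can be sketched in NumPy. The report does not say which slice was used. Taking the middle slice is an assumption, and so is treating array axis 2 as the axial axis, which typically holds for FSL standard-space volumes loaded with a library such as nibabel. The helper names are hypothetical. The output is channels-first, which is what PyTorch's ResNets expect.

```python
from typing import Optional

import numpy as np


def center_crop_or_pad(img: np.ndarray, size: int = 224) -> np.ndarray:
    """Center-crop axes longer than `size` and zero-pad axes shorter than it."""
    out = np.zeros((size, size), dtype=img.dtype)
    h, w = img.shape
    rows, cols = min(h, size), min(w, size)
    # Source offset is non-zero when cropping; destination offset when padding
    src_r, src_c = max((h - size) // 2, 0), max((w - size) // 2, 0)
    dst_r, dst_c = max((size - h) // 2, 0), max((size - w) // 2, 0)
    out[dst_r:dst_r + rows, dst_c:dst_c + cols] = img[src_r:src_r + rows, src_c:src_c + cols]
    return out


def axial_slice_as_rgb(volume: np.ndarray, index: Optional[int] = None,
                       axial_axis: int = 2, size: int = 224) -> np.ndarray:
    """Take one axial slice and return it as a (3, size, size) pseudo-RGB image."""
    if index is None:
        index = volume.shape[axial_axis] // 2  # middle slice (assumption)
    slice_2d = np.take(volume, index, axis=axial_axis)
    slice_2d = center_crop_or_pad(slice_2d, size)
    # Repeat the grayscale slice into the three RGB channels
    return np.stack([slice_2d, slice_2d, slice_2d], axis=0)
```

For a 182×218×182 standard-space volume, the axial slice is 182×218, so both axes are zero-padded up to 224.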
We replaced the final fully connected layers with:
- Regression: Single-neuron linear layer, trained with MSE loss
- Classification: 14-class Softmax layer (5-year age bins starting from 18)
All pre-trained layers were unfrozen for fine-tuning. The best hyperparameters were a batch size of 4, a learning rate of 0.001 decayed by a factor of 0.1 every 5 epochs, and stochastic gradient descent (SGD) with Nesterov momentum of 0.9.
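These settings map onto PyTorch's `torch.optim.SGD` (with `nesterov=True`) and `torch.optim.lr_scheduler.StepLR`. The `fit` helper below is an illustrative sketch: its loop structure is an assumption, and its `.squeeze(1)` assumes the single-output regression head trained with MSE. The dataset can be any `torch.utils.data.Dataset` of (image, age) pairs. The helper returns the learning rate used in each epoch, which makes the decay schedule easy to check.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset


def make_optimizer(model: nn.Module):
    # SGD with Nesterov momentum 0.9 and initial learning rate 0.001
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3,
                                momentum=0.9, nesterov=True)
    # Multiply the learning rate by 0.1 every 5 epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    return optimizer, scheduler


def fit(model: nn.Module, dataset: Dataset, n_epochs: int) -> list:
    """Train a regression model; return the learning rate used in each epoch."""
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    optimizer, scheduler = make_optimizer(model)
    loss_fn = nn.MSELoss()
    lrs = []
    for _ in range(n_epochs):
        lrs.append(optimizer.param_groups[0]["lr"])
        model.train()
        for images, ages in loader:
            optimizer.zero_grad()
            preds = model(images).squeeze(1)  # (N, 1) -> (N,)
            loss = loss_fn(preds, ages)
            loss.backward()
            optimizer.step()
        scheduler.step()  # one scheduler step per epoch
    return lrs
```

With this schedule, epochs 0 to 4 use 0.001, epochs 5 to 9 use 0.0001, and so on.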
Custom CNN Architectures
We also trained two models from scratch to preserve the full dimensionality of MRI data:
10-Layer 2D CNN: Five blocks of paired Conv2d layers with BatchNorm and AvgPool, progressively increasing channels (32 → 64 → 128 → 256 → 512), followed by a linear regression head.
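A PyTorch sketch of this 2D architecture follows. The block structure and channel widths come from the description above. The 3×3 kernels with padding 1, the ReLU activations, the single grayscale input channel, the 2×2 average pooling, and the global average pooling before the linear head are assumptions chosen to make the model runnable. The class and function names are hypothetical.

```python
import torch
import torch.nn as nn


def conv2d_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Two Conv2d-BatchNorm-ReLU layers followed by 2x2 average pooling."""
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.AvgPool2d(kernel_size=2),
    )


class BrainAge2DCNN(nn.Module):
    """Five paired-conv blocks (32 -> 512 channels) plus a linear regression head."""

    def __init__(self, in_channels: int = 1, channels=(32, 64, 128, 256, 512)):
        super().__init__()
        blocks, prev = [], in_channels
        for ch in channels:
            blocks.append(conv2d_block(prev, ch))
            prev = ch
        self.features = nn.Sequential(*blocks)
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.head = nn.Linear(prev, 1)       # single predicted age

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, 1, H, W) grayscale slice
        return self.head(self.pool(self.features(x)).flatten(1))
```

The "10 layers" are the ten convolutions: two per block across five blocks.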
6-Layer 3D CNN: Three blocks of paired Conv3d layers with BatchNorm, ReLU, and AvgPool3d (64 → 128 → 256 channels), preserving the full volumetric information of the MRI scans.
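The 3D variant follows the same pattern with volumetric layers. As in the 2D case, the 3×3×3 kernels, the 2×2×2 pooling, the single input channel, and the global average pooling plus linear head are assumptions made for a runnable sketch. The class and function names are hypothetical.

```python
import torch
import torch.nn as nn


def conv3d_block(in_ch: int, out_ch: int) -> nn.Sequential:
    """Two Conv3d-BatchNorm-ReLU layers followed by 2x2x2 average pooling."""
    return nn.Sequential(
        nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_ch),
        nn.ReLU(inplace=True),
        nn.AvgPool3d(kernel_size=2),
    )


class BrainAge3DCNN(nn.Module):
    """Three paired-conv blocks (64 -> 256 channels) on the full MRI volume."""

    def __init__(self, in_channels: int = 1, channels=(64, 128, 256)):
        super().__init__()
        layers, prev = [], in_channels
        for ch in channels:
            layers.append(conv3d_block(prev, ch))
            prev = ch
        self.features = nn.Sequential(*layers)
        self.pool = nn.AdaptiveAvgPool3d(1)  # collapse the remaining D x H x W
        self.head = nn.Linear(prev, 1)       # single predicted age

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N, 1, D, H, W) whole MRI volume
        return self.head(self.pool(self.features(x)).flatten(1))
```

Because every convolution sees the whole volume, memory grows with the cube of the input size. This is one practical reason 3D models are slower to train than the 2D transfer-learning models.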
Results
Regression Models
| Model | Validation MAE (years) | Test MAE (years) |
|---|---|---|
| Linear Regression | — | 7.99 |
| ResNet18 | 3.80 | 7.76 |
| ResNet18 + Data Aug. | 4.13 | 6.41 |
| ResNet50 | 1.67 | 5.83 |
| ResNet50 + Data Aug. | 0.37 | 7.40 |
| 10-Layer 2D CNN | 2.20 | 11.00 |
| 6-Layer 3D CNN | 10.11 | 10.89 |
Classification Models
| Model | Validation Acc. | Test Acc. |
|---|---|---|
| ResNet18 | 49.84% | 26.45% |
| ResNet18 + Data Aug. | 50.39% | 48.23% |
| ResNet50 | 60.07% | 54.14% |
| ResNet50 + Data Aug. | 65.78% | 55.23% |
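ResNet50 with data augmentation (combining all three datasets) achieved the best classification accuracy (65.78% validation, 55.23% test) and the lowest validation MAE (0.37 years). On the regression test set, however, plain ResNet50 had the lowest MAE (5.83 years), while the augmented ResNet50 reached 7.40 years.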
Saliency Map Analysis
We generated saliency maps to understand what the network learns. The network focuses most on grey matter regions and the frontal cortex — consistent with neuroscience findings about brain aging biomarkers.
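The report does not say which saliency method was used. One common choice is the vanilla gradient saliency of Simonyan et al. It backpropagates the predicted age to the input and keeps the absolute gradient at each pixel or voxel. The sketch below assumes a model that outputs one predicted age per input. The function name and the max-over-channels reduction are assumptions.

```python
import torch
import torch.nn as nn


def gradient_saliency(model: nn.Module, x: torch.Tensor) -> torch.Tensor:
    """Vanilla gradient saliency for a brain-age regressor.

    x: (N, C, ...) input batch; returns (N, ...) saliency maps.
    """
    model.eval()  # BatchNorm uses running stats, so samples do not interact
    x = x.clone().detach().requires_grad_(True)
    preds = model(x)  # (N, 1) predicted ages
    # Summing is safe in eval mode: each prediction depends only on its own
    # input, so d(sum)/d(x_i) equals d(pred_i)/d(x_i)
    preds.sum().backward()
    # Absolute gradient, reduced over channels by taking the maximum
    return x.grad.abs().amax(dim=1)
```

For a pseudo-RGB slice this yields one 224×224 map per input. Overlaying it on the slice shows which regions, such as grey matter or frontal cortex, most influence the predicted age.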
Discussion
- Regression outperformed classification for brain age estimation, likely due to the continuous nature of aging
- Data augmentation across all three datasets was critical — training on IXI alone led to overfitting due to limited samples
- Age class imbalance in the FCP dataset (skewed toward the 18–28 age range) degraded classification performance
- Custom 3D CNN models showed promise but needed more layers and training time to compete with transfer learning approaches
Contributions
- Zixian Ma: Data pre-processing pipeline, FSL integration
- Prabhjot Singh Rai: AWS infrastructure and DevOps for the project, model training
- Harry (Xinxuan) Jiang: Model architecture analysis and tuning
References
- Cole et al. "Predicting brain age with deep learning from raw imaging data results in a reliable and heritable biomarker." NeuroImage, 2017.
- He et al. "Deep Residual Learning for Image Recognition." CoRR, 2015.
- Jenkinson et al. "FSL." NeuroImage, 2012.
- Islam & Zhang. "Brain MRI analysis for Alzheimer's disease diagnosis using an ensemble system of deep CNNs." Brain Informatics, 2018.