Brain Age Prediction Using Deep CNNs

Predicting physiological brain age from MRI scans using transfer learning with ResNet and custom 2D/3D convolutional neural networks

Introduction

We investigated the application of computer vision to predicting brain age from MRI scans. Neurodegenerative disorders like Alzheimer's disease affect an estimated 5% of people older than 65 and 30% of those older than 85. Using CNNs, we predict the physiological age of a brain from its MRI data, with the aim of enabling early detection of such disorders.

We pre-processed the MRI datasets with the FMRIB Software Library (FSL) and trained convolutional neural networks, both pre-trained ResNets via transfer learning and custom architectures from scratch. This yielded validation errors as low as 1–4 years and test errors of 5–7 years, comparable to state-of-the-art performance.

Dataset

We collected brain MRI data from three sources, retaining 3,321 subjects after cleaning:

  • IXI Dataset: 619 healthy subjects from 3 London hospitals (brain-development.org)
  • FCP Dataset: 1,206 subjects from 32 hospitals worldwide (International Neuroimaging Data-Sharing Initiative)
  • ADNI Dataset: 1,724 subjects including normal subjects and patients with Alzheimer's Disease and Mild Cognitive Impairment

After cleaning (removing missing data, duplicates, and outliers), we retained 563 IXI subjects, 1,034 FCP subjects, and all 1,724 ADNI subjects.
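
The per-dataset counts above can be cross-checked with a few lines of arithmetic. This is only a sanity-check sketch: the numbers are copied from the text, and the `COUNTS` and `totals` names are illustrative.

```python
# Subject counts per dataset as reported in the text: (collected, retained).
COUNTS = {
    "IXI": (619, 563),
    "FCP": (1206, 1034),
    "ADNI": (1724, 1724),
}


def totals(counts):
    """Return (total collected, total retained) across all datasets."""
    collected = sum(raw for raw, _ in counts.values())
    retained = sum(kept for _, kept in counts.values())
    return collected, retained


collected, retained = totals(COUNTS)
print(f"collected={collected}, retained={retained}, dropped={collected - retained}")
# collected=3549, retained=3321, dropped=228
```

The 3,321 figure is therefore the post-cleaning total; 3,549 subjects were collected before cleaning.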

Data Pre-processing Pipeline

We used FSL's Anatomical Processing Script (fsl_anat) with five key steps:

  1. Automatic cropping: Remove neck and lower head portions via robustfov
  2. Bias-field correction: Correct spatial intensity variations using FAST
  3. Linear registration: Register to standard space using FLIRT
  4. Brain extraction: Remove non-brain tissue using BET
  5. Tissue-type segmentation: Segment into Grey Matter, White Matter, and CSF using FAST
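
As a rough illustration of how such a pipeline might be driven from a script, the sketch below builds the command line for one fsl_anat run. It assumes FSL is installed and that the default fsl_anat options were used; the file names and the `build_fsl_anat_cmd` helper are hypothetical and not taken from the project.

```python
import shlex


def build_fsl_anat_cmd(t1_path, out_basename):
    """Build the argument list for one fsl_anat run (does not execute it)."""
    # -i: input structural image
    # -o: basename of the output directory (fsl_anat appends ".anat" to it)
    return ["fsl_anat", "-i", str(t1_path), "-o", str(out_basename)]


# Hypothetical subject file and output location.
cmd = build_fsl_anat_cmd("sub-001_T1w.nii.gz", "out/sub-001")
print(shlex.join(cmd))
```

On a machine with FSL on the PATH, the list could be passed to `subprocess.run(cmd, check=True)`, once per subject.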

Methods

Baseline: Linear Regression

We constructed histograms of skull-stripped MRI voxel intensities (capturing grey and white matter distributions) and fed them into a RidgeCV regression model, achieving a mean absolute error of 7.99 years.
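
A minimal sketch of the histogram feature step, assuming background voxels are zero after skull stripping and that a fixed number of bins is used. The bin count and the `intensity_histogram` name are assumptions; the original settings are not stated.

```python
import numpy as np


def intensity_histogram(volume, n_bins=64):
    """Normalized intensity histogram of the nonzero voxels in a volume."""
    voxels = volume[volume > 0]  # assumed: zeros are background after skull stripping
    if voxels.size == 0:
        return np.zeros(n_bins)
    counts, _ = np.histogram(voxels, bins=n_bins, range=(0.0, float(voxels.max())))
    return counts / counts.sum()  # fractions, so brains of different size are comparable


# Toy volume standing in for a skull-stripped scan: about half the voxels are zero.
rng = np.random.default_rng(0)
toy_volume = rng.random((8, 8, 8)) * (rng.random((8, 8, 8)) > 0.5)
features = intensity_histogram(toy_volume, n_bins=16)
print(features.shape)  # (16,)
```

Stacking one such vector per subject gives a feature matrix `X` that could then be fitted with `sklearn.linear_model.RidgeCV().fit(X, ages)`.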

Transfer Learning: ResNet18 and ResNet50

Since MRI scans are 3D grayscale images while ResNets expect 2D RGB inputs, we:

  • Sliced across the axial dimension for a 2D representation
  • Resized each slice to 224×224 by center-cropping or zero-padding
  • Stacked the grayscale slice three times for the RGB channels
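
The three steps above might look like the following NumPy sketch. The choice of the middle slice, the position of the axial axis, and the function names are assumptions made for illustration.

```python
import numpy as np


def center_crop_or_pad(img, size=224):
    """Center-crop dimensions larger than `size` and zero-pad smaller ones."""
    out = np.zeros((size, size), dtype=img.dtype)
    h, w = img.shape
    # Source window (used when cropping) and destination window (used when padding).
    src_top, src_left = max((h - size) // 2, 0), max((w - size) // 2, 0)
    dst_top, dst_left = max((size - h) // 2, 0), max((size - w) // 2, 0)
    ch, cw = min(h, size), min(w, size)
    out[dst_top:dst_top + ch, dst_left:dst_left + cw] = \
        img[src_top:src_top + ch, src_left:src_left + cw]
    return out


def slice_to_rgb(volume, index=None, axial_axis=2, size=224):
    """Take one axial slice and repeat it over three channels: (3, size, size)."""
    if index is None:
        index = volume.shape[axial_axis] // 2  # assumed: use the middle slice
    sl = np.take(volume, index, axis=axial_axis)
    sl = center_crop_or_pad(sl, size)
    return np.stack([sl, sl, sl], axis=0)


# Toy volume: 182 rows (will be padded), 256 columns (will be cropped).
volume = np.ones((182, 256, 100), dtype=np.float32)
x = slice_to_rgb(volume)
print(x.shape)  # (3, 224, 224)
```

The channels-first layout matches what ImageNet-pretrained ResNets commonly expect; any intensity normalization would be applied separately.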

We replaced the final fully connected layers with:

  • Regression: Single-neuron linear layer, trained with MSE loss
  • Classification: 14-class Softmax layer (5-year age bins starting from 18)
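
For the classification head, ages have to be mapped to one of the 14 bins. The sketch below assumes half-open 5-year bins starting at 18 (18–23, 23–28, and so on up to 83–88) and clamps ages outside that range to the first or last bin; the clamping rule is an assumption, not something the text specifies.

```python
N_CLASSES = 14
BIN_WIDTH = 5
MIN_AGE = 18


def age_to_class(age):
    """Map an age in years to a class index in [0, N_CLASSES - 1]."""
    idx = int((age - MIN_AGE) // BIN_WIDTH)
    return max(0, min(N_CLASSES - 1, idx))  # assumed: clamp out-of-range ages


def class_to_age_range(idx):
    """Return the [low, high) age range covered by a class index."""
    low = MIN_AGE + idx * BIN_WIDTH
    return low, low + BIN_WIDTH


print(age_to_class(18), age_to_class(23), age_to_class(87))  # 0 1 13
print(class_to_age_range(13))  # (83, 88)
```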

All pre-trained layers were unfrozen for fine-tuning. The best hyperparameters were a batch size of 4, an initial learning rate of 0.001 multiplied by 0.1 every 5 epochs, and SGD with Nesterov momentum of 0.9.
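
The learning-rate schedule can be written out explicitly. This sketch reads the decay as a step schedule that multiplies the rate by 0.1 every 5 epochs; that interpretation and the `step_lr` name are assumptions.

```python
def step_lr(epoch, base_lr=1e-3, gamma=0.1, step_size=5):
    """Learning rate at a given epoch (0-indexed) under step decay."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 4, 5, 10):
    print(epoch, f"{step_lr(epoch):.0e}")
```

If the models were trained in PyTorch, which the layer names in this write-up suggest but do not confirm, the equivalent setup would be `torch.optim.SGD(params, lr=0.001, momentum=0.9, nesterov=True)` combined with `torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)`.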

Custom CNN Architectures

We also trained two models from scratch to preserve the full dimensionality of MRI data:

10-Layer 2D CNN: Five blocks of paired Conv2d layers with BatchNorm and AvgPool, progressively increasing channels (32 → 64 → 128 → 256 → 512), followed by a linear regression head.

6-Layer 3D CNN: Three blocks of paired Conv3d layers with BatchNorm, ReLU, and AvgPool3d (64 → 128 → 256 channels), preserving the full volumetric information of the MRI scans.
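
To get a feel for the size of these networks, the sketch below traces feature-map shapes through the blocks. It assumes that convolutions are padded to preserve spatial size and that each block ends in a pooling layer of stride 2. The input sizes (224×224 for the 2D model, 96×96×96 for the 3D model) are hypothetical, since the text does not give them.

```python
def trace_blocks(spatial, channels, pool=2):
    """Return (channels, spatial shape) after each conv block."""
    shapes = []
    for out_ch in channels:
        # Assumed: padded convolutions keep the size, so only pooling shrinks it.
        spatial = tuple(s // pool for s in spatial)
        shapes.append((out_ch, spatial))
    return shapes


def flat_features(shapes):
    """Number of values fed to the linear head after flattening the last block."""
    ch, spatial = shapes[-1]
    n = ch
    for s in spatial:
        n *= s
    return n


cnn2d = trace_blocks((224, 224), [32, 64, 128, 256, 512])
cnn3d = trace_blocks((96, 96, 96), [64, 128, 256])
print(cnn2d[-1], flat_features(cnn2d))  # (512, (7, 7)) 25088
print(cnn3d[-1], flat_features(cnn3d))  # (256, (12, 12, 12)) 442368
```

Under these assumptions the 3D model hands far more features to its regression head than the 2D model, which gives a sense of why volumetric models are more expensive to train.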

Results

Regression Models

Model                   Validation MAE (years)   Test MAE (years)
Linear Regression       7.99 (single reported value)
ResNet18                3.80                     7.76
ResNet18 + Data Aug.    4.13                     6.41
ResNet50                1.67                     5.83
ResNet50 + Data Aug.    0.37                     7.40
10-Layer 2D CNN         2.20                     11.00
6-Layer 3D CNN          10.11                    10.89
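
For reference, the MAE reported here is the average absolute difference between predicted and true age, in years. A minimal sketch with made-up ages:

```python
def mean_absolute_error(true_ages, predicted_ages):
    """Average absolute difference between true and predicted ages."""
    if len(true_ages) != len(predicted_ages) or not true_ages:
        raise ValueError("inputs must be non-empty and of equal length")
    total = sum(abs(t - p) for t, p in zip(true_ages, predicted_ages))
    return total / len(true_ages)


# Illustrative values only, not taken from the experiments.
print(mean_absolute_error([20, 30], [22, 27]))  # 2.5
```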

Classification Models

Model                   Validation Acc.   Test Acc.
ResNet18                49.84%            26.45%
ResNet18 + Data Aug.    50.39%            48.23%
ResNet50                60.07%            54.14%
ResNet50 + Data Aug.    65.78%            55.23%

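ResNet50 gave the best results overall. With data augmentation (combining all three datasets) it reached the highest classification accuracy (65.78% validation, 55.23% test) and the lowest validation MAE (0.37 years), while ResNet50 without augmentation had the lowest test MAE (5.83 years).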

Saliency Map Analysis

We generated saliency maps to understand what the network learns. The network focuses most on grey matter regions and the frontal cortex — consistent with neuroscience findings about brain aging biomarkers.
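
The text does not say how the saliency maps were computed. A common approach takes the magnitude of the gradient of the predicted age with respect to each input pixel. The sketch below approximates that gradient numerically on a toy linear predictor, purely for illustration; a real network would use backpropagation instead, and `saliency_map` and `toy_predict` are hypothetical names.

```python
import numpy as np


def saliency_map(predict, image, eps=1e-3):
    """Absolute central-difference sensitivity of a scalar prediction to each pixel."""
    base = image.astype(np.float64)
    sal = np.zeros_like(base)
    for idx in np.ndindex(base.shape):
        up, down = base.copy(), base.copy()
        up[idx] += eps
        down[idx] -= eps
        sal[idx] = abs(predict(up) - predict(down)) / (2 * eps)
    return sal


# Toy "model": the predicted age is a weighted sum of pixel intensities.
weights = np.array([[0.0, 2.0], [-3.0, 0.5]])


def toy_predict(img):
    return float((weights * img).sum())


sal = saliency_map(toy_predict, np.ones((2, 2)))
print(np.round(sal, 3))
```

For this linear toy model the saliency equals the absolute weight of each pixel, so the pixel with weight -3.0 is the most salient.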

Discussion

  • Regression outperformed classification for brain age estimation, likely due to the continuous nature of aging
  • Data augmentation across all three datasets was critical — training on IXI alone led to overfitting due to limited samples
  • Age class imbalance in the FCP dataset (skewed toward the 18–28 age range) degraded classification performance
  • Custom 3D CNN models showed promise but needed more layers and training time to compete with transfer learning approaches

Contributions

  • Zixian Ma: Data pre-processing pipeline, FSL integration
  • Prabhjot Singh Rai: AWS infrastructure and DevOps for the project, model training
  • Harry (Xinxuan) Jiang: Model architecture analysis and tuning

References

  1. Cole et al. "Predicting brain age with deep learning from raw imaging data results in a reliable and heritable biomarker." NeuroImage, 2017.
  2. He et al. "Deep Residual Learning for Image Recognition." CoRR, 2015.
  3. Jenkinson et al. "FSL." NeuroImage, 2012.
  4. Islam & Zhang. "Brain MRI analysis for Alzheimer's disease diagnosis using an ensemble system of deep CNNs." Brain Informatics, 2018.