机器学习-基于CNN的青光眼检测模型


小组成员:潘序(组长),梁禹,万立志,方宇豪

一、模型背景

青光眼,作为一种导致不可逆视力丧失的眼科疾病,在全球范围内对人类健康构成了重大威胁。
据估计,到2020年,全球将有超过1100万人因青光眼而失明。鉴于其早期症状不明显,早期诊断对于预防视力损害至关重要。
然而,专业眼科医生的缺乏,尤其是在偏远地区,限制了青光眼的早期筛查和治疗。
为了解决这一问题,我们小组设计了一种基于卷积神经网络(CNN)的青光眼眼底病变检测模型——EyeNet。

数据集链接:Fundus Glaucoma Detection Data [PyTorch format]
Github:EyeNet: A Convolutional Neural Network for Glaucomatous Fundus Lesion Detection

二、环境测试

测试当前环境是否可用GPU训练。

三、模型结构

model
在EyeNet中,卷积层由多个Conv2d层组成,负责提取输入图像的特征。
第一个卷积层将输入通道数从3(RGB图像)增加到64,使用3x3的卷积核;
第二个卷积层将通道数从64增加到128;
第三个卷积层保持通道数为128;
第四个卷积层将通道数从128增加到256;
第五个卷积层保持通道数为256。
每个卷积层后面都紧跟一个ReLU激活函数,用于引入非线性。

EyeNet中的池化层用于降低特征图的空间维度,以减少计算量并增加感受野,同时也提高了模型对小的位置变化的不变性。
在EyeNet中,池化层使用了2x2的池化核和2的步长,使得输出为输入大小的1/4。

除了最后一个卷积层,ReLU激活函数在EyeNet的每个卷积层后面使用,这是因为ReLU能够在训练过程中提供快速的收敛速度,并且减轻梯度消失的问题。

打印模型结构:

四、数据集加载器

本模块包含了一个最基本的PyTorch的数据集加载器和一个getROI(image)函数。

(一)相关库的引入

(二)getROI(image)函数

经过小组查阅资料,青光眼病变时医生通常根据眼底图像,一般是后视网膜图像中的眼底视盘区域的形状改变进行诊断。
这块区域在图像中呈现的是一个明显的偏亮的圆形或椭圆形区域。
因此,我们使用getROI函数,找到图像中平均像素值最高的一个200x200区域并返回。
这是一种数据增强的方式,我们小组通过这种方法大大缩短了神经网络的训练时间,并且能够防止模型的过拟合。
getROI()函数示意图如下:
getroi

(三)PyTorch数据集加载器GlaucomaDataset(Dataset)

根据数据集的摆放方式创建图像-标签表供模型训练和评估。PyTorch规定了这一类的编写方法。

五、训练模式

(一)我们使用了:

1、学习率:首先设置为0.01,发现模型loss几乎不下降,于是逐步降低至0.0001发现模型loss下降,且在测试集的正确率上升。最后训练到模型准确率到95%左右时发现loss和accuracy都停滞,于是继续减小学习率,直至模型历史正确率高达98.56%。
2、损失函数:由于我们拟解决的是一个二分类问题,于是我们使用了二元交叉熵损失函数。
3、优化器:我们使用了Adam优化器,来优化反向传播时参数调整的过程。

前期工作结果:
训练至第150 epoch的模型(准确率98.56%),其历史准确率折线图如下:
accuracy
在第30epoch训练完毕后,发现准确度不上升,调整学习率后发现准确率高速上升。

(二)train(pre_epochs, epochs)函数:

首先看train函数输入,如果不是从0 epoch开始训练的则使用预训练的权重。然后使用数据集加载器加载的图像-标签对,对模型进行训练。具体的训练过程依次包括梯度归零、前向传播计算输出、损失函数计算、反向传播和优化器优化。
每一个epoch训练完毕后保存一次模型,并且调用下一个Cell中的eva()函数,使用本epoch的模型对测试集图像标签对进行准确率评估。如果本epoch训练的结果模型准确度创下整个训练过程的历史记录,则将本模型保存为best model。

六、评估模式(隐藏模式)

本模式只能由train(pre_epochs, epochs)函数调用,不可由使用者直接调用。
如果想对单张图片进行青光眼诊断,则可以使用下一Cell的diagnose()函数。

七、诊断模式(直接调用模式)

可以通过本模式,对指定的图像,使用我们的最高准确率模型进行青光眼诊断。
本Cell中使用了数据集中archive/val中的后视网膜图像,并且取出其中8张进行诊断结果可视化。