[代码实战]手把手带你训练一个COVID检测网络,准确率高达90%

Posted Tina姐

tags:

篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[代码实战]手把手带你训练一个COVID检测网络,准确率高达90%相关的知识,希望对你有一定的参考价值。

本次实战的的概况如下:

  • 代码来源:https://github.com/junaidiqbalsyed/Covid_detection_CNN
  • 目的:使用 CNN(vgg and resnet)检测 COVID 并使用GRAD-CAM进行可视化
  • 方法: 二分类(normal or covid)
  • 框架: keras (不会没关系,很简单)
  • 结果:源代码在VGG-16的准确率为 85%, 我使用resnet跑的结果为88%,甚至90%。
  • 难易程度: ⭐️⭐️

结果展示

  • resnet准确度

  • 可视化结果

感兴趣的话一起进入代码环节吧~~


1 准备工作

  • 在github上下载源代码
  • 下载数据集(约800M)左右。
    如果你网络够好(可随意打开GitHub,Google网站),下载数据集这一步可省略,后面通过代码下载。如果网络不好,建议先下载数据集并解压。
  • 配置环境
# 查看版本
import tensorflow as tf
import keras
print(tf.__version__) # 2.4.1
print(keras.__version__)  # 2.4.3

注意:这里有个坑,当我在 tf版本 2.0.0 keras 版本 2.3.1运行时,准确度一直在0.5左右徘徊,没有提升。如果你遇到了相同的问题,请重新建一个环境。

2 使用jupyter运行代码

文件:Covid_detection_using_chest_X_Ray(using_ResNet_50)%20(2).ipynb

这部分主要解读代码。

导入包的部分自动略过~~

2.1 数据集是否下载

如果刚开始你没有下载数据集,可运行 cell2 and cell3 下载数据集及解压

!wget https://www.dropbox.com/s/e1r2laj50nh4tez/COVID-19_Radiography_Dataset.zip?dl=0
!unzip "/content/COVID-19_Radiography_Dataset.zip?dl=0"

如果已经下载好,则这两步省略~~~~

2.2 探索数据

直接运行代码即可,不过多介绍。

可以简单探索下:数据的名字,格式,大小,以及下载来源。比如这里的covid数据来自sirm网站。

2.3 把所有类别的数据放在同一个文件夹并可视化

下载好的数据集是按照疾病分类的。这里我们把所有图像放在同一个文件夹。

ROOT_DIR = "/content/COVID-19_Radiography_Dataset/"
imgs = ['COVID','Lung_Opacity','Normal','Viral Pneumonia']

NEW_DIR = "content/all_images/"

需要注意这部分代码的地址。作者在地址前加了”/“,如/content,表示content目录在根目录中,如果你的数据不在根目录,就不要加前面的”/“。

同理,如果后文出现了找不到文件夹,很有可能是你这里加了”/“, 去掉即可。

可视化看看数据集的分布

下载的数据集有4类,但是我们代码中实际上只用到了2类(COVID 和 Normal)

2.4 把数据集分成训练集、验证集和测试集

if not os.path.exists(NEW_DIR+"train_test_split/"):

  os.makedirs(NEW_DIR+"train_test_split/")
  .....
  os.makedirs(NEW_DIR+"train_test_split/validation/Covid")


  # Train Data
  for i in np.random.choice(replace= False , size= 3000 , a = glob.glob(NEW_DIR+imgs[0]+"*") ):
    shutil.copy(i , NEW_DIR+"train_test_split/train/Covid" )
    os.remove(i)

  for i in np.random.choice(replace= False , size= 3900 , a = glob.glob(NEW_DIR+imgs[2]+"*") ):
    shutil.copy(i , NEW_DIR+"train_test_split/train/Normal" )
    os.remove(i)
    ....   

这部分代码里有很多可以学习的地方:

比如,我们要在所有covid图像中,随机选取3000个作为训练集。怎么做到?

答案: np.random.choice()

当这部分选取作为训练集后,如何保证这部分数据在验证集和测试集中选不到它。

答案:os.remove(i) 当被选取后,删除它,那么再选择的时候就选择不到它了。

现在,数据集有了。

2.5 为keras生成数据流

train_path  = "content/all_images/train_test_split/train"
valid_path  = "content/all_images/train_test_split/validation"
test_path   = "content/all_images/train_test_split/test"

这是我们数据集存放的地址。再强调一次,content前面的’/‘我已经去掉。

为各数据集生成keras可识别的数据

train_data_gen = ImageDataGenerator(preprocessing_function= preprocess_input, 
                                    zoom_range= 0.2, 
                                    horizontal_flip= True, 
                                    shear_range= 0.2,
                                    
                                    )

train = train_data_gen.flow_from_directory(directory= train_path, 
                                           target_size=(224,224))

训练集中: Found 7800 images belonging to 2 classes.

2.6 构建模型

res = ResNet50( input_shape=(224,224,3), include_top= False) 
# include_top will consider the new weights

include_top= False表示不要全连接层,只加载特征部分。

这里加载的预训练权重是在:https://storage.googleapis.com

注意
我用其他版本keras加载的权重是在:https://github.com/
不同的版本会有区别,也会影响到结果。我也不知道为啥。

你如果得不到一个较好的结果,程序也没有报错,可能就是这步出现了问题,注意检查。

2.6 冻结特征层,这里我们只训练网络最后一层

for layer in res.layers:           # Dont Train the parameters again 
  layer.trainable = False

2.7 添加全连接层

x = Flatten()(res.output)
x = Dense(units=2 , activation='sigmoid', name = 'predictions' )(x)

# creating our model.
model = Model(res.input, x)

可以通过 model.summary()查看每一层的信息

2.8 训练模型

model.compile( optimizer= 'adam' , loss = 'categorical_crossentropy', metrics=['accuracy'])

es = EarlyStopping(monitor= "val_accuracy" , min_delta= 0.01, patience= 3, verbose=1)
mc = ModelCheckpoint(filepath="bestmodel.h5", monitor="val_accuracy", verbose=1, save_best_only= True)

设置优化器,loss, 评估指标
通过监控val_accuracy来保存网络,使用EarlyStopping来结束训练。

然后,开始训练👇

hist = model.fit_generator(train, steps_per_epoch= 10, epochs= 30, validation_data= valid , validation_steps= 16, callbacks=[es,mc])

model.fit_generator可以了解一下keras的这个函数,可参数的意思,不想了解直接运行就对了。

你的运行准确度应该在80%以上才是正确的,如果结果异常,检查哪一步除了问题。并解决它。

2.9 加载模型查看训练历史结果

## load only the best model 
from keras.models import load_model
model = load_model("bestmodel.h5")

查看保存了哪些历史信息

图片中可以看到,分别保存了训练集和验证集的loss和acc.用matplotlib画出来

plt.plot(h['accuracy'])
plt.plot(h['val_accuracy'] , c = "red")
plt.title("acc vs v-acc")
plt.show()

在测试集上评估模型

acc = model.evaluate_generator(generator= test)[1] 
print(f"The accuracy of your model is = {acc*100} %")

The accuracy of your model is = 88 %

2.10 如何对测试集单张测试

这里需要注意的是,即便是单张测试,也需要对图像进行预处理。

预处理方法同测试集的方法一样。

这里使用from keras.preprocessing import image方法进行预处理。需要统一大小(224,224,3), 转化成nunpy数据,并添加一个batch维度。

def get_img_array(img_path):
  """
  Input : Takes in image path as input 
  Output : Gives out Pre-Processed image
  """
  path = img_path
  img = image.load_img(path, target_size=(224,224,3))
  img = image.img_to_array(img)
  img = np.expand_dims(img , axis= 0 )
  
  return img

处理好的图像就可以通过model.predict(img)进行预测。

2.11 可视化

keras的可视化代码我没研究过,感兴趣的自行研究。

可视化结果如下:
这是最后一个卷积层获得的热力图 尺寸为7*7

将它放大到原始图像一样大,并叠加在原始图像上的效果如下:

我们在可视化一个normal样本


可以发现,健康样本的热力图为全白图像,叠加在原始图像使得整个图像偏蓝。

3 可能遇到的问题总结

  • 环境问题
    如果你的环境可能存在问题,建议尝试重新创建一个虚拟环境,安装tensorflow and keras
conda install tensorflow-gpu
conda install keras

我用的版本为:tf: 2.4.1 keras: 2.4.3

  • 找不到文件问题
    如果你的数据地址前面加了”/“,表示根目录,通常我们不会把数据放在根目录,删掉”/“。

  • 得到的结果跟我的差很远,甚至网络train不动,acc一直在0.5左右 第一可能是环境问题,重新安装环境后,还存在此问题,那么可能是网络问题,预训练权重没下载下来,可以多尝试几次。

如果不出问题,把数据集下载好,训练只要几分钟的时间。

tip:我这里只用resnet进行了实验,你还可以尝试train另一个用vgg16训练的文件.

希望您能享受这次实验,并从中获取知识~~

文章持续更新,可以关注微信公众号【医学图像人工智能实战营】获取最新动态,一个关注于医学图像处理领域前沿科技的公众号。坚持已实践为主,手把手带你做项目,打比赛,写论文。凡原创文章皆提供理论讲解,实验代码,实验数据。只有实践才能成长的更快,关注我们,一起学习进步~

我是Tina, 我们下篇博客见~

白天工作晚上写文,呕心沥血

觉得写的不错的话最后,求点赞,评论,收藏。或者一键三连
在这里插入图片描述

以上是关于[代码实战]手把手带你训练一个COVID检测网络,准确率高达90%的主要内容,如果未能解决你的问题,请参考以下文章

版本不对,努力白费。这是我花了240元买来的教训!与君共勉-Python项目如何生成requirements.txt文件

YoloV5实战:手把手教物体检测——YoloV5

YoloV6实战:手把手教你使用Yolov6进行物体检测(附数据集)

YoloV6实战:手把手教你使用Yolov6进行物体检测(附数据集)

『Python开发实战菜鸟教程』实战篇:一文带你了解人脸识别应用原理及手把手教学实现自己的人脸识别项目

手把手教你如何自制目标检测框架(从理论到实现)