[代码实战]手把手带你训练一个COVID检测网络,准确率高达90%
Posted Tina姐
tags:
篇首语:本文由小常识网(cha138.com)小编为大家整理,主要介绍了[代码实战]手把手带你训练一个COVID检测网络,准确率高达90%相关的知识,希望对你有一定的参考价值。
本次实战的的概况如下:
- 代码来源:https://github.com/junaidiqbalsyed/Covid_detection_CNN
- 目的:使用
CNN
(vgg and resnet)检测 COVID 并使用GRAD-CAM
进行可视化 - 方法: 二分类(normal or covid)
- 框架: keras (不会没关系,很简单)
- 结果:源代码在VGG-16的准确率为 85%, 我使用resnet跑的结果为88%,甚至90%。
- 难易程度: ⭐️⭐️
结果展示
- resnet准确度
- 可视化结果
感兴趣的话一起进入代码环节吧~~
1 准备工作
- 在github上下载源代码
- 下载数据集(约800M)左右。
如果你网络够好(可随意打开GitHub,Google网站),下载数据集这一步可省略,后面通过代码下载。如果网络不好,建议先下载数据集并解压。 - 配置环境
# 查看版本
import tensorflow as tf
import keras
print(tf.__version__) # 2.4.1
print(keras.__version__) # 2.4.3
注意:这里有个坑,当我在 tf版本 2.0.0 keras 版本 2.3.1运行时,准确度一直在0.5左右徘徊,没有提升。如果你遇到了相同的问题,请重新建一个环境。
2 使用jupyter运行代码
文件:Covid_detection_using_chest_X_Ray(using_ResNet_50)%20(2).ipynb
这部分主要解读代码。
导入包的部分自动略过~~
2.1 数据集是否下载
如果刚开始你没有下载数据集,可运行 cell2 and cell3 下载数据集及解压
!wget https://www.dropbox.com/s/e1r2laj50nh4tez/COVID-19_Radiography_Dataset.zip?dl=0
!unzip "/content/COVID-19_Radiography_Dataset.zip?dl=0"
如果已经下载好,则这两步省略~~~~
2.2 探索数据
直接运行代码即可,不过多介绍。
可以简单探索下:数据的名字,格式,大小,以及下载来源。比如这里的covid数据来自sirm网站。
2.3 把所有类别的数据放在同一个文件夹并可视化
下载好的数据集是按照疾病分类的。这里我们把所有图像放在同一个文件夹。
ROOT_DIR = "/content/COVID-19_Radiography_Dataset/"
imgs = ['COVID','Lung_Opacity','Normal','Viral Pneumonia']
NEW_DIR = "content/all_images/"
需要注意这部分代码的地址。作者在地址前加了”/“,如/content
,表示content目录在根目录中,如果你的数据不在根目录,就不要加前面的”/“。
同理,如果后文出现了找不到文件夹,很有可能是你这里加了”/“, 去掉即可。
可视化看看数据集的分布
下载的数据集有4类,但是我们代码中实际上只用到了2类(COVID 和 Normal)
2.4 把数据集分成训练集、验证集和测试集
if not os.path.exists(NEW_DIR+"train_test_split/"):
os.makedirs(NEW_DIR+"train_test_split/")
.....
os.makedirs(NEW_DIR+"train_test_split/validation/Covid")
# Train Data
for i in np.random.choice(replace= False , size= 3000 , a = glob.glob(NEW_DIR+imgs[0]+"*") ):
shutil.copy(i , NEW_DIR+"train_test_split/train/Covid" )
os.remove(i)
for i in np.random.choice(replace= False , size= 3900 , a = glob.glob(NEW_DIR+imgs[2]+"*") ):
shutil.copy(i , NEW_DIR+"train_test_split/train/Normal" )
os.remove(i)
....
这部分代码里有很多可以学习的地方:
比如,我们要在所有covid图像中,随机选取3000个作为训练集。怎么做到?
答案: np.random.choice()
当这部分选取作为训练集后,如何保证这部分数据在验证集和测试集中选不到它。
答案:os.remove(i)
当被选取后,删除它,那么再选择的时候就选择不到它了。
现在,数据集有了。
2.5 为keras生成数据流
train_path = "content/all_images/train_test_split/train"
valid_path = "content/all_images/train_test_split/validation"
test_path = "content/all_images/train_test_split/test"
这是我们数据集存放的地址。再强调一次,content
前面的’/‘我已经去掉。
为各数据集生成keras可识别的数据
train_data_gen = ImageDataGenerator(preprocessing_function= preprocess_input,
zoom_range= 0.2,
horizontal_flip= True,
shear_range= 0.2,
)
train = train_data_gen.flow_from_directory(directory= train_path,
target_size=(224,224))
训练集中: Found 7800 images belonging to 2 classes.
2.6 构建模型
res = ResNet50( input_shape=(224,224,3), include_top= False)
# include_top will consider the new weights
include_top= False
表示不要全连接层,只加载特征部分。
这里加载的预训练权重是在:https://storage.googleapis.com
注意
我用其他版本keras加载的权重是在:https://github.com/
不同的版本会有区别,也会影响到结果。我也不知道为啥。
你如果得不到一个较好的结果,程序也没有报错,可能就是这步出现了问题,注意检查。
2.6 冻结特征层,这里我们只训练网络最后一层
for layer in res.layers: # Dont Train the parameters again
layer.trainable = False
2.7 添加全连接层
x = Flatten()(res.output)
x = Dense(units=2 , activation='sigmoid', name = 'predictions' )(x)
# creating our model.
model = Model(res.input, x)
可以通过 model.summary()
查看每一层的信息
2.8 训练模型
model.compile( optimizer= 'adam' , loss = 'categorical_crossentropy', metrics=['accuracy'])
es = EarlyStopping(monitor= "val_accuracy" , min_delta= 0.01, patience= 3, verbose=1)
mc = ModelCheckpoint(filepath="bestmodel.h5", monitor="val_accuracy", verbose=1, save_best_only= True)
设置优化器,loss, 评估指标
通过监控val_accuracy来保存网络,使用EarlyStopping来结束训练。
然后,开始训练👇
hist = model.fit_generator(train, steps_per_epoch= 10, epochs= 30, validation_data= valid , validation_steps= 16, callbacks=[es,mc])
model.fit_generator
可以了解一下keras的这个函数,可参数的意思,不想了解直接运行就对了。
你的运行准确度应该在80%以上才是正确的,如果结果异常,检查哪一步除了问题。并解决它。
2.9 加载模型查看训练历史结果
## load only the best model
from keras.models import load_model
model = load_model("bestmodel.h5")
查看保存了哪些历史信息
图片中可以看到,分别保存了训练集和验证集的loss和acc.用matplotlib画出来
plt.plot(h['accuracy'])
plt.plot(h['val_accuracy'] , c = "red")
plt.title("acc vs v-acc")
plt.show()
在测试集上评估模型
acc = model.evaluate_generator(generator= test)[1]
print(f"The accuracy of your model is = {acc*100} %")
The accuracy of your model is = 88 %
2.10 如何对测试集单张测试
这里需要注意的是,即便是单张测试,也需要对图像进行预处理。
预处理方法同测试集的方法一样。
这里使用from keras.preprocessing import image
方法进行预处理。需要统一大小(224,224,3), 转化成nunpy数据,并添加一个batch维度。
def get_img_array(img_path):
"""
Input : Takes in image path as input
Output : Gives out Pre-Processed image
"""
path = img_path
img = image.load_img(path, target_size=(224,224,3))
img = image.img_to_array(img)
img = np.expand_dims(img , axis= 0 )
return img
处理好的图像就可以通过model.predict(img)
进行预测。
2.11 可视化
keras的可视化代码我没研究过,感兴趣的自行研究。
可视化结果如下:
这是最后一个卷积层获得的热力图 尺寸为7*7
将它放大到原始图像一样大,并叠加在原始图像上的效果如下:
我们在可视化一个normal样本
可以发现,健康样本的热力图为全白图像,叠加在原始图像使得整个图像偏蓝。
3 可能遇到的问题总结
- 环境问题
如果你的环境可能存在问题,建议尝试重新创建一个虚拟环境,安装tensorflow and keras
conda install tensorflow-gpu
conda install keras
我用的版本为:tf: 2.4.1 keras: 2.4.3
-
找不到文件问题
如果你的数据地址前面加了”/“,表示根目录,通常我们不会把数据放在根目录,删掉”/“。 -
得到的结果跟我的差很远,甚至网络train不动,acc一直在0.5左右 第一可能是环境问题,重新安装环境后,还存在此问题,那么可能是网络问题,预训练权重没下载下来,可以多尝试几次。
如果不出问题,把数据集下载好,训练只要几分钟的时间。
tip:我这里只用resnet进行了实验,你还可以尝试train另一个用vgg16训练的文件.
希望您能享受这次实验,并从中获取知识~~
文章持续更新,可以关注微信公众号【医学图像人工智能实战营】获取最新动态,一个关注于医学图像处理领域前沿科技的公众号。坚持已实践为主,手把手带你做项目,打比赛,写论文。凡原创文章皆提供理论讲解,实验代码,实验数据。只有实践才能成长的更快,关注我们,一起学习进步~
我是Tina, 我们下篇博客见~
白天工作晚上写文,呕心沥血
觉得写的不错的话最后,求点赞,评论,收藏。或者一键三连
以上是关于[代码实战]手把手带你训练一个COVID检测网络,准确率高达90%的主要内容,如果未能解决你的问题,请参考以下文章
版本不对,努力白费。这是我花了240元买来的教训!与君共勉-Python项目如何生成requirements.txt文件
YoloV6实战:手把手教你使用Yolov6进行物体检测(附数据集)
YoloV6实战:手把手教你使用Yolov6进行物体检测(附数据集)