1.导入必要的库
本文主要基于torch对肺部感染数据集进行处理,最终测试集的精度能够达到80%以上。首先我们将操作过程中所需要用到的库全部写在了开头,以供后续直接调用。
2. 定义图像变换函数:
我们首先定义了一个transform,主要功能是对我们的图片做一些变换,比如说伸缩、旋转、以及裁剪等等。采用了字典的格式。三个关键词:'train'、'val'和'test'分别是训练、验证和测试。每一个关键词后面的value都是对图片的变换。
对train,test,val依次都进行相似的的转换。封装成图像变换功能的函数。
3.读取数据集
4.数据集展示
读取数据集后,用dataloaders的dataset属性输出了数据集的简介,能够清晰的看到数据集的数量和所做的转换。同时也展示了一些肺部正常和感染的图片,为后续调用做铺垫。
5.图片的展示
因为在训练的过程中,我们需要将很多中间结果值保存起来,所以导入了SummaryWriter,来创建日志。
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Label : NORMAL
6.定义模型
troch中有很多预定义或者是预训练的模型。我们方法的实现基于迁移学习,用了Resnet50模型,他的参数是已经在其他数据集上训练好的,所以中间层的参数我们对其冻结,不做任何改变。因为我们训练的数据集发生了变化,所以主要是对模型中最后的池化层和全连接层做了改变。用自定义的层替换掉原来的层,然后进行训练。
7. 定义训练,测试函数
训练和测试函数的定义方式十分相似,通常都是相对固定的一种格式。输入一系列参数,定义了训练损失,验证损失,验证准确率,使用循环来读取训练集。接下来的流程十分固定:梯度置0,前向传播,计算损失,反向传播,更新参数,再计算累计损失和平均损失。只是要注意验证和测试要声明一下并不进行梯度更新。
8. 进行训练和测试
把训练的流程封装成了一个函数,后面直接调用函数就能输出结果。首先输入模型以及其他参数,然后初始化损失定义为正无穷,这样我们最终能够得到一个最小的损失,所对应的模型。即我们需要的模型。接着直接调用之前已经定义好的函数,train_val()和test()。分别返回各自的结果。然后判断test_loss 是否小于best_loss,保存最小损失,保存模型。我们只运行了10轮。最后函数输出每一轮的信息:训练损失,验证损失,验证准确率,测试损失和测试准确率。测试集的精度能够达到80%以上。
cuda
Epoch | Train Loss | val_loss | val_acc | Test Loss | Test_acc
0 | 0.6797741823401188 | 0.48668360710144043 | 0.75 | 2.606718271970749 | 75.80128205128206
1 | 0.546652967944467 | 0.38362419605255127 | 0.8125 | 2.339508444070816 | 79.32692307692308
2 | 0.4975595817975471 | 0.34449177980422974 | 0.8125 | 2.224026769399643 | 81.57051282051282
3 | 0.46354614317051474 | 0.32737037539482117 | 0.8125 | 2.1333694756031036 | 82.53205128205128
4 | 0.4505095761612149 | 0.30775487422943115 | 0.8125 | 2.118358075618744 | 82.6923076923077
5 | 0.42891258824091016 | 0.31182220578193665 | 0.8125 | 2.0867340862751007 | 82.37179487179488
6 | 0.42118780104660547 | 0.31789082288742065 | 0.8125 | 2.0254012644290924 | 82.21153846153847
7 | 0.40498529109486775 | 0.3049515187740326 | 0.8125 | 1.9985454976558685 | 83.17307692307692
8 | 0.3925688873770778 | 0.29997318983078003 | 0.8125 | 1.9681228697299957 | 83.81410256410257
9 | 0.3795613549604006 | 0.30442166328430176 | 0.8125 | 1.9643790423870087 | 83.65384615384616