%matplotlib inline
import time

import matplotlib.pyplot as plt
import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
from IPython import display

display.set_matplotlib_formats("svg")  # render figures as crisp SVG
# Download Fashion-MNIST; ToTensor() converts each HxWxC uint8 image to a CxHxW float tensor.
mnist_train = torchvision.datasets.FashionMNIST(root='../dataset/fashion_mnist', train=True, download=True, transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='../dataset/fashion_mnist', train=False, download=True, transform=transforms.ToTensor())
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to ../dataset/fashion_mnist/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting ../dataset/fashion_mnist/FashionMNIST/raw/train-images-idx3-ubyte.gz to ../dataset/fashion_mnist/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to ../dataset/fashion_mnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting ../dataset/fashion_mnist/FashionMNIST/raw/train-labels-idx1-ubyte.gz to ../dataset/fashion_mnist/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to ../dataset/fashion_mnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting ../dataset/fashion_mnist/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to ../dataset/fashion_mnist/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to ../dataset/fashion_mnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting ../dataset/fashion_mnist/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to ../dataset/fashion_mnist/FashionMNIST/raw
Processing...
Done!
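As a quick sanity check (added here, not part of the original run), the two splits should contain 60,000 training and 10,000 test examples:

len(mnist_train), len(mnist_test)  # expected: (60000, 10000)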
feature, label = mnist_train[0]
print(feature.shape, label) # Channel x Height x Width
torch.Size([1, 28, 28]) 9
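ToTensor() also rescales the raw uint8 pixels from [0, 255] to floats in [0, 1]; a minimal check (added for illustration, not in the original run):

print(feature.min().item(), feature.max().item())  # both values lie in [0.0, 1.0]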
def get_fashion_mnist_labels(labels):
    # Map numeric class indices to human-readable Fashion-MNIST label names.
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]
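For example (a one-line check added for illustration):

get_fashion_mnist_labels([0, 9])  # ['t-shirt', 'ankle boot']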
def show_fashion_mnist(images, labels):
    # One row of subplots, one image per column, axes hidden.
    _, figs = plt.subplots(1, len(images), figsize=(12, 12))
    for f, img, lbl in zip(figs, images, labels):
        f.imshow(img.view((28, 28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
# Show the first ten training examples with their text labels.
X, y = [], []
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_train[i][1])
show_fashion_mnist(X, get_fashion_mnist_labels(y))
batch_size = 256
num_workers = 4  # data-loading subprocesses; use 0 on Windows if multiprocess loading hangs
train_iter = DataLoader(mnist_train, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_iter = DataLoader(mnist_test, batch_size=batch_size, shuffle=False, num_workers=num_workers)
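Reading the data can easily dominate training time, so it is worth measuring; a rough timing sketch (added here, not from the original run) for one full pass over the training iterator:

start = time.perf_counter()
for X, y in train_iter:
    pass
print(f'one pass over train_iter: {time.perf_counter() - start:.2f}s')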
def accuracy(net, data_iter, is_one_hot=False):
    '''
    net: the network to evaluate
    data_iter: data iterator; yields one-hot labels if is_one_hot is True, class indices otherwise
    Returns (average cross-entropy in bits, accuracy) over the whole dataset.
    '''
    sum_acc = 0.0
    sum_crossEnt = 0.0
    n = 0
    for X, y in data_iter:
        hat_y = torch.softmax(net(X), dim=1)
        if is_one_hot:
            acc = torch.sum(torch.max(y, dim=1)[1] == torch.max(hat_y, dim=1)[1]).item()
            # cross-entropy with one-hot targets: -sum over samples and classes of y * log2(p)
            crossEnt = -torch.sum(y.float() * torch.log2(hat_y)).item()
        else:
            acc = torch.sum(y == torch.max(hat_y, dim=1)[1]).item()
            # cross-entropy with index targets: -sum of log2 of the probability at the true class
            crossEnt = -torch.sum(torch.log2(hat_y[torch.arange(hat_y.shape[0]), y.long()])).item()
        sum_acc += acc
        sum_crossEnt += crossEnt
        n += X.shape[0]
    return sum_crossEnt / n, sum_acc / n
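A quick smoke test of the helper (hypothetical, assuming your PyTorch build provides nn.Flatten): an untrained classifier should score near chance level, i.e. accuracy around 0.10:

dummy = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))
with torch.no_grad():
    ce, acc = accuracy(dummy, test_iter)
print(f'cross-entropy: {ce:.4f} bits, accuracy: {acc:.4f}')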
class FlattenLayer(torch.nn.Module):
    def __init__(self):
        super(FlattenLayer, self).__init__()
    def forward(self, x):  # x shape: (batch, *, *, ...)
        # Collapse every dimension except the batch dimension.
        return x.view(x.shape[0], -1)
num_inputs, num_outputs, num_hiddens = 784, 10, 256
net = nn.Sequential(
    FlattenLayer(),
    nn.Linear(num_inputs, num_hiddens),
    nn.ReLU(),
    nn.Linear(num_hiddens, num_outputs),
)
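On PyTorch 1.2 or newer, the built-in nn.Flatten makes the custom layer unnecessary; an equivalent definition (an alternative sketch, assuming that version):

net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(num_inputs, num_hiddens),
    nn.ReLU(),
    nn.Linear(num_hiddens, num_outputs),
)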
for param in net.parameters():  # note: this draws biases as well as weights from N(0, 0.001²)
    torch.nn.init.normal_(param, mean=0, std=0.001)
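A more common convention (an alternative sketch, not what produced the log below) randomizes only the weight matrices and zeroes the biases explicitly:

for name, param in net.named_parameters():
    if 'weight' in name:
        nn.init.normal_(param, mean=0, std=0.001)
    else:
        nn.init.zeros_(param)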
loss = torch.nn.CrossEntropyLoss()  # applies log-softmax internally, so the net outputs raw logits
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
t1 = time.perf_counter()
num_epochs = 40
for epoch in range(num_epochs):
    for t_x, t_y in train_iter:
        l = loss(net(t_x), t_y)  # cross-entropy loss on the current mini-batch
        optimizer.zero_grad()    # clear the accumulated parameter gradients
        l.backward()             # backpropagate to compute fresh gradients
        optimizer.step()         # update the parameters
    with torch.no_grad():        # no gradients needed for evaluation, which speeds it up
        train_accuracy = accuracy(net, train_iter)
        test_accuracy = accuracy(net, test_iter)
    print(f'Epoch {epoch+1}, train_loss: {train_accuracy[0]:.4f} train_accuracy: {train_accuracy[1]:.4f}, test_loss: {test_accuracy[0]:.4f}, test_accuracy: {test_accuracy[1]:.4f}')
print(f"Elapsed: {time.perf_counter()-t1:.2f}s")
Epoch 1, train_loss: 4.9924 train_accuracy: 0.7117, test_loss: 5.0431, test_accuracy: 0.7050
Epoch 2, train_loss: 3.7232 train_accuracy: 0.7927, test_loss: 3.8094, test_accuracy: 0.7845
Epoch 3, train_loss: 3.6153 train_accuracy: 0.7890, test_loss: 3.7583, test_accuracy: 0.7775
Epoch 4, train_loss: 3.5457 train_accuracy: 0.8189, test_loss: 3.7068, test_accuracy: 0.8100
Epoch 5, train_loss: 2.6074 train_accuracy: 0.8377, test_loss: 2.7829, test_accuracy: 0.8240
Epoch 6, train_loss: 2.5498 train_accuracy: 0.8427, test_loss: 2.7479, test_accuracy: 0.8278
Epoch 7, train_loss: 3.4656 train_accuracy: 0.8262, test_loss: 3.7041, test_accuracy: 0.8137
Epoch 8, train_loss: 2.6096 train_accuracy: 0.8329, test_loss: 2.8372, test_accuracy: 0.8132
Epoch 9, train_loss: 2.3112 train_accuracy: 0.8478, test_loss: 2.5409, test_accuracy: 0.8326
Epoch 10, train_loss: 2.2200 train_accuracy: 0.8332, test_loss: 2.4359, test_accuracy: 0.8134
Epoch 11, train_loss: 3.0599 train_accuracy: 0.8403, test_loss: 3.3502, test_accuracy: 0.8258
Epoch 12, train_loss: 2.2588 train_accuracy: 0.8596, test_loss: 2.5230, test_accuracy: 0.8423
Epoch 13, train_loss: 2.0555 train_accuracy: 0.8670, test_loss: 2.3131, test_accuracy: 0.8480
Epoch 14, train_loss: 2.1826 train_accuracy: 0.8710, test_loss: 2.4712, test_accuracy: 0.8551
Epoch 15, train_loss: 1.8623 train_accuracy: 0.8763, test_loss: 2.1402, test_accuracy: 0.8580
Epoch 16, train_loss: 2.0191 train_accuracy: 0.8719, test_loss: 2.3311, test_accuracy: 0.8544
Epoch 17, train_loss: 1.8316 train_accuracy: 0.8800, test_loss: 2.1278, test_accuracy: 0.8620
Epoch 18, train_loss: 1.9074 train_accuracy: 0.8817, test_loss: 2.2157, test_accuracy: 0.8617
Epoch 19, train_loss: 2.6135 train_accuracy: 0.8618, test_loss: 2.9847, test_accuracy: 0.8464
Epoch 20, train_loss: 1.7353 train_accuracy: 0.8694, test_loss: 2.0451, test_accuracy: 0.8478
Epoch 21, train_loss: 1.8104 train_accuracy: 0.8856, test_loss: 2.1539, test_accuracy: 0.8627
Epoch 22, train_loss: 1.7400 train_accuracy: 0.8790, test_loss: 2.0763, test_accuracy: 0.8606
Epoch 23, train_loss: 1.9590 train_accuracy: 0.8870, test_loss: 2.3318, test_accuracy: 0.8680
Epoch 24, train_loss: 1.7209 train_accuracy: 0.8837, test_loss: 2.0876, test_accuracy: 0.8608
Epoch 25, train_loss: 1.6549 train_accuracy: 0.8906, test_loss: 2.0260, test_accuracy: 0.8674
Epoch 26, train_loss: 2.0681 train_accuracy: 0.8835, test_loss: 2.4687, test_accuracy: 0.8646
Epoch 27, train_loss: 1.7220 train_accuracy: 0.8955, test_loss: 2.0987, test_accuracy: 0.8713
Epoch 28, train_loss: 2.0688 train_accuracy: 0.8843, test_loss: 2.5082, test_accuracy: 0.8613
Epoch 29, train_loss: 1.8367 train_accuracy: 0.8920, test_loss: 2.2408, test_accuracy: 0.8702
Epoch 30, train_loss: 1.7036 train_accuracy: 0.8925, test_loss: 2.1060, test_accuracy: 0.8709
Epoch 31, train_loss: 1.6256 train_accuracy: 0.8937, test_loss: 2.0227, test_accuracy: 0.8718
Epoch 32, train_loss: 1.4946 train_accuracy: 0.9015, test_loss: 1.9191, test_accuracy: 0.8769
Epoch 33, train_loss: 1.6215 train_accuracy: 0.9014, test_loss: 2.0463, test_accuracy: 0.8788
Epoch 34, train_loss: 1.8242 train_accuracy: 0.8916, test_loss: 2.2910, test_accuracy: 0.8675
Epoch 35, train_loss: 1.7814 train_accuracy: 0.8937, test_loss: 2.2591, test_accuracy: 0.8674
Epoch 36, train_loss: 1.6753 train_accuracy: 0.9031, test_loss: 2.1689, test_accuracy: 0.8799
Epoch 37, train_loss: 1.4796 train_accuracy: 0.8974, test_loss: 1.9399, test_accuracy: 0.8718
Epoch 38, train_loss: 1.5795 train_accuracy: 0.9028, test_loss: 2.0463, test_accuracy: 0.8761
Epoch 39, train_loss: 1.4715 train_accuracy: 0.8898, test_loss: 1.9323, test_accuracy: 0.8607
Epoch 40, train_loss: 1.6440 train_accuracy: 0.9021, test_loss: 2.1673, test_accuracy: 0.8738
Elapsed: 589.91s
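Once training finishes, the learned parameters can be persisted with torch.save; a minimal sketch (the file name is hypothetical):

torch.save(net.state_dict(), 'mlp_fashion_mnist.pt')  # hypothetical output path
# later: net.load_state_dict(torch.load('mlp_fashion_mnist.pt'))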