import torch
from torchvision import datasets, transforms
# Fix the RNG so the shuffled training order (and anything else random) is repeatable.
torch.manual_seed(44)

BATCH_SIZE = 64
DATA_ROOT = './data'

# ToTensor scales pixel values to [0, 1]; Normalize then applies (x - 0.5) / 0.5,
# mapping them to [-1, 1]. FashionMNIST is single-channel, so mean/std are given
# as 1-tuples (one entry per channel) rather than bare floats.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

trainset = datasets.FashionMNIST(root=DATA_ROOT, download=True, train=True, transform=transform)
# Training batches are reshuffled every epoch.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

testset = datasets.FashionMNIST(root=DATA_ROOT, download=True, train=False, transform=transform)
# Evaluation does not benefit from shuffling; a fixed order keeps test runs reproducible.
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)

# `.data` is the dataset's stored image tensor, read directly (the transform is not
# applied here), so it has no channel dimension.
print("Train shape:", trainloader.dataset.data.shape)
print("Test shape:", testloader.dataset.data.shape)
# Output:
# Train shape: torch.Size([60000, 28, 28])
# Test shape: torch.Size([10000, 28, 28])
# Report the mini-batch size each loader was configured with.
for caption, loader in (("Train batch size:", trainloader), ("Test batch size:", testloader)):
    print(caption, loader.batch_size)
# Output:
# Train batch size: 64
# Test batch size: 64
# Show how the training loader chooses sample indices (sampler) and how it
# merges individual samples into a batch (collate function).
sampler_obj = trainloader.sampler
print("Sampler:", sampler_obj)
collate_obj = trainloader.collate_fn
print("Collate function:", collate_obj)
# Output (object addresses vary from run to run):
# Sampler: <torch.utils.data.sampler.RandomSampler object at 0x7f2b591610d0>
# Collate function: <function default_collate at 0x7f2b5b70e8c0>
import matplotlib.pyplot as plt

# Draw one mini-batch from the training loader (its size is the loader's batch_size).
images, labels = next(iter(trainloader))
first_image = images[0]  # first image of the batch

# squeeze() drops singleton dimensions so imshow receives a 2-D array.
# The image is normalized, but imshow rescales the colormap to the data range,
# so no un-normalization is needed for display.
plt.imshow(first_image.numpy().squeeze(), cmap='Greys_r')
# Caption the plot with the human-readable class name of this sample.
plt.title(trainloader.dataset.classes[labels[0].item()])
plt.show()