import matplotlib.pyplot as plt
import torch
from torchvision import datasets, transforms
# Seed torch's global RNG so random operations below (e.g. the training
# loader's shuffling) are repeatable across runs.
torch.manual_seed(44)

DATA_ROOT = './data'  # where FashionMNIST is downloaded / cached
BATCH_SIZE = 64  # shared by the train and test loaders

# ToTensor maps pixels to [0, 1]; Normalize then applies (x - 0.5) / 0.5,
# giving values in [-1, 1]. The images have a single channel, so mean/std
# are one-element tuples -- note `(0.5)` is just a float, `(0.5,)` is a tuple.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

trainset = datasets.FashionMNIST(root=DATA_ROOT, download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

# Evaluation data does not need shuffling; a fixed order keeps test-time
# results and any per-batch inspection reproducible.
testset = datasets.FashionMNIST(root=DATA_ROOT, download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
# Diagnostic summary of the two loaders: the shape of each dataset's `.data`
# attribute, the configured batch sizes, and the sampler / collate function
# the training loader uses to build batches.
print(f"Train shape: {trainloader.dataset.data.shape}")
print(f"Test shape: {testloader.dataset.data.shape}")
print(f"Train batch size: {trainloader.batch_size}")
print(f"Test batch size: {testloader.batch_size}")
print(f"Sampler: {trainloader.sampler}")
print(f"Collate function: {trainloader.collate_fn}")
# Visual sanity check: pull one batch from the training loader and display
# its first image, titled with that sample's class name.
images, labels = next(iter(trainloader))  # one batch of `trainloader.batch_size` samples
first_image = images[0]
first_label = labels[0].item()  # plain int, usable as a list index
# ToTensor yields (C, H, W); squeeze drops the single channel so imshow gets
# a 2-D (H, W) array. Pixel values were normalized to [-1, 1]; imshow scales
# the colormap to the data range by default.
plt.imshow(first_image.numpy().squeeze(), cmap='Greys_r')
plt.title(trainset.classes[first_label])
plt.show()