import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from torchvision import transforms
import torchvision.utils
from tqdm import tqdm
import matplotlib.pyplot as plt
# MNIST training split. ToTensor() turns each image into a float tensor
# scaled to [0, 1]; the files are downloaded into ./data if not present.
mnist = datasets.MNIST(
    root="./data",
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
The PyTorch DataLoader class is an efficient iterable over a dataset that returns batches of elements and can perform useful preprocessing along the way. Here, we use its ability to batch and shuffle data, but DataLoaders are capable of much more (for example, loading data in parallel with multiple worker processes).
Note that each time we iterate over a DataLoader, it starts again from the beginning (and, with shuffle=True, visits the data in a new random order).
Below we use torchvision.utils.make_grid()
to show a sample batch of inputs.
# Group the dataset into batches of 64 examples, shuffled on every pass.
data_loader = torch.utils.data.DataLoader(dataset=mnist, batch_size=64, shuffle=True)

# Plot only the first batch: the loop body runs once and then breaks.
# Each batch of images has shape [batch_size, 1, 28, 28], where 1 is the
# "channels" dimension of the image.
for images, labels in data_loader:
    grid_img = torchvision.utils.make_grid(images)
    # make_grid tiles the batch into one channels-first image; imshow wants
    # channels last, so the channel axis is moved to the end.
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.title("A single batch of images")
    break
Here we define a simple neural network with one hidden layer for classification on MNIST. It takes a parameter, hidden_size, that sets the number of units in the hidden layer.
class MNISTNetwork(nn.Module):
    """Fully connected MNIST classifier with a single sigmoid hidden layer.

    Args:
        hidden_size: number of units in the hidden layer.

    The forward pass maps inputs whose last dimension is 784 (a flattened
    28x28 image) to 10 unnormalized class scores.
    """

    def __init__(self, hidden_size):
        super().__init__()
        self.linear_0 = nn.Linear(784, hidden_size)  # input -> hidden
        self.linear_1 = nn.Linear(hidden_size, 10)   # hidden -> class scores

    def forward(self, inputs):
        hidden = torch.sigmoid(self.linear_0(inputs))
        return self.linear_1(hidden)
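We will consider three networks: small_net, with a single hidden unit and all parameters initialized to zero; large_net, with 64 hidden units and all parameters initialized to zero; and large_net_rand, with 64 hidden units and PyTorch's default random initialization.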
In the code below, we utilize some important PyTorch methods which you'll want to be familiar with. This includes:
torch.nn.Module.parameters()
: Returns an iterator over module parameters (i.e. for passing to an optimizer that will update those parameters).
torch.Tensor.view()
: Returns a view into the original Tensor. The result of this method shares the same underlying data as the input Tensor. This avoids copying the data, which means it can be more efficient, but it also means that when the original Tensor is modified, so is the view!
torch.Tensor.item()
: Returns the value of a single-element Tensor as a standard Python number. This only works for tensors with one element. For other cases, see torch.Tensor.tolist()
.
torch.Tensor.backward()
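: Computes the gradient of the current tensor with respect to the graph leaves. This only works if the tensor is part of a computation graph, i.e. it was computed from tensors whose requires_grad
is True (plain tensors default to False, but nn.Module parameters default to True). After calling this, the .grad
attribute of every leaf tensor that requires gradients is updated; new gradients are accumulated (added) into it rather than overwriting it. These gradients are used, for example, when calling the .step()
method of an optimizer.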
torch.optim.Optimizer.zero_grad()
: Resets the gradients of all parameters managed by the optimizer (to zero, or to None, which is the default in recent PyTorch versions). This should be done before each step of an optimization procedure (i.e., for each batch when training a DNN). If .zero_grad()
is not called, gradients accumulate (add up) over iterations.
small_net = MNISTNetwork(1)
large_net = MNISTNetwork(64)
large_net_rand = MNISTNetwork(64)

# Replace every parameter of small_net and large_net with zeros, so that all
# of their hidden units start out identical. large_net_rand keeps PyTorch's
# default random initialization for comparison.
for net in (small_net, large_net):
    for param in net.parameters():
        param.data = torch.zeros_like(param.data)
We will train all three networks simultaneously using the same learning rate. After each epoch, we print each network's training loss averaged over the batches of that epoch.
# Train the three MNIST networks side by side on identical batches with
# identically configured Adam optimizers, so that the only difference between
# them is their size and initialization.
epochs = 32
optimizer_small = optim.Adam(small_net.parameters(), lr=5e-3)
optimizer_large = optim.Adam(large_net.parameters(), lr=5e-3)
optimizer_large_rand = optim.Adam(large_net_rand.parameters(), lr=5e-3)

for i in range(epochs):
    # Sum of per-batch losses for this epoch; divided by the number of batches
    # when printed.
    loss_small_epoch = 0.
    loss_large_epoch = 0.
    loss_large_rand_epoch = 0.
    for images, labels in tqdm(data_loader):
        # The networks are fully connected: flatten each 28x28 image to 784 values.
        images = images.view(-1, 784)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer_small.zero_grad()
        optimizer_large.zero_grad()
        optimizer_large_rand.zero_grad()

        y_small = small_net(images)
        y_large = large_net(images)
        y_large_rand = large_net_rand(images)

        loss_small = F.cross_entropy(y_small, labels)
        loss_large = F.cross_entropy(y_large, labels)
        loss_large_rand = F.cross_entropy(y_large_rand, labels)

        # .item() yields a Python float, so the running sums do not keep the
        # autograd graph alive.
        loss_small_epoch += loss_small.item()
        loss_large_epoch += loss_large.item()
        loss_large_rand_epoch += loss_large_rand.item()

        loss_small.backward()
        loss_large.backward()
        loss_large_rand.backward()

        optimizer_small.step()
        optimizer_large.step()
        optimizer_large_rand.step()

    print("Small Loss:", loss_small_epoch / len(data_loader))
    print("Large Loss:", loss_large_epoch / len(data_loader))
    print("Large rand Loss:", loss_large_rand_epoch / len(data_loader))
100%|██████████| 938/938 [00:08<00:00, 110.92it/s]
Small Loss: 1.9747004094662697 Large Loss: 1.8210063917296273 Large rand Loss: 0.3225371094106802
100%|██████████| 938/938 [00:08<00:00, 106.81it/s]
Small Loss: 1.8062316338136506 Large Loss: 1.6598941615141276 Large rand Loss: 0.1449374124618259
100%|██████████| 938/938 [00:08<00:00, 106.02it/s]
Small Loss: 1.7683436330447573 Large Loss: 1.6282896525315893 Large rand Loss: 0.10546199680427149
100%|██████████| 938/938 [00:08<00:00, 105.37it/s]
Small Loss: 1.745496065250592 Large Loss: 1.611243960318535 Large rand Loss: 0.08176170171959314
100%|██████████| 938/938 [00:08<00:00, 115.63it/s]
Small Loss: 1.7230805862687035 Large Loss: 1.6086191712920346 Large rand Loss: 0.06632703205229028
100%|██████████| 938/938 [00:08<00:00, 106.73it/s]
Small Loss: 1.6960091592152235 Large Loss: 1.6033877005963437 Large rand Loss: 0.056162314508348576
100%|██████████| 938/938 [00:08<00:00, 107.39it/s]
Small Loss: 1.6704798057389412 Large Loss: 1.601104776488184 Large rand Loss: 0.048491643619850706
100%|██████████| 938/938 [00:08<00:00, 117.15it/s]
Small Loss: 1.6492018880112085 Large Loss: 1.6005864369589637 Large rand Loss: 0.04086352238900213
100%|██████████| 938/938 [00:08<00:00, 107.32it/s]
Small Loss: 1.631801547272119 Large Loss: 1.5975915173223532 Large rand Loss: 0.037130898026539
100%|██████████| 938/938 [00:08<00:00, 106.40it/s]
Small Loss: 1.6177405220613297 Large Loss: 1.5955544795308794 Large rand Loss: 0.03204905803090454
100%|██████████| 938/938 [00:08<00:00, 111.08it/s]
Small Loss: 1.6089600046306276 Large Loss: 1.5970904198028386 Large rand Loss: 0.02750099962080783
100%|██████████| 938/938 [00:08<00:00, 111.01it/s]
Small Loss: 1.6031588676895923 Large Loss: 1.5954826061151175 Large rand Loss: 0.02494130808654537
100%|██████████| 938/938 [00:08<00:00, 105.97it/s]
Small Loss: 1.5941073978379336 Large Loss: 1.59343382125216 Large rand Loss: 0.022491324605218677
100%|██████████| 938/938 [00:08<00:00, 105.73it/s]
Small Loss: 1.592615304979434 Large Loss: 1.5952912087379487 Large rand Loss: 0.01928929526280405
100%|██████████| 938/938 [00:08<00:00, 115.24it/s]
Small Loss: 1.5848750128929041 Large Loss: 1.5933883008418053 Large rand Loss: 0.017751067268148586
100%|██████████| 938/938 [00:09<00:00, 103.21it/s]
Small Loss: 1.5789772860531106 Large Loss: 1.594103632705298 Large rand Loss: 0.016460558151617014
100%|██████████| 938/938 [00:08<00:00, 105.42it/s]
Small Loss: 1.5755586494514937 Large Loss: 1.5923805371530528 Large rand Loss: 0.016497078008015455
100%|██████████| 938/938 [00:07<00:00, 117.27it/s]
Small Loss: 1.5722746952001982 Large Loss: 1.5920828934163174 Large rand Loss: 0.013686977899110844
100%|██████████| 938/938 [00:08<00:00, 106.64it/s]
Small Loss: 1.5709358697761096 Large Loss: 1.5918515444056058 Large rand Loss: 0.01256602146017523
100%|██████████| 938/938 [00:08<00:00, 106.35it/s]
Small Loss: 1.5678447042701087 Large Loss: 1.5934399259624197 Large rand Loss: 0.015225826330961529
100%|██████████| 938/938 [00:08<00:00, 111.82it/s]
Small Loss: 1.5655181012682315 Large Loss: 1.5908296955928112 Large rand Loss: 0.010799695614480607
100%|██████████| 938/938 [00:08<00:00, 109.93it/s]
Small Loss: 1.5624548138331757 Large Loss: 1.5924690467462357 Large rand Loss: 0.012164810865083392
100%|██████████| 938/938 [00:08<00:00, 105.01it/s]
Small Loss: 1.562048412081021 Large Loss: 1.5917949865875975 Large rand Loss: 0.010378873941811499
100%|██████████| 938/938 [00:08<00:00, 105.63it/s]
Small Loss: 1.5588710741447742 Large Loss: 1.5909390035214455 Large rand Loss: 0.008664014752820392
100%|██████████| 938/938 [00:08<00:00, 116.55it/s]
Small Loss: 1.5587773711950794 Large Loss: 1.5918893418840763 Large rand Loss: 0.009671726888597338
100%|██████████| 938/938 [00:08<00:00, 106.70it/s]
Small Loss: 1.5614188310942416 Large Loss: 1.5933234642055243 Large rand Loss: 0.009133911602114868
100%|██████████| 938/938 [00:08<00:00, 105.90it/s]
Small Loss: 1.5554405610952804 Large Loss: 1.5895309865093434 Large rand Loss: 0.006141310933210551
100%|██████████| 938/938 [00:08<00:00, 115.72it/s]
Small Loss: 1.5551571632499126 Large Loss: 1.590696329374049 Large rand Loss: 0.011262383425718362
100%|██████████| 938/938 [00:08<00:00, 106.35it/s]
Small Loss: 1.555824849270046 Large Loss: 1.5892616450659502 Large rand Loss: 0.011654529858772496
100%|██████████| 938/938 [00:08<00:00, 106.49it/s]
Small Loss: 1.557392876158391 Large Loss: 1.5905824543824836 Large rand Loss: 0.00701571908392653
100%|██████████| 938/938 [00:08<00:00, 109.28it/s]
Small Loss: 1.5533092093111864 Large Loss: 1.589730658383766 Large rand Loss: 0.005507708725134489
100%|██████████| 938/938 [00:08<00:00, 112.20it/s]
Small Loss: 1.5584697225200597 Large Loss: 1.5897600718144416 Large rand Loss: 0.007374370418365521
# Parameters of the zero-initialized large network after training.
W_0 = large_net.linear_0.weight
b_0 = large_net.linear_0.bias
W_1 = large_net.linear_1.weight
b_1 = large_net.linear_1.bias

# Compare every row of W_0 (one row per hidden unit) against the first row.
# If they all match, the hidden units compute exactly the same feature.
print("W_0 => All weights equal for each hidden unit:", (W_0 == W_0[:1]).all().item())
print("Example of weights:")
print(W_0[:, 256])
W_0 => All weights equal for each hidden unit: True Example of weights: tensor([0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595], grad_fn=<SelectBackward0>)
# Compare every column of W_1 (one column per hidden unit) against the first
# column. If they all match, each hidden unit has identical outgoing weights.
print("W_1 => All weights equal for each hidden unit:", (W_1 == W_1[:, :1]).all().item())
print("Weights:")
print(W_1[8])
W_1 => All weights equal for each hidden unit: True Weights: tensor([-0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400], grad_fn=<SelectBackward0>)
# Check whether every hidden-layer bias equals the first one.
print("b_0 => All biases equal for each hidden unit:", b_0.eq(b_0[0]).all().item())
print("Bias:")
print(b_0)
b_0 => All biases equal for each hidden unit: True Bias: Parameter containing: tensor([-0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500], requires_grad=True)
# b_1 is the output-layer bias: one entry per class, not per hidden unit.
# Zero initialization does not force these entries to stay equal, because each
# class receives its own gradient.
print("b_1 => All biases equal for each output unit:", (b_1[0] == b_1).all().item())
print("Bias:")
print(b_1)
b_1 => All biases equal for each hidden unit: False Bias: Parameter containing: tensor([ 2.5093, -9.0861, -3.1097, 0.5818, 2.8968, 1.6153, -1.6222, 3.9712, 0.4858, 3.9563], requires_grad=True)
Below is an implementation of the network from the section handout. We use torchinfo.summary()
to view the size of the data as it flows through the network; additionally, we print the size of the weights and biases of the trainable layers. Note that this network is just for demonstration and may not work well in practice.
Note: this section uses the torchinfo
package; see its GitHub repository for installation instructions, or run one of the following commands:
install via conda:
conda install -c conda-forge torchinfo
install via pip:
pip install torchinfo
from torchinfo import summary
class DemoNetwork(nn.Module):
    """Small demonstration CNN that maps 1x28x28 images to 10 class scores.

    It is meant to illustrate how tensor shapes change through convolution,
    pooling, flattening, and linear layers, not to perform well.
    """

    def __init__(self):
        super().__init__()
        # 28x28 -> 28x28 (3x3 kernel, padding 1), then 28x28 -> 14x14 (pool).
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
        self.max1 = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # 14x14 -> 12x12 (3x3 kernel, no padding), then 12x12 -> 7x7 (pool, padding 1).
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=0)
        self.max2 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
        # 1x1 convolution: changes the channel count only (32 -> 8).
        self.conv3 = nn.Conv2d(32, 8, kernel_size=1, stride=1, padding=0)
        # 7x7 -> 3x3 (5x5 kernel, no padding).
        self.conv4 = nn.Conv2d(8, 4, kernel_size=5, stride=1, padding=0)
        self.flatten = nn.Flatten()
        # 4 channels * 3 * 3 positions = 36 features after flattening.
        self.linear1 = nn.Linear(36, 10)

    @property
    def trainable_layers(self):
        """The layers that own weights and biases (convolutions and the linear layer)."""
        return [self.conv1, self.conv2, self.conv3, self.conv4, self.linear1]

    def forward(self, inputs):
        """Implements the forward pass: apply every stage in order."""
        x = inputs
        for stage in (
            self.conv1,
            self.max1,
            self.conv2,
            self.max2,
            self.conv3,
            self.conv4,
            self.flatten,
            self.linear1,
        ):
            x = stage(x)
        return x

    def print_weight_shapes(self):
        """Print the weight and bias shapes of each trainable layer, one line per layer."""
        for module in self.trainable_layers:
            weight_shape, bias_shape = module.weight.shape, module.bias.shape
            print(f"Weight shape: {weight_shape}; Bias shape: {bias_shape}")
batch_size = 64
demo = DemoNetwork()
# Describe per-layer output shapes and parameter counts for one batch of
# single-channel 28x28 images (the last expression is displayed by the notebook).
summary(demo, input_size=(batch_size, 1, 28, 28))
========================================================================================== Layer (type:depth-idx) Output Shape Param # ========================================================================================== DemoNetwork [64, 10] -- ├─Conv2d: 1-1 [64, 16, 28, 28] 160 ├─MaxPool2d: 1-2 [64, 16, 14, 14] -- ├─Conv2d: 1-3 [64, 32, 12, 12] 4,640 ├─MaxPool2d: 1-4 [64, 32, 7, 7] -- ├─Conv2d: 1-5 [64, 8, 7, 7] 264 ├─Conv2d: 1-6 [64, 4, 3, 3] 804 ├─Flatten: 1-7 [64, 36] -- ├─Linear: 1-8 [64, 10] 370 ========================================================================================== Total params: 6,238 Trainable params: 6,238 Non-trainable params: 0 Total mult-adds (M): 52.11 ========================================================================================== Input size (MB): 0.20 Forward/backward pass size (MB): 9.01 Params size (MB): 0.02 Estimated Total Size (MB): 9.23 ==========================================================================================
# Print the weight and bias shapes of each trainable layer of the demo network.
for layer in demo.trainable_layers:
    print(f"Weight shape: {layer.weight.shape}; Bias shape: {layer.bias.shape}")
Weight shape: torch.Size([16, 1, 3, 3]); Bias shape: torch.Size([16]) Weight shape: torch.Size([32, 16, 3, 3]); Bias shape: torch.Size([32]) Weight shape: torch.Size([8, 32, 1, 1]); Bias shape: torch.Size([8]) Weight shape: torch.Size([4, 8, 5, 5]); Bias shape: torch.Size([4]) Weight shape: torch.Size([10, 36]); Bias shape: torch.Size([10])
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hold out 10% of the MNIST *training* set (mnist was loaded with train=True)
# for evaluation; the official MNIST test set is not used here.
train_data, test_data = torch.utils.data.random_split(mnist, [0.9, 0.1])
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
# Order does not affect evaluation metrics, so the held-out data is not shuffled.
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=False)
# Train the demo CNN with Adam, reporting the mean batch loss and the accuracy
# on the training split after every epoch.
epochs = 16
# The batches below are moved to `device`, so the model must live there too;
# otherwise the forward pass fails with a device mismatch when CUDA is present.
demo.to(device)
optimizer = optim.Adam(demo.parameters(), lr=5e-3)
for i in range(epochs):
    demo.train()
    loss = 0.
    correct_labels = 0
    total_labels = 0
    for images, labels in tqdm(train_dataloader):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        y_hat = demo(images)
        batch_loss = F.cross_entropy(y_hat, labels)
        batch_loss.backward()
        optimizer.step()
        loss += batch_loss.item()
        correct_labels += torch.sum(torch.argmax(y_hat, dim=1) == labels).item()
        total_labels += len(labels)
    # Average over the batches actually iterated (train_dataloader), not the
    # full-dataset data_loader used earlier in the notebook.
    print("Train Loss:", loss / len(train_dataloader))
    print("Train Accuracy:", correct_labels / total_labels)
100%|██████████| 844/844 [00:08<00:00, 95.84it/s]
Train Loss: 0.052556460685644506 Train Accuracy: 0.9822777777777778
100%|██████████| 844/844 [00:07<00:00, 109.86it/s]
Train Loss: 0.05048376267987749 Train Accuracy: 0.9822037037037037
100%|██████████| 844/844 [00:11<00:00, 74.26it/s]
Train Loss: 0.049236171157421994 Train Accuracy: 0.9834259259259259
100%|██████████| 844/844 [00:08<00:00, 100.83it/s]
Train Loss: 0.05079400505832603 Train Accuracy: 0.9826111111111111
100%|██████████| 844/844 [00:09<00:00, 86.99it/s]
Train Loss: 0.04830449687629523 Train Accuracy: 0.9828888888888889
100%|██████████| 844/844 [00:10<00:00, 79.22it/s]
Train Loss: 0.04747499192038229 Train Accuracy: 0.9830740740740741
100%|██████████| 844/844 [00:07<00:00, 110.20it/s]
Train Loss: 0.04763950119656362 Train Accuracy: 0.9835555555555555
100%|██████████| 844/844 [00:08<00:00, 99.47it/s]
Train Loss: 0.04605868608596909 Train Accuracy: 0.984
100%|██████████| 844/844 [00:08<00:00, 100.11it/s]
Train Loss: 0.044495291716124645 Train Accuracy: 0.9844814814814815
100%|██████████| 844/844 [00:07<00:00, 108.20it/s]
Train Loss: 0.0477359937937854 Train Accuracy: 0.9832222222222222
100%|██████████| 844/844 [00:08<00:00, 99.90it/s]
Train Loss: 0.047190548927358876 Train Accuracy: 0.9840740740740741
100%|██████████| 844/844 [00:08<00:00, 100.12it/s]
Train Loss: 0.046100782785842793 Train Accuracy: 0.9841481481481481
100%|██████████| 844/844 [00:07<00:00, 108.29it/s]
Train Loss: 0.04420503837182182 Train Accuracy: 0.9849629629629629
100%|██████████| 844/844 [00:09<00:00, 90.20it/s]
Train Loss: 0.045091666309939285 Train Accuracy: 0.9840925925925926
100%|██████████| 844/844 [00:08<00:00, 99.03it/s]
Train Loss: 0.04636539066281008 Train Accuracy: 0.9838333333333333
100%|██████████| 844/844 [00:07<00:00, 108.33it/s]
Train Loss: 0.044687388977771234 Train Accuracy: 0.985
# Evaluate the trained demo CNN on the held-out split without tracking gradients.
demo.eval()
with torch.no_grad():
    loss = 0.
    correct_labels = 0
    total_labels = 0
    for images, labels in tqdm(test_dataloader):
        images, labels = images.to(device), labels.to(device)
        y_hat = demo(images)
        # Sum per-example losses so that the smaller final batch is weighted
        # correctly when dividing by the number of examples below.
        batch_loss = F.cross_entropy(y_hat, labels, reduction="sum")
        loss += batch_loss.item()
        correct_labels += torch.sum(torch.argmax(y_hat, dim=1) == labels).item()
        total_labels += len(labels)
# Mean loss per held-out example. The previous version divided by
# len(data_loader), the full-dataset batch count, which understated the loss ~10x.
print("Test Loss:", loss / total_labels)
print("Test Accuracy:", correct_labels / total_labels)
100%|██████████| 94/94 [00:01<00:00, 79.54it/s]
Test Loss: 0.005917948178278266 Test Accuracy: 0.984