In [22]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from torchvision import transforms
import torchvision.utils
from tqdm import tqdm
import matplotlib.pyplot as plt
In [23]:
mnist = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())

Constructing the DataLoader¶

The PyTorch DataLoader class is an efficient implementation of an iterator that can perform useful preprocessing and returns batches of elements. Here, we use its ability to batch and shuffle data, but DataLoaders are capable of much more.

Note that each time we iterate over a DataLoader, it starts again from the beginning.

Below we use torchvision.utils.make_grid() to show a sample batch of inputs.

In [24]:
data_loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)

# Show one batch of images. Each batch of images has shape [batch_size, 1, 28, 28],
# where 1 is the "channels" dimension of the image.
for images,labels in data_loader:
    grid_img = torchvision.utils.make_grid(images)
    plt.imshow(grid_img.permute(1, 2, 0))
    plt.title("A single batch of images")
    break

Defining the Network¶

Here we define a simple 1-hidden-layer neural network for classification on MNIST. It takes a parameter that determines the hidden size of the hidden layer.

In [25]:
class MNISTNetwork(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear_0 = nn.Linear(784, hidden_size)
        self.linear_1 = nn.Linear(hidden_size, 10)
        
    def forward(self, inputs):
        x = self.linear_0(inputs)
        x = torch.sigmoid(x)
        return self.linear_1(x)

Instantiating the Networks¶

We will consider three networks.

  1. One that only has a single hidden unit and all of its weights are initialized to exactly 0.
  2. One that has 64 hidden units and all of its weights are initialized to exactly 0.
  3. One that has 64 hidden units and the weights are initialized using Torch's default, Kaiming Uniform initialization.

In the code below, we utilize some important PyTorch methods which you'll want to be familiar with. This includes:

  • torch.nn.Module.parameters(): Returns an iterator over module parameters (i.e. for passing to an optimizer that will update those parameters).

  • torch.Tensor.view(): Returns a view into the original Tensor. The result of this method shares the same underlying data as the input Tensor. This avoids copying the data, which means it can be mnore efficient, but it also means that when the original Tensor is modified, so is the view!

  • torch.Tensor.item(): Returns the value of a single-element Tensor as a standard Python number. This only works for tensors with one element. For other cases, see torch.Tensor.tolist().

  • torch.Tensor.backward(): Computes the gradients of current tensor wrt the graph leaves (note that this is only called if Tensor.requires_grad is True, which is the case by default). After calling this, a Tensor's .grad attribute is updated with the current gradients. These are used, for example, when calling .step() method of an optimizer.

  • torch.optim.Optimizer.zero_grad(): Sets the gradients of all variables to zero. This should be conducted before each step of an optimization procedure (i.e., for each batch of training a DNN). If .zero_grad() is not called, gradients accumulate (add) over iterations.

In [26]:
small_net = MNISTNetwork(1)
large_net = MNISTNetwork(64)
large_net_rand = MNISTNetwork(64)
In [27]:
for p in zip(small_net.parameters(), large_net.parameters()):
    p1, p2 = p
    p1.data = torch.zeros_like(p1.data)
    p2.data = torch.zeros_like(p2.data)

Training¶

We will train all three networks simulateneously using the same learning rate. After each epoch, we print the current loss of each network.

In [28]:
epochs = 32

optimizer_small = optim.Adam(small_net.parameters(), lr=5e-3)
optimizer_large = optim.Adam(large_net.parameters(), lr=5e-3)
optimizer_large_rand = optim.Adam(large_net_rand.parameters(), lr=5e-3)

for i in range(epochs):
    loss_small_epoch = 0.
    loss_large_epoch = 0.
    loss_large_rand_epoch = 0.
    
    for batch in tqdm(data_loader):
        images, labels = batch
        images, labels = images, labels
        
        images = images.view(-1, 784)
        
        optimizer_small.zero_grad()
        optimizer_large.zero_grad()
        optimizer_large_rand.zero_grad()
        
        y_small = small_net(images)
        y_large = large_net(images)
        y_large_rand = large_net_rand(images)
        
        loss_small = F.cross_entropy(y_small, labels)
        loss_large = F.cross_entropy(y_large, labels)
        loss_large_rand = F.cross_entropy(y_large_rand, labels)
        
        loss_small_epoch += loss_small.item()
        loss_large_epoch += loss_large.item()
        loss_large_rand_epoch += loss_large_rand.item()
        
        loss_small.backward()
        loss_large.backward()
        loss_large_rand.backward()
        
        optimizer_small.step()
        optimizer_large.step()
        optimizer_large_rand.step()
        
    print("Small Loss:", loss_small_epoch / len(data_loader))
    print("Large Loss:", loss_large_epoch / len(data_loader))
    print("Large rand Loss:", loss_large_rand_epoch / len(data_loader))
100%|██████████| 938/938 [00:08<00:00, 110.92it/s]
Small Loss: 1.9747004094662697
Large Loss: 1.8210063917296273
Large rand Loss: 0.3225371094106802
100%|██████████| 938/938 [00:08<00:00, 106.81it/s]
Small Loss: 1.8062316338136506
Large Loss: 1.6598941615141276
Large rand Loss: 0.1449374124618259
100%|██████████| 938/938 [00:08<00:00, 106.02it/s]
Small Loss: 1.7683436330447573
Large Loss: 1.6282896525315893
Large rand Loss: 0.10546199680427149
100%|██████████| 938/938 [00:08<00:00, 105.37it/s]
Small Loss: 1.745496065250592
Large Loss: 1.611243960318535
Large rand Loss: 0.08176170171959314
100%|██████████| 938/938 [00:08<00:00, 115.63it/s]
Small Loss: 1.7230805862687035
Large Loss: 1.6086191712920346
Large rand Loss: 0.06632703205229028
100%|██████████| 938/938 [00:08<00:00, 106.73it/s]
Small Loss: 1.6960091592152235
Large Loss: 1.6033877005963437
Large rand Loss: 0.056162314508348576
100%|██████████| 938/938 [00:08<00:00, 107.39it/s]
Small Loss: 1.6704798057389412
Large Loss: 1.601104776488184
Large rand Loss: 0.048491643619850706
100%|██████████| 938/938 [00:08<00:00, 117.15it/s]
Small Loss: 1.6492018880112085
Large Loss: 1.6005864369589637
Large rand Loss: 0.04086352238900213
100%|██████████| 938/938 [00:08<00:00, 107.32it/s]
Small Loss: 1.631801547272119
Large Loss: 1.5975915173223532
Large rand Loss: 0.037130898026539
100%|██████████| 938/938 [00:08<00:00, 106.40it/s]
Small Loss: 1.6177405220613297
Large Loss: 1.5955544795308794
Large rand Loss: 0.03204905803090454
100%|██████████| 938/938 [00:08<00:00, 111.08it/s]
Small Loss: 1.6089600046306276
Large Loss: 1.5970904198028386
Large rand Loss: 0.02750099962080783
100%|██████████| 938/938 [00:08<00:00, 111.01it/s]
Small Loss: 1.6031588676895923
Large Loss: 1.5954826061151175
Large rand Loss: 0.02494130808654537
100%|██████████| 938/938 [00:08<00:00, 105.97it/s]
Small Loss: 1.5941073978379336
Large Loss: 1.59343382125216
Large rand Loss: 0.022491324605218677
100%|██████████| 938/938 [00:08<00:00, 105.73it/s]
Small Loss: 1.592615304979434
Large Loss: 1.5952912087379487
Large rand Loss: 0.01928929526280405
100%|██████████| 938/938 [00:08<00:00, 115.24it/s]
Small Loss: 1.5848750128929041
Large Loss: 1.5933883008418053
Large rand Loss: 0.017751067268148586
100%|██████████| 938/938 [00:09<00:00, 103.21it/s]
Small Loss: 1.5789772860531106
Large Loss: 1.594103632705298
Large rand Loss: 0.016460558151617014
100%|██████████| 938/938 [00:08<00:00, 105.42it/s]
Small Loss: 1.5755586494514937
Large Loss: 1.5923805371530528
Large rand Loss: 0.016497078008015455
100%|██████████| 938/938 [00:07<00:00, 117.27it/s]
Small Loss: 1.5722746952001982
Large Loss: 1.5920828934163174
Large rand Loss: 0.013686977899110844
100%|██████████| 938/938 [00:08<00:00, 106.64it/s]
Small Loss: 1.5709358697761096
Large Loss: 1.5918515444056058
Large rand Loss: 0.01256602146017523
100%|██████████| 938/938 [00:08<00:00, 106.35it/s]
Small Loss: 1.5678447042701087
Large Loss: 1.5934399259624197
Large rand Loss: 0.015225826330961529
100%|██████████| 938/938 [00:08<00:00, 111.82it/s]
Small Loss: 1.5655181012682315
Large Loss: 1.5908296955928112
Large rand Loss: 0.010799695614480607
100%|██████████| 938/938 [00:08<00:00, 109.93it/s]
Small Loss: 1.5624548138331757
Large Loss: 1.5924690467462357
Large rand Loss: 0.012164810865083392
100%|██████████| 938/938 [00:08<00:00, 105.01it/s]
Small Loss: 1.562048412081021
Large Loss: 1.5917949865875975
Large rand Loss: 0.010378873941811499
100%|██████████| 938/938 [00:08<00:00, 105.63it/s]
Small Loss: 1.5588710741447742
Large Loss: 1.5909390035214455
Large rand Loss: 0.008664014752820392
100%|██████████| 938/938 [00:08<00:00, 116.55it/s]
Small Loss: 1.5587773711950794
Large Loss: 1.5918893418840763
Large rand Loss: 0.009671726888597338
100%|██████████| 938/938 [00:08<00:00, 106.70it/s]
Small Loss: 1.5614188310942416
Large Loss: 1.5933234642055243
Large rand Loss: 0.009133911602114868
100%|██████████| 938/938 [00:08<00:00, 105.90it/s]
Small Loss: 1.5554405610952804
Large Loss: 1.5895309865093434
Large rand Loss: 0.006141310933210551
100%|██████████| 938/938 [00:08<00:00, 115.72it/s]
Small Loss: 1.5551571632499126
Large Loss: 1.590696329374049
Large rand Loss: 0.011262383425718362
100%|██████████| 938/938 [00:08<00:00, 106.35it/s]
Small Loss: 1.555824849270046
Large Loss: 1.5892616450659502
Large rand Loss: 0.011654529858772496
100%|██████████| 938/938 [00:08<00:00, 106.49it/s]
Small Loss: 1.557392876158391
Large Loss: 1.5905824543824836
Large rand Loss: 0.00701571908392653
100%|██████████| 938/938 [00:08<00:00, 109.28it/s]
Small Loss: 1.5533092093111864
Large Loss: 1.589730658383766
Large rand Loss: 0.005507708725134489
100%|██████████| 938/938 [00:08<00:00, 112.20it/s]
Small Loss: 1.5584697225200597
Large Loss: 1.5897600718144416
Large rand Loss: 0.007374370418365521

In [29]:
W_0 = large_net.linear_0.weight
b_0 = large_net.linear_0.bias

W_1 = large_net.linear_1.weight
b_1 = large_net.linear_1.bias
In [30]:
print("W_0 => All weights equal for each hidden unit:", (W_0[0, :].unsqueeze(0) == W_0).all().item())
print("Example of weights:")
print(W_0[:, 256])
W_0 => All weights equal for each hidden unit: True
Example of weights:
tensor([0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595,
        0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595,
        0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595,
        0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595,
        0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595,
        0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595,
        0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595, 0.2595,
        0.2595], grad_fn=<SelectBackward0>)
In [31]:
print("W_1 => All weights equal for each hidden unit:", (W_1[:, 0].unsqueeze(-1) == W_1).all().item())
print("Weights:")
print(W_1[8])
W_1 => All weights equal for each hidden unit: True
Weights:
tensor([-0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400,
        -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400, -0.0400],
       grad_fn=<SelectBackward0>)
In [32]:
print("b_0 => All biases equal for each hidden unit:", (b_0[0] == b_0).all().item())
print("Bias:")
print(b_0)
b_0 => All biases equal for each hidden unit: True
Bias:
Parameter containing:
tensor([-0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500,
        -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500,
        -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500,
        -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500,
        -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500,
        -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500,
        -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500,
        -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500, -0.7500],
       requires_grad=True)
In [33]:
print("b_1 => All biases equal for each hidden unit:", (b_1[0] == b_1).all().item())
print("Bias:")
print(b_1)
b_1 => All biases equal for each hidden unit: False
Bias:
Parameter containing:
tensor([ 2.5093, -9.0861, -3.1097,  0.5818,  2.8968,  1.6153, -1.6222,  3.9712,
         0.4858,  3.9563], requires_grad=True)

Tensor and Layer sizes¶

Below is an implementation of the network from the section handout. We use torchinfo-summary() to view the size of the data as it flows through the network; additionally, we print and the size of the weights and biases of the layers during a forward pass. Note that this network is just for demonstration and may not work well in practice.

Note: this section uses the torchinfo package; see the github repo for installation instructions or run one of the following lines below:

install via conda:

conda install -c conda-forge torchinfo

install via pip:

pip install torchinfo
In [48]:
from torchinfo import summary

class DemoNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.conv1 = nn.Conv2d(1, 16, 3, 1, 1)
        self.max1 = nn.MaxPool2d(2, 2, 0)
        self.conv2 = nn.Conv2d(16, 32, 3, 1, 0)
        self.max2 = nn.MaxPool2d(2, 2, 1)
        self.conv3 = nn.Conv2d(32, 8, 1, 1, 0)
        self.conv4 = nn.Conv2d(8, 4, 5, 1, 0)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(36, 10)
    
    @property
    def trainable_layers(self):
        """A utility property to easily access a list of all model layers."""
        return [self.conv1, self.conv2, self.conv3, self.conv4, self.linear1]
        
    def forward(self, inputs):
        """Implements the forward pass."""
        x = self.conv1(inputs)
        x = self.max1(x)
        x = self.conv2(x)
        x = self.max2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.flatten(x)
        x = self.linear1(x)
        return x

    def print_weight_shapes(self):
        """Utility function to print the shapes of weights in trainable layers."""
        for layer in self.trainable_layers:
            print(f"Weight shape: {layer.weight.shape}; Bias shape: {layer.bias.shape}")

demo = DemoNetwork()
batch_size = 64
summary(demo, input_size=(batch_size, 1, 28, 28))
Out[48]:
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
DemoNetwork                              [64, 10]                  --
├─Conv2d: 1-1                            [64, 16, 28, 28]          160
├─MaxPool2d: 1-2                         [64, 16, 14, 14]          --
├─Conv2d: 1-3                            [64, 32, 12, 12]          4,640
├─MaxPool2d: 1-4                         [64, 32, 7, 7]            --
├─Conv2d: 1-5                            [64, 8, 7, 7]             264
├─Conv2d: 1-6                            [64, 4, 3, 3]             804
├─Flatten: 1-7                           [64, 36]                  --
├─Linear: 1-8                            [64, 10]                  370
==========================================================================================
Total params: 6,238
Trainable params: 6,238
Non-trainable params: 0
Total mult-adds (M): 52.11
==========================================================================================
Input size (MB): 0.20
Forward/backward pass size (MB): 9.01
Params size (MB): 0.02
Estimated Total Size (MB): 9.23
==========================================================================================
In [49]:
demo.print_weight_shapes()
Weight shape: torch.Size([16, 1, 3, 3]); Bias shape: torch.Size([16])
Weight shape: torch.Size([32, 16, 3, 3]); Bias shape: torch.Size([32])
Weight shape: torch.Size([8, 32, 1, 1]); Bias shape: torch.Size([8])
Weight shape: torch.Size([4, 8, 5, 5]); Bias shape: torch.Size([4])
Weight shape: torch.Size([10, 36]); Bias shape: torch.Size([10])
In [59]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
train_data, test_data = torch.utils.data.random_split(mnist, [0.9, 0.1])
train_dataloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=64, shuffle=True)

epochs = 16
optimizer = optim.Adam(demo.parameters(), lr=5e-3)

for i in range(epochs):
    loss = 0.
    correct_labels = 0
    total_labels = 0
    
    for batch in tqdm(train_dataloader):
        images, labels = batch
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        y_hat = demo(images)
        batch_loss = F.cross_entropy(y_hat, labels)
        batch_loss.backward()
        optimizer.step()

        loss += batch_loss.item()
        correct_labels += torch.sum(torch.argmax(y_hat, dim=1) == labels).item()
        total_labels += len(labels)
        
    print("Train Loss:", loss / len(data_loader))
    print("Train Accuracy:", correct_labels / total_labels)
100%|██████████| 844/844 [00:08<00:00, 95.84it/s] 
Train Loss: 0.052556460685644506
Train Accuracy: 0.9822777777777778
100%|██████████| 844/844 [00:07<00:00, 109.86it/s]
Train Loss: 0.05048376267987749
Train Accuracy: 0.9822037037037037
100%|██████████| 844/844 [00:11<00:00, 74.26it/s] 
Train Loss: 0.049236171157421994
Train Accuracy: 0.9834259259259259
100%|██████████| 844/844 [00:08<00:00, 100.83it/s]
Train Loss: 0.05079400505832603
Train Accuracy: 0.9826111111111111
100%|██████████| 844/844 [00:09<00:00, 86.99it/s]
Train Loss: 0.04830449687629523
Train Accuracy: 0.9828888888888889
100%|██████████| 844/844 [00:10<00:00, 79.22it/s]
Train Loss: 0.04747499192038229
Train Accuracy: 0.9830740740740741
100%|██████████| 844/844 [00:07<00:00, 110.20it/s]
Train Loss: 0.04763950119656362
Train Accuracy: 0.9835555555555555
100%|██████████| 844/844 [00:08<00:00, 99.47it/s] 
Train Loss: 0.04605868608596909
Train Accuracy: 0.984
100%|██████████| 844/844 [00:08<00:00, 100.11it/s]
Train Loss: 0.044495291716124645
Train Accuracy: 0.9844814814814815
100%|██████████| 844/844 [00:07<00:00, 108.20it/s]
Train Loss: 0.0477359937937854
Train Accuracy: 0.9832222222222222
100%|██████████| 844/844 [00:08<00:00, 99.90it/s] 
Train Loss: 0.047190548927358876
Train Accuracy: 0.9840740740740741
100%|██████████| 844/844 [00:08<00:00, 100.12it/s]
Train Loss: 0.046100782785842793
Train Accuracy: 0.9841481481481481
100%|██████████| 844/844 [00:07<00:00, 108.29it/s]
Train Loss: 0.04420503837182182
Train Accuracy: 0.9849629629629629
100%|██████████| 844/844 [00:09<00:00, 90.20it/s] 
Train Loss: 0.045091666309939285
Train Accuracy: 0.9840925925925926
100%|██████████| 844/844 [00:08<00:00, 99.03it/s] 
Train Loss: 0.04636539066281008
Train Accuracy: 0.9838333333333333
100%|██████████| 844/844 [00:07<00:00, 108.33it/s]
Train Loss: 0.044687388977771234
Train Accuracy: 0.985

In [60]:
with torch.no_grad():
    loss = 0.
    correct_labels = 0
    total_labels = 0
    for batch in tqdm(test_dataloader):
        images, labels = batch
        images, labels = images.to(device), labels.to(device)
        y_hat = demo(images)
        batch_loss = F.cross_entropy(y_hat, labels)

        loss += batch_loss.item()
        correct_labels += torch.sum(torch.argmax(y_hat, dim=1) == labels).item()
        total_labels += len(labels)
        
    print("Test Loss:", loss / len(data_loader))
    print("Test Accuracy:", correct_labels / total_labels)
100%|██████████| 94/94 [00:01<00:00, 79.54it/s]
Test Loss: 0.005917948178278266
Test Accuracy: 0.984