import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from torchvision import transforms
from tqdm import tqdm
mnist = datasets.MNIST(root="./data", train=True, download=True, transform=transforms.ToTensor())
data_loader = torch.utils.data.DataLoader(mnist, batch_size=64, shuffle=True)
Here we define a simple one-hidden-layer neural network for classification on MNIST. It takes a parameter that determines the size of the hidden layer.
class MNISTNetwork(nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.linear_0 = nn.Linear(784, hidden_size)
        self.linear_1 = nn.Linear(hidden_size, 10)

    def forward(self, inputs):
        x = self.linear_0(inputs)
        x = torch.sigmoid(x)
        return self.linear_1(x)
We will consider three networks: one with a single hidden unit, one with 64 hidden units, and a second 64-unit network. The first two have all of their parameters set to zero; the third keeps PyTorch's default random initialization.
small_net = MNISTNetwork(1)
large_net = MNISTNetwork(64)
large_net_rand = MNISTNetwork(64)
for p1, p2 in zip(small_net.parameters(), large_net.parameters()):
    p1.data = torch.zeros_like(p1.data)
    p2.data = torch.zeros_like(p2.data)
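Note that although small_net and large_net now start from identical all-zero parameters, the two architectures differ enormously in capacity. As a quick illustrative check (the num_params helper below is ours, not part of the handout code):
def num_params(net):
    # Total number of trainable parameters in a network.
    return sum(p.numel() for p in net.parameters())

print(num_params(small_net))  # 805   = (784*1 + 1) + (1*10 + 10)
print(num_params(large_net))  # 50890 = (784*64 + 64) + (64*10 + 10)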
We will train all three networks simultaneously using the same learning rate. After each epoch, we print the average training loss of each network.
epochs = 32
optimizer_small = optim.Adam(small_net.parameters(), lr=5e-3)
optimizer_large = optim.Adam(large_net.parameters(), lr=5e-3)
optimizer_large_rand = optim.Adam(large_net_rand.parameters(), lr=5e-3)
for i in range(epochs):
    loss_small_epoch = 0.
    loss_large_epoch = 0.
    loss_large_rand_epoch = 0.
    for batch in tqdm(data_loader):
        images, labels = batch
        images = images.view(-1, 784)
        optimizer_small.zero_grad()
        optimizer_large.zero_grad()
        optimizer_large_rand.zero_grad()
        y_small = small_net(images)
        y_large = large_net(images)
        y_large_rand = large_net_rand(images)
        loss_small = F.cross_entropy(y_small, labels)
        loss_large = F.cross_entropy(y_large, labels)
        loss_large_rand = F.cross_entropy(y_large_rand, labels)
        loss_small_epoch += loss_small.item()
        loss_large_epoch += loss_large.item()
        loss_large_rand_epoch += loss_large_rand.item()
        loss_small.backward()
        loss_large.backward()
        loss_large_rand.backward()
        optimizer_small.step()
        optimizer_large.step()
        optimizer_large_rand.step()
    print("Small Loss:", loss_small_epoch / len(data_loader))
    print("Large Loss:", loss_large_epoch / len(data_loader))
    print("Large rand Loss:", loss_large_rand_epoch / len(data_loader))
Small Loss: 1.992125940856649 Large Loss: 1.846776630578519 Large rand Loss: 0.32522186862110203
Small Loss: 1.8445630946647384 Large Loss: 1.6948263950185227 Large rand Loss: 0.14667500496140992
Small Loss: 1.8232598135720437 Large Loss: 1.6582392909125225 Large rand Loss: 0.10941013419675007
Small Loss: 1.810512869342812 Large Loss: 1.647986725957663 Large rand Loss: 0.08550640116078354
Small Loss: 1.7984815566524515 Large Loss: 1.6425841501526741 Large rand Loss: 0.07038939598355014
Small Loss: 1.7832257642166447 Large Loss: 1.639439013593995 Large rand Loss: 0.05966398936248561
Small Loss: 1.7638618088226075 Large Loss: 1.634959454729613 Large rand Loss: 0.05089896230567207
Small Loss: 1.7513694408605855 Large Loss: 1.6394141232535275 Large rand Loss: 0.04360822196315124
Small Loss: 1.7410384184007706 Large Loss: 1.6323074300660254 Large rand Loss: 0.03757109730477149
Small Loss: 1.733578706473938 Large Loss: 1.6320092139213578 Large rand Loss: 0.033873332093712065
Small Loss: 1.7297165021459178 Large Loss: 1.6296177059094281 Large rand Loss: 0.029986836188132745
Small Loss: 1.7265885717579041 Large Loss: 1.6225902791470608 Large rand Loss: 0.025942205315528273
Small Loss: 1.7229655160070227 Large Loss: 1.614469338200494 Large rand Loss: 0.022426178650200756
Small Loss: 1.7207565692696236 Large Loss: 1.59702094086706 Large rand Loss: 0.020207441224754333
Small Loss: 1.7177712766092215 Large Loss: 1.5778811080877715 Large rand Loss: 0.018660261891278396
Small Loss: 1.7150879690387864 Large Loss: 1.5686708898432473 Large rand Loss: 0.01802812687989234
Small Loss: 1.7111699737465458 Large Loss: 1.564772816482129 Large rand Loss: 0.01563816340027083
Small Loss: 1.7066961173563877 Large Loss: 1.5604815232728335 Large rand Loss: 0.013912582195142775
Small Loss: 1.7002086607632099 Large Loss: 1.5598054363020957 Large rand Loss: 0.01297786537696099
Small Loss: 1.6938693372171316 Large Loss: 1.5611818760697014 Large rand Loss: 0.014058153353132449
Small Loss: 1.6906237053210293 Large Loss: 1.5555067268261777 Large rand Loss: 0.012793340248741836
Small Loss: 1.6848017213695339 Large Loss: 1.5571024852520876 Large rand Loss: 0.009328897065228112
Small Loss: 1.682982569056025 Large Loss: 1.5609279830318523 Large rand Loss: 0.013200376111322638
Small Loss: 1.682613905050607 Large Loss: 1.5555425987823177 Large rand Loss: 0.010847247160684718
Small Loss: 1.6804717756283563 Large Loss: 1.5529911856153118 Large rand Loss: 0.009251076868505154
Small Loss: 1.679624924908823 Large Loss: 1.5580362841518702 Large rand Loss: 0.009180401486334025
Small Loss: 1.6756493868604143 Large Loss: 1.5558550791191395 Large rand Loss: 0.010234546676092152
Small Loss: 1.6781808355215515 Large Loss: 1.5566776331299659 Large rand Loss: 0.008899102125460578
Small Loss: 1.6757417534714314 Large Loss: 1.5549392756114382 Large rand Loss: 0.008694451855173537
Small Loss: 1.6722396652835774 Large Loss: 1.5535656431082214 Large rand Loss: 0.007481966306253475
Small Loss: 1.671703512607607 Large Loss: 1.5559691183094277 Large rand Loss: 0.0073218714386225495
Small Loss: 1.6672354841283135 Large Loss: 1.553782471969946 Large rand Loss: 0.008971417971195494
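The randomly initialized 64-unit network reaches a far lower training loss than either zero-initialized network, and the zero-initialized 64-unit network barely improves on the single-unit one. To see why, we inspect the trained parameters of the zero-initialized large network.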
W_0 = large_net.linear_0.weight
b_0 = large_net.linear_0.bias
W_1 = large_net.linear_1.weight
b_1 = large_net.linear_1.bias
print("W_0 => All weights equal for each hidden unit:", (W_0[0, :].unsqueeze(0) == W_0).all().item())
print("Example of weights:")
print(W_0[:, 256])
W_0 => All weights equal for each hidden unit: True Example of weights: tensor([-0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256, -0.0256], grad_fn=<SelectBackward>)
print("W_1 => All weights equal for each hidden unit:", (W_1[:, 0].unsqueeze(-1) == W_1).all().item())
print("Weights:")
print(W_1[8])
W_1 => All weights equal for each hidden unit: True Weights: tensor([-0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049, -0.2049], grad_fn=<SelectBackward>)
print("b_0 => All biases equal for each hidden unit:", (b_0[0] == b_0).all().item())
print("Bias:")
print(b_0)
b_0 => All biases equal for each hidden unit: True Bias: Parameter containing: tensor([-1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603, -1.5603], requires_grad=True)
print("b_1 => All biases equal for each hidden unit:", (b_1[0] == b_1).all().item())
print("Bias:")
print(b_1)
b_1 => All biases equal for each hidden unit: False Bias: Parameter containing: tensor([ 3.2585, -7.8880, 2.7643, 1.4902, -0.2802, 1.9114, 3.9217, -3.0606, 1.2542, -1.7107], requires_grad=True)
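So every hidden unit of the zero-initialized network has learned exactly the same incoming weights, outgoing weights, and bias. This is the symmetry problem: units that start with identical parameters compute identical activations, receive identical gradients, and therefore remain identical after every update, so the 64-unit zero-initialized network can never represent more than a single-unit network. (The output-layer biases do differ, because each class logit receives its own gradient; the symmetry is across hidden units, not across classes.) Below is a minimal sketch, not from the handout and using fake data, that shows the identical gradients after a single backward pass.
# Illustrative example (assumed setup, not the handout's code): when every hidden
# unit starts with the same incoming and outgoing weights, all hidden units
# receive identical gradients, so gradient updates can never make them differ.
W0_demo = torch.zeros(2, 784, requires_grad=True)        # incoming weights, one row per hidden unit
b0_demo = torch.zeros(2, requires_grad=True)
W1_demo = torch.full((10, 2), 0.1, requires_grad=True)   # outgoing weights, identical for both units
b1_demo = torch.zeros(10, requires_grad=True)

x_demo = torch.randn(8, 784)                             # fake batch standing in for images
y_demo = torch.randint(0, 10, (8,))

h_demo = torch.sigmoid(F.linear(x_demo, W0_demo, b0_demo))
loss_demo = F.cross_entropy(F.linear(h_demo, W1_demo, b1_demo), y_demo)
loss_demo.backward()

# Both hidden units receive the same gradients, so they can never diverge.
print(torch.allclose(W0_demo.grad[0], W0_demo.grad[1]))        # True
print(torch.allclose(W1_demo.grad[:, 0], W1_demo.grad[:, 1]))  # True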
Below is an implementation of the network from the section handout. The forward pass includes print statements that report the shape of the data as it flows through the network, along with the shapes of each layer's weights and biases. Note that this network is just for demonstration and would not work well in practice.
class DemoNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3, 1, 1)
        self.max1 = nn.MaxPool2d(2, 2, 0)
        self.conv2 = nn.Conv2d(16, 32, 3, 1, 0)
        self.max2 = nn.MaxPool2d(2, 2, 1)
        self.conv3 = nn.Conv2d(32, 8, 1, 1, 0)
        self.conv4 = nn.Conv2d(8, 4, 5, 1, 0)
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(576, 10)

    def forward(self, inputs):
        x = self.conv1(inputs)
        print("Data shape:", x.shape)
        print("Weight shape:", self.conv1.weight.shape, "Bias shape:", self.conv1.bias.shape)
        x = self.max1(x)
        print("Data shape:", x.shape)
        x = self.conv2(x)
        print("Data shape:", x.shape)
        print("Weight shape:", self.conv2.weight.shape, "Bias shape:", self.conv2.bias.shape)
        x = self.max2(x)
        print("Data shape:", x.shape)
        x = self.conv3(x)
        print("Data shape:", x.shape)
        print("Weight shape:", self.conv3.weight.shape, "Bias shape:", self.conv3.bias.shape)
        x = self.conv4(x)
        print("Data shape:", x.shape)
        print("Weight shape:", self.conv4.weight.shape, "Bias shape:", self.conv4.bias.shape)
        x = self.flatten(x)
        print("Data shape:", x.shape)
        x = self.linear1(x)
        print("Data shape:", x.shape)
        print("Weight shape:", self.linear1.weight.shape, "Bias shape:", self.linear1.bias.shape)
        return x
demo = DemoNetwork()
_ = demo(torch.zeros(9, 3, 64, 64))
Data shape: torch.Size([9, 16, 64, 64])
Weight shape: torch.Size([16, 3, 3, 3]) Bias shape: torch.Size([16])
Data shape: torch.Size([9, 16, 32, 32])
Data shape: torch.Size([9, 32, 30, 30])
Weight shape: torch.Size([32, 16, 3, 3]) Bias shape: torch.Size([32])
Data shape: torch.Size([9, 32, 16, 16])
Data shape: torch.Size([9, 8, 16, 16])
Weight shape: torch.Size([8, 32, 1, 1]) Bias shape: torch.Size([8])
Data shape: torch.Size([9, 4, 12, 12])
Weight shape: torch.Size([4, 8, 5, 5]) Bias shape: torch.Size([4])
Data shape: torch.Size([9, 576])
Data shape: torch.Size([9, 10])
Weight shape: torch.Size([10, 576]) Bias shape: torch.Size([10])
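As a sanity check on the shapes printed above, the spatial size after each convolution or pooling layer follows the usual formula floor((W - K + 2P) / S) + 1, where W is the input size, K the kernel size, P the padding, and S the stride. A quick sketch (the out_size helper is ours, not part of the handout):
# Hypothetical helper (not from the handout): output spatial size of a conv or pooling layer.
def out_size(w, kernel, stride, padding):
    return (w - kernel + 2 * padding) // stride + 1

s = 64
s = out_size(s, 3, 1, 1)   # conv1 -> 64
s = out_size(s, 2, 2, 0)   # max1  -> 32
s = out_size(s, 3, 1, 0)   # conv2 -> 30
s = out_size(s, 2, 2, 1)   # max2  -> 16
s = out_size(s, 1, 1, 0)   # conv3 -> 16
s = out_size(s, 5, 1, 0)   # conv4 -> 12
print(s, 4 * s * s)        # 12 576, matching the Linear(576, 10) input size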