{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# PyTorch Introduction\n", "\n", "Today, we will be introducing PyTorch, \"an open source deep learning platform that provides a seamless path from research prototyping to production deployment\".\n", "\n", "This notebook is by no means comprehensive. If you have any questions, the documentation and Google are your friends.\n", "\n", "Goal takeaways:\n", "- Automatic differentiation is a powerful tool\n", "- PyTorch implements common functions used in deep learning" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "from mpl_toolkits.mplot3d import Axes3D # registers the '3d' projection on older matplotlib versions\n", "import matplotlib.pyplot as plt\n", "\n", "import numpy as np\n", "\n", "torch.manual_seed(446)\n", "np.random.seed(446)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tensors and relation to numpy\n", "\n", "By this point, we have worked with numpy quite a bit. PyTorch's basic building block, the `tensor`, is similar to numpy's `ndarray`." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "x_numpy, x_torch\n", "[0.1 0.2 0.3] tensor([0.1000, 0.2000, 0.3000])\n", "\n", "to and from numpy and pytorch\n", "tensor([0.1000, 0.2000, 0.3000], dtype=torch.float64) [0.1 0.2 0.3]\n", "\n", "x+y\n", "[3.1 4.2 5.3] tensor([3.1000, 4.2000, 5.3000])\n", "\n", "norm\n", "0.37416573867739417 tensor(0.3742)\n", "\n", "mean along the 0th dimension\n", "[2. 3.] 
tensor([2., 3.])\n" ] } ], "source": [ "# we create tensors in a similar way to numpy ndarrays\n", "x_numpy = np.array([0.1, 0.2, 0.3])\n", "x_torch = torch.tensor([0.1, 0.2, 0.3])\n", "print('x_numpy, x_torch')\n", "print(x_numpy, x_torch)\n", "print()\n", "\n", "# converting between numpy and pytorch\n", "print('to and from numpy and pytorch')\n", "print(torch.from_numpy(x_numpy), x_torch.numpy())\n", "print()\n", "\n", "# we can do basic operations like +-*/\n", "y_numpy = np.array([3,4,5.])\n", "y_torch = torch.tensor([3,4,5.])\n", "print(\"x+y\")\n", "print(x_numpy + y_numpy, x_torch + y_torch)\n", "print()\n", "\n", "# many functions that are in numpy are also in pytorch\n", "print(\"norm\")\n", "print(np.linalg.norm(x_numpy), torch.norm(x_torch))\n", "print()\n", "\n", "# to apply an operation along a dimension,\n", "# we use the dim keyword argument instead of axis\n", "print(\"mean along the 0th dimension\")\n", "x_numpy = np.array([[1,2],[3,4.]])\n", "x_torch = torch.tensor([[1,2],[3,4.]])\n", "print(np.mean(x_numpy, axis=0), torch.mean(x_torch, dim=0))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### `Tensor.view`\n", "We can use the `Tensor.view()` function to reshape tensors, similarly to `numpy.reshape()`.\n", "\n", "It can also infer one dimension automatically if `-1` is passed for it. This is useful when we are working with batches whose size is not known in advance."
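] }, { "cell_type": "markdown", "metadata": {}, "source": [ "A minimal sketch of `view()` on a small tensor, which also shows one caveat: `view()` only works when the new shape is compatible with the tensor's memory layout. A transposed tensor is not contiguous, so calling `view(-1)` on it raises a `RuntimeError`, whereas `reshape()` falls back to copying the data. The tensor `t_demo` and the other names below are illustrative only." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# illustrative tensor holding the values 0 to 5\n", "t_demo = torch.arange(6)\n", "print(t_demo.view(2, 3).shape)  # torch.Size([2, 3])\n", "print(t_demo.view(-1, 2).shape)  # -1 is inferred as 3: torch.Size([3, 2])\n", "\n", "# a transpose shares memory with t_demo but is no longer contiguous\n", "t_T = t_demo.view(2, 3).t()\n", "try:\n", "    t_T.view(-1)\n", "except RuntimeError:\n", "    print('view() failed on a non-contiguous tensor')\n", "\n", "flat = t_T.reshape(-1)  # reshape() copies when a view is impossible\n", "print(flat)  # tensor([0, 3, 1, 4, 2, 5])"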
] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([10000, 3, 28, 28])\n", "torch.Size([10000, 3, 784])\n", "torch.Size([10000, 3, 784])\n" ] } ], "source": [ "# \"MNIST\"\n", "N, C, H, W = 10000, 3, 28, 28\n", "X = torch.randn((N, C, H, W))\n", "\n", "print(X.shape)\n", "print(X.view(N, C, 784).shape)\n", "print(X.view(-1, C, 784).shape) # -1 lets PyTorch infer the size of the 0th dimension" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Computation graphs\n", "\n", "What's special about PyTorch's `tensor` object is that, whenever an operation involves a tensor with `requires_grad=True`, PyTorch implicitly records a computation graph in the background. A computation graph is a way of writing a mathematical expression as a graph. Backpropagation (reverse-mode automatic differentiation) computes the gradients of the output with respect to all the variables of a computation graph in time on the same order as computing the function itself.\n", "\n", "Consider the expression $e=(a+b)*(b+1)$ with values $a=2, b=1$. We can draw the evaluated computation graph as\n", "
\n", "
\n", "\n", "![tree-img](./tree-eval.png)\n", "\n", "[source](https://colah.github.io/posts/2015-08-Backprop/)\n", "\n", "In PyTorch, we can write this as" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "c tensor(3., grad_fn=<AddBackward0>)\n", "d tensor(2., grad_fn=<AddBackward0>)\n", "e tensor(6., grad_fn=<MulBackward0>)\n" ] } ], "source": [ "a = torch.tensor(2.0, requires_grad=True) # we set requires_grad=True to let PyTorch know to keep the graph\n", "b = torch.tensor(1.0, requires_grad=True)\n", "c = a + b\n", "d = b + 1\n", "e = c * d\n", "print('c', c)\n", "print('d', d)\n", "print('e', e)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `grad_fn` attribute of each result shows that PyTorch kept track of the computation graph for us." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## PyTorch as an autograd framework\n", "\n", "Now that we have seen that PyTorch keeps the graph around, let's use it to compute some gradients.\n", "\n", "Consider the function $f(x) = (x-2)^2$.\n", "\n", "Q: Compute $\\frac{d}{dx} f(x)$ and then compute $f'(1)$.\n", "\n", "We call `backward()` on the output `y` of the computation (the root of the graph). This computes the gradient of `y` with respect to every leaf tensor that has `requires_grad=True` (here, `x`) in one pass and stores it in that leaf's `.grad` attribute."
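] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For example, calling `backward()` on the output `e` of the computation graph from the previous section fills in `.grad` on the leaves `a` and `b`. By the chain rule, at $a=2, b=1$ we expect $\\frac{\\partial e}{\\partial a} = b + 1 = 2$ and $\\frac{\\partial e}{\\partial b} = (b + 1) + (a + b) = 5$. A minimal sketch that rebuilds that graph:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "# rebuild the graph e = (a + b) * (b + 1) with a = 2, b = 1\n", "a = torch.tensor(2.0, requires_grad=True)\n", "b = torch.tensor(1.0, requires_grad=True)\n", "e = (a + b) * (b + 1)\n", "\n", "e.backward()  # backpropagate from the output e to the leaves a and b\n", "print(a.grad)  # tensor(2.)\n", "print(b.grad)  # tensor(5.)"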
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Analytical f'(x): tensor([-2.], grad_fn=<MulBackward0>)\n", "PyTorch's f'(x): tensor([-2.])\n" ] } ], "source": [ "def f(x):\n", " return (x-2)**2\n", "\n", "def fp(x):\n", " return 2*(x-2)\n", "\n", "x = torch.tensor([1.0], requires_grad=True)\n", "\n", "y = f(x)\n", "y.backward()\n", "\n", "print('Analytical f\\'(x):', fp(x))\n", "print('PyTorch\\'s f\\'(x):', x.grad)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It can also compute gradients of functions of several variables.\n", "\n", "Let $w = [w_1, w_2]^T$\n", "\n", "Consider $g(w) = 2w_1w_2 + w_2\\cos(w_1)$\n", "\n", "Q: Compute $\\nabla_w g(w)$ and verify $\\nabla_w g([\\pi,1]) = [2, 2\\pi - 1]^T$" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Analytical grad g(w) tensor([2.0000, 5.2832])\n", "PyTorch's grad g(w) tensor([2.0000, 5.2832])\n" ] } ], "source": [ "def g(w):\n", " return 2*w[0]*w[1] + w[1]*torch.cos(w[0])\n", "\n", "def grad_g(w):\n", " return torch.tensor([2*w[1] - w[1]*torch.sin(w[0]), 2*w[0] + torch.cos(w[0])])\n", "\n", "w = torch.tensor([np.pi, 1], requires_grad=True)\n", "\n", "z = g(w)\n", "z.backward()\n", "\n", "print('Analytical grad g(w)', grad_g(w))\n", "print('PyTorch\\'s grad g(w)', w.grad)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using the gradients\n", "Now that we have gradients, we can use our favorite optimization algorithm: gradient descent!\n", "\n", "Let $f$ be the same function we defined above.\n", "\n", "Q: What is the value of $x$ that minimizes $f$?"
] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter,\tx,\tf(x),\tf'(x),\tf'(x) pytorch\n", "0,\t5.000,\t9.000,\t6.000,\t6.000\n", "1,\t3.500,\t2.250,\t3.000,\t3.000\n", "2,\t2.750,\t0.562,\t1.500,\t1.500\n", "3,\t2.375,\t0.141,\t0.750,\t0.750\n", "4,\t2.188,\t0.035,\t0.375,\t0.375\n", "5,\t2.094,\t0.009,\t0.188,\t0.188\n", "6,\t2.047,\t0.002,\t0.094,\t0.094\n", "7,\t2.023,\t0.001,\t0.047,\t0.047\n", "8,\t2.012,\t0.000,\t0.023,\t0.023\n", "9,\t2.006,\t0.000,\t0.012,\t0.012\n", "10,\t2.003,\t0.000,\t0.006,\t0.006\n", "11,\t2.001,\t0.000,\t0.003,\t0.003\n", "12,\t2.001,\t0.000,\t0.001,\t0.001\n", "13,\t2.000,\t0.000,\t0.001,\t0.001\n", "14,\t2.000,\t0.000,\t0.000,\t0.000\n" ] } ], "source": [ "x = torch.tensor([5.0], requires_grad=True)\n", "step_size = 0.25\n", "\n", "print('iter,\\tx,\\tf(x),\\tf\\'(x),\\tf\\'(x) pytorch')\n", "for i in range(15):\n", " y = f(x)\n", " y.backward() # compute the gradient\n", " \n", " print('{},\\t{:.3f},\\t{:.3f},\\t{:.3f},\\t{:.3f}'.format(i, x.item(), f(x).item(), fp(x).item(), x.grad.item()))\n", " \n", " x.data = x.data - step_size * x.grad # perform a GD update step\n", " \n", " # We need to zero the grad variable since the backward()\n", " # call accumulates the gradients in .grad instead of overwriting.\n", " # The detach_() is for efficiency. You do not need to worry too much about it.\n", " x.grad.detach_()\n", " x.grad.zero_()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Linear Regression\n", "\n", "Now, instead of minimizing a made-up function, let's minimize a loss function on some made-up data.\n", "\n", "We will implement gradient descent to solve the task of linear regression."
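] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As a point of reference, the $w$ that minimizes the squared error $\\frac{1}{n}||y - Xw||_2^2$ can also be computed directly by least squares, and gradient descent should approach it. A minimal sketch on synthetic data built like the dataset below, using `torch.linalg.lstsq` (PyTorch 1.9 or later) and a private random generator so that the notebook's global seed is not disturbed. The `_ls` variable names are illustrative." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "gen = torch.Generator().manual_seed(0)  # private RNG; the global torch.manual_seed(446) stream is untouched\n", "n_ls, d_ls = 50, 2\n", "X_ls = torch.randn(n_ls, d_ls, generator=gen)\n", "w_true_ls = torch.tensor([[-1.0], [2.0]])\n", "y_ls = X_ls @ w_true_ls + 0.1 * torch.randn(n_ls, 1, generator=gen)\n", "\n", "# least-squares solution of min_w ||y - Xw||^2\n", "w_ls = torch.linalg.lstsq(X_ls, y_ls).solution\n", "print(w_ls.view(-1))  # close to [-1, 2]; the added noise keeps it from matching exactly"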
] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "X shape torch.Size([50, 2])\n", "y shape torch.Size([50, 1])\n", "w shape torch.Size([2, 1])\n" ] } ], "source": [ "# make a simple linear dataset with some noise\n", "\n", "d = 2\n", "n = 50\n", "X = torch.randn(n,d)\n", "true_w = torch.tensor([[-1.0], [2.0]])\n", "y = X @ true_w + torch.randn(n,1) * 0.1\n", "print('X shape', X.shape)\n", "print('y shape', y.shape)\n", "print('w shape', true_w.shape)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Note: dimensions\n", "PyTorch does a lot of operations on batches of data. The convention is to have your data be of size $(N, d)$ where $N$ is the size of the batch of data." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# visualize the dataset\n", "\n", "fig = plt.figure()\n", "ax = fig.add_subplot(111, projection='3d')\n", "\n", "ax.scatter(X[:,0].numpy(), X[:,1].numpy(), y.numpy(), c='r', marker='o')\n", "\n", "ax.set_xlabel('$X_1$')\n", "ax.set_ylabel('$X_2$')\n", "ax.set_zlabel('$Y$')\n", "\n", "plt.title('Dataset')\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def visualize_fun(w, title, num_pts=20):\n", " \n", " x1, x2 = np.meshgrid(np.linspace(-2,2, num_pts), np.linspace(-2,2, num_pts))\n", " X_plane = torch.tensor(np.stack([np.reshape(x1, (num_pts**2)), np.reshape(x2, (num_pts**2))], axis=1)).float()\n", " y_plane = np.reshape((X_plane @ w).detach().numpy(), (num_pts, num_pts))\n", " \n", " plt3d = plt.figure().gca(projection='3d')\n", " plt3d.plot_surface(x1, x2, y_plane, alpha=0.2)\n", "\n", " ax = plt.gca()\n", " ax.scatter(X[:,0].numpy(), X[:,1].numpy(), y.numpy(), c='r', marker='o')\n", "\n", " ax.set_xlabel('$X_1$')\n", " ax.set_ylabel('$X_2$')\n", " ax.set_zlabel('$Y$')\n", " \n", " plt.title(title)\n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "visualize_fun(true_w, 'Dataset and true $w$')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sanity check\n", "To verify PyTorch is computing the gradients correctly, let's recall the gradient for the RSS objective:\n", "\n", "$$\\nabla_w \\mathcal{L}_{RSS}(w; X) = \\nabla_w\\frac{1}{n} ||y - Xw||_2^2 = -\\frac{2}{n}X^T(y-Xw)$$\n", "\n", "Let's see if the match up:" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Analytical gradient [ 5.1867113 -5.5912566]\n", "PyTorch's gradient [ 5.186712 -5.5912566]\n" ] } ], "source": [ "# define a linear model with no bias\n", "def model(X, w):\n", " return X @ w\n", "\n", "# the residual sum of squares loss function\n", "def rss(y, y_hat):\n", " return torch.norm(y - y_hat)**2 / n\n", "\n", "# analytical expression for the gradient\n", "def grad_rss(X, y, w):\n", " return -2*X.t() @ (y - X @ w) / n\n", "\n", "w = torch.tensor([[1.], [0]], requires_grad=True)\n", "y_hat = model(X, w)\n", "\n", "loss = rss(y, y_hat)\n", "loss.backward()\n", "\n", "print('Analytical gradient', grad_rss(X, y, w).detach().view(2).numpy())\n", "print('PyTorch\\'s gradient', w.grad.view(2).numpy())\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now that we've seen PyTorch is doing the right think, let's use the gradients!\n", "\n", "## Linear regression using GD with automatically computed derivatives\n", "\n", "We will now use the gradients to run the gradient descent algorithm.\n", "\n", "Note: This example is an illustration to connect ideas we have seen before to PyTorch's way of doing things. We will see how to do this in the \"PyTorchic\" way in the next example." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter,\tloss,\tw\n", "0,\t10.80,\t[-0.03734243 1.1182513 ]\n", "1,\t2.31,\t[-0.28690195 1.3653738 ]\n", "2,\t1.24,\t[-0.4724271 1.5428905]\n", "3,\t0.67,\t[-0.6105486 1.6702049]\n", "4,\t0.36,\t[-0.71353513 1.7613506 ]\n", "5,\t0.20,\t[-0.79044634 1.8264704 ]\n", "6,\t0.11,\t[-0.8479796 1.8728881]\n", "7,\t0.06,\t[-0.89109135 1.9058872 ]\n", "8,\t0.04,\t[-0.92345405 1.9292755 ]\n", "9,\t0.03,\t[-0.94779253 1.9457937 ]\n", "10,\t0.02,\t[-0.9661309 1.957412 ]\n", "11,\t0.01,\t[-0.97997516 1.9655445 ]\n", "12,\t0.01,\t[-0.9904472 1.9712044]\n", "13,\t0.01,\t[-0.9983844 1.9751165]\n", "14,\t0.01,\t[-1.0044125 1.9777979]\n", "15,\t0.01,\t[-1.0090001 1.9796168]\n", "16,\t0.01,\t[-1.0124985 1.9808345]\n", "17,\t0.01,\t[-1.0151719 1.9816359]\n", "18,\t0.01,\t[-1.0172188 1.9821515]\n", "19,\t0.01,\t[-1.0187894 1.9824725]\n", "\n", "true w\t\t [-1. 2.]\n", "estimated w\t [-1.0187894 1.9824725]\n" ] } ], "source": [ "step_size = 0.1\n", "\n", "# clear the gradient left over from the sanity check above;\n", "# otherwise the first step would use twice the gradient\n", "w.grad.zero_()\n", "\n", "print('iter,\\tloss,\\tw')\n", "for i in range(20):\n", " y_hat = model(X, w)\n", " loss = rss(y, y_hat)\n", " \n", " loss.backward() # compute the gradient of the loss\n", " \n", " w.data = w.data - step_size * w.grad # do a gradient descent step\n", " \n", " print('{},\\t{:.2f},\\t{}'.format(i, loss.item(), w.view(2).detach().numpy()))\n", " \n", " # We need to zero the grad variable since the backward()\n", " # call accumulates the gradients in .grad instead of overwriting.\n", " # The detach_() is for efficiency. You do not need to worry too much about it.\n", " w.grad.detach_()\n", " w.grad.zero_()\n", "\n", "print('\\ntrue w\\t\\t', true_w.view(2).numpy())\n", "print('estimated w\\t', w.view(2).detach().numpy())" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "visualize_fun(w, 'Dataset with learned $w$ (Manual GD)')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## torch.nn.Module\n", "\n", "`Module` is PyTorch's way of performing operations on tensors. Modules are implemented as subclasses of the `torch.nn.Module` class. All modules are callable and can be composed together to create complex functions.\n", "\n", "[`torch.nn` docs](https://pytorch.org/docs/stable/nn.html)\n", "\n", "Note: most of the functionality implemented for modules can be accessed in a functional form via `torch.nn.functional`, but these require you to create and manage the weight tensors yourself.\n", "\n", "[`torch.nn.functional` docs](https://pytorch.org/docs/stable/nn.html#torch-nn-functional).\n", "\n", "### Linear Module\n", "The bread and butter of modules is the Linear module which does a linear transformation with a bias. It takes the input and output dimensions as parameters, and creates the weights in the object.\n", "\n", "Unlike how we initialized our $w$ manually, the Linear module automatically initializes the weights randomly. For minimizing non convex loss functions (e.g. training neural networks), initialization is important and can affect results. If training isn't working as well as expected, one thing to try is manually initializing the weights to something different from the default. 
PyTorch implements some common initializations in `torch.nn.init`.\n", "\n", "[`torch.nn.init` docs](https://pytorch.org/docs/stable/nn.html#torch-nn-init)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "example_tensor torch.Size([2, 3])\n", "transformed torch.Size([2, 4])\n", "\n", "We can see that the weights exist in the background\n", "\n", "W: Parameter containing:\n", "tensor([[ 0.2151, -0.2631, 0.4498],\n", " [-0.3092, 0.3098, -0.4239],\n", " [-0.0499, -0.2222, 0.0085],\n", " [-0.0356, 0.5260, 0.4925]], requires_grad=True)\n", "b: Parameter containing:\n", "tensor([-0.0887, 0.3944, 0.4080, 0.2182], requires_grad=True)\n" ] } ], "source": [ "d_in = 3\n", "d_out = 4\n", "linear_module = nn.Linear(d_in, d_out)\n", "\n", "example_tensor = torch.tensor([[1.,2,3], [4,5,6]])\n", "# applies a linear transformation to the data\n", "transformed = linear_module(example_tensor)\n", "print('example_tensor', example_tensor.shape)\n", "print('transformed', transformed.shape)\n", "print()\n", "print('We can see that the weights exist in the background\\n')\n", "print('W:', linear_module.weight)\n", "print('b:', linear_module.bias)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Activation functions\n", "PyTorch implements a number of activation functions including but not limited to `ReLU`, `Tanh`, and `Sigmoid`. Since they are modules, they need to be instantiated before they can be called."
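] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The same activations are also available as plain functions (`torch.nn.functional.relu`, `torch.tanh`, `torch.sigmoid`), which need no instantiation and return the same values as the corresponding modules. A minimal comparison on an illustrative example tensor:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "\n", "act_in = torch.tensor([-1.0, 1.0, 0.0])\n", "\n", "# module form: create the object, then call it\n", "relu_module_out = nn.ReLU()(act_in)\n", "# functional form: call the function directly\n", "relu_fn_out = F.relu(act_in)\n", "print(relu_module_out, relu_fn_out)  # tensor([0., 1., 0.]) tensor([0., 1., 0.])\n", "\n", "# the Tanh and Sigmoid modules match torch.tanh and torch.sigmoid\n", "print(torch.equal(nn.Tanh()(act_in), torch.tanh(act_in)))  # True\n", "print(torch.equal(nn.Sigmoid()(act_in), torch.sigmoid(act_in)))  # True"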
] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "example_tensor tensor([-1., 1., 0.])\n", "activated tensor([0., 1., 0.])\n" ] } ], "source": [ "activation_fn = nn.ReLU() # we instantiate an instance of the ReLU module\n", "example_tensor = torch.tensor([-1.0, 1.0, 0.0])\n", "activated = activation_fn(example_tensor)\n", "print('example_tensor', example_tensor)\n", "print('activated', activated)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Sequential\n", "\n", "Many times, we want to compose Modules together. `torch.nn.Sequential` provides a good interface for composing simple modules." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "transformed torch.Size([2, 1])\n" ] } ], "source": [ "d_in = 3\n", "d_hidden = 4\n", "d_out = 1\n", "model = torch.nn.Sequential(\n", " nn.Linear(d_in, d_hidden),\n", " nn.Tanh(),\n", " nn.Linear(d_hidden, d_out),\n", " nn.Sigmoid()\n", " )\n", "\n", "example_tensor = torch.tensor([[1.,2,3],[4,5,6]])\n", "transformed = model(example_tensor)\n", "print('transformed', transformed.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note: we can access *all* of the parameters (of any `nn.Module`) with the `parameters()` method. 
" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Parameter containing:\n", "tensor([[-0.1409, 0.0518, 0.3034],\n", " [ 0.0913, 0.2452, -0.2616],\n", " [ 0.5021, 0.0118, 0.1383],\n", " [ 0.4757, -0.3128, 0.2707]], requires_grad=True)\n", "Parameter containing:\n", "tensor([-0.3952, 0.1285, 0.1777, -0.4675], requires_grad=True)\n", "Parameter containing:\n", "tensor([[ 0.0391, -0.4876, -0.1731, 0.4704]], requires_grad=True)\n", "Parameter containing:\n", "tensor([0.0454], requires_grad=True)\n" ] } ], "source": [ "params = model.parameters()\n", "\n", "for param in params:\n", " print(param)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Loss functions\n", "PyTorch implements many common loss functions including `MSELoss` and `CrossEntropyLoss`." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.6667)\n" ] } ], "source": [ "mse_loss_fn = nn.MSELoss()\n", "\n", "input = torch.tensor([[0., 0, 0]])\n", "target = torch.tensor([[1., 0, -1]])\n", "\n", "loss = mse_loss_fn(input, target)\n", "\n", "print(loss)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## torch.optim\n", "PyTorch implements a number of gradient-based optimization methods in `torch.optim`, including Gradient Descent. At the minimum, it takes in the model parameters and a learning rate.\n", "\n", "Optimizers do not compute the gradients for you, so you must call `backward()` yourself. 
You also must call `optim.zero_grad()` before calling `backward()`, since by default PyTorch adds to the `.grad` attribute in place rather than overwriting it.\n", "\n", "`optim.zero_grad()` clears the `.grad` of every parameter the optimizer manages, taking the place of the manual `detach_()` and `zero_()` calls we used earlier. (Since PyTorch 2.0, it sets each `.grad` to `None` by default instead of filling it with zeros.)\n", "\n", "[`torch.optim` docs](https://pytorch.org/docs/stable/optim.html)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "model params before: Parameter containing:\n", "tensor([[0.1950]], requires_grad=True)\n", "model params after: Parameter containing:\n", "tensor([[0.2219]], requires_grad=True)\n" ] } ], "source": [ "# create a simple model\n", "model = nn.Linear(1, 1)\n", "\n", "# create a simple dataset\n", "X_simple = torch.tensor([[1.]])\n", "y_simple = torch.tensor([[2.]])\n", "\n", "# create our optimizer\n", "optim = torch.optim.SGD(model.parameters(), lr=1e-2)\n", "mse_loss_fn = nn.MSELoss()\n", "\n", "y_hat = model(X_simple)\n", "print('model params before:', model.weight)\n", "loss = mse_loss_fn(y_hat, y_simple)\n", "optim.zero_grad()\n", "loss.backward()\n", "optim.step()\n", "print('model params after:', model.weight)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the weight increased from 0.1950 to 0.2219, which moves the prediction for $x=1$ toward the target $y=2$." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Linear regression using GD with automatically computed derivatives and PyTorch's Modules\n", "\n", "Now let's combine what we've learned to solve linear regression in a \"PyTorchic\" way."
] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter,\tloss,\tw\n", "0,\t4.37,\t[-0.5072826 0.7721884]\n", "1,\t2.34,\t[-0.6624693 1.0903175]\n", "2,\t1.25,\t[-0.7725248 1.3242052]\n", "3,\t0.67,\t[-0.85030663 1.4962891 ]\n", "4,\t0.36,\t[-0.90505993 1.6230037 ]\n", "5,\t0.20,\t[-0.9434225 1.716392 ]\n", "6,\t0.11,\t[-0.9701522 1.7852831]\n", "7,\t0.06,\t[-0.98865306 1.8361537 ]\n", "8,\t0.04,\t[-1.0013554 1.8737577]\n", "9,\t0.02,\t[-1.0099901 1.9015862]\n", "10,\t0.02,\t[-1.0157865 1.9222052]\n", "11,\t0.01,\t[-1.019615 1.9375019]\n", "12,\t0.01,\t[-1.0220896 1.9488654]\n", "13,\t0.01,\t[-1.0236413 1.9573189]\n", "14,\t0.01,\t[-1.0245715 1.963617 ]\n", "15,\t0.01,\t[-1.0250894 1.9683164]\n", "16,\t0.01,\t[-1.0253391 1.9718288]\n", "17,\t0.01,\t[-1.0254192 1.9744583]\n", "18,\t0.01,\t[-1.0253965 1.9764304]\n", "19,\t0.01,\t[-1.025315 1.9779121]\n", "\n", "true w\t\t [-1. 2.]\n", "estimated w\t [-1.025315 1.9779121]\n" ] } ], "source": [ "step_size = 0.1\n", "\n", "linear_module = nn.Linear(d, 1, bias=False)\n", "\n", "loss_func = nn.MSELoss()\n", "\n", "optim = torch.optim.SGD(linear_module.parameters(), lr=step_size)\n", "\n", "print('iter,\\tloss,\\tw')\n", "\n", "for i in range(20):\n", " y_hat = linear_module(X)\n", " loss = loss_func(y_hat, y)\n", " optim.zero_grad()\n", " loss.backward()\n", " optim.step()\n", " \n", " print('{},\\t{:.2f},\\t{}'.format(i, loss.item(), linear_module.weight.view(2).detach().numpy()))\n", "\n", "print('\\ntrue w\\t\\t', true_w.view(2).numpy())\n", "print('estimated w\\t', linear_module.weight.view(2).detach().numpy())" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "visualize_fun(linear_module.weight.t(), 'Dataset with learned $w$ (PyTorch GD)')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Linear regression using SGD \n", "In the previous examples, we computed the average gradient over the entire dataset (Gradient Descent). We can implement Stochastic Gradient Descent with a simple modification." ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter,\tloss,\tw\n", "0,\t0.01,\t[-0.16747226 0.69458336]\n", "20,\t0.73,\t[-0.52777785 1.409119 ]\n", "40,\t0.05,\t[-0.74168175 1.7194624 ]\n", "60,\t0.04,\t[-0.8074937 1.831477 ]\n", "80,\t0.09,\t[-0.88882697 1.8813883 ]\n", "100,\t0.06,\t[-0.9371291 1.9570427]\n", "120,\t0.00,\t[-0.964763 1.9772899]\n", "140,\t0.00,\t[-0.9806282 1.9791764]\n", "160,\t0.04,\t[-0.98312473 1.9838824 ]\n", "180,\t0.01,\t[-0.99795353 1.9885796 ]\n", "\n", "true w\t\t [-1. 
2.]\n", "estimated w\t [-0.9991454 1.9860797]\n" ] } ], "source": [ "step_size = 0.01\n", "\n", "linear_module = nn.Linear(d, 1)\n", "loss_func = nn.MSELoss()\n", "optim = torch.optim.SGD(linear_module.parameters(), lr=step_size)\n", "print('iter,\\tloss,\\tw')\n", "for i in range(200):\n", " rand_idx = np.random.choice(n) # take a random point from the dataset\n", " x = X[rand_idx] \n", " y_hat = linear_module(x)\n", " loss = loss_func(y_hat, y[rand_idx]) # only compute the loss on the single point\n", " optim.zero_grad()\n", " loss.backward()\n", " optim.step()\n", " \n", " if i % 20 == 0:\n", " print('{},\\t{:.2f},\\t{}'.format(i, loss.item(), linear_module.weight.view(2).detach().numpy()))\n", "\n", "print('\\ntrue w\\t\\t', true_w.view(2).numpy())\n", "print('estimated w\\t', linear_module.weight.view(2).detach().numpy())" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "visualize_fun(linear_module.weight.t(), 'Dataset with learned $w$ (PyTorch SGD)')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Neural Network Basics in PyTorch\n", "\n", "Let's consider the dataset from hw3. We will try and fit a simple neural network to the data." ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "%matplotlib inline\n", "\n", "d = 1\n", "n = 200\n", "X = torch.rand(n,d)\n", "y = 4 * torch.sin(np.pi * X) * torch.cos(6*np.pi*X**2)\n", "\n", "plt.scatter(X.numpy(), y.numpy())\n", "plt.title('plot of $f(x)$')\n", "plt.xlabel('$x$')\n", "plt.ylabel('$y$')\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here we define a simple two hidden layer neural network with Tanh activations. There are a few hyper parameters to play with to get a feel for how they change the results." ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter,\tloss\n", "0,\t3.96\n", "600,\t3.69\n", "1200,\t2.58\n", "1800,\t1.10\n", "2400,\t0.91\n", "3000,\t0.68\n", "3600,\t0.14\n", "4200,\t0.08\n", "4800,\t0.06\n", "5400,\t0.15\n" ] } ], "source": [ "# feel free to play with these parameters\n", "\n", "step_size = 0.05\n", "n_epochs = 6000\n", "n_hidden_1 = 32\n", "n_hidden_2 = 32\n", "d_out = 1\n", "\n", "neural_network = nn.Sequential(\n", " nn.Linear(d, n_hidden_1), \n", " nn.Tanh(),\n", " nn.Linear(n_hidden_1, n_hidden_2),\n", " nn.Tanh(),\n", " nn.Linear(n_hidden_2, d_out)\n", " )\n", "\n", "loss_func = nn.MSELoss()\n", "\n", "optim = torch.optim.SGD(neural_network.parameters(), lr=step_size)\n", "print('iter,\\tloss')\n", "for i in range(n_epochs):\n", " y_hat = neural_network(X)\n", " loss = loss_func(y_hat, y)\n", " optim.zero_grad()\n", " loss.backward()\n", " optim.step()\n", " \n", " if i % (n_epochs // 10) == 0:\n", " print('{},\\t{:.2f}'.format(i, loss.item()))\n", "\n" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "X_grid = torch.from_numpy(np.linspace(0,1,50)).float().view(-1, d)\n", "y_hat = neural_network(X_grid)\n", "plt.scatter(X.numpy(), y.numpy())\n", "plt.plot(X_grid.detach().numpy(), y_hat.detach().numpy(), 'r')\n", "plt.title('plot of $f(x)$ and $\\hat{f}(x)$')\n", "plt.xlabel('$x$')\n", "plt.ylabel('$y$')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Things that might help on the homework\n", "\n", "## Brief Sidenote: Momentum\n", "\n", "There are other optimization algorithms besides stochastic gradient descent. One is a modification of SGD called momentum. We won't get into it here, but if you would like to read more [here](https://distill.pub/2017/momentum/) is a good place to start.\n", "\n", "We only change the step size and add the momentum keyword argument to the optimizer. Notice how it reduces the training loss in fewer iterations." ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "iter,\tloss\n", "0,\t3.83\n", "150,\t3.06\n", "300,\t0.74\n", "450,\t0.12\n", "600,\t0.04\n", "750,\t0.03\n", "900,\t0.01\n", "1050,\t0.00\n", "1200,\t0.00\n", "1350,\t0.00\n" ] } ], "source": [ "# feel free to play with these parameters\n", "\n", "step_size = 0.05\n", "momentum = 0.9\n", "n_epochs = 1500\n", "n_hidden_1 = 32\n", "n_hidden_2 = 32\n", "d_out = 1\n", "\n", "neural_network = nn.Sequential(\n", " nn.Linear(d, n_hidden_1), \n", " nn.Tanh(),\n", " nn.Linear(n_hidden_1, n_hidden_2),\n", " nn.Tanh(),\n", " nn.Linear(n_hidden_2, d_out)\n", " )\n", "\n", "loss_func = nn.MSELoss()\n", "\n", "optim = torch.optim.SGD(neural_network.parameters(), lr=step_size, momentum=momentum)\n", "print('iter,\\tloss')\n", "for i in range(n_epochs):\n", " y_hat = neural_network(X)\n", " loss = loss_func(y_hat, y)\n", " optim.zero_grad()\n", " loss.backward()\n", " 
optim.step()\n", " \n", " if i % (n_epochs // 10) == 0:\n", " print('{},\\t{:.2f}'.format(i, loss.item()))\n", "\n" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "X_grid = torch.from_numpy(np.linspace(0,1,50)).float().view(-1, d)\n", "y_hat = neural_network(X_grid)\n", "plt.scatter(X.numpy(), y.numpy())\n", "plt.plot(X_grid.detach().numpy(), y_hat.detach().numpy(), 'r')\n", "plt.title('plot of $f(x)$ and $\\hat{f}(x)$')\n", "plt.xlabel('$x$')\n", "plt.ylabel('$y$')\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## CrossEntropyLoss\n", "So far, we have been considering regression tasks and have used the [MSELoss](https://pytorch.org/docs/stable/nn.html#torch.nn.MSELoss) module. For the homework, we will be performing a classification task and will use the cross entropy loss.\n", "\n", "PyTorch implements a version of the cross entropy loss in one module called [CrossEntropyLoss](https://pytorch.org/docs/stable/nn.html#torch.nn.CrossEntropyLoss). Its usage is slightly different from that of MSELoss, so we will break it down here.\n", "\n", "- input: The first argument to CrossEntropyLoss is the output of our network. It expects a *real-valued* tensor of shape $(N,C)$, where $N$ is the minibatch size and $C$ is the number of classes. In our case, $N=3$ and $C=2$. The values along the second dimension are raw, unnormalized scores (logits) for each class. The CrossEntropyLoss module combines the log-softmax and the negative log-likelihood internally, so we should not apply our own softmax to the output of our neural network.\n", "- target: The second argument to CrossEntropyLoss is the true label. It expects an *integer-valued* tensor of shape $(N)$, where each element is the index of the correct class, between $0$ and $C-1$. In our case, the \"correct\" class labels are class 1, class 1, and class 0.\n", "\n", "Try out the loss function on the four toy inputs below (swap one in at a time by uncommenting it); each is a minibatch of three predictions. The true class labels are $y=[1,1,0]$. The first two inputs correspond to predictions that are \"correct\" in that they have higher raw scores for the correct class. 
The second input is \"more confident\" in its predictions, leading to a smaller loss. The last two inputs are incorrect predictions with lower and higher confidence, respectively." ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "tensor(0.1269)\n" ] } ], "source": [ "loss = nn.CrossEntropyLoss()\n", "\n", "input = torch.tensor([[-1., 1],[-1, 1],[1, -1]]) # raw scores favor the correct class\n", "# input = torch.tensor([[-3., 3],[-3, 3],[3, -3]]) # raw scores favor the correct class with higher confidence\n", "# input = torch.tensor([[1., -1],[1, -1],[-1, 1]]) # raw scores favor the incorrect class\n", "# input = torch.tensor([[3., -3],[3, -3],[-3, 3]]) # raw scores favor the incorrect class with higher confidence\n", "\n", "target = torch.tensor([1, 1, 0]) # integer class indices (dtype torch.int64)\n", "output = loss(input, target)\n", "print(output)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Learning rate schedulers\n", "\n", "Often we do not want to use a fixed learning rate throughout training. PyTorch offers learning rate schedulers to change the learning rate over time. Common strategies include multiplying the learning rate by a constant factor (e.g. 0.9) every epoch and halving it when the training loss plateaus.\n", "\n", "See the [learning rate scheduler docs](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate) for usage and examples." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Convolutions\n", "When working with images, we often want to use convolutions to extract features. PyTorch implements this for us in the `torch.nn.Conv2d` module. 
It expects the input to have a specific dimension $(N, C_{in}, H_{in}, W_{in})$ where $N$ is batch size, $C_{in}$ is the number of channels the image has, and $H_{in}, W_{in}$ are the image height and width respectively.\n", "\n", "We can modify the convolution to have different properties with the parameters:\n", "- kernel_size\n", "- stride\n", "- padding\n", "\n", "They can change the output dimension so be careful.\n", "\n", "See the [`torch.nn.Conv2d` docs](https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d) for more information." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To illustrate what the `Conv2d` module is doing, let's set the conv weights manually to a Gaussian blur kernel.\n", "\n", "We can see that it applies the kernel to the image." ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAEICAYAAACZA4KlAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAARo0lEQVR4nO3dfZBddX3H8feHkAdCQh5EYhqCqwgUdGzQFSzSGooixgewtpRYNWXQUEBbK4MijjUy6iDjw2iBaCiUBwXE4SHBhioEKdVRYKNIwoM8BpKwZCEJkwCS7Cbf/nFPnJuw59zNvec+7P4+r5mdvXu+5+F7b/K559xz7jlHEYGZjXx7tLsBM2sNh90sEQ67WSIcdrNEOOxmiXDYzRLhsA8Dkr4v6Utlj1tjPl2SQtKeOfX7Jc1udDnWOvJxdhuMpC7gCWB0RAy0uR0rgdfsHU7SqHb3YCODw94Gkg6VdIek57PN4Q9W1S6XtFDSUkkvAsdkw75aNc7nJPVKelrSJ7LN7TdUTf/V7PFsSWsknSWpL5vmlKr5vE/S7yRtkrRa0oLdeA6rJL0re7xA0k8k/VDSZkkrJB0s6QvZcldLOq5q2lMkPZiN+7ik03aZd9HzGyvpm5KekrQu+9iy1+7+G6TIYW8xSaOBm4GfA/sBnwZ+JOmQqtE+AnwNmAj8cpfpjwc+C7wLeAMwu8YiXwNMAmYApwIXSZqS1V4EPg5MBt4HnC7pxPqeGR8ArgKmAL8Dfkbl/9cM4DzgB1Xj9gHvB/YBTgG+I+ktQ3x+5wMHA7Oy+gzg3+vsOSkOe+u9HZgAnB8RWyPiduCnwNyqcRZHxK8iYntEvLzL9CcB/xUR90fES8CCGsvrB86LiP6IWAq8ABwCEBF3RMSKbDn3AdcA76zzef1fRPws+3z/E+DV2XPsB64FuiRNzpb73xHxWFT8L5U3vr+q9fwkCZgP/FtEbIiIzcDXgZPr7Dkpg+5ptab6M2B1RGyvGvYklTXUDqtrTN8zxHEB1u+yg+0lKm82SDqSypryTcAYYCyVoNZjXdXjPwLPRcS2qr/Jlvu8pPcCX6ayht4DGA+syMYpen6vzsZdXsk9AAK8X2MIvGZvva
eBmZKqX/sDgLVVfxcdIukF9q/6e2YDvVwNLAFmRsQk4PtUwtM0ksYC1wPfBKZFxGRgadVyi57fc1TeON4YEZOzn0kRMaGZPY8UDnvr3UVl7fo5SaOzY9UfoLKpOxTXAadkO/nGA40cU58IbIiIlyUdQWVfQbPt2IJ4FhjI1vLHVdVzn1+2NXQJlc/4+wFImiHpPS3oe9hz2FssIrZSCfd7qaypLgY+HhEPDXH6W4DvAb8AHgV+k5W21NHOGcB5kjZT2cl1XR3z2C3Z5+x/yZa1kcobzJKqeq3n9/kdwyVtAm4j2wdhxfylmmFO0qHASmDsSPzyy0h/fq3kNfswJOlD2fHmKcA3gJtHUhBG+vNrF4d9eDqNyrHqx4BtwOntbad0I/35tYU3480S4TW7WSJa+qWaMRob49i7lYs0S8rLvMjW2DLodyUaCnv2PebvUvkG039GxPlF449jb47UsY0s0swK3BXLcmt1b8Znp15eROV48WHAXEmH1Ts/M2uuRj6zHwE8GhGPZ18UuRY4oZy2zKxsjYR9BjufpLCGnU/mAEDSfEk9knr66/qSl5mVoel74yNiUUR0R0T3aMY2e3FmlqORsK9l5zOS9mfnM7fMrIM0EvZ7gIMkvU7SGCoXEFhSYxoza5O6D71FxICkT1G5/NAo4LKIuL+0zsysVA0dZ88uc7S0pF7MrIn8dVmzRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WiIZu2SxpFbAZ2AYMRER3GU2ZWfkaCnvmmIh4roT5mFkTeTPeLBGNhj2An0taLmn+YCNImi+pR1JPP1saXJyZ1avRzfijI2KtpP2AWyU9FBF3Vo8QEYuARQD7aGo0uDwzq1NDa/aIWJv97gNuBI4ooykzK1/dYZe0t6SJOx4DxwEry2rMzMrVyGb8NOBGSTvmc3VE/E8pXZlZ6eoOe0Q8DvxFib2YWRP50JtZIhx2s0Q47GaJcNjNEuGwmyWijBNhrM16P3tUbk01vrM4bn3xCBv/vHj66b/eVjz/m+8unoG1jNfsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kiRsxx9r4z8481Azz/5v7C+o3HXVhmOy116Jh76p725RgorE/aY6/Cet/HXiysP/29/P9i337m3YXTrj9pn8L6wOo1hXXbmdfsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kiFNG6m7Tso6lxpI6te/qHL3lbbu2hORcXTjtWo+terrXHR1fNLqxv/EiN4/Crniqxm+HhrljGptigwWpes5slwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmiRhW57MvPObK3Fqt4+jfWH9QYb1v68S6eirDDcvfWlg/4OZBD5t2hDXHFq8vLphzdW7twxM2FU77w647CusfvXp2YX3jP+yfW0vxXPiaa3ZJl0nqk7SyathUSbdKeiT7PaW5bZpZo4ayGX85cPwuw84BlkXEQcCy7G8z62A1wx4RdwIbdhl8AnBF9vgK4MRy2zKzstX7mX1aRPRmj58BpuWNKGk+MB9gHOPrXJyZNarhvfFROZMm92yaiFgUEd0R0T2asY0uzszqVG/Y10maDpD97iuvJTNrhnrDvgSYlz2eBywupx0za5aa57NLugaYDewLrAO+DNwEXAccADwJnBQRu+7Ee4VGz2fXW9+YW3tuVvG5zfvd9IfC+rb1Ndu3Ouzx5vwbvL//2l8VTnvm5NUNLfuQS0/PrXV96dcNzbtTFZ3PXnMHXUTMzSnVn1ozazl/XdYsEQ67WSIcdrNEOOxmiXDYzRIxrC4lbSPL+k/+ZWG95ysLG5r/8i1bc2vnvu6IhubdqXwpaTNz2M1S4bCbJcJhN0uEw26WCI
fdLBEOu1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WCIfdLBEOu1kihtUtm234WXPuUbm17Ydvbuqyp43KP5994G+Kb5O95+3Ly26n7bxmN0uEw26WCIfdLBEOu1kiHHazRDjsZolw2M0S4evGjwB7vr4rt/boqdMLp7345EUld7Oz2eP6c2uj1L51zWP9LxTWz3jt0S3qpFwNXTde0mWS+iStrBq2QNJaSfdmP3PKbNjMyjeUt9bLgeMHGf6diJiV/Swtty0zK1vNsEfEncCGFvRiZk3UyIemT0m6L9vMn5I3kqT5knok9fSzpYHFmVkj6g37QuBAYBbQC3wrb8SIWBQR3RHRPZqxdS7OzBpVV9gjYl1EbIuI7cAlwMi8JabZCFJX2CVVH8/5ELAyb1wz6ww1z2eXdA0wG9hX0hrgy8BsSbOAAFYBpzWvxZHvhb8/srD+7FuK35PP+9trc2snT9xYV0/l6czvbb3rts8U1g+mpzWNtFDNsEfE3EEGX9qEXsysiTrzbdfMSuewmyXCYTdLhMNulgiH3SwRvpR0CXT4Gwvrky/sLawv7VpYWG/mqaA3vTihsL7yj/s3NP+fXjA7tzZqS/Hp1fPOu7mwPn/S0/W0BMCYZ0bXPe1w5TW7WSIcdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIH2cfoie/kn/r4S+d/OPCaf9x4vrC+lMDLxXWH9qae9UvAD59zSdya+N7B72q8J9Mv+O5wvq2Bx4urNcyid/UPe0jX5hWY+bFx9mfKLhcdNfi4ktJj0Res5slwmE3S4TDbpYIh90sEQ67WSIcdrNEOOxmifBx9iGa/La+3Fqt4+jHPvDBwnr/f7ymsL7X4rsL6138urBeZFvdUzZu+zsPL6yfOLnWRYyL11Ubto/JL969osa8Rx6v2c0S4bCbJcJhN0uEw26WCIfdLBEOu1kiHHazRAzlls0zgSuBaVRu0bwoIr4raSrwY6CLym2bT4qIdt8fuGledWr++c9v+OzphdMeeHbxcfA9eaqunoa7jQePK6y/Y1xj66L5Kz+aW9uXxs7TH46G8moOAGdFxGHA24EzJR0GnAMsi4iDgGXZ32bWoWqGPSJ6I+K32ePNwIPADOAE4IpstCuAE5vUo5mVYLe2kyR1AYcDdwHTImLHfY2eobKZb2YdashhlzQBuB74TERsqq5FRFD5PD/YdPMl9Ujq6WdLQ82aWf2GFHZJo6kE/UcRcUM2eJ2k6Vl9OjDomSIRsSgiuiOiezRjy+jZzOpQM+ySBFwKPBgR364qLQHmZY/nAYvLb8/MyjKUU1zfAXwMWCHp3mzYucD5wHWSTgWeBE5qSocdYqD3mdzagWfn1yzf+rcNNDT9g1uLL8E98eJJDc1/pKkZ9oj4JZB38fFjy23HzJrF36AzS4TDbpYIh90sEQ67WSIcdrNEOOxmifClpK2p3rNyU27txskX1Zi64FLQwLz75xXWp9xyT435p8VrdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sET7Obk31d/vcl1sbv8eEwmkf7n+xsD7+wsn1tJQsr9nNEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0T4OLs1pO+Mowrr00bln1P+RH/+bbAB5n797ML6vrcU3wrbduY1u1kiHHazRDjsZolw2M0S4bCbJcJhN0uEw26WiJrH2SXNBK4EpgEBLIqI70paAHwSeDYb9dyIWNqsRq09NHZsYf3D/3x7YX3z9q25tTl3n1447QE/8HH0Mg3lSzUDwFkR8VtJE4Hlkm7Nat+JiG82rz0zK0vNsEdEL9CbPd4s6UFgRrMbM7Ny7dZndkldwOHAXdmgT0m6T9JlkqbkTDNfUo+knn62NNatmdVtyGGXNAG4HvhMRGwCFgIHArOorPm/Ndh0EbEoIrojons0xZ//zKx5hhR2SaOpBP1HEXEDQESsi4htEbEduAQ4onltmlmjaoZdkoBLgQcj4ttVw6dXjf
YhYGX57ZlZWYayN/4dwMeAFZLuzYadC8yVNIvK4bhVwGlN6M/abXsUlq+6+ZjC+i2/n51bO+C639TRkNVrKHvjfwlokJKPqZsNI/4GnVkiHHazRDjsZolw2M0S4bCbJcJhN0uELyVthaI//xRVgK4v+jTU4cJrdrNEOOxmiXDYzRLhsJslwmE3S4TDbpYIh90sEYooPl+51IVJzwJPVg3aF3iuZQ3snk7trVP7AvdWrzJ7e21EvHqwQkvD/oqFSz0R0d22Bgp0am+d2he4t3q1qjdvxpslwmE3S0S7w76ozcsv0qm9dWpf4N7q1ZLe2vqZ3cxap91rdjNrEYfdLBFtCbuk4yX9QdKjks5pRw95JK2StELSvZJ62tzLZZL6JK2sGjZV0q2SHsl+D3qPvTb1tkDS2uy1u1fSnDb1NlPSLyQ9IOl+Sf+aDW/ra1fQV0tet5Z/Zpc0CngYeDewBrgHmBsRD7S0kRySVgHdEdH2L2BI+mvgBeDKiHhTNuwCYENEnJ+9UU6JiM93SG8LgBfafRvv7G5F06tvMw6cCPwTbXztCvo6iRa8bu1Ysx8BPBoRj0fEVuBa4IQ29NHxIuJOYMMug08ArsgeX0HlP0vL5fTWESKiNyJ+mz3eDOy4zXhbX7uCvlqiHWGfAayu+nsNnXW/9wB+Lmm5pPntbmYQ0yKiN3v8DDCtnc0MouZtvFtpl9uMd8xrV8/tzxvlHXSvdHREvAV4L3BmtrnakaLyGayTjp0O6TberTLIbcb/pJ2vXb23P29UO8K+FphZ9ff+2bCOEBFrs999wI103q2o1+24g272u6/N/fxJJ93Ge7DbjNMBr107b3/ejrDfAxwk6XWSxgAnA0va0McrSNo723GCpL2B4+i8W1EvAeZlj+cBi9vYy0465TbeebcZp82vXdtvfx4RLf8B5lDZI/8Y8MV29JDT1+uB32c/97e7N+AaKpt1/VT2bZwKvApYBjwC3AZM7aDergJWAPdRCdb0NvV2NJVN9PuAe7OfOe1+7Qr6asnr5q/LmiXCO+jMEuGwmyXCYTdLhMNulgiH3SwRDrtZIhx2s0T8P3ImkM40Bc0gAAAAAElFTkSuQmCC\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP0AAAEICAYAAACUHfLiAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAUS0lEQVR4nO3dfZBddX3H8fdnN9ld8kQSEsJmk/IQHgNTA0YUZSotiIhaoNOxYkfRUUMdHbF1puBDlenQKTo+1M50bIMgiKilokJnaDVEWmpnRAONeQIMhMRsyCNJyBPZZHe//eOedS6493c2u3f33s3v85rZ2XvP9+zvfPduPjn33nPu7ygiMLN8tDS6ATMbWw69WWYcerPMOPRmmXHozTLj0JtlxqFvEEkbJV1Ro3aZpO6x7qlq+6dJCkkTatTXSrpsbLuyehn0j2qWEhHnN7oHGz7v6Y8zg+2da+2xLU8OfWO9TtI6SXskfVNSx2ArFU+1z6y6f7ek24rbl0nqlnSzpG3ANyXdKun7kr4taR/wfkknSrpT0lZJWyTdJqm1GKNV0pck7ZK0AXh7qunqlybFtv6t2NZ+SaslnS3pU5J2SNos6cqqn/2ApKeKdTdIuvFVY/910eMLkj5U/btLai/6/I2k7ZL+WdIJw3rkM+bQN9afA28FFgBnA58d5jinADOBU4ElxbJrgO8D04H7gLuBXuBM4ELgSuBDxbofBt5RLF8M/Okxbv+dwL3ADOD/gB9T+bfVBfwt8C9V6+4otjUN+ADwVUkXAUi6Cvgr4Iqiz8tetZ3bqTxOi4p6F/C5Y+zVIsJfDfgCNgJ/UXX/auC54vZlQHdVLYAzq+7fDdxWte4RoKOqfivwWNX9OUAPcELVsuuBR4vbP31VL1cW25yQ6P2Kqm0tq6q9EzgAtBb3pxZjTa8x1o+Am4rbdwF/X1U7c+B3BwQcBBZU1S8Bnm/033K8ffm1XmNtrrq9CZg7zHF2RsThxNinAhOBrZIGlrVUrTN3kF6Oxfaq2y8DuyKir+o+wBRgr6S3AZ+nssduASYBq6v6WFHjd5hdrPtE1e8goPUYe82eQ99Y86tu/x7wQo31DlH5Bz/gFKD6kN5gH5WsXraZyp5+VkT0DrLu1kF6qTtJ7cADwPuAByPiqKQfUQnvQB/zqn6kuqddVP4DOT8itoxGf7nwa/rG+qikeZJmAp8B/rXGeiuB9xRvuF0FvPlYNhIRW4GfAF+WNE1Si6QFkgbGuR/4eNHLDOCWYf025dqAdmAn0Fvs9a+sqt8PfEDSeZImAX9T9Tv0A3dQeQ/gZABJXZLeOkq9Hrcc+sb6DpUwbgCeA26rsd5NVF4r76Xy5t+PhrGt91EJ3TpgD5U3+TqL2h1U3nz7FfAk8INhjF8qIvYDH6cS7j3Ae4CHqur/Afwj8CjwLPDzotRTfL95YHlxVOIR4JzR6PV4puINEbOmI+k8YA3QXuNliQ2D9/TWVCRdVxyPnwF8Afh3B76+HHprNjdSOZb/HNAHfKSx7Rx//PTeLDPe05tlZkyP07epPTqYPJabNMvKYQ5yJHqUWmdEoS+OGX+NyllR34iI21PrdzCZ1+vykWzSzBIej+Wl6wz76X3xCa1/At4GLASul7RwuOOZ2dgYyWv6i4FnI2JDRBwBvkflk11m1sRGEvouXvmBiO5imZk1sVF/I0/SEorPeHe84jMjZtYII9nTb+GVn4KaVyx7hYhYGhGLI2LxRNpHsDkzq4eRhP6XwFmSTpfUBrybqg9PmFlzGvbT+4jolfQxKp/OagXuioi1devMzEbFiF7TR8TDwMN16sXMxoBPwzXLjENvlhmH3iwzDr1ZZhx6s8w49GaZcejNMuPQm2XGoTfLjENvlhmH3iwzDr1ZZhx6s
8w49GaZcejNMuPQm2XGoTfLjENvlhmH3iwzDr1ZZhx6s8w49GaZcejNMuPQm2XGoTfLjENvlhmH3iwzDr1ZZhx6s8w49GaZcejNMuPQm2VmQqMbsMG1dHSUrzNjenqF9rb6NJMQB18uXad/z570GL299WrHhmBEoZe0EdgP9AG9EbG4Hk2Z2eipx57+DyNiVx3GMbMx4Nf0ZpkZaegD+ImkJyQtGWwFSUskrZC04ig9I9ycmY3USJ/eXxoRWySdDCyT9HREPFa9QkQsBZYCTNPMGOH2zGyERrSnj4gtxfcdwA+Bi+vRlJmNnmGHXtJkSVMHbgNXAmvq1ZiZjY6RPL2fA/xQ0sA434mI/xxJM2pvL12ndfasZL1v9vTSMY6clD4G3t/e+Pc3e6a1lq5zoCvdZ98JI+8jlK63pw/BAzC1uy9Zn9x9qHSMludfSNb7XtydHiD8ynLAsEMfERuA19SxFzMbA43fpZnZmHLozTLj0JtlxqE3y4xDb5YZh94sMw69WWbGdBINtbTQMmlyzXrfa84sHWPr62v/PMC+c4+WjjFnfvqMklmTDpaOMdp+r728h7Mn70jWT2wtn+CiTIv6k/UtPTNKx1j1UleyvvbX80rH6Fx+VrI+47+fT9Z7t6cfKyCbE3i8pzfLjENvlhmH3iwzDr1ZZhx6s8w49GaZcejNMjO2F7tom4hOrX3MdtPbJ5UO8cbLVyfrV8xYVzrGaRN3JutTW46UjlHmcKQnwdhw5ORkfdOR9GQhAIf7Jybru3vT5zQMxdTWw8n6xVM2lI7x7hm/SNafnjundIzPTf/jZD1aTk/WZz5auonyY/nHyXF87+nNMuPQm2XGoTfLjENvlhmH3iwzDr1ZZhx6s8yM6XH6aBH9k9pq1o9OS392G6CnL93yvVveUDrGjgNTkvW+/pH/X9hzJN3nke3pcxJOeKH8YhcTyq8RMWJlF8w4ND99IQuAi37/uWT9I3PLD6L/3aIHk/Wbe/8kWW87cGrpNib9V/oB7d+/v3SM8cB7erPMOPRmmXHozTLj0JtlxqE3y4xDb5YZh94sMw69WWbG9mIXR/to3bq7Zr3rp+WTPjyz5txkvWN3+Qk+M3b3JuvqH/lkCTqa7mPC3peS9Za95SeCRM/IJ/soo7b0RB19p5Rf7OLZN56drH/x2o7SMb5wxgPJ+l8uWp6s/8Nv3lG6jbOePSW9wlOZnJwj6S5JOyStqVo2U9IySeuL7+V/eTNrCkN5en83cNWrlt0CLI+Is4DlxX0zGwdKQx8RjwGvfk5+DXBPcfse4Nr6tmVmo2W4r+nnRMTW4vY2oObMhpKWAEsAOlqnDnNzZlYvI373PiICqPnOV0QsjYjFEbG4raXkI1tmNuqGG/rtkjoBiu9DuA6wmTWD4Yb+IeCG4vYNQPrDzmbWNEpf00v6LnAZMEtSN/B54HbgfkkfBDYB7xrKxuLoUXq3bq9Zn/JI+XHQqSXHjeNg+cwS/YfTF3AYC2VnE5SfbdAkEn/PAZ09ZyXr60+dVzrG0/PSx9CvmfJUsn7X+ZeUbqOnc1qyPiG9iXGjNPQRcX2N0uV17sXMxoBPwzXLjENvlhmH3iwzDr1ZZhx6s8w49GaZcejNMjOmk2gA0F/7iijHyxVEctJScrIUQHSk/5mpT6VjHI70dma3tifrp0+vPXnLgG0z0tNCTJyQ/j2iNz05S7Pwnt4sMw69WWYcerPMOPRmmXHozTLj0JtlxqE3y8zYH6e3oVH5sevWqSUTjXbVnK/0t46ckh6jt6M1We+Znq4D7Dkv/bt0Xbg1WQc4ty29zgTSfVx04ubSbdx3dvqiHCd2pify6N3cXbqNZuA9vVlmHHqzzDj0Zplx6M0y49CbZcahN8uMQ2+WGR+nHyWa2Jast8ycnqxH56zSbew5L31xhl0Xlh/rn3LOnmT95CkHkvW5HQdLt/Fn059P1t84aX3pGAsn1p6HAaBVHcn6H01ZV7qNb5z3pmS9t2tmegAfp
zezZuTQm2XGoTfLjENvlhmH3iwzDr1ZZhx6s8w49GaZ8ck5g2lJT8jQOvuk0iGOLJyXrO+6IH0yyb6z0yejAJx+3gvJ+i1dvygd47Udm5L1dqX72Nk/qXQb63vSk0883dNZOsb0lnSfC9RfOkaZKLnohvoi/fMj7mBslO7pJd0laYekNVXLbpW0RdLK4uvq0W3TzOplKE/v7wauGmT5VyNiUfH1cH3bMrPRUhr6iHgMKL8QmJmNCyN5I+9jklYVT/9rXvlP0hJJKyStOErPCDZnZvUw3NB/HVgALAK2Al+utWJELI2IxRGxeCLpK4ua2egbVugjYntE9EVEP3AHcHF92zKz0TKs0EuqPsZyHbCm1rpm1lxKj9NL+i5wGTBLUjfweeAySYuoHJrcCNw4ei3WX8uk9LFlnT4/Wd9xSclkCsCeNx9O1t967pPJ+tmTtpVuY2LJMfTuI+V9Ltu9MFnfvH96sr79xRNLt9G6MX1OwtEp5Ue4r7s0fc7Bp07+n2R9bc+C0m20bSmZ+GRP+v3s8jMrmkNp6CPi+kEW3zkKvZjZGPBpuGaZcejNMuPQm2XGoTfLjENvlhmH3iwzDr1ZZrKcRKNlzuxkvfst6UkyplxVfuLM+zqfTtb39aZPWPnepsWl29ixId3npO70ZCAAk7ekT4w5YVdvsn7G3iOl25iwe2eyvuei8qv5rLqgK1k/NDv9e6w+lJ7UBGDK5pIVXtxbOsZ44D29WWYcerPMOPRmmXHozTLj0JtlxqE3y4xDb5aZLI/TR9vEZP3I9PTPT2wpv7DCt9emZxA74Yn0RB4znz5auo1zNr+UrLfs3l86Rv9L+9L1AwfSA8QQLvEwPT3RRm9H+rwJgM5J6d/1aEkbq/fMLd3GlC3paTD6Sh6r8cJ7erPMOPRmmXHozTLj0JtlxqE3y4xDb5YZh94sM1kep2fHi8nynMdrXo8TgH3bOpN1gK5N6c+hT165IVnv27mrdBv9veltlJ9NMDZ04rRk/eBclY6xaGp3sv7c0fTfbMPzc0q3cc62Q8l69I+Xy1mkeU9vlhmH3iwzDr1ZZhx6s8w49GaZcejNMuPQm2XGoTfLTJYn5/Tt3ZusT/rfZ9L1J9MXqgCI/enJJ3oPpU8EGS9aT5pZus6+16YnsOi94GDpGKe2pU9WemB3+uIgJ65KT5wC0Nr9m2Q9fSrU+FG6p5c0X9KjktZJWivppmL5TEnLJK0vvqdPiTKzpjCUp/e9wCcjYiHwBuCjkhYCtwDLI+IsYHlx38yaXGnoI2JrRDxZ3N4PPAV0AdcA9xSr3QNcO0o9mlkdHdNrekmnARcCjwNzImJrUdoGDPqJBklLgCUAHaQngzSz0Tfkd+8lTQEeAD4REa+YFjQiAhh0PtKIWBoRiyNi8UTaR9SsmY3ckEIvaSKVwN8XET8oFm+X1FnUO4Edo9OimdXTUN69F3An8FREfKWq9BBwQ3H7BuDB+rdnZvU2lNf0bwLeC6yWtLJY9mngduB+SR8ENgHvGpUOR0PJBRr69pVc1KCsnpG+BV2l63RflZ7O47OLfly+HdITbSxbtzBZX7Dy5fJt7EpPrnK8KA19RPwMaj7il9e3HTMbbT4N1ywzDr1ZZhx6s8w49GaZcejNMuPQm2Umy8/T29CpPX3q9L7TJ5eO8frzf52sv+6EjaVjfG7TNcn69J+3Jett69MXFwHoLbl4yPHCe3qzzDj0Zplx6M0y49CbZcahN8uMQ2+WGYfeLDMOvVlmfHKOJbXOPSVZ33Nu+X7jLVO2J+v37r6kdIxnfrogWT/tZ7uT9VwmyBgK7+nNMuPQm2XGoTfLjENvlhmH3iwzDr1ZZhx6s8z4OL0l9c2alqz3zExfyALgsR1nJuubV3WWjnHGI4eS9Vj/fLqeyQQZQ+E9vVlmHHqzzDj0Zplx6M0y49CbZcahN8uMQ2+WGYfeLDOlJ+dImg98C5gDBLA0Ir4m6Vbgw8DOYtVPR8TDo9WoNUbLvpeT9
ZNWTikdY9+6ucn6GavTJ94AtK5OX6Gmv6endAyrGMoZeb3AJyPiSUlTgSckLStqX42IL41ee2ZWb6Whj4itwNbi9n5JTwFdo92YmY2OY3pNL+k04ELg8WLRxyStknSXpBn1bs7M6m/IoZc0BXgA+ERE7AO+DiwAFlF5JvDlGj+3RNIKSSuO4tddZo02pNBLmkgl8PdFxA8AImJ7RPRFRD9wB3DxYD8bEUsjYnFELJ5I+rLHZjb6SkMvScCdwFMR8ZWq5dWfh7wOWFP/9sys3oby7v2bgPcCqyWtLJZ9Grhe0iIqh/E2AjeOQn9mVmeKiLHbmLQT2FS1aBawa8waGD73WV/joc/x0CP8bp+nRsTs1A+Maeh/Z+PSiohY3LAGhsh91td46HM89AjD69On4ZplxqE3y0yjQ7+0wdsfKvdZX+Ohz/HQIwyjz4a+pjezsdfoPb2ZjTGH3iwzDQu9pKskPSPpWUm3NKqPMpI2SlotaaWkFY3uZ0DxIacdktZULZspaZmk9cX3hn4IqkaPt0raUjyeKyVd3cgei57mS3pU0jpJayXdVCxvtsezVp/H9Jg25DW9pFbg18BbgG7gl8D1EbFuzJspIWkjsDgimupEDUl/ABwAvhURFxTLvgjsjojbi/9IZ0TEzU3W463AgWaah6E4pbyzes4I4Frg/TTX41mrz3dxDI9po/b0FwPPRsSGiDgCfA+4pkG9jEsR8Riw+1WLrwHuKW7fQ+UfRMPU6LHpRMTWiHiyuL0fGJgzotkez1p9HpNGhb4L2Fx1v5vmnZgjgJ9IekLSkkY3U2JOMekJwDYqU5w1o6adh+FVc0Y07eM5krkt/EZeuUsj4iLgbcBHi6esTS8qr9ua8XjskOZhaIRB5oz4rWZ6PIc7t8WARoV+CzC/6v68YlnTiYgtxfcdwA+pMW9Ak9g+8JHn4vuOBvfzO4Y6D8NYG2zOCJrw8RzJ3BYDGhX6XwJnSTpdUhvwbuChBvVSk6TJxRsmSJoMXElzzxvwEHBDcfsG4MEG9jKoZpyHodacETTZ41m3uS0ioiFfwNVU3sF/DvhMo/oo6fEM4FfF19pm6hP4LpWnckepvCfyQeAkYDmwHngEmNmEPd4LrAZWUQlVZxM8lpdSeeq+ClhZfF3dhI9nrT6P6TH1abhmmfEbeWaZcejNMuPQm2XGoTfLjENvlhmH3iwzDr1ZZv4fAakkDvh/NJcAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "# an entire mnist digit\n", "image = np.array(..., dtype=np.float32)  # 28x28 pixel values of one MNIST digit (elided in this copy)\n", "image_torch = torch.from_numpy(image).view(1, 1, 28, 28)  # (N, C_in, H_in, W_in)\n", "\n", "# a gaussian blur kernel\n", "gaussian_kernel = torch.tensor([[1., 2, 1],[2, 4, 2],[1, 2, 1]]) / 16.0\n", "\n", "conv = nn.Conv2d(1, 1, 3, bias=False)  # no bias, so the output is purely the blurred image\n", "# manually set the conv weight\n", "conv.weight.data[:] = gaussian_kernel\n", "\n", "convolved = conv(image_torch)\n", "\n", "plt.title('original image')\n", "plt.imshow(image_torch.view(28,28).detach().numpy())\n", "plt.show()\n", "\n", "plt.title('blurred image')\n", "plt.imshow(convolved.view(26,26).detach().numpy())\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As we can see, the image is blurred as expected. Because the $3 \\times 3$ kernel is applied without padding, the $28 \\times 28$ image shrinks to $26 \\times 26$.\n", "\n", "In practice, we learn many kernels at a time. In this example, we take in an RGB image (3 channels) and output a 16-channel image. After an activation function is applied, that output could be used as input to another `Conv2d` module." 
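] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As mentioned above, `kernel_size`, `stride` and `padding` can change the output dimension. The next cell is a minimal sketch that checks this, assuming the output-size formula from the [`torch.nn.Conv2d` docs](https://pytorch.org/docs/stable/nn.html#torch.nn.Conv2d) with the default `dilation=1`: $H_{out} = \\lfloor (H_{in} + 2p - k) / s \\rfloor + 1$, where $k$, $s$ and $p$ are the kernel size, stride and padding (the same holds for the width). For example, a $3 \\times 3$ kernel with `stride=1` and `padding=1` keeps a $32 \\times 32$ input at $32 \\times 32$. The helper `conv2d_out_size` is a hypothetical name used only for this illustration; the loop compares its prediction with the shape `Conv2d` actually returns for a batch of four 3-channel $32 \\times 32$ images." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "\n", "# hypothetical helper: Conv2d output-size formula, assuming dilation=1\n", "def conv2d_out_size(size, kernel_size, stride=1, padding=0):\n", "    return (size + 2 * padding - kernel_size) // stride + 1\n", "\n", "x = torch.randn(4, 3, 32, 32)  # (N, C_in, H_in, W_in)\n", "for k, s, p in [(3, 1, 0), (3, 1, 1), (5, 2, 2), (3, 2, 0)]:\n", "    m = nn.Conv2d(3, 16, kernel_size=k, stride=s, padding=p)\n", "    out = m(x)\n", "    expected = conv2d_out_size(32, k, s, p)\n", "    assert out.shape == (4, 16, expected, expected)  # formula matches the real output\n", "    print('kernel_size={}, stride={}, padding={} -> {}'.format(k, s, p, tuple(out.shape)))" 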
] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "im shape torch.Size([4, 3, 32, 32])\n", "convolved im shape torch.Size([4, 16, 30, 30])\n" ] } ], "source": [ "im_channels = 3 # RGB images have 3 input channels; grayscale images have 1\n", "out_channels = 16 # this is a hyperparameter we can tune\n", "kernel_size = 3 # this is another hyperparameter we can tune\n", "batch_size = 4\n", "image_width = 32\n", "image_height = 32\n", "\n", "# Conv2d expects input of shape (N, C_in, H_in, W_in): height comes before width\n", "im = torch.randn(batch_size, im_channels, image_height, image_width)\n", "\n", "m = nn.Conv2d(im_channels, out_channels, kernel_size)\n", "convolved = m(im) # it is a module so we can call it\n", "\n", "print('im shape', im.shape)\n", "print('convolved im shape', convolved.shape)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Useful links:\n", "- [60 minute PyTorch Tutorial](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)\n", "- [PyTorch Docs](https://pytorch.org/docs/stable/index.html)\n", "- [Lecture notes on Auto-Diff](https://courses.cs.washington.edu/courses/cse446/19wi/notes/auto-diff.pdf)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.7" } }, "nbformat": 4, "nbformat_minor": 2 }