{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "UNmkdWoXvkBE" }, "source": [ "## Image Classification on CIFAR-10\n", "In this problem we will explore different deep learning architectures for image classification on the CIFAR-10 dataset. Make sure that you are familiar with torch `Tensor`s, two-dimensional convolutions (`nn.Conv2d`) and fully-connected layers (`nn.Linear`), ReLU non-linearities (`F.relu`), pooling (`nn.MaxPool2d`), and tensor reshaping (`view`).\n", "\n", "We will use Colab because it has free GPU runtimes available; GPUs can accelerate training times for this problem by 10-100x. **You will need to enable the GPU runtime to use it**. To do so, click \"Runtime\" above and then \"Change runtime type\". There under hardware accelerator choose \"GPU\".\n", "\n", "This notebook provides some starter code for the CIFAR-10 problem on HW4, including a completed training loop to assist with some of the Pytorch setup. You'll need to modify this code to implement the layers required for the assignment, but this provides a working training loop to start from.\n", "\n", "*Note: GPU runtimes are limited on Colab. Limit your training to short-running jobs (around 20mins or less) and spread training out over time, if possible. Colab WILL limit your usage of GPU time, so plan ahead and be prepared to take breaks during training.* We also suggest performing your early coding/sweeps on a small fraction of the dataset (~10%) to minimize training time and GPU usage." ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "id": "bb7WymOmv_cx" }, "outputs": [], "source": [ "import torch\n", "from torch import nn\n", "import numpy as np\n", "\n", "from typing import Tuple, Union, List, Callable\n", "from torch.optim import SGD\n", "import torchvision\n", "from torch.utils.data import DataLoader, TensorDataset, random_split\n", "import matplotlib.pyplot as plt\n", "from tqdm.notebook import tqdm\n", "\n", "%matplotlib inline " ] }, { "cell_type": "markdown", "metadata": { "id": "-SusLoH91CEz" }, "source": [ "Let's verify that we are using a gpu:" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vmxNdxwvxNs1", "outputId": "b15a87a2-1199-4be3-bbcb-591d1686d99a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cuda\n" ] } ], "source": [ "assert torch.cuda.is_available(), \"GPU is not available, check the directions above (or disable this assertion to use CPU)\"\n", "\n", "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(DEVICE) # this should print out CUDA" ] }, { "cell_type": "markdown", "metadata": { "id": "LbFk2t2RxYcn" }, "source": [ "To use the GPU you will need to send both the model and data to a device; this transfers the model from its default location on CPU to the GPU.\n", "\n", "Note that torch operations on Tensors will fail if they are not located on the same device.\n", "\n", "```python\n", "model = model.to(DEVICE) # Sending a model to GPU\n", "\n", "for x, y in tqdm(data_loader):\n", " x, y = x.to(DEVICE), y.to(DEVICE)\n", "```\n", "When reading tensors you may need to send them back to cpu, you can do so with `x = x.cpu()`." ] }, { "cell_type": "markdown", "metadata": { "id": "xODE5P5D1Wwy" }, "source": [ "Let's load CIFAR-10 data. This is how we load datasets using PyTorch in the real world!" 
] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wOvLuZry1cKc", "outputId": "080f62ec-5519-46b8-d465-fa75816613cb" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [ "train_dataset = torchvision.datasets.CIFAR10(\"./data\", train=True, download=True, transform=torchvision.transforms.ToTensor())\n", "test_dataset = torchvision.datasets.CIFAR10(\"./data\", train=False, download=True, transform=torchvision.transforms.ToTensor())" ] }, { "cell_type": "markdown", "metadata": { "id": "oG78KMOj61HJ" }, "source": [ "Here, we'll use the torch `DataLoader` to wrap our datasets. `DataLoader`s handle batching, shuffling, and iterating over data; they can also be useful for building more complex input pipelines that perform transfoermations such as data augmentation." ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "id": "mHIZY4BoRxvC" }, "outputs": [], "source": [ "batch_size = 128\n", "\n", "train_dataset, val_dataset = random_split(train_dataset, [int(0.9 * len(train_dataset)), int( 0.1 * len(train_dataset))])\n", "\n", "# Create separate dataloaders for the train, test, and validation set\n", "train_loader = DataLoader(\n", " train_dataset,\n", " batch_size=batch_size,\n", " shuffle=True\n", ")\n", "\n", "val_loader = DataLoader(\n", " val_dataset,\n", " batch_size=batch_size,\n", " shuffle=True\n", ")\n", "\n", "test_loader = DataLoader(\n", " test_dataset,\n", " batch_size=batch_size,\n", " shuffle=True\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "ju8QRAyv2u7F" }, "source": [ "## For Reference: Logistic Regression\n", "\n", "This problem is about deep learning architectures, not pytorch. We are providing an implementation of logistic regression using SGD in torch, which can serve as a template for the rest of your implementation in this problem." ] }, { "cell_type": "markdown", "metadata": { "id": "ELjydYRt5Frf" }, "source": [ "Before we get started, let's take a look at our data to get an understanding of what we are doing. CIFAR-10 is a dataset containing images split into 10 classes." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 370 }, "id": "_UU8aIle5m8Q", "outputId": "f55d6286-7701-473e-fbbb-b0e960abe6f4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "A single batch of images has shape: torch.Size([128, 3, 32, 32])\n", "A single RGB image has 3 channels, width 32, and height 32.\n", "Size of a batch of images flattened with view: torch.Size([128, 3072])\n", "Size of a batch of images flattened with flatten: torch.Size([128, 3072])\n", "True\n", "This image is labeled as class frog\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPsAAAD5CAYAAADhukOtAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/Il7ecAAAACXBIWXMAAAsTAAALEwEAmpwYAAAZ50lEQVR4nO2de4ycV3nGn3dmr967vev12vElNg65kevihAaSlECVAmqISlNARVGFMKpAKlL7R0SlQqtWolUB8QcCmSYQLiWES5S0pJc0SptyC9lAcEwudhJ8WXu9F3vX3rV317szb/+Yieqk5zm7nus65/lJlmfPM2e+d7+ZZ77Z8877HnN3CCFe/2TqHYAQojbI7EIkgswuRCLI7EIkgswuRCLI7EIkQkM5k83sVgBfAJAF8I/u/pnY/Xt7e33Lli3lHPI1VCNtaJU9nkcezyKPV+q8ihM7H2KlsX//fkxMTASftJLNbmZZAF8E8E4AwwCeNLOH3P1ZNmfLli0YGhoKarlcLnassOD55Qf8qgeMibEPO+EYo19V8Mjj2WJkHn9qLGZ2qpX6BpGNaHojqASV/K7LddddR7VyPsbvAPCiu7/s7mcA3AfgtjIeTwhRRcox+wYAh876ebg4JoRYgVR9gc7MdprZkJkNjY+PV/twQghCOWY/DGDjWT9fUBx7Fe6+y90H3X2wr6+vjMMJIcqhHLM/CWC7mV1oZk0A3g/gocqEJYSoNCWvxrv7opl9HMC/o7Bke4+7/7rUx8tmY6u+jNhqfGyFM3KsSMorT1afLbIqbRke4/zCCao1ZNuolsm28OPR9+9oyiCixa4HWo2vBLWqPC0rz+7uDwN4uEKxCCGqiL5BJ0QiyOxCJILMLkQiyOxCJILMLkQilLUaX288VhkWSa/F6mcswx8zY+R0RcIYO7aPanteeIRqTU3dVLvisvdQraWpKzjeYPx8RBNokaIbVp8kzg1a6FVhdGUXIhFkdiESQWYXIhFkdiESQWYXIhFWzGp8rBiAr1aWtuLO2ksBwMnpI1ybCdfjLy4u0Dn//ZP7qHZ8ZjfVhoenqZbPhlfcAeC3rn5XeE6kA1Y29p4fkTxSQFOrFWaxfHRlFyIRZHYhEkFmFyIRZHYhEkFmFyIRZHYhEmHFpN5KwSMptEyWv4/tefZnVPvvn36baqfmDwXHm1t4CnB6Zphqc2eOUe3gMNcefuweqjU3NQXHL92yg85pb+mkmnrQvX7QlV2IRJDZhUgEmV2IRJDZhUgEmV2IRJDZhUiEslJvZrYfwDQKZWSL7j5YiaCWTWTXHMc81f7np9/n2k9+SLW1A93B8fZV/FjrulZRDWikSm8/r2wbneDVct/87t8Hx3/rTbfROb//7j+mWmNzB9U8lnpjz43xdGns2hPbYqvSxHZjivfriz5qqRMrRiXy7L/t7hMVeBwhRBXRx3ghEqFcszuA/zCzp8xsZyUCEkJUh3I/xr/V3Q+b2VoAj5jZ8+7++Nl3KL4J7ASATZs2lXk4IUSplHVld/fDxf/HADwA4P99Advdd7n7oLsP9vX1lXM4IUQZlGx2M2szs45XbgP4HQB7KhWYEKKylPMxvh/AA8XGgg0A/snd/60iUS0Ti2xplI90WOxdy9NJ7e1tVDszG06jzZ2hUzA5yZtRTpyeo1pTP69EG1jbSrV9+/YGx7/74BfpnPW966h20023U21xkV8rspnwSyveiLLU9FRsXliLbh0W0WLT4gm08zj15u4vA7iygrEIIaqIUm9CJILMLkQiyOxCJILMLkQiyOxCJML53XAysqFbNttMtYu28uaLBy7hXxUYHx0Ljs+M8Oq1sRle5XXKeZpvbOI41VpbeqjW0R6O5eSxUTrn6/f/LdWaWnl687o330o1eEt4PM8fL5qWK7miLPwaiR3LncdYOvVvzqkruxCJILMLkQgyuxCJILMLkQgyuxCJcF6vxkdrKiKs67uYaj1d/VTLnTkRHG+a46vxJ3O82GVhihfrYI6vCJ84witvvDn8lHZ38ZX/yakDVLv3m5+j2uypWapdt+OdwfG2tl46B7FV8FgFSuSalc/HVurDWKYaRStajRdC1AiZXYhEkNmFSASZXYhEkNmFSASZXYhEOM9Tb6W9V/X18pbWF268lmqnZoaD47nOGTrn2HGeeuvp4sU6XW18+6eG0/z3zjSEH/P0PO+F19zEU4DDh1+g2ue/8NdUu/GtTwbHb7vt9+mcTRdcQjVaWAOgqZGfx0xEY+RyfDsvI731AMBQjQKayqEruxCJILMLkQgyuxCJILMLkQgyuxCJILMLkQhLpt7M7B4A7wEw5u6XF8dWA/gOgC0A9gO4w90nywkkvi1QKfBUUybDq9SuvZL3VRubCPenm5vlfevm87xPXlueV6+t7+ApowXnT9uJ0+Ged+0N7XTOzAJPNeUWeepw8thBqj3w4D3B8b0v/ojOufwynvbs7lxPtVWta6h2wYaLguMXbn0jndPXP0A1RPoeIpoKPj+q3r4G4LUOuAvAo+6+HcCjxZ+FECuYJc1e3G/9ta1ObwNwb/H2vQDeW9mwhBCVptS/2fvdfaR4+ygKO7oKIVYwZS/Qubsj0rjbzHaa2ZCZDY2Pj5d7OCFEiZRq9lEzGwCA4v/h3RMAuPsudx9098G+vr4SDyeEKJdSzf4QgDuLt+8E8GBlwhFCVIvlpN6+DeBmAL1mNgzgUwA+A+B+M/swgAMA7ig3kMJfAzSGUh4xciy+JVNXB0+77Lj294LjxyZ51rG1bYpqjTOnqHbdlVdR7cgY3xrqx3v2BccXmvn7+uou/olr/hQ/j7l2/rws5MIpqkNHw/EBwJFRrm1cv5Vqx8Z548uR4fA5fvOOG+ic29/3Qapdv+PtVGts7KBarIVlhp7Gc2+WGWNJs7v7B4h0S
0UjEUJUFX2DTohEkNmFSASZXYhEkNmFSASZXYhEOK8bTsbhlW3xTB4XN2/YERy/7OIjdM7Lz/OGjWtW8f3XDg6Hm1sCwGJki7hMNtyYceY0rwLs7OFx9HR0U+30SZ4aOj0brpbzJl411tkWaRwZ2X/N89NUY/vz7dv7czrn0cf4a2Bujqf5bnzbe6nW0MAbZsLCv1upr1OGruxCJILMLkQiyOxCJILMLkQiyOxCJILMLkQivI5Tb5Vv8JexVcHxKy99G52z/6XdVBs78AzVrDmytxl4+qqL9LA8Pc3TU9lWnqacO85Tdr19vEFRHuFqs9l5XiHY3BU+vwDQ1MRjbIuk7No7wvMmj03ROXuff55qY6NfpdqaNWupdu3VvFrOWXVbtOhNqTchBEFmFyIRZHYhEkFmFyIRZHYhEuF1vBpfBcjqaE/XBjrlphv/gGq7vsx7rvX39VBtdnaKaqOnwoUfJ+Z5v7v+Fr7S3dLBXyJN1kQ1z4dXi1ua+e+1CL7V1MGDI1TLnYkU+XS1Bscnj5+kcw79hhchrWrjq+Dfuf+LVMvztoe45pq3BMczxrMMWo0XQlBkdiESQWYXIhFkdiESQWYXIhFkdiESYTnbP90D4D0Axtz98uLYpwF8BMAr27J+0t0frlaQbGuo0raFKh12PM/z98z16y6m2uYtl1Htp0M/pFp7B0/JzJMGddksr6rIGm9qt66/k2otDXy7o4yHC15OnzpN58QuPataeZ+8E7MzVLOG8HPWu5rHfnKa95k7cYxvvRWpT8K//PM3qZbLhyfuGLyZzjHL8oMRlnNl/xqAWwPjn3f3q4r/qmZ0IURlWNLs7v44gMjbmRDifKCcv9k/bma7zeweM+NfixJCrAhKNfuXAGwDcBWAEQCfZXc0s51mNmRmQ+Pj4+xuQogqU5LZ3X3U3XPungfwFQDh3RMK993l7oPuPtjXx/cBF0JUl5LMbmYDZ/14O4A9lQlHCFEtlpN6+zaAmwH0mtkwgE8BuNnMrkKhDmw/gI+WGwhLrxVjKPfhK4OFUyQWqUBqbmqn2o5rw9VOADB64CdUOzoxRbUGsu1VJ+nFBgCtztM4jeBpvtw8f84aGsKPmQcv/5qb59oiL2zDqhaeloOH04qrVvF048IZ0sgPwPjIFNXWrOafXMfGXqbaC3vDfQrffO2NdE6mBEssaXZ3/0Bg+O5zP5QQop7oG3RCJILMLkQiyOxCJILMLkQiyOxCJMKKaTi5YtJrUcKpoXyep6BykZTRmg6elhvoXUO1E5M8bTQ+Hi5jaMrxp/p0nqeaZiZ5Oiyf4Q0nf3PkSHA808zjiO12lHE+b26Gn+SFhXATy97VfKuplmaeplxc5L/zsTG+tdWGzTw9uH37NqLwa3EsVc3QlV2IRJDZhUgEmV2IRJDZhUgEmV2IRJDZhUiEFZN6i1H5tBzvDFgo0Q9zcmoqOD565DCdMza6n2qnp16kWn6ep3jGJniDxYaFcErm+KFpOifXHXkZZPj56OhqoVpzS7jqramV/17794f3qQOAfI7P27Ce77V36GB437YzC7zxZW83/706Wvn1sZVn89DTtZZqV14Wrn60SLoUDUq9CSEIMrsQiSCzC5EIMrsQiSCzC5EI58VqfKWJrbh7ZKU+Q94aMw28WGRm7hjVRo7yVXzP8KXdqRN8e6Ibb7goHMckX8F/6eABqq3t76ZacxNftb78kvDWVoeGx/jjNZykWj4bWQVvaaVa/7qu4HhDIy/+wQIvrGlo5v361g2sp9rEBM80PLvnmeD44LXvoHNYP8QYurILkQgyuxCJILMLkQgyuxCJILMLkQgyuxCJsJztnzYC+DqAfhTahO1y9y+Y2WoA3wGwBYUtoO5wd96EayURK6yJ1BcskoZycwvzdE7fwIVUW9XC02v5OZ7+GZ+OxJ8Jp9G61/CnevFlnoY6deIU1WYR7u8GACdnwi+F0TFekNOzhqfypqd5jHt+9SzVLr9iIDh+yZs20TnDB3i69NQJ/js3taymWovz5+wnP30kOH7xJdfSOR1d4ZRijOVc2RcB/Jm7XwrgegAfM7NLAdwF4FF33w7g0eLPQogVypJmd/cRd/9F8fY0gOcAbABwG4B7i3e7F8B7qxSjEKICnNPf7Ga2BcDVAJ4A0O/uI0XpKAof84UQK5Rlm93M2gF8H8An3P1V32v0QhPr4F+7ZrbTzIbMbGh8fLysYIUQpbMss5tZIwpG/5a7/6A4PGpmA0V9AEDwS8/uvsvdB919sK+P718thKguS5rdCj2h7gbwnLt/7izpIQB3Fm/fCeDByocnhKgUy6l6uwHAhwA8Y2ZPF8c+CeAzAO43sw8DOADgjqpEWA0iW+fEttWZmwunf9as5j3QeteGUz8A4It8G6emDH9qmtp4P7NvfC2cFFnby1M1c7wdG9b18d5vPV3dVBv55UvB8VWr+ONNTvCKuA3rN1KtvYk/Zl9Pd3C80fh1bjHyvBwa5tnl9s7w1lsAECm0xMHMvuD49DR/vFJSb0ua3d1/BIAlCW855yMKIeqCvkEnRCLI7EIkgswuRCLI7EIkgswuRCIk2XAyxqlTvMrL8+G03Lp+nnrLZBuptpDj77ULC7zKa2CAfzO5qzOsXXEZr6BqauKPd2Tk11TriKTetr0x3PhyIccrBOfneqh20bZLqTZyZIJqw0fCVYAL85HU22keI/K8eu3o4UNUyxhP6Xb3hhtVtra28ThKQFd2IRJBZhciEWR2IRJBZhciEWR2IRJBZhciEZJMvVmk4WQmUg119Gg4tTIzxVM/m7ZeQrXmJn6svfueotoTP3+Maqtaw/uenT7F00njEyNUsyzfRw0ZniqbJ0042zp5U8m1veuodnqGV4DNnzlCtROT4eemo3UNndMaSZfmzvCKuO72Tqq1NFMJF6wPNyXt6ODnN9oZlaAruxCJILMLkQgyuxCJILMLkQgyuxCJkORqvEe24mlr76DaRW8MF2M8+8wv6ZzDP36Uaps28m67Tz/FV9yf+tnDVNuyJdzzbu/+5+iciSle/HPZlVdTbVUn74XX1x++jjQ38lXk/fv2UM3sJNXa2rNUW90TXv3PLfCmcJPHZ6jW3s637OpczZ/PtWt5sdG73/P+4HhDE1/C91hTO4Ku7EIkgswuRCLI7EIkgswuRCLI7EIkgswuRCIsmXozs40Avo7ClswOYJe7f8HMPg3gIwBe2Zr1k+7Oc0IrCp56yy3ylEZjc7gn2OZtb6BzHn/sh1T7tx9+lWqzJ0eptmOQF9eMTYXnTc7wbYvu+OBHqHbzLe+m2sED4W2LAOA///WB4PjY4ZfpnOlpXlDU28cLciYneVruxIlwWjGf4z3+Nm3eRLVr3rCdavt+M0y1P/zdP6Latu3h9GZsK7LYa5ixnDz7IoA/c/dfmFkHgKfM7JGi9nl3/4dzPqoQouYsZ6+3EQAjxdvTZvYcAN5OVQixIjmnv9nNbAuAqwE8URz6uJntNrN7zCxWfCuEqDPLNruZtQP4PoBPuPtJAF8CsA3AVShc+T9L5u00syEzGxofHw/dRQhRA5Zl
djNrRMHo33L3HwCAu4+6e84LX9L9CoAdobnuvsvdB919sK+Pf3dYCFFdljS7FXo43Q3gOXf/3FnjZ1dc3A6AVzEIIerOclbjbwDwIQDPmNnTxbFPAviAmV2FQjpuP4CPViG+qhBpQYe5uTmqvfji88Hx2TmeMnLjKR5kudbS0US1A4eP8Xlt7cHxm26+ns5505v41lCz0zzG9pbwsQBgduZEcPyFF56mczKZyO88zFOHXd08LdfSHq56W5jnL4ILNm+jWj7Pe+jd8o7bqfaW699BNd5PLpZeO/cedMtZjf8ROep5klMXQgD6Bp0QySCzC5EIMrsQiSCzC5EIMrsQiZBkw8lYSqOxkW/9M7COlAQY/7LQmp4uqq3tW021A/v3Um1hjlfmbb5wc3A808DTU0NP/oJqk8fHqDZ7aopqp8h2Ux3tF9A5HV3dVJtfXKBabx/fdqkxG06Vzc7m6JyNWy+j2tY3XE61q665gWrZDG9UyYrbYiniEjJvurILkQoyuxCJILMLkQgyuxCJILMLkQgyuxCJkGjqjdPczPfX6l8X3kctlgbp799Ite3br6Da7GykieIkrwAbPhxueriml+/LtnUbz/E8/cv/opqBp7x27LgpON6QCTftBIDWNl5R1tzG93PLLfLKvFwunGJraeV7+vWt5V3XOjtjPRl42jb2Gomm2Co4SVd2IRJBZhciEWR2IRJBZhciEWR2IRJBZhciEVZM6s1Kyj9Untj+Wp4PaxbJq8S36+KVaKvaeIoK4I0e2zvXB8c7OrvpHHfeZLOri6eouiJVatkGlpar9UuOVQhGnrNIVeQiSeUBQEM2ll/jUin7tpWCruxCJILMLkQiyOxCJILMLkQiyOxCJMKSS6Nm1gLgcQDNxft/z90/ZWYXArgPwBoATwH4kLtH9jo6P4hlBYyutvKecAZewBF7r83l+Mpucytfjc9mwvHnI6vIHomjpyfc064An5fPs3MSiSPaV40fyyMr62bhOPKRg5lxW2SzkWKXyOsgzspZjZ8H8HZ3vxKF7ZlvNbPrAfwdgM+7+xsATAL4cNWiFEKUzZJm9wIzxR8bi/8cwNsBfK84fi+A91YjQCFEZVju/uzZ4g6uYwAeAfASgCl3XyzeZRgALwIWQtSdZZnd3XPufhWACwDsAHDxcg9gZjvNbMjMhsbHx0uLUghRNue0Gu/uUwAeA/AWAN32fysZFwA4TObscvdBdx/s64t1+RBCVJMlzW5mfWbWXbzdCuCdAJ5DwfTvK97tTgAPVilGIUQFWE5VwgCAe80si8Kbw/3u/i9m9iyA+8zsbwD8EsDd5QTCeoUB50ORTCy+WDqmhD18lpjGTmM0pRh5GcTSYfGUFzteKXOWmBeZxa5n2cik2O/FiqGWmhfn3FN2sYItxpJmd/fdAK4OjL+Mwt/vQojzAH2DTohEkNmFSASZXYhEkNmFSASZXYhEsFKW8Es+mNk4gAPFH3sBTNTs4BzF8WoUx6s53+LY7O7Bb6/V1OyvOrDZkLsP1uXgikNxJBiHPsYLkQgyuxCJUE+z76rjsc9GcbwaxfFqXjdx1O1vdiFEbdHHeCESoS5mN7NbzewFM3vRzO6qRwzFOPab2TNm9rSZDdXwuPeY2ZiZ7TlrbLWZPWJm+4r/99Qpjk+b2eHiOXnazN5Vgzg2mtljZvasmf3azP60OF7TcxKJo6bnxMxazOznZvarYhx/VRy/0MyeKPrmO2bWdE4P7O41/Qcgi0Jbq60AmgD8CsCltY6jGMt+AL11OO6NAK4BsOessb8HcFfx9l0A/q5OcXwawJ/X+HwMALimeLsDwF4Al9b6nETiqOk5QaFqt714uxHAEwCuB3A/gPcXx78M4E/O5XHrcWXfAeBFd3/ZC62n7wNwWx3iqBvu/jiA468Zvg2Fxp1AjRp4kjhqjruPuPsvirenUWiOsgE1PieROGqKF6h4k9d6mH0DgENn/VzPZpUO4D/M7Ckz21mnGF6h391HirePAuivYywfN7PdxY/5Vf9z4mzMbAsK/ROeQB3PyWviAGp8TqrR5DX1Bbq3uvs1AH4XwMfM7MZ6BwQU3tlRchubsvkSgG0o7BEwAuCztTqwmbUD+D6AT7j7ybO1Wp6TQBw1PydeRpNXRj3MfhjAxrN+ps0qq427Hy7+PwbgAdS3886omQ0AQPH/sXoE4e6jxRdaHsBXUKNzYmaNKBjsW+7+g+Jwzc9JKI56nZPisadwjk1eGfUw+5MAthdXFpsAvB/AQ7UOwszazKzjldsAfgfAnvisqvIQCo07gTo28HzFXEVuRw3OiRWaz90N4Dl3/9xZUk3PCYuj1uekak1ea7XC+JrVxnehsNL5EoC/qFMMW1HIBPwKwK9rGQeAb6PwcXABhb+9PozCnnmPAtgH4D8BrK5THN8A8AyA3SiYbaAGcbwVhY/ouwE8Xfz3rlqfk0gcNT0nAK5AoYnrbhTeWP7yrNfszwG8COC7AJrP5XH1DTohEiH1BTohkkFmFyIRZHYhEkFmFyIRZHYhEkFmFyIRZHYhEkFmFyIR/hexgOXVZbQJfwAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "imgs, labels = next(iter(train_loader))\n", "print(f\"A single batch of images has shape: {imgs.size()}\")\n", "example_image, example_label = imgs[0], labels[0]\n", "c, w, h = example_image.size()\n", "print(f\"A single RGB image has {c} channels, width {w}, and height {h}.\")\n", "\n", "# This is one way to flatten our images\n", "batch_flat_view = imgs.view(-1, c * w * h)\n", "print(f\"Size of a batch of images flattened with view: {batch_flat_view.size()}\")\n", "\n", "# This is another equivalent way\n", "batch_flat_flatten = imgs.flatten(1)\n", "print(f\"Size of a batch of images flattened with flatten: {batch_flat_flatten.size()}\")\n", "\n", "# The new dimension is just the product of the ones we flattened\n", "d = example_image.flatten().size()[0]\n", "print(c * w * h == d)\n", "\n", "# View the image\n", "t = torchvision.transforms.ToPILImage()\n", "plt.imshow(t(example_image))\n", "\n", "# These are what the class labels in CIFAR-10 represent. For more information,\n", "# visit https://www.cs.toronto.edu/~kriz/cifar.html\n", "classes = [\"airplane\", \"automobile\", \"bird\", \"cat\", \"deer\", \"dog\", \"frog\",\n", " \"horse\", \"ship\", \"truck\"]\n", "print(f\"This image is labeled as class {classes[example_label]}\")\n" ] }, { "cell_type": "markdown", "metadata": { "id": "JdBbLQdC_mA_" }, "source": [ "In this problem, we will attempt to predict what class an image is labeled as." ] }, { "cell_type": "markdown", "metadata": { "id": "J3HQyGMn_42D" }, "source": [ "First, let's create our model. For a linear model we could flatten the data before passing it into the model, but that is not be the case for the convolutional neural network." ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "id": "5qLVC9PbACDt" }, "outputs": [], "source": [ "def linear_model() -> nn.Module:\n", " \"\"\"Instantiate a linear model and send it to device.\"\"\"\n", " model = nn.Sequential(\n", " nn.Flatten(),\n", " nn.Linear(d, 10)\n", " )\n", " return model.to(DEVICE)" ] }, { "cell_type": "markdown", "metadata": { "id": "kd49udL8AZ_E" }, "source": [ "Let's define a method to train this model using SGD as our optimizer." ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "id": "YgcFP1-UBj1Z" }, "outputs": [], "source": [ "def train(\n", " model: nn.Module, optimizer: SGD,\n", " train_loader: DataLoader, val_loader: DataLoader,\n", " epochs: int = 20\n", " )-> Tuple[List[float], List[float], List[float], List[float]]:\n", " \"\"\"\n", " Trains a model for the specified number of epochs using the loaders.\n", "\n", " Returns: \n", " Lists of training loss, training accuracy, validation loss, validation accuracy for each epoch.\n", " \"\"\"\n", "\n", " loss = nn.CrossEntropyLoss()\n", " train_losses = []\n", " train_accuracies = []\n", " val_losses = []\n", " val_accuracies = []\n", " for e in tqdm(range(epochs)):\n", " model.train()\n", " train_loss = 0.0\n", " train_acc = 0.0\n", "\n", " # Main training loop; iterate over train_loader. 
"        # terminates when the train loader finishes iterating, which is one epoch.\n", "        for (x_batch, labels) in train_loader:\n", "            x_batch, labels = x_batch.to(DEVICE), labels.to(DEVICE)\n", "            optimizer.zero_grad()\n", "            labels_pred = model(x_batch)\n", "            batch_loss = loss(labels_pred, labels)\n", "            train_loss = train_loss + batch_loss.item()\n", "\n", "            labels_pred_max = torch.argmax(labels_pred, 1)\n", "            batch_acc = torch.sum(labels_pred_max == labels)\n", "            train_acc = train_acc + batch_acc.item()\n", "\n", "            batch_loss.backward()\n", "            optimizer.step()\n", "        train_losses.append(train_loss / len(train_loader))\n", "        # Divide by the true number of examples; the final batch may be smaller than batch_size.\n", "        train_accuracies.append(train_acc / len(train_loader.dataset))\n", "\n", "        # Validation loop; use .no_grad() context manager to save memory.\n", "        model.eval()\n", "        val_loss = 0.0\n", "        val_acc = 0.0\n", "\n", "        with torch.no_grad():\n", "            for (v_batch, labels) in val_loader:\n", "                v_batch, labels = v_batch.to(DEVICE), labels.to(DEVICE)\n", "                labels_pred = model(v_batch)\n", "                v_batch_loss = loss(labels_pred, labels)\n", "                val_loss = val_loss + v_batch_loss.item()\n", "\n", "                v_pred_max = torch.argmax(labels_pred, 1)\n", "                batch_acc = torch.sum(v_pred_max == labels)\n", "                val_acc = val_acc + batch_acc.item()\n", "        val_losses.append(val_loss / len(val_loader))\n", "        val_accuracies.append(val_acc / len(val_loader.dataset))\n", "\n", "    return train_losses, train_accuracies, val_losses, val_accuracies\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2ZUqV2ZrEf-u" }, "source": [ "For this problem, we will be using SGD. The two hyperparameters for our linear model trained with SGD are the learning rate and the momentum; only the learning rate is searched over in this example.\n", "\n", "Note: We ask you to plot the accuracies for the best 5 models for each structure, so you will need to return multiple sets of hyperparameters for the homework, or, if you do random search, run your hyperparameter search multiple times." ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "def parameter_search(train_loader: DataLoader,\n", "                     val_loader: DataLoader,\n", "                     model_fn: Callable[[], nn.Module]) -> float:\n", "    \"\"\"\n", "    Parameter search for our linear model using SGD.\n", "\n", "    Args:\n", "        train_loader: the train dataloader.\n", "        val_loader: the validation dataloader.\n", "        model_fn: a function that, when called, returns a torch.nn.Module.\n", "\n", "    Returns:\n", "        The learning rate with the least validation loss.\n", "        NOTE: you may need to modify this function to search over and return\n", "        other parameters beyond learning rate.\n", "    \"\"\"\n", "    num_iter = 10  # This will likely not be enough for the rest of the problem.\n", "    best_loss = torch.tensor(np.inf)\n", "    best_lr = 0.0\n", "\n", "    # NOTE: learning rates are usually swept on a log scale; linspace is used here only for simplicity.\n", "    lrs = torch.linspace(10 ** (-6), 10 ** (-1), num_iter)\n", "\n", "    for lr in lrs:\n", "        print(f\"trying learning rate {lr}\")\n", "        model = model_fn()\n", "        optim = SGD(model.parameters(), lr)\n", "\n", "        train_loss, train_acc, val_loss, val_acc = train(\n", "            model,\n", "            optim,\n", "            train_loader,\n", "            val_loader,\n", "            epochs=20\n", "        )\n", "\n", "        if min(val_loss) < best_loss:\n", "            best_loss = min(val_loss)\n", "            best_lr = float(lr)\n", "\n", "    return best_lr" ] }, { "cell_type": "markdown", "metadata": { "id": "ZULxD9sGHm1D" }, "source": [ "Now that we have everything, we can train and evaluate our model."
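] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Concretely, the flow in the cells below is: run `parameter_search` to pick a learning rate, retrain a fresh model with it, then plot and evaluate. A minimal sketch of the training step (momentum and other `SGD` options are left for you to choose):\n", "\n", "```python\n", "best_lr = parameter_search(train_loader, val_loader, linear_model)\n", "\n", "model = linear_model()\n", "optimizer = SGD(model.parameters(), lr=best_lr)\n", "train_loss, train_accuracy, val_loss, val_accuracy = train(\n", "    model, optimizer, train_loader, val_loader, epochs=20)\n", "```"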
] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "azwXX-AEIGKx", "outputId": "55b65e28-7d24-4924-e8a7-bbb5b1fc22ff" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "trying learning rate 9.999999974752427e-07\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b854d671884f4665a52c35a60e10ee17", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/20 [00:00" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "epochs = range(1, 21)\n", "plt.plot(epochs, train_accuracy, label=\"Train Accuracy\")\n", "plt.plot(epochs, val_accuracy, label=\"Validation Accuracy\")\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Accuracy\")\n", "plt.legend()\n", "plt.title(\"Logistic Regression Accuracy for CIFAR-10 vs Epoch\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "zu8AaNRl4FDi" }, "source": [ "The last thing we have to do is evaluate our model on the testing data." ] }, { "cell_type": "code", "execution_count": 15, "metadata": { "id": "KzxYGlYH4MwC" }, "outputs": [], "source": [ "def evaluate(\n", " model: nn.Module, loader: DataLoader\n", ") -> Tuple[float, float]:\n", " \"\"\"Computes test loss and accuracy of model on loader.\"\"\"\n", " loss = nn.CrossEntropyLoss()\n", " model.eval()\n", " test_loss = 0.0\n", " test_acc = 0.0\n", " with torch.no_grad():\n", " for (batch, labels) in loader:\n", " batch, labels = batch.to(DEVICE), labels.to(DEVICE)\n", " y_batch_pred = model(batch)\n", " batch_loss = loss(y_batch_pred, labels)\n", " test_loss = test_loss + batch_loss.item()\n", "\n", " pred_max = torch.argmax(y_batch_pred, 1)\n", " batch_acc = torch.sum(pred_max == labels)\n", " test_acc = test_acc + batch_acc.item()\n", " test_loss = test_loss / len(loader)\n", " test_acc = test_acc / (batch_size * len(loader))\n", " return test_loss, test_acc" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ekfbYD_34XFB", "outputId": "97ef26db-193f-40ba-e9d4-3d005e1ed61a" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Test Accuracy: 0.38558148734177217\n" ] } ], "source": [ "test_loss, test_acc = evaluate(model, test_loader)\n", "print(f\"Test Accuracy: {test_acc}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "OK-KuYt41uFA" }, "source": [ "The rest is yours to code. You can structure the code any way you would like.\n", "\n", "We do advise making using code cells and functions (train, search, predict etc.) for each subproblem, since they will make your code easier to debug. \n", "\n", "Also note that several of the functions above can be reused for the various different models you will implement for this problem; i.e., you won't need to write a new `evaluate()`." 
] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "image-classification-on-cifar-10-starter-code.ipynb", "provenance": [] }, "kernelspec": { "display_name": "cse446", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" }, "varInspector": { "cols": { "lenName": 16, "lenType": 16, "lenVar": 40 }, "kernels_config": { "python": { "delete_cmd_postfix": "", "delete_cmd_prefix": "del ", "library": "var_list.py", "varRefreshCmd": "print(var_dic_list())" }, "r": { "delete_cmd_postfix": ") ", "delete_cmd_prefix": "rm(", "library": "var_list.r", "varRefreshCmd": "cat(var_dic_list()) " } }, "types_to_exclude": [ "module", "function", "builtin_function_or_method", "instance", "_Feature" ], "window_display": false }, "vscode": { "interpreter": { "hash": "bf814fda1cc440ac317e22279a1ec33d1b20faeecc9ea242b28923e33d4f784d" } } }, "nbformat": 4, "nbformat_minor": 1 }