{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "7d33e6c4-dabc-4e68-a492-0a42d7248679", "metadata": {}, "outputs": [], "source": [ "from torchvision.datasets import MNIST\n", "from torchvision.transforms.v2 import ToImage, Compose, PILToTensor, ToDtype, Lambda, Normalize\n", "import torch\n", "import math\n", "\n", "transform = Compose([\n", " PILToTensor(), \n", " ToDtype(torch.float), \n", " Lambda(lambda x: x / 255.0)\n", "])\n", "train_mnist = MNIST(root=\"mnist\", train=True, download=True, transform=transform)\n", "valid_mnist = MNIST(root=\"mnist\", train=False,download=True, transform=transform)" ] }, { "cell_type": "code", "execution_count": 2, "id": "73edab51-bc6a-499b-ae59-d461a18bfc16", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAbe0lEQVR4nO3df2xV9f3H8dflR6+I7e1KbW8rPyygsIlgxqDrVMRRKd1G5McWdS7BzWhwrRGYuNRM0W2uDqczbEz5Y4GxCSjJgEEWNi22ZLNgQBgxbg0l3VpGWyZb7y2FFmw/3z+I98uVFjyXe/u+vTwfySeh955378fjtU9vezn1OeecAADoZ4OsNwAAuDIRIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYGKI9QY+qaenR8eOHVN6erp8Pp/1dgAAHjnn1N7ervz8fA0a1PfrnKQL0LFjxzRq1CjrbQAALlNTU5NGjhzZ5/1J9y249PR06y0AAOLgUl/PExag1atX6/rrr9dVV12lwsJCvfvuu59qjm+7AUBquNTX84QE6PXXX9eyZcu0YsUKvffee5oyZYpKSkp0/PjxRDwcAGAgcgkwffp0V1ZWFvm4u7vb5efnu8rKykvOhkIhJ4nFYrFYA3yFQqGLfr2P+yugM2fOaP/+/SouLo7cNmjQIBUXF6u2tvaC47u6uhQOh6MWACD1xT1AH374obq7u5Wbmxt1e25urlpaWi44vrKyUoFAILJ4BxwAXBnM3wVXUVGhUCgUWU1NTdZbAgD0g7j/PaDs7GwNHjxYra2tUbe3trYqGAxecLzf75ff74/3NgAASS7ur4DS0tI0depUVVVVRW7r6elRVVWVioqK4v1wAIABKiFXQli2bJkWLVqkL3zhC5o+fbpefvlldXR06Nvf/nYiHg4AMAAlJED33HOP/vOf/+jpp59WS0uLbrnlFu3cufOCNyYAAK5cPuecs97E+cLhsAKBgPU2AACXKRQKKSMjo8/7zd8FBwC4MhEgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmhlhvAEgmgwcP9jwTCAQSsJP4KC8vj2nu6quv9jwzYcIEzzNlZWWeZ372s595nrnvvvs8z0hSZ2en55nnn3/e88yzzz7reSYV8AoIAGCCAAEATMQ9QM8884x8Pl/UmjhxYrwfBgAwwCXkZ0A33XST3nrrrf9/kCH8qAkAEC0hZRgyZIiCwWAiPjUAIEUk5GdAhw8fVn5+vsaOHav7779fjY2NfR7b1dWlcDgctQAAqS/uASosLNS6deu0c+dOvfLKK2poaNDtt9+u9vb2Xo+vrKxUIBCIrFGjRsV7SwCAJBT3AJWWluob3/iGJk+erJKSEv3xj39UW1ub3njjjV6Pr6ioUCgUiqympqZ4bwkAkIQS/u6AzMxM3Xjjjaqvr+/1fr/fL7/fn+htAACSTML/HtDJkyd15MgR5eXlJfqhAAADSNwD9Pjjj6umpkb//Oc/9c4772j+/PkaPHhwzJfCAACkprh/C+7o0aO67777dOLECV177bW67bbbtGfPHl177bXxfigAwAAW9wBt2rQp3p8SSWr06NGeZ9LS0jzPfOlLX/I8c9ttt3mekc79zNKrhQsXxvRYqebo0aOeZ1atWuV5Zv78+Z5n+noX7qX87W9/8zxTU1MT02NdibgWHADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgwuecc9abOF84HFYgELDexhXllltuiWlu165dnmf4dzsw9PT0eJ75zne+43nm5MmTnmdi0dzcHNPc//73P88zdXV1MT1WKgqFQsrIyOjzfl4BAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwMQQ6w3AXmNjY0xzJ06c8DzD1bDP2bt3r+eZtrY2zzN33nmn5xlJOnPmjOeZ3/72tzE9Fq5cvAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAExwMVLov//9b0xzy5cv9zzzta99zfPMgQMHPM+sWrXK80ysDh486Hnmrrvu8jzT0dHheeamm27yPCNJjz32WExzgBe8AgIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATPicc856E+cLh8MKBALW20CCZGRkeJ5pb2/3PLNmzRrPM5L04IMPep751re+5Xlm48aNnmeAgSYUCl30v3leAQEATBAgAIAJzwHavXu35s6dq/z8fPl8Pm3dujXqfuecnn76aeXl5WnYsGEqLi7W4cOH47VfAECK8Bygjo4OTZkyRatXr+71/pUrV2rVqlV69dVXtXfvXg0fPlwlJSXq7Oy87M0CAFKH59+IWlpaqtLS0l7vc87p5Zdf1g9+8APdfffdkqT169crNzdXW7du1b333nt5uwUApIy4/gyooaFBLS0tKi4ujtwWCARUWFio2traXme6uroUDoejFgAg9cU1QC0tLZKk3NzcqNtzc3Mj931SZWWlAoFAZI0aNSqeWwIAJCnzd8FVVFQoFApFVlNTk/WWAAD9IK4BCgaDkqTW1tao21tbWyP3fZLf71dGRkbUAgCkvrgGqKCgQMFgUFVVVZHbwuGw9u7dq6Kiong+FABggPP8LriTJ0+qvr4+8nFDQ4MOHjyorKwsjR49WkuWLNGPf/xj3XDDDSooKNBTTz2l/Px8zZs3L577BgAMcJ4DtG/fPt15552Rj5ctWyZJWrRokdatW6cnnnhCHR0devjhh9XW1qbbbrtNO3fu1FVXXRW/XQMABjwuRoqU9MILL8Q09/H/UHlRU1Pjeeb8v6rwafX09HieASxxMVIAQFIiQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACa6GjZQ0fPjwmOa2b9/ueeaOO+7wPFNaWup55s9//rPnGcASV8MGACQlAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEFyMFzjNu3DjPM++9957nmba2Ns8zb7/9tueZffv2eZ6RpNWrV3ueSbIvJUgCXIwUAJCUCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATXIwUuEzz58/3PLN27VrPM+np6Z5nYvXkk096nlm/fr3nmebmZs8zGDi4GCkAICkRIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACa4GClgYNKkSZ5nXnrpJc8zs2bN8jwTqzVr1nieee655zzP/Pvf//Y8AxtcjBQAkJQIEADAhOcA7d69W3PnzlV+fr58Pp+2bt0adf8DDzwgn88XtebMmROv/QIAUoTnAHV0dGjKlClavXp1n8fMmTNHzc3NkbVx48bL2iQAIPUM8TpQWlqq0tLSix7j9/sVDAZj3hQAIPUl5GdA1dXVysnJ0YQJE/TII4/oxIkTfR7b1dWlcDgctQAAqS/uAZozZ47Wr1+vqqoq/fSnP1VNTY1KS0vV3d3d6/GVlZUKBAKRNWrUqHhvCQCQhDx/C+5S7r333sifb775Zk2ePFnjxo1TdXV1r38noaKiQsuWLYt8HA6HiRAAXAES/jbssWPHKjs7W/X19b3e7/f7lZGREbUAAKkv4QE6evSoTpw4oby8vEQ/FABgAPH8LbiTJ09GvZppaGjQwYMHlZWVpaysLD377LNauHChgsGgjhw5oieeeELjx49XSUlJXDcOABjYPAdo3759uvPOOyMff/zzm0WLFumVV17RoUOH9Jvf/EZtbW3Kz8/X7Nmz9aMf/Uh+vz9+uwYADHhcjBQYIDIzMz3PzJ07N6bHWrt2recZn8/neWbXrl2eZ+666y7PM7DBxUgBAEmJAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJrgaNoALdHV1eZ4ZMsTzb3fRRx995Hkmlt8tVl1d7XkGl4+rYQMAkhIBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYML71QMBXLbJkyd7nvn617/ueWbatGmeZ6TYLiwaiw8++MDzzO7duxOwE1jgFRAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIKLkQLnmTBhgueZ8vJyzzMLFizwPBMMBj3P9Kfu7m7PM83NzZ5nenp6PM8gOfEKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwcVIkfRiuQjnfffdF9NjxXJh0euvvz6mx0pm+/bt8zzz3HPPeZ75wx/+4HkGqYNXQAAAEwQIAGDCU4AqKys1bdo0paenKycnR/PmzVNdXV3UMZ2dnSorK9OIESN0zTXXaOHChWptbY3rpgEAA5+nANXU1KisrEx79uzRm2++qbNnz2r27Nnq6OiIHLN06VJt375dmzdvVk1NjY4dOxbTL98CAKQ2T29C2LlzZ9TH69atU05Ojvbv368ZM2YoFArp17/+tTZs2KAvf/nLkqS1a9fqs5/9rPbs2aMvfvGL8ds5AGBAu6yfAYVCIUlSVlaWJGn//v06e/asiouLI8dMnDhRo0ePVm1tba+fo6urS+FwOGoBAFJfzAHq6enRkiVLdOutt2rSpEmSpJaWFqWlpSkzMzPq2NzcXLW0tPT6eSorKxUIBCJr1KhRsW4JADCAxBygsrIyvf/++9q0adNlbaCiokKhUCiympqaLuvzAQAGhpj+Imp5ebl27Nih3bt3a+TIkZHbg8Ggzpw5o7a2tqhXQa2trX3+ZUK/3y+/3x/LNgAAA5inV0DOOZWXl2vLli3atWuXCgoKou6fOnWqhg4dqqqqqshtdXV1amxsVFFRUXx2DABICZ5eAZWVlWnDhg3atm2b0tPTIz/XCQQCGjZsmAKBgB588EEtW7ZMWVlZysjI0KOPPqqioiLeAQcAiOIpQK+88ookaebMmVG3r127Vg888IAk6ec//7kGDRqkhQsXqqurSyUlJfrVr34Vl80CAFKHzznnrDdxvnA4rEAgYL0NfAq5ubmeZz73uc95nvnlL3/peWbixImeZ5Ld3r17Pc+88MILMT3Wtm3bPM/09PTE9FhIXaFQSBkZGX3ez7XgAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYCKm34iK5JWVleV5Zs2aNTE91i233OJ5ZuzYsTE9VjJ75513PM+8+OKLnmf+9Kc/eZ45ffq05xmgv/AKCABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwwcVI+0lhYaHnmeXLl3uemT59uueZ6667zvNMsjt16lRMc6tWrfI885Of/MTzTEdHh+cZINXwCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMMHFSPvJ/Pnz+2WmP33wwQeeZ3bs2OF55qOPPvI88+KLL3qekaS2traY5gB4xysgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMCEzznnrDdxvnA4rEAgYL0NAMBlCoVCysjI6PN+XgEBAEwQIACACU8Bqqys1LRp05Senq6cnBzNmzdPdXV1UcfMnDlTPp8vai1evDiumwYADHyeAlRTU6OysjLt2bNHb775ps6ePavZs2ero6Mj6riHHnpIzc3NkbVy5cq4bhoAMPB5+o2oO3fujPp43bp1ysnJ0f79+zVjxozI7VdffbWCwWB8dggASEmX9TOgUCgkScrKyoq6/bXXXlN2drYmTZqkiooKnTp1qs/P0dXVpXA4HLUAAFcAF6Pu7m731a9+1d16661Rt69Zs8bt3LnTHTp0yP3ud79z1113nZs/f36fn2fFihVOEovFYrFSbIVCoYt2JOYALV682I0ZM8Y1NTVd9LiqqionydXX1/d6f2dnpwuFQpHV1NRkftJYLBaLdfnrUgHy9DOgj5WXl2vHjh3avXu3Ro4cedFjCwsLJUn19fUaN27cBff7/X75/f5YtgEAGMA8Bcg5p0cffVRbtmxRdXW1CgoKLjlz8OBBSVJeXl5MGwQApCZPASorK9OGDRu0bds2paenq6WlRZIUCAQ0bNgwHTlyRBs2bNBXvvIVjRgxQocOHdLSpUs1Y8YMTZ48OSH/AACAAcrLz33Ux/f51q5d65xzrrGx0c2YMcNlZWU5v9/vxo8f75YvX37J7wOeLxQKmX/fksVisViXvy71tZ+LkQIAEoKLkQIAkhIBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwETSBcg5Z70FAEAcXOrredIFqL293XoLAIA4uNTXc59LspccPT09OnbsmNLT0+Xz+aLuC4fDGjVqlJqampSRkWG0Q3uch3M4D+dwHs7hPJyTDOfBOaf29nbl5+dr0KC+X+cM6cc9fSqDBg3SyJEjL3pMRkbGFf0E+xjn4RzOwzmch3M4D+dYn4dAIHDJY5LuW3AAgCsDAQIAmBhQAfL7/VqxYoX8fr/1VkxxHs7hPJzDeTiH83DOQDoPSfcmBADAlWFAvQICAKQOAgQAMEGAAAAmCBAAwMSACdDq1at1/fXX66qrrlJhYaHeffdd6y31u2eeeUY+ny9qTZw40XpbCbd7927NnTtX+fn58vl82rp1a9T9zjk9/fTTysvL07Bhw1RcXKzDhw/bbDaBLnUeHnjggQueH3PmzLHZbIJUVlZq2rRpSk9PV05OjubNm6e6urqoYzo7O1VWVqYRI0bommuu0cKFC9Xa2mq048T4NOdh5syZFzwfFi9ebLTj3g2IAL3++utatmyZVqxYoffee09TpkxRSUmJjh8/br21fnfTTTepubk5sv7yl79YbynhOjo6NGXKFK1evbrX+1euXKlVq1bp1Vdf1d69ezV8+HCVlJSos7Ozn3eaWJc6D5I0Z86cqOfHxo0b+3GHiVdTU6OysjLt2bNHb775ps6ePavZs2ero6MjcszSpUu1fft2bd68WTU1NTp27JgWLFhguOv4+zTnQZIeeuihqOfDypUrjXbcBzcATJ8+3ZWVlUU+7u7udvn5+a6ystJwV/1vxYoVbsqUKdbbMCXJbdmyJfJxT0+PCwaD7oUXXojc1tbW5vx+v9u4caPBDvvHJ8+Dc84tWrTI3X333Sb7sXL8+HEnydXU1Djnzv27Hzp0qNu8eXPkmL///e9OkqutrbXaZsJ98jw459wdd9zhHnvsMbtNfQpJ/wrozJkz2r9/v4qLiyO3DRo0SMXFxaqtrTXcmY3Dhw8rPz9fY8eO1f3336/GxkbrLZlqaGhQS0tL1PMjEAiosLDwinx+VFdXKycnRxMmTNAjjzyiEydOWG8poUKhkCQpKytLkrR//36dPXs26vkwceJEjR49OqWfD588Dx977bXXlJ2drUmTJqmiokKnTp2y2F6fku5ipJ/04Ycfqru7W7m5uVG35+bm6h//+IfRrmwUFhZq3bp1mjBhgpqbm/Xss8/q9ttv1/vvv6/09HTr7ZloaWmRpF6fHx/fd6WYM2eOFixYoIKCAh05ckRPPvmkSktLVVtbq8GDB1tvL+56enq0ZMkS3XrrrZo0aZKkc8+HtLQ0ZWZmRh2bys+H3s6DJH3zm9/UmDFjlJ+fr0OHDun73/++6urq9Pvf/95wt9GSPkD4f6WlpZE/T548WYWFhRozZozeeOMNPfjgg4Y7QzK49957I3+++eabNXnyZI0bN07V1dWaNWuW4c4So6ysTO+///4V8XPQi+nrPDz88MORP998883Ky8vTrFmzdOTIEY0bN66/t9mrpP8WXHZ2tgYPHnzBu1haW1sVDAaNdpUcMjMzdeONN6q+vt56K2Y+fg7w/LjQ2LFjlZ2dnZLPj/Lycu3YsUNvv/121K9vCQaDOnPmjNra2qKOT9XnQ1/noTeFhYWSlFTPh6QPUFpamqZOnaqqqqrIbT09PaqqqlJRUZHhzuydPHlSR44cUV5envVWzBQUFCgYDEY9P8LhsPbu3XvFPz+OHj2qEydOpNTzwzmn8vJybdmyRbt27VJBQUHU/VOnTtXQoUOjng91dXVqbGxMqefDpc5Dbw4ePChJyfV8sH4XxKexadMm5/f73bp169wHH3zgHn74YZeZmelaWlqst9avvve977nq6mrX0NDg/vrXv7ri4mKXnZ3tjh8/br21hGpvb3cHDhxwBw4ccJLcSy+95A4cOOD+9a9/Oeece/75511mZqbbtm2bO3TokLv77rtdQUGBO336tPHO4+ti56G9vd09/vjjrra21jU0NLi33nrLff7zn3c33HCD6+zstN563DzyyCMuEAi46upq19zcHFmnTp2KHLN48WI3evRot2vXLrdv3z5XVFTkioqKDHcdf5c6D/X19e6HP/yh27dvn2toaHDbtm1zY8eOdTNmzDDeebQBESDnnPvFL37hRo8e7dLS0tz06dPdnj17rLfU7+655x6Xl5fn0tLS3HXXXefuueceV19fb72thHv77bedpAvWokWLnHPn3or91FNPudzcXOf3+92sWbNcXV2d7aYT4GLn4dSpU2727Nnu2muvdUOHDnVjxoxxDz30UMr9T1pv//yS3Nq1ayPHnD592n33u991n/nMZ9zVV1/t5s+f75qbm+02nQCXOg+NjY1uxowZLisry/n9fjd+/Hi3fPlyFwqFbDf+Cfw6BgCAiaT/GRAAIDURIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACb+Dwuo74MxItlsAAAAAElFTkSuQmCC", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "import cv2\n", "import matplotlib.pyplot as plt\n", "\n", "img = (train_mnist[0][0].numpy() * 255).astype(np.uint8)\n", "plt.imshow(img.transpose(1, 2, 0), cmap=\"gray\")" ] }, { "cell_type": "code", "execution_count": 3, "id": "1c12e730-2c97-4b25-87ee-96fdf0197839", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "train_loader = DataLoader (train_mnist, batch_size=1000, shuffle=True)\n", "valid_loader = DataLoader (valid_mnist, batch_size=10000, shuffle=False)" ] }, { "cell_type": "code", "execution_count": 4, "id": "d3af01c9-f20c-4860-a1f3-8a237c761a89", "metadata": {}, "outputs": [], "source": [ "class LeNet_5(torch.nn.Module):\n", " def __init__(self):\n", " super(LeNet_5, self).__init__()\n", " # 卷积层1 :1 @ 28 * 28 - > 6 @ 28 * 28 -> 6 @ 14 * 14 \n", " # 卷积层2 :6 @ 14 * 14 -> 16 @ 10 * 10 -> 16 @ 5 * 5\n", " # 卷积层3 :16 @ 5 * 5 -> 120 @ 1 * 1\n", " self.layer_1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), padding=2) \n", " self.layer_2 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), padding=0)\n", " self.layer_3 = torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=(5, 5), padding=0)\n", " # 连接层1 : 120 -> 84\n", " # 链接层2 : 84 -> 10\n", " self.layer_4 = torch.nn.Linear(120, 84)\n", " self.layer_5 = torch.nn.Linear(84, 10)\n", " \n", " \n", " def forward(self, input):\n", " # 预测模型实现\n", " # 卷积层\n", " t = self.layer_1(input) \n", " t = torch.nn.functional.relu(t)\n", " t = torch.nn.functional.max_pool2d(t, kernel_size=(2, 2))\n", " \n", " t = self.layer_2(t)\n", " t = torch.nn.functional.relu(t)\n", " t = torch.nn.functional.max_pool2d(t, kernel_size=(2, 2))\n", " \n", " t = self.layer_3(t)\n", " t = torch.nn.functional.relu(t)\n", " \n", " t = t.squeeze() # 长度为1的维数直接降维\n", " # 链接层\n", " t = self.layer_4(t)\n", " t= torch.nn.functional.relu(t)\n", " \n", " t = self.layer_5(t)\n", " t = torch.nn.functional.log_softmax(t, dim=1)\n", " return t\n", " " ] }, { "cell_type": "code", "execution_count": 5, "id": "113f45d6-823d-481f-b9cf-798ec429b962", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "第00轮, \t损失度:0.375775,\t准确率: 89.17%\n", "第10轮, \t损失度:0.059103,\t准确率: 98.57%\n", "第20轮, \t损失度:0.020224,\t准确率: 99.00%\n", "第30轮, \t损失度:0.014173,\t准确率: 99.02%\n", "第40轮, \t损失度:0.016326,\t准确率: 98.86%\n", "第50轮, \t损失度:0.002744,\t准确率: 99.17%\n", "第60轮, \t损失度:0.002227,\t准确率: 99.17%\n", "第70轮, \t损失度:0.008838,\t准确率: 99.01%\n", "第80轮, \t损失度:0.000361,\t准确率: 99.17%\n", "第90轮, \t损失度:0.001122,\t准确率: 99.06%\n" ] } ], "source": [ "# 模型\n", "model = LeNet_5()\n", "parameters = model.parameters()\n", "# 巡视函数\n", "criterion = torch.nn.CrossEntropyLoss()\n", "# 优化器\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 学习率\n", "\n", "# 训练参数\n", "epoch = 100\n", "for e in range(epoch):\n", " # 批次处理\n", " for data, target in train_loader:\n", " # 清空梯度\n", " optimizer.zero_grad() \n", " # 计算输出\n", " out = model(data)\n", " # 计算损失\n", " loss = criterion(out, target)\n", " # 计算梯度\n", " loss.backward()\n", " # 更新梯度\n", " optimizer.step()\n", " # 一轮结束,可以使用测试集测试准确率\n", " if e % 10 == 0:\n", " with torch.no_grad(): # 关闭梯度计算跟踪\n", " for data, target in valid_loader:\n", " y_ = model(data)\n", " predict = torch.argmax(y_, dim=1)\n", " correct_rate = (predict == target).float().mean()\n", " print(F\"第{e:02d}轮, \\t损失度:{loss:8.6f},\\t准确率:{correct_rate * 100: 5.2f}%\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4645eb6e-3d08-4857-bf12-2610d702db57", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.7" } }, "nbformat": 4, "nbformat_minor": 5 }