# %% [markdown]
# # 1. Dataset processing
# ## 1.1. Loading and transforming the dataset

# %%
from torchvision.datasets import MNIST  # dataset download / access
from torchvision.transforms.v2 import Compose, PILToTensor, ToDtype, Lambda
import torch


# %%
def _divide_by_255(pixels):
    """Return ``pixels / 255.0`` (the last step of the preprocessing pipeline)."""
    return pixels / 255.0


# Preprocessing applied to every sample, in order:
# PIL image -> tensor -> float dtype -> values divided by 255.
_pipeline_steps = [PILToTensor(), ToDtype(torch.float), Lambda(_divide_by_255)]
transform = Compose(_pipeline_steps)

# Both splits live under the same root folder and share the same pipeline;
# only the `train` flag differs.
_mnist_options = dict(root="mnist", download=True, transform=transform)
train_mnist = MNIST(train=True, **_mnist_options)
valid_mnist = MNIST(train=False, **_mnist_options)

# %%
# Number of samples in the training and validation splits.
len(train_mnist), len(valid_mnist)
"iVBORw0KGgoAAAANSUhEUgAAAaAAAAGdCAYAAABU0qcqAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjguMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8g+/7EAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAaaElEQVR4nO3df0zU9x3H8dehctoWjiGFg6oUtdWlKsucMmZL7SQCXRqtZtHOZboYjQ6bqeuP2KzaH0tY3dI1XZgu2SZrqrYzm5qazMTSgtkGttIa41qZODZxCq4m3CEqOvnsD9PbTvHHF+94c/h8JN9E7r4fvu9+e+HpF84vPuecEwAAfSzJegAAwO2JAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABODrQe4Und3t06cOKGUlBT5fD7rcQAAHjnn1NHRoZycHCUlXfs6p98F6MSJExo5cqT1GACAW9TS0qIRI0Zc8/l+9y24lJQU6xEAADFwo6/ncQtQZWWl7r33Xg0dOlQFBQX64IMPbmod33YDgIHhRl/P4xKgt99+W6tXr9a6dev00UcfKT8/XyUlJTp16lQ8DgcASEQuDqZOnerKy8sjH1+6dMnl5OS4ioqKG64NhUJOEhsbGxtbgm+hUOi6X+9jfgV04cIFNTQ0qLi4OPJYUlKSiouLVVdXd9X+XV1dCofDURsAYOCLeYA+++wzXbp0SVlZWVGPZ2VlqbW19ar9KyoqFAgEIhvvgAOA24P5u+DWrFmjUCgU2VpaWqxHAgD0gZj/O6CMjAwNGjRIbW1tUY+3tbUpGAxetb/f75ff74/1GACAfi7mV0DJycmaPHmyqqurI491d3erurpahYWFsT4cACBBxeVOCKtXr9bChQv1la98RVOnTtVrr72mzs5Offe7343H4QAACSguAZo3b57+/e9/a+3atWptbdWXvvQl7d69+6o3JgAAbl8+55yzHuL/hcNhBQIB6zEAALcoFAopNTX1ms+bvwsOAHB7IkAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwMth4AALyYMWOG5zWbN2/u1bEefvhhz2saGxt7dazbEVdAAAATBAgAYCLmAXrhhRfk8/mitvHjx8f6MACABBeXnwE98MADevfdd/93kMH8qAkAEC0uZRg8eLCCwWA8PjUAYICIy8+Ajhw5opycHI0ePVoLFizQsWPHrrlvV1eXwuFw1AYAGPhiHqCCggJVVVVp9+7d2rBhg5qbm/XQQw+po6Ojx/0rKioUCAQi28iRI2M9EgCgH/I551w8D9De3q7c3Fy9+uqrWrx48VXPd3V1qaurK/JxOBwmQgCuiX8HlDhCoZBSU1Ov+Xzc3x2Qlpam+++/X01NTT0+7/f75ff74z0GAKCfifu/Azpz5oyOHj2q7OzseB8KAJBAYh6gp556SrW1tfrHP/6hv/zlL3r88cc1aNAgPfHEE7E+FAAggcX8W3DHjx/XE088odOnT+vuu+/Wgw8+qPr6et19992xPhQAIIHFPEBvvfVWrD/lgFBUVOR5zfDhwz2v2b59u+c1QCKZMmWK5zUffvhhHCbBreJecAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACAibj/QjpcNn36dM9r7rvvPs9ruBkpEklSkve/A+fl5Xlek5ub63mNJPl8vl6tw83
hCggAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmuBt2H/nOd77jeU1dXV0cJgH6j+zsbM9rlixZ4nnNm2++6XmNJB0+fLhX63BzuAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAExwM9I+kpRE64Er/epXv+qT4xw5cqRPjgNv+KoIADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJjgZqS9MGnSJM9rsrKy4jAJkNgCgUCfHGfPnj19chx4wxUQAMAEAQIAmPAcoL179+qxxx5TTk6OfD6fduzYEfW8c05r165Vdna2hg0bpuLiYn4XBwDgKp4D1NnZqfz8fFVWVvb4/Pr16/X6669r48aN2rdvn+68806VlJTo/PnztzwsAGDg8PwmhLKyMpWVlfX4nHNOr732mn74wx9q1qxZkqQ33nhDWVlZ2rFjh+bPn39r0wIABoyY/gyoublZra2tKi4ujjwWCARUUFCgurq6Htd0dXUpHA5HbQCAgS+mAWptbZV09VuOs7KyIs9dqaKiQoFAILKNHDkyliMBAPop83fBrVmzRqFQKLK1tLRYjwQA6AMxDVAwGJQktbW1RT3e1tYWee5Kfr9fqampURsAYOCLaYDy8vIUDAZVXV0deSwcDmvfvn0qLCyM5aEAAAnO87vgzpw5o6ampsjHzc3NOnDggNLT0zVq1CitXLlSP/rRj3TfffcpLy9Pzz//vHJycjR79uxYzg0ASHCeA7R//3498sgjkY9Xr14tSVq4cKGqqqr0zDPPqLOzU0uXLlV7e7sefPBB7d69W0OHDo3d1ACAhOc5QNOnT5dz7prP+3w+vfTSS3rppZduabD+7NFHH/W8ZtiwYXGYBOg/enPD3by8vDhMcrV//etffXIceGP+LjgAwO2JAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJjzfDRvSuHHj+uQ4f/3rX/vkOEAs/PSnP/W8pjd30P7b3/7meU1HR4fnNYg/roAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABPcjLQf+/DDD61HQD+SmprqeU1paWmvjvXtb3/b85qZM2f26lhevfzyy57XtLe3x34Q3DKugAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAE9yMtB9LT0+3HiHm8vPzPa/x+Xye1xQXF3teI0kjRozwvCY5OdnzmgULFnhek5Tk/e+L586d87xGkvbt2+d5TVdXl+c1gwd7/xLU0NDgeQ36J66AAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAAT3Iy0F3pzg0fnnOc1Gzdu9Lzmueee87ymL02aNMnzmt7cjPQ///mP5zWSdPbsWc9rPvnkE89rfvOb33hes3//fs9ramtrPa+RpLa2Ns9rjh8/7nnNsGHDPK85fPiw5zXon7gCAgCYIEAAABOeA7R371499thjysnJkc/n044dO6KeX7RokXw+X9RWWloaq3kBAAOE5wB1dnYqPz9flZWV19yntLRUJ0+ejGxbt269pSEBAAOP5zchlJWVqays7Lr7+P1+BYPBXg8FABj44vIzoJqaGmVmZmrcuHFavny5Tp8+fc19u7q6FA6HozYAwMAX8wCVlpbqjTfeUHV1tV555RXV1taqrKxMly5d6nH/iooKBQKByDZy5MhYjwQA6Idi/u+A5s+fH/nzxIkTNWnSJI0ZM0Y1NTWaMWPGVfuvWbNGq1evjnwcDoeJEADcBuL+Nuz
Ro0crIyNDTU1NPT7v9/uVmpoatQEABr64B+j48eM6ffq0srOz430oAEAC8fwtuDNnzkRdzTQ3N+vAgQNKT09Xenq6XnzxRc2dO1fBYFBHjx7VM888o7Fjx6qkpCSmgwMAEpvnAO3fv1+PPPJI5OPPf36zcOFCbdiwQQcPHtRvf/tbtbe3KycnRzNnztTLL78sv98fu6kBAAnP53pzl8w4CofDCgQC1mPE3LPPPut5zde+9rU4TJJ4rrzbxs349NNPe3Ws+vr6Xq0baJYuXep5TW9unvv3v//d85qxY8d6XgMboVDouj/X515wAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMBHzX8mNnr3yyivWIwA3bcaMGX1ynN///vd9chz0T1wBAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmuBkpADPbt2+3HgGGuAICAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJgYbD0AgIHB5/N5XnP//fd7XlNfX+95DfonroAAACYIEADAhKcAVVRUaMqUKUpJSVFmZqZmz56txsbGqH3Onz+v8vJyDR8+XHfddZfmzp2rtra2mA4NAEh8ngJUW1ur8vJy1dfXa8+ePbp48aJmzpypzs7OyD6rVq3SO++8o23btqm2tlYnTpzQnDlzYj44ACCxeXoTwu7du6M+rqqqUmZmphoaGlRUVKRQKKRf//rX2rJli77+9a9LkjZt2qQvfvGLqq+v11e/+tXYTQ4ASGi39DOgUCgkSUpPT5ckNTQ06OLFiyouLo7sM378eI0aNUp1dXU9fo6uri6Fw+GoDQAw8PU6QN3d3Vq5cqWmTZumCRMmSJJaW1uVnJystLS0qH2zsrLU2tra4+epqKhQIBCIbCNHjuztSACABNLrAJWXl+vQoUN66623bmmANWvWKBQKRbaWlpZb+nwAgMTQq3+IumLFCu3atUt79+7ViBEjIo8Hg0FduHBB7e3tUVdBbW1tCgaDPX4uv98vv9/fmzEAAAnM0xWQc04rVqzQ9u3b9d577ykvLy/q+cmTJ2vIkCGqrq6OPNbY2Khjx46psLAwNhMDAAYET1dA5eXl2rJli3bu3KmUlJTIz3UCgYCGDRumQCCgxYsXa/Xq1UpPT1dqaqqefPJJFRYW8g44AEAUTwHasGGDJGn69OlRj2/atEmLFi2SJP3sZz9TUlKS5s6dq66uLpWUlOgXv/hFTIYFAAwcngLknLvhPkOHDlVlZaUqKyt7PRSAxHMzXx+ulJTE3cBuZ/zfBwCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgIle/UZUAIiF3vyiyqqqqtgPAhNcAQEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJrgZKYCY8Pl81iMgwXAFBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY4GakAK7yxz/+0fOab37zm3GYBAMZV0AAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAmfc85ZD/H/wuGwAoGA9RgAgFsUCoWUmpp6zee5AgIAmCBAAAATngJUUVGhKVOmKCUlRZmZmZo9e7YaGxuj9pk+fbp8Pl/UtmzZspgODQBIfJ4CVFtbq/LyctXX12vPnj26ePGiZs6cqc7Ozqj9lixZopMnT0a
29evXx3RoAEDi8/QbUXfv3h31cVVVlTIzM9XQ0KCioqLI43fccYeCwWBsJgQADEi39DOgUCgkSUpPT496fPPmzcrIyNCECRO0Zs0anT179pqfo6urS+FwOGoDANwGXC9dunTJfeMb33DTpk2LevyXv/yl2717tzt48KB788033T333OMef/zxa36edevWOUlsbGxsbANsC4VC1+1IrwO0bNkyl5ub61paWq67X3V1tZPkmpqaenz+/PnzLhQKRbaWlhbzk8bGxsbGduvbjQLk6WdAn1uxYoV27dqlvXv3asSIEdfdt6CgQJLU1NSkMWPGXPW83++X3+/vzRgAgATmKUDOOT355JPavn27ampqlJeXd8M1Bw4ckCRlZ2f3akAAwMDkKUDl5eXasmWLdu7cqZSUFLW2tkqSAoGAhg0bpqNHj2rLli169NFHNXz4cB08eFCrVq1SUVGRJk2aFJf/AABAgvLycx9d4/t8mzZtcs45d+zYMVdUVOTS09Od3+93Y8eOdU8//fQNvw/4/0KhkPn3LdnY2NjYbn270dd+bkYKAIgLbkYKAOiXCBAAwAQBAgCYIEAAABMECABgggABAEwQIACACQIEADBBgAAAJggQAMAEAQIAmCBAAAATBAgAYIIAAQBMECAAgAkCBAAwQYAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAm+l2AnHPWIwAAYuBGX8/7XYA6OjqsRwAAxMCNvp77XD+75Oju7taJEyeUkpIin88X9Vw4HNbIkSPV0tKi1NRUowntcR4u4zxcxnm4jPNwWX84D845dXR0KCcnR0lJ177OGdyHM92UpKQkjRgx4rr7pKam3tYvsM9xHi7jPFzGebiM83CZ9XkIBAI33KfffQsOAHB7IEAAABMJFSC/369169bJ7/dbj2KK83AZ5+EyzsNlnIfLEuk89Ls3IQAAbg8JdQUEABg4CBAAwAQBAgCYIEAAABMJE6DKykrde++9Gjp0qAoKCvTBBx9Yj9TnXnjhBfl8vqht/Pjx1mPF3d69e/XYY48pJydHPp9PO3bsiHreOae1a9cqOztbw4YNU3FxsY4cOWIzbBzd6DwsWrToqtdHaWmpzbBxUlFRoSlTpiglJUWZmZmaPXu2Ghsbo/Y5f/68ysvLNXz4cN11112aO3eu2trajCaOj5s5D9OnT7/q9bBs2TKjiXuWEAF6++23tXr1aq1bt04fffSR8vPzVVJSolOnTlmP1uceeOABnTx5MrL96U9/sh4p7jo7O5Wfn6/Kysoen1+/fr1ef/11bdy4Ufv27dOdd96pkpISnT9/vo8nja8bnQdJKi0tjXp9bN26tQ8njL/a2lqVl5ervr5ee/bs0cWLFzVz5kx1dnZG9lm1apXeeecdbdu2TbW1tTpx4oTmzJljOHXs3cx5kKQlS5ZEvR7Wr19vNPE1uAQwdepUV15eHvn40qVLLicnx1VUVBhO1ffWrVvn8vPzrccwJclt37498nF3d7cLBoPuJz/5SeSx9vZ25/f73datWw0m7BtXngfnnFu4cKGbNWuWyTxWTp065SS52tpa59zl//dDhgxx27Zti+zz6aefOkmurq7Oasy4u/I8OOfcww8/7L7//e/bDXUT+v0V0IULF9TQ0KDi4uLIY0lJSSouLlZdXZ3hZDaOHDminJwcjR49WgsWLNCxY8esRzLV3Nys1tbWqNdHIBBQQUHBbfn6qKmpUWZmpsaNG6fly5fr9OnT1iPFVSgUkiSlp6dLkhoaGnTx4sWo18P48eM1atSoAf16uPI8fG7z5s3KyMjQhAkTtGbNGp09e9ZivGvqdzcjvdJnn32mS5cuKSsrK+rxrKwsHT582GgqGwUFBaqqqtK4ceN08uRJvfjii3rooYd06NAhpaSkWI9norW1VZJ6fH18/tztorS0VHPmzFFeXp6OHj2q5557TmVlZaqrq9OgQYOsx4u57u5urVy5UtOmTdO
ECRMkXX49JCcnKy0tLWrfgfx66Ok8SNK3vvUt5ebmKicnRwcPHtSzzz6rxsZG/eEPfzCcNlq/DxD+p6ysLPLnSZMmqaCgQLm5ufrd736nxYsXG06G/mD+/PmRP0+cOFGTJk3SmDFjVFNToxkzZhhOFh/l5eU6dOjQbfFz0Ou51nlYunRp5M8TJ05Udna2ZsyYoaNHj2rMmDF9PWaP+v234DIyMjRo0KCr3sXS1tamYDBoNFX/kJaWpvvvv19NTU3Wo5j5/DXA6+Nqo0ePVkZGxoB8faxYsUK7du3S+++/H/XrW4LBoC5cuKD29vao/Qfq6+Fa56EnBQUFktSvXg/9PkDJycmaPHmyqqurI491d3erurpahYWFhpPZO3PmjI4ePars7GzrUczk5eUpGAxGvT7C4bD27dt3278+jh8/rtOnTw+o14dzTitWrND27dv13nvvKS8vL+r5yZMna8iQIVGvh8bGRh07dmxAvR5udB56cuDAAUnqX68H63dB3Iy33nrL+f1+V1VV5T755BO3dOlSl5aW5lpbW61H61M/+MEPXE1NjWtubnZ//vOfXXFxscvIyHCnTp2yHi2uOjo63Mcff+w+/vhjJ8m9+uqr7uOPP3b//Oc/nXPO/fjHP3ZpaWlu586d7uDBg27WrFkuLy/PnTt3znjy2Lreeejo6HBPPfWUq6urc83Nze7dd991X/7yl919993nzp8/bz16zCxfvtwFAgFXU1PjTp48GdnOnj0b2WfZsmVu1KhR7r333nP79+93hYWFrrCw0HDq2LvReWhqanIvvfSS279/v2tubnY7d+50o0ePdkVFRcaTR0uIADnn3M9//nM3atQol5yc7KZOnerq6+utR+pz8+bNc9nZ2S45Odndc889bt68ea6pqcl6rLh7//33naSrtoULFzrnLr8V+/nnn3dZWVnO7/e7GTNmuMbGRtuh4+B65+Hs2bNu5syZ7u6773ZDhgxxubm5bsmSJQPuL2k9/fdLcps2bYrsc+7cOfe9733PfeELX3B33HGHe/zxx93Jkyftho6DG52HY8eOuaKiIpeenu78fr8bO3ase/rpp10oFLId/Ar8OgYAgIl+/zMgAMDARIAAACYIEADABAECAJggQAAAEwQIAGCCAAEATBAgAIAJAgQAMEGAAAAmCBAAwAQBAgCY+C9JPEvo0+q40gAAAABJRU5ErkJggg==", "text/plain": [ "
# %%
import numpy
import matplotlib.pyplot as plt

# Fetch the sample at index 2 once, undo the 1/255 scaling, and move the
# channel axis last (C, H, W -> H, W, C) so that matplotlib can draw it.
sample_image, sample_label = train_mnist[2]
img = numpy.transpose((sample_image * 255.0).numpy().astype(numpy.uint8), (1, 2, 0))
plt.imshow(img, cmap="gray")

# %%
# Label that belongs to the image displayed above.
sample_label

# %% [markdown]
# ## 1.2. Splitting the datasets into batches

# %%
from torch.utils.data import DataLoader

# Same batch size for both loaders; only the training loader shuffles.
_batch_size = 10000
train_loader = DataLoader(train_mnist, batch_size=_batch_size, shuffle=True)
valid_loader = DataLoader(valid_mnist, batch_size=_batch_size, shuffle=False)

# %%
# Walk through the training loader once and print each batch's tensor shape
# together with the number of labels it carries.
for images, labels in train_loader:
    print(images.shape, len(labels))

# %% [markdown]
# - `N C H W`
# %% [markdown]
# # 3. Prediction model

# %%
class LeNet5(torch.nn.Module):
    """LeNet-5 style convolutional classifier for 1 x 28 x 28 images.

    Shape flow for an input of shape (N, 1, 28, 28):

    * stage 1: conv 5x5, padding 2 -> (N, 6, 28, 28); 2x2 max-pool -> (N, 6, 14, 14)
    * stage 2: conv 5x5, no padding -> (N, 16, 10, 10); 2x2 max-pool -> (N, 16, 5, 5)
    * stage 3: conv 5x5, no padding -> (N, 120, 1, 1)
    * stage 4: flatten -> (N, 120); linear -> (N, 84)
    * stage 5: linear -> (N, 10); log-softmax over the class axis

    ``forward`` returns **log-probabilities**, not probabilities. They can be
    fed to ``torch.nn.NLLLoss`` directly. ``torch.nn.CrossEntropyLoss`` gives
    the same loss value, because applying log-softmax to values that are
    already log-probabilities leaves them unchanged.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 1 @ 28x28 -> 6 @ 28x28 -> 6 @ 14x14
        self.layer_1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=(5, 5), padding=2)
        self.act_1 = torch.nn.ReLU()  # max(0, x): negative values become 0
        self.pool_1 = torch.nn.MaxPool2d((2, 2))
        # Stage 2: 6 @ 14x14 -> 16 @ 10x10 -> 16 @ 5x5
        self.layer_2 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=(5, 5), padding=0)
        self.act_2 = torch.nn.ReLU()
        self.pool_2 = torch.nn.MaxPool2d((2, 2))
        # Stage 3: 16 @ 5x5 -> 120 @ 1x1
        self.layer_3 = torch.nn.Conv2d(in_channels=16, out_channels=120, kernel_size=(5, 5), padding=0)
        self.act_3 = torch.nn.ReLU()
        # Stage 4: fully connected 120 -> 84
        self.layer_4 = torch.nn.Linear(120, 84)
        self.act_4 = torch.nn.ReLU()
        # Stage 5: fully connected 84 -> 10 (one output per class)
        self.layer_5 = torch.nn.Linear(84, 10)

    def forward(self, x):
        """Compute class log-probabilities for a batch of images.

        Args:
            x: float tensor of shape (N, 1, 28, 28). Other spatial sizes do
                not reduce to 120 features and make ``layer_4`` raise a
                shape error.

        Returns:
            Tensor of shape (N, 10) holding log-probabilities; each row
            exponentiates to a distribution that sums to 1.
        """
        y = self.pool_1(self.act_1(self.layer_1(x)))
        y = self.pool_2(self.act_2(self.layer_2(y)))
        y = self.act_3(self.layer_3(y))

        # Collapse (N, 120, 1, 1) to (N, 120). Flattening from dim 1 keeps the
        # batch axis even when N == 1; a bare ``squeeze()`` would also drop a
        # size-1 batch axis and break ``log_softmax(dim=1)`` below.
        y = torch.flatten(y, start_dim=1)

        y = self.act_4(self.layer_4(y))
        y = self.layer_5(y)

        # Log-probabilities over the class axis.
        return torch.nn.functional.log_softmax(y, dim=1)
# %% [markdown]
# # 4. Training

# %%
# Seed torch's global RNG so weight initialisation (and anything else drawn
# from it) is repeatable between runs.
torch.manual_seed(0)

# Use the GPU when one is available, otherwise stay on the CPU. The model,
# the loss and every batch are moved to this same device; calling `.cuda()`
# unconditionally fails on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model
model = LeNet5().to(device)

# Loss function
criterion = torch.nn.CrossEntropyLoss().to(device)

# Optimizer that applies the gradient updates (lr is the learning rate)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Number of passes over the training set
epoch = 100
for e in range(epoch):
    # ---- training pass -------------------------------------------------
    model.train()
    loss_sum = 0.0
    batch_count = 0
    for x, y in train_loader:
        x = x.to(device)
        y = y.to(device)  # labels must stay labels (not a copy of x)
        optimizer.zero_grad()
        y_ = model(x)
        batch_loss = criterion(y_, y)
        batch_loss.backward()
        optimizer.step()
        loss_sum += batch_loss.item()
        batch_count += 1
    # Mean training loss of this epoch, not just the last mini-batch.
    loss = loss_sum / max(batch_count, 1)

    # ---- validation pass (after every epoch) ---------------------------
    model.eval()
    correct = 0
    seen = 0
    with torch.no_grad():  # no gradient tracking needed for evaluation
        for v_x, v_y in valid_loader:
            v_x = v_x.to(device)
            v_y = v_y.to(device)
            predict = torch.argmax(model(v_x), dim=1)
            # Accumulate over all batches so the accuracy covers the whole
            # validation set regardless of the loader's batch size.
            correct += (predict == v_y).sum().item()
            seen += v_y.numel()
    correct_rate = correct / max(seen, 1)
    print(F"第{e:02d}轮, \t损失度:{loss:8.6f},\t准确率:{correct_rate * 100: 5.2f}%")
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.7" } }, "nbformat": 4, "nbformat_minor": 5 }