Explorar o código

上传文件至 'day6 20200404113 刘帅帅'

wifeRoy hai 1 ano
pai
achega
3f88fd7f10

+ 58 - 0
day6 20200404113 刘帅帅/lenet5.py

@@ -0,0 +1,58 @@
+from torch.nn import Module
+from torch.nn import Conv2d, Linear
+from torch.nn.functional import relu, max_pool2d
+import torch
+import torch.nn as nn
+class Lenet5(Module):
+    def __init__(self):
+        super(Lenet5, self).__init__()
+        # 第一层卷积层,输入通道为1,输出通道为6,卷积核大小为5
+        self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
+        # 最大池化层,池化核大小为2,步长为2
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        # 第二层卷积层,输入通道为6,输出通道为16,卷积核大小为5
+        self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
+        # 全连接层,输入节点数为16*4*4,输出节点数为120
+        self.fc1 = nn.Linear(16 * 4 * 4, 120)
+        # 全连接层,输入节点数为120,输出节点数为84
+        self.fc2 = nn.Linear(120, 84)
+        # 输出层,输入节点数为84,输出节点数为10
+        self.fc3 = nn.Linear(84, 10)
+
+    def forward(self, x):
+        # 第一层卷积,通过relu激活函数
+        x = self.pool(relu(self.conv1(x)))
+        # 第二层卷积,通过relu激活函数
+        x = self.pool(relu(self.conv2(x)))
+        # 展开张量,将其变成一维向量
+        x = x.view(-1, 16 * 4 * 4)
+        # 全连接层,通过relu激活函数
+        x = relu(self.fc1(x))
+        # 全连接层,通过relu激活函数
+        x = relu(self.fc2(x))
+        # 输出层,不使用激活函数
+        x = self.fc3(x)
+        return x
+
+
+
+# # 加载训练数据和测试数据
+# transform = transforms.Compose([
+#     transforms.ToTensor(),  # 转换为Tensor对象
+#     transforms.Normalize((0.5,), (0.5,))  # 归一化处理
+# ])
+# trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
+# trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
+# testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
+# testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
+#
+# # 实例化LeNet-5模型和损失函数、优化器
+# net = LeNet5()
+# criterion = nn.CrossEntropyLoss()  # 交叉熵损失函数
+# optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)  # 随机梯度下降优化器
+#
+# # 训练网络
+# for epoch in range(10):
+#     running_loss = 0.0
+#     for i, data in enumerate(trainloader, 0):
+#         inputs, labels = data

BIN=BIN
day6 20200404113 刘帅帅/学习收获.docx