# lenet5.py — LeNet-5 convolutional network for MNIST (PyTorch)
  1. from torch.nn import Module
  2. from torch.nn import Conv2d, Linear
  3. from torch.nn.functional import relu, max_pool2d
  4. import torch
  5. import torch.nn as nn
  6. class Lenet5(Module):
  7. def __init__(self):
  8. super(Lenet5, self).__init__()
  9. # 第一层卷积层,输入通道为1,输出通道为6,卷积核大小为5
  10. self.conv1 = nn.Conv2d(1, 6, kernel_size=5)
  11. # 最大池化层,池化核大小为2,步长为2
  12. self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
  13. # 第二层卷积层,输入通道为6,输出通道为16,卷积核大小为5
  14. self.conv2 = nn.Conv2d(6, 16, kernel_size=5)
  15. # 全连接层,输入节点数为16*4*4,输出节点数为120
  16. self.fc1 = nn.Linear(16 * 4 * 4, 120)
  17. # 全连接层,输入节点数为120,输出节点数为84
  18. self.fc2 = nn.Linear(120, 84)
  19. # 输出层,输入节点数为84,输出节点数为10
  20. self.fc3 = nn.Linear(84, 10)
  21. def forward(self, x):
  22. # 第一层卷积,通过relu激活函数
  23. x = self.pool(relu(self.conv1(x)))
  24. # 第二层卷积,通过relu激活函数
  25. x = self.pool(relu(self.conv2(x)))
  26. # 展开张量,将其变成一维向量
  27. x = x.view(-1, 16 * 4 * 4)
  28. # 全连接层,通过relu激活函数
  29. x = relu(self.fc1(x))
  30. # 全连接层,通过relu激活函数
  31. x = relu(self.fc2(x))
  32. # 输出层,不使用激活函数
  33. x = self.fc3(x)
  34. return x
  35. # # 加载训练数据和测试数据
  36. # transform = transforms.Compose([
  37. # transforms.ToTensor(), # 转换为Tensor对象
  38. # transforms.Normalize((0.5,), (0.5,)) # 归一化处理
  39. # ])
  40. # trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
  41. # trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
  42. # testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
  43. # testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=False)
  44. #
  45. # # 实例化LeNet-5模型和损失函数、优化器
  46. # net = Lenet5()
  47. # criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
  48. # optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) # 随机梯度下降优化器
  49. #
  50. # # 训练网络
  51. # for epoch in range(10):
  52. # running_loss = 0.0
  53. # for i, data in enumerate(trainloader, 0):
  54. # inputs, labels = data