DigitModule.py 756 B

123456789101112131415161718192021
  1. import torch.nn as nn # 神经网络的层的实现:卷积层
  2. import torch.nn.functional as fu
  3. class LeNet(nn.Module):
  4. def __init__(self, cls_num=10):
  5. super(LeNet, self).__init__()
  6. self.conv1 = nn.Conv2d(1, 6, 5, padding=2)
  7. self.conv2 = nn.Conv2d(6, 16, 5)
  8. self.fc1 = nn.Linear(16 * 5 * 5, 120)
  9. self.fc2 = nn.Linear(120, 84)
  10. self.fc3 = nn.Linear(84, cls_num)
  11. def forward(self, x):
  12. y = fu.max_pool2d(fu.relu(self.conv1(x)), (2, 2))
  13. y = fu.max_pool2d(fu.relu(self.conv2(y)), (2, 2))
  14. # 格式转换
  15. y = y.view(y.size()[0], -1)
  16. y = fu.relu(self.fc1(y))
  17. y = fu.relu(self.fc2(y))
  18. y = self.fc3(y)
  19. return y