model.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. from torch.nn import Module # 扩展该类实现我们自己的深度网络模型
  2. from torch.nn import Conv2d, Linear # 卷积运算(特征抽取),全邻接线性运算(分类器)
  3. from torch.nn.functional import relu, max_pool2d, avg_pool2d # relu折线函数,maxpool(从数组返回一个最大值)
  4. import torch
  5. class LeNet5(Module):
  6. # 构造器
  7. def __init__(self, class_num=10): # 10手写数字的分类。一共10个类别
  8. super(LeNet5, self).__init__()
  9. """
  10. 5层 (28 * 28 * 1)
  11. |- 1. 卷积5 * 5 -> (28 * 28 * 6) -(2, 2) -> (14, 14 , 6)
  12. |- 2. 卷积5 * 5 -> (10 * 10 * 16) -(2, 2) -> (5, 5, 16)
  13. |- 3. 卷积5 * 5 -> (1 * 1 * 120)
  14. |- 4. 全连接 120 -> 84
  15. |- 5. 全连接 84 - 10 (1, 0, 0, 0, 0, 0, 0, 0, 0, 0) 取概率最大的下标就是识别出来的数字
  16. """
  17. self.conv1 = Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
  18. self.conv2 = Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
  19. self.conv3 = Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1,padding=0)
  20. self.fc1 = Linear(120, 84)
  21. self.fc2 = Linear(84, 10)
  22. # 预测函数
  23. def forward(self, x):
  24. """
  25. x表述输入图像, 格式是:[NHWC]
  26. """
  27. y = x
  28. # 计算预测
  29. # 第一层网络
  30. y = self.conv1(y) # (28*28*1) -> (28*28*6)
  31. y = max_pool2d(y, (2, 2)) # (28*28*6) -> (14*14*6)
  32. y = relu(y) # 过滤负值(所有的负值设置为0)
  33. # 第二层
  34. y = self.conv2(y) # (14*14*6) -> (10*10*16)
  35. y = max_pool2d(y, (2, 2)) # (10*10*16) -> (5*5*16)
  36. y = relu(y) # 过滤负值
  37. # 第三层
  38. y = self.conv3(y) # (5*5*16) -> (1*1*120)
  39. # 把y从(1*1*120) -> (120)向量
  40. y = y.view(-1, 120)
  41. # 第四层
  42. y = self.fc1(y)
  43. y = relu(y)
  44. # 第五层
  45. y = self.fc2(y)
  46. # y = relu(y) # 这个激活函数已经没有意义
  47. # 把向量的分量全部转换0-1之间的值(概率)
  48. y = torch.softmax(y, dim=1)
  49. return y
  50. # print(__name__)
  51. if __name__ == "__main__": # 表示是独立执行册程序块
  52. # 下面代码被调用,则执行不到。
  53. img = torch.randint(0, 256, (1, 1, 28, 28)) # 构造一个随机矩阵 == 噪音图像[NCHW]
  54. img = img.float() # 神经网络输入的必须是float类型
  55. net = LeNet5()
  56. y = net(img)
  57. # y = net.forward(img) # 等价于y = net(img)
  58. # 判定最大下标
  59. cls = torch.argmax(y, dim=1)
  60. print(F"识别的结果是:{cls.numpy()[0]}")
  61. print(y)