|
@@ -0,0 +1,75 @@
|
|
|
+from torch.nn import Module # 扩展该类实现我们自己的深度网络模型
|
|
|
+from torch.nn import Conv2d, Linear # 卷积运算(特征抽取),全邻接线性运算(分类器)
|
|
|
+from torch.nn.functional import relu, max_pool2d, avg_pool2d # relu折线函数,maxpool(从数组返回一个最大值)
|
|
|
+import torch
|
|
|
+
|
|
|
+class LeNet5(Module):
|
|
|
+ # 构造器
|
|
|
+ def __init__(self, class_num=10): # 10手写数字的分类。一共10个类别
|
|
|
+ super(LeNet5, self).__init__()
|
|
|
+ """
|
|
|
+ 5层 (28 * 28 * 1)
|
|
|
+ |- 1. 卷积5 * 5 -> (28 * 28 * 6) -(2, 2) -> (14, 14 , 6)
|
|
|
+ |- 2. 卷积5 * 5 -> (10 * 10 * 16) -(2, 2) -> (5, 5, 16)
|
|
|
+ |- 3. 卷积5 * 5 -> (1 * 1 * 120)
|
|
|
+ |- 4. 全连接 120 -> 84
|
|
|
+ |- 5. 全连接 84 - 10 (1, 0, 0, 0, 0, 0, 0, 0, 0, 0) 取概率最大的下标就是识别出来的数字
|
|
|
+ """
|
|
|
+ self.conv1 = Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
|
|
|
+ self.conv2 = Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
|
|
|
+ self.conv3 = Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1,padding=0)
|
|
|
+
|
|
|
+ self.fc1 = Linear(120, 84)
|
|
|
+ self.fc2 = Linear(84, 10)
|
|
|
+
|
|
|
+
|
|
|
+ # 预测函数
|
|
|
+ def forward(self, x):
|
|
|
+ """
|
|
|
+ x表述输入图像, 格式是:[NHWC]
|
|
|
+ """
|
|
|
+ y = x
|
|
|
+ # 计算预测
|
|
|
+ # 第一层网络
|
|
|
+ y = self.conv1(y) # (28*28*1) -> (28*28*6)
|
|
|
+ y = max_pool2d(y, (2, 2)) # (28*28*6) -> (14*14*6)
|
|
|
+ y = relu(y) # 过滤负值(所有的负值设置为0)
|
|
|
+
|
|
|
+ # 第二层
|
|
|
+ y = self.conv2(y) # (14*14*6) -> (10*10*16)
|
|
|
+ y = max_pool2d(y, (2, 2)) # (10*10*16) -> (5*5*16)
|
|
|
+ y = relu(y) # 过滤负值
|
|
|
+
|
|
|
+ # 第三层
|
|
|
+ y = self.conv3(y) # (5*5*16) -> (1*1*120)
|
|
|
+
|
|
|
+ # 把y从(1*1*120) -> (120)向量
|
|
|
+ y = y.view(-1, 120)
|
|
|
+
|
|
|
+ # 第四层
|
|
|
+ y = self.fc1(y)
|
|
|
+ y = relu(y)
|
|
|
+
|
|
|
+ # 第五层
|
|
|
+ y = self.fc2(y)
|
|
|
+ # y = relu(y) # 这个激活函数已经没有意义
|
|
|
+
|
|
|
+ # 把向量的分量全部转换0-1之间的值(概率)
|
|
|
+ y = torch.softmax(y, dim=1)
|
|
|
+
|
|
|
+ return y
|
|
|
+
|
|
|
+
|
|
|
+# print(__name__)
|
|
|
+if __name__ == "__main__": # 表示是独立执行册程序块
|
|
|
+ # 下面代码被调用,则执行不到。
|
|
|
+ img = torch.randint(0, 256, (1, 1, 28, 28)) # 构造一个随机矩阵 == 噪音图像[NCHW]
|
|
|
+ img = img.float() # 神经网络输入的必须是float类型
|
|
|
+ net = LeNet5()
|
|
|
+ y = net(img)
|
|
|
+ # y = net.forward(img) # 等价于y = net(img)
|
|
|
+
|
|
|
+ # 判定最大下标
|
|
|
+ cls = torch.argmax(y, dim=1)
|
|
|
+ print(F"识别的结果是:{cls.numpy()[0]}")
|
|
|
+ print(y)
|