123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- from torch.nn import Module # 扩展该类实现我们自己的深度网络模型
- from torch.nn import Conv2d, Linear # 卷积运算(特征抽取),全邻接线性运算(分类器)
- from torch.nn.functional import relu, max_pool2d, avg_pool2d # relu折线函数,maxpool(从数组返回一个最大值)
- import torch
- class LeNet5(Module):
- # 构造器
- def __init__(self, class_num=10): # 10手写数字的分类。一共10个类别
- super(LeNet5, self).__init__()
- """
- 5层 (28 * 28 * 1)
- |- 1. 卷积5 * 5 -> (28 * 28 * 6) -(2, 2) -> (14, 14 , 6)
- |- 2. 卷积5 * 5 -> (10 * 10 * 16) -(2, 2) -> (5, 5, 16)
- |- 3. 卷积5 * 5 -> (1 * 1 * 120)
- |- 4. 全连接 120 -> 84
- |- 5. 全连接 84 - 10 (1, 0, 0, 0, 0, 0, 0, 0, 0, 0) 取概率最大的下标就是识别出来的数字
- """
- self.conv1 = Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=2)
- self.conv2 = Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
- self.conv3 = Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1,padding=0)
- self.fc1 = Linear(120, 84)
- self.fc2 = Linear(84, 10)
- # 预测函数
- def forward(self, x):
- """
- x表述输入图像, 格式是:[NHWC]
- """
- y = x
- # 计算预测
- # 第一层网络
- y = self.conv1(y) # (28*28*1) -> (28*28*6)
- y = max_pool2d(y, (2, 2)) # (28*28*6) -> (14*14*6)
- y = relu(y) # 过滤负值(所有的负值设置为0)
- # 第二层
- y = self.conv2(y) # (14*14*6) -> (10*10*16)
- y = max_pool2d(y, (2, 2)) # (10*10*16) -> (5*5*16)
- y = relu(y) # 过滤负值
- # 第三层
- y = self.conv3(y) # (5*5*16) -> (1*1*120)
- # 把y从(1*1*120) -> (120)向量
- y = y.view(-1, 120)
- # 第四层
- y = self.fc1(y)
- y = relu(y)
- # 第五层
- y = self.fc2(y)
- # y = relu(y) # 这个激活函数已经没有意义
- # 把向量的分量全部转换0-1之间的值(概率)
- y = torch.softmax(y, dim=1)
- return y
- # print(__name__)
- if __name__ == "__main__": # 表示是独立执行册程序块
- # 下面代码被调用,则执行不到。
- img = torch.randint(0, 256, (1, 1, 28, 28)) # 构造一个随机矩阵 == 噪音图像[NCHW]
- img = img.float() # 神经网络输入的必须是float类型
- net = LeNet5()
- y = net(img)
- # y = net.forward(img) # 等价于y = net(img)
- # 判定最大下标
- cls = torch.argmax(y, dim=1)
- print(F"识别的结果是:{cls.numpy()[0]}")
- print(y)
|