spidermanYT 1 jaar geleden
bovenliggende
commit
4e8edde965

BIN
day06/infer/03_1.jpg


BIN
day06/infer/04_9.jpg


BIN
day06/infer/05_2.jpg


BIN
day06/infer/__pycache__/model.cpython-39.pyc


+ 76 - 0
day06/infer/infer.py

@@ -0,0 +1,76 @@
+from model import LeNet5
+
+import torch
+
+import cv2
+
+import numpy as np
+
+
+
+class  DigitClassifier:
+
+    def __init__(self): # 初始化
+
+        super(DigitClassifier, self).__init__()
+
+        # 创建网络
+
+        self.net = LeNet5()
+
+        # 加载模型算子(训练好的模型)
+
+        state = torch.load("lenet5.pt")
+
+        self.net.load_state_dict(state)
+
+
+
+    
+
+    def recognize_file(self, digit_file): # 输入图像文件
+
+        # 1. 读取文件
+
+        img = cv2.imread(digit_file)
+
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+
+        # 2. 处理文件:图像转为NCHW格式的float张量
+
+        img = img.astype(np.float32) 
+
+        img = img / 255.0   # 像素转换为0-1之间的值
+
+        img = torch.from_numpy(img).clone()  # 把矩阵转换为张量
+
+        img = img.view(1, 1, 28, 28)  # 模型支持4维图像 NCHW
+
+        # 3. 调用self.net预测
+
+        y = self.net(img)
+
+        # 4. 处理预测结果:类别与概率
+
+        cls = torch.argmax(y, dim=1).item()  # item取张量的值
+
+        prob = y[0][cls].item()
+
+        print(cls, prob)
+
+        return cls,  prob# 返回类别,返回这个类别的概率
+
+
+
+
+
+
+
+if __name__ == "__main__":
+
+    classifier = DigitClassifier() # 生成分类器
+
+    cls, prob = classifier.recognize_file("05_2.jpg")
+
+    print(F"类别:{cls},概率:{prob}")
+

BIN
day06/infer/lenet5.pt


+ 149 - 0
day06/infer/model.py

@@ -0,0 +1,149 @@
+from torch.nn import Module    # 扩展该类实现我们自己的深度网络模型
+
+from torch.nn import Conv2d, Linear   # 卷积运算(特征抽取),全邻接线性运算(分类器)
+
+from torch.nn.functional import relu, max_pool2d, avg_pool2d  # relu折线函数,maxpool(从数组返回一个最大值)
+
+import torch
+
+
+
+class LeNet5(Module):
+
+    # 构造器
+
+    def __init__(self, class_num=10):  # 10手写数字的分类。一共10个类别
+
+        super(LeNet5, self).__init__()
+
+        """
+
+            5层  (28 * 28 * 1)
+
+                |- 1. 卷积5 * 5 -> (28 * 28 * 6)    -(2, 2) -> (14, 14 , 6)
+
+                |- 2. 卷积5 * 5 -> (10 * 10 * 16)   -(2, 2) -> (5, 5, 16)
+
+                |- 3. 卷积5 * 5 -> (1 * 1 * 120)   
+
+                |- 4. 全连接 120 -> 84 
+
+                |- 5. 全连接 84 - 10 (1, 0, 0, 0, 0, 0, 0, 0, 0, 0)  取概率最大的下标就是识别出来的数字
+
+        """
+
+        self.conv1 = Conv2d(in_channels=1, out_channels=6,  kernel_size=5, stride=1, padding=2)
+
+        self.conv2 = Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0)
+
+        self.conv3 = Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1,padding=0)
+
+
+
+        self.fc1 = Linear(120, 84)
+
+        self.fc2 = Linear(84, 10)
+
+
+
+
+
+    # 预测函数
+
+    def forward(self, x):
+
+        """
+
+            x表述输入图像, 格式是:[NHWC]
+
+        """
+
+        y = x
+
+        # 计算预测
+
+        # 第一层网络
+
+        y = self.conv1(y)  # (28*28*1) -> (28*28*6)
+
+        y = max_pool2d(y, (2, 2))  # (28*28*6) -> (14*14*6)
+
+        y = relu(y)        # 过滤负值(所有的负值设置为0)
+
+
+
+        # 第二层
+
+        y = self.conv2(y)   # (14*14*6) -> (10*10*16)
+
+        y = max_pool2d(y, (2, 2))   # (10*10*16) -> (5*5*16)
+
+        y = relu(y)         # 过滤负值
+
+
+
+        # 第三层
+
+        y = self.conv3(y)   # (5*5*16) -> (1*1*120)
+
+
+
+        # 把y从(1*1*120) -> (120)向量
+
+        y = y.view(-1, 120)  
+
+
+
+        # 第四层
+
+        y = self.fc1(y)
+
+        y = relu(y)
+
+
+
+        # 第五层
+
+        y = self.fc2(y)
+
+        # y = relu(y)        # 这个激活函数已经没有意义
+
+
+
+        # 把向量的分量全部转换0-1之间的值(概率)     
+
+        y = torch.softmax(y, dim=1) 
+
+
+
+        return y
+
+
+
+
+
+# print(__name__)
+
+if __name__ == "__main__":  # 表示是独立执行册程序块
+
+    # 下面代码被调用,则执行不到。
+
+    img = torch.randint(0, 256, (1, 1, 28, 28))  # 构造一个随机矩阵 == 噪音图像[NCHW]
+
+    img = img.float()    # 神经网络输入的必须是float类型
+
+    net = LeNet5()
+
+    y = net(img)
+
+    # y = net.forward(img)  # 等价于y = net(img)
+
+
+
+    # 判定最大下标
+
+    cls = torch.argmax(y, dim=1)
+
+    print(F"识别的结果是:{cls.numpy()[0]}")
+
+    print(y)

BIN
day06/lenet/03_1.jpg


BIN
day06/lenet/04_9.jpg


BIN
day06/lenet/05_2.jpg


BIN
day06/lenet/lenet5.pt