Просмотр исходного кода

上传文件至 'Day5/PyQt登录/PyQTdeng'lu/build/lib/digitapp'

wang 1 год назад
Родитель
Сommit
8d44e24006

+ 67 - 0
Day5/PyQt登录/PyQTdeng'lu/build/lib/digitapp/DigitAI.py

@@ -0,0 +1,67 @@
+import cv2
+import numpy as np
+import torch
+from .DigitModule import LeNet
+import os
+
+cur_dir = os.path.dirname(__file__)
+mod_file = os.path.join(cur_dir,"data/models.lenet")
+
+class DigitRecognizier:
+    def __init__(self):
+        super(DigitRecognizier, self).__init__()
+        self.CUDA = torch.cuda.is_available()
+        self.net = LeNet(10)
+        if self.CUDA:
+            self.net.cuda()
+        state = torch.load(mod_file)
+        self.net.load_state_dict(state)
+    
+
+    def pre_image(self, img):
+        # 大小
+        img = cv2.resize(img, (28,28))
+        # 灰度
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        # 小数类型
+        img = img.astype("float32")
+        # 逆转(F-> B, B->F)
+        img = 255.0 -img
+        # 去噪
+        img[img <= 150] = 0
+        cv2.imwrite("g.png", img.astype("uint8"))
+        # 转换为张量
+        img = torch.from_numpy(img).clone()
+        # 转换为N C H W
+        
+        return img
+
+    
+    def recognize(self, img):
+        result = []  # (数字,概率)
+        # 根据模型计算输出
+        p_img = self.pre_image(img)
+        if self.CUDA:
+            p_img = p_img.cuda()
+
+        predict = self.net.forward(p_img.view(1, 1, 28, 28))
+
+        pred_prob = torch.nn.functional.softmax(predict, dim=1)
+        # 计算在gpu,速度快
+        pred_prob = pred_prob[0]
+        # pred_prob = torch.squeeze(pred_prob, 0)
+        # pred_prob = pred_prob.view((pred_prob.shape[1]))
+        # 找出最大概率及其下标,判定概率
+        top1 = torch.argmax(pred_prob)
+        pro1 = pred_prob[top1]
+        result.append((top1.cpu().detach().numpy(), pro1.cpu().detach().numpy()))
+        if pro1 < 1.0:
+            # 把top1置为0,再找最大值
+            pred_prob[top1] = 0.0
+            top2 = torch.argmax(pred_prob)
+            pro2 = pred_prob[top2]
+            result.append((top2.cpu().detach().numpy(), pro2.cpu().detach().numpy()))
+
+
+        return result # 返回长度为10的概率向量
+

+ 17 - 0
Day5/PyQt登录/PyQTdeng'lu/build/lib/digitapp/DigitApp.py

@@ -0,0 +1,17 @@
+"""
+"""
+from PyQt5.QtWidgets import QApplication
+from .DigitForm import DigitForm
+import sys 
+
+class DigitApp(QApplication):
+    """
+    """
+    def __init__(self):
+        """
+        """
+        super(DigitApp, self).__init__(sys.argv)
+        # 创建应用主窗体
+        self.dlg = DigitForm()
+        self.dlg.show()
+    

+ 36 - 0
Day5/PyQt登录/PyQTdeng'lu/build/lib/digitapp/DigitDev.py

@@ -0,0 +1,36 @@
+from PyQt5.QtCore import QThread, pyqtSignal
+import cv2
+
+class DigitDev(QThread):
+    
+    signal_video = pyqtSignal(int, int, int, bytes)
+
+    def __init__(self):
+        super(DigitDev, self).__init__()
+        self.is_over = False
+        # 初始化设备
+        self.dev = cv2.VideoCapture(0, cv2.CAP_DSHOW)
+
+
+    def run(self):
+        while not self.is_over:
+            # 图像抓取
+            status, image = self.dev.read()
+            # 状态判定
+            if not status:
+                self.dev.release()
+                self.exit(0)
+            # 如果抓取图像成功,发送
+            shape = image.shape
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            self.signal_video.emit(shape[0], shape[1], shape[2], image.tobytes())
+            # 视觉暂停
+            QThread.usleep(100000)
+
+    def close(self):
+        self.is_over = True
+        while self.isRunning():
+            pass
+        # 释放设备
+        if self.dev.isOpened():
+            self.dev.release()

+ 84 - 0
Day5/PyQt登录/PyQTdeng'lu/build/lib/digitapp/DigitForm.py

@@ -0,0 +1,84 @@
+from PyQt5.QtWidgets import QDialog
+from PyQt5.QtGui import QImage, QPixmap
+import sys
+from .DigitUI import Ui_Digit
+from .DigitDev import DigitDev
+from .DigitAI import DigitRecognizier
+import cv2
+import numpy as np
+
+class DigitForm(QDialog):
+    def __init__(self):
+        super(DigitForm, self).__init__()
+        # 加载UI(先设计好)
+        # 创建对象
+        self.ui = Ui_Digit()
+        # 使用setupUi绑定对话框(父窗体)
+        self.ui.setupUi(self)
+        # AI识别对象
+        self.reco = DigitRecognizier()
+        # 创建视频对象
+        self.dev = DigitDev()
+        self.dev.signal_video.connect(self.show_video)
+        self.dev.start()
+
+
+
+    # 覆盖QDialog原来两个我们不需要的默认功能
+    def keyPressEvent(self, e):
+        pass
+
+
+    def closeEvent(self, e):
+        # 完成一些需要的释放工作
+        self.dev.close()
+        sys.exit(0)
+
+    # UI中的两个槽函数
+    def capture_image(self):
+        # 抓取图像
+        self.capture_data = self.buffer_data
+        self.capture_shape = self.buffer_shape
+        # 显示抓取的图像
+        # byte -> QImage
+        h, w, ch = self.capture_shape
+        image = QImage(self.capture_data, w, h, w*ch, QImage.Format_BGR888)
+        # QImage -> QPixmap
+        pixmap = QPixmap.fromImage(image)
+        # QPixmap -> QLabel 
+        self.ui.lbl_image.setPixmap(pixmap)
+        self.ui.lbl_image.setScaledContents(True)
+
+    def digit_recognize(self):
+        # 已知
+        # self.capture_data
+        # self.capture_shape
+        # 准备:训练好的手写字符模型:models.lenet:LeNet-5
+        # 实现与模型一致的 神经网络结构(层数,层类型,每层参数一致)
+        # 利用神经网络结构,识别图片
+        image = np.ndarray(
+            shape=self.capture_shape,    # 构建图像矩阵的形状
+            dtype=np.uint8,
+            buffer=self.buffer_data
+        )
+
+        result = self.reco.recognize(image)
+        self.ui.lbl_top1.setText(F"<font size=20 color=red><b><strong>{result[0][0]}</strong><b></font>")
+        self.ui.lbl_prob1.setText(F"{result[0][1]:3.2f}")
+        if len(result) == 2:
+            self.ui.lbl_top2.setText(F"{result[1][0]}")
+            self.ui.lbl_prob2.setText(F"{result[1][1]:3.2f}")
+        else:
+            self.ui.lbl_top2.setText("--")
+            self.ui.lbl_prob2.setText("--")
+
+    def show_video(self, h, w, ch, data):
+        self.buffer_data = data
+        self.buffer_shape = (h, w, ch)
+        # byte -> QImage
+        image = QImage(data, w, h, w*ch, QImage.Format_RGB888)
+        # QImage -> QPixmap
+        pixmap = QPixmap.fromImage(image)
+        # QPixmap -> QLabel 
+        self.ui.lbl_video.setPixmap(pixmap)
+        self.ui.lbl_video.setScaledContents(True)

BIN
Day5/PyQt登录/PyQTdeng'lu/build/lib/digitapp/__init__.py