瀏覽代碼

上传文件至 'Day5 刘琪20200404222/PyQT登录/build/lib/digitapp'

lq 1 年之前
父節點
當前提交
7916deaff6

+ 67 - 0
Day5 刘琪20200404222/PyQT登录/build/lib/digitapp/DigitAI.py

@@ -0,0 +1,67 @@
+import cv2
+import numpy as np
+import torch
+from .DigitModule import LeNet
+import os
+
+cur_dir = os.path.dirname(__file__)
+mod_file = os.path.join(cur_dir,"data/models.lenet")
+
+class DigitRecognizier:
+    def __init__(self):
+        super(DigitRecognizier, self).__init__()
+        self.CUDA = torch.cuda.is_available()
+        self.net = LeNet(10)
+        if self.CUDA:
+            self.net.cuda()
+        state = torch.load(mod_file)
+        self.net.load_state_dict(state)
+    
+
+    def pre_image(self, img):
+        # 大小
+        img = cv2.resize(img, (28,28))
+        # 灰度
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
+        # 小数类型
+        img = img.astype("float32")
+        # 逆转(F-> B, B->F)
+        img = 255.0 -img
+        # 去噪
+        img[img <= 150] = 0
+        cv2.imwrite("g.png", img.astype("uint8"))
+        # 转换为张量
+        img = torch.from_numpy(img).clone()
+        # 转换为N C H W
+        
+        return img
+
+    
+    def recognize(self, img):
+        result = []  # (数字,概率)
+        # 根据模型计算输出
+        p_img = self.pre_image(img)
+        if self.CUDA:
+            p_img = p_img.cuda()
+
+        predict = self.net.forward(p_img.view(1, 1, 28, 28))
+
+        pred_prob = torch.nn.functional.softmax(predict, dim=1)
+        # 计算在gpu,速度快
+        pred_prob = pred_prob[0]
+        # pred_prob = torch.squeeze(pred_prob, 0)
+        # pred_prob = pred_prob.view((pred_prob.shape[1]))
+        # 找出最大概率及其下标,判定概率
+        top1 = torch.argmax(pred_prob)
+        pro1 = pred_prob[top1]
+        result.append((top1.cpu().detach().numpy(), pro1.cpu().detach().numpy()))
+        if pro1 < 1.0:
+            # 把top1置为0,再找最大值
+            pred_prob[top1] = 0.0
+            top2 = torch.argmax(pred_prob)
+            pro2 = pred_prob[top2]
+            result.append((top2.cpu().detach().numpy(), pro2.cpu().detach().numpy()))
+
+
+        return result # 返回长度为10的概率向量
+

+ 17 - 0
Day5 刘琪20200404222/PyQT登录/build/lib/digitapp/DigitApp.py

@@ -0,0 +1,17 @@
+"""
+"""
+from PyQt5.QtWidgets import QApplication
+from .DigitForm import DigitForm
+import sys 
+
+class DigitApp(QApplication):
+    """
+    """
+    def __init__(self):
+        """
+        """
+        super(DigitApp, self).__init__(sys.argv)
+        # 创建应用主窗体
+        self.dlg = DigitForm()
+        self.dlg.show()
+    

+ 36 - 0
Day5 刘琪20200404222/PyQT登录/build/lib/digitapp/DigitDev.py

@@ -0,0 +1,36 @@
+from PyQt5.QtCore import QThread, pyqtSignal
+import cv2
+
+class DigitDev(QThread):
+    
+    signal_video = pyqtSignal(int, int, int, bytes)
+
+    def __init__(self):
+        super(DigitDev, self).__init__()
+        self.is_over = False
+        # 初始化设备
+        self.dev = cv2.VideoCapture(0, cv2.CAP_DSHOW)
+
+
+    def run(self):
+        while not self.is_over:
+            # 图像抓取
+            status, image = self.dev.read()
+            # 状态判定
+            if not status:
+                self.dev.release()
+                self.exit(0)
+            # 如果抓取图像成功,发送
+            shape = image.shape
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            self.signal_video.emit(shape[0], shape[1], shape[2], image.tobytes())
+            # 视觉暂停
+            QThread.usleep(100000)
+
+    def close(self):
+        self.is_over = True
+        while self.isRunning():
+            pass
+        # 释放设备
+        if self.dev.isOpened():
+            self.dev.release()