|
@@ -0,0 +1,67 @@
|
|
|
|
+import cv2
|
|
|
|
+import numpy as np
|
|
|
|
+import torch
|
|
|
|
+from .DigitModule import LeNet
|
|
|
|
+import os
|
|
|
|
+
|
|
|
|
+cur_dir = os.path.dirname(__file__)
|
|
|
|
+mod_file = os.path.join(cur_dir,"data/models.lenet")
|
|
|
|
+
|
|
|
|
+class DigitRecognizier:
|
|
|
|
+ def __init__(self):
|
|
|
|
+ super(DigitRecognizier, self).__init__()
|
|
|
|
+ self.CUDA = torch.cuda.is_available()
|
|
|
|
+ self.net = LeNet(10)
|
|
|
|
+ if self.CUDA:
|
|
|
|
+ self.net.cuda()
|
|
|
|
+ state = torch.load(mod_file)
|
|
|
|
+ self.net.load_state_dict(state)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ def pre_image(self, img):
|
|
|
|
+ # 大小
|
|
|
|
+ img = cv2.resize(img, (28,28))
|
|
|
|
+ # 灰度
|
|
|
|
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
|
|
|
+ # 小数类型
|
|
|
|
+ img = img.astype("float32")
|
|
|
|
+ # 逆转(F-> B, B->F)
|
|
|
|
+ img = 255.0 -img
|
|
|
|
+ # 去噪
|
|
|
|
+ img[img <= 150] = 0
|
|
|
|
+ cv2.imwrite("g.png", img.astype("uint8"))
|
|
|
|
+ # 转换为张量
|
|
|
|
+ img = torch.from_numpy(img).clone()
|
|
|
|
+ # 转换为N C H W
|
|
|
|
+
|
|
|
|
+ return img
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ def recognize(self, img):
|
|
|
|
+ result = [] # (数字,概率)
|
|
|
|
+ # 根据模型计算输出
|
|
|
|
+ p_img = self.pre_image(img)
|
|
|
|
+ if self.CUDA:
|
|
|
|
+ p_img = p_img.cuda()
|
|
|
|
+
|
|
|
|
+ predict = self.net.forward(p_img.view(1, 1, 28, 28))
|
|
|
|
+
|
|
|
|
+ pred_prob = torch.nn.functional.softmax(predict, dim=1)
|
|
|
|
+ # 计算在gpu,速度快
|
|
|
|
+ pred_prob = pred_prob[0]
|
|
|
|
+ # pred_prob = torch.squeeze(pred_prob, 0)
|
|
|
|
+ # pred_prob = pred_prob.view((pred_prob.shape[1]))
|
|
|
|
+ # 找出最大概率及其下标,判定概率
|
|
|
|
+ top1 = torch.argmax(pred_prob)
|
|
|
|
+ pro1 = pred_prob[top1]
|
|
|
|
+ result.append((top1.cpu().detach().numpy(), pro1.cpu().detach().numpy()))
|
|
|
|
+ if pro1 < 1.0:
|
|
|
|
+ # 把top1置为0,再找最大值
|
|
|
|
+ pred_prob[top1] = 0.0
|
|
|
|
+ top2 = torch.argmax(pred_prob)
|
|
|
|
+ pro2 = pred_prob[top2]
|
|
|
|
+ result.append((top2.cpu().detach().numpy(), pro2.cpu().detach().numpy()))
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+ return result # 返回长度为10的概率向量
|
|
|
|
+
|