# DigitAI.py (2.0 KB)
  1. import cv2
  2. import numpy as np
  3. import torch
  4. from .DigitModule import LeNet
  5. import os
  6. cur_dir = os.path.dirname(__file__)
  7. mod_file = os.path.join(cur_dir,"data/models.lenet")
  8. class DigitRecognizier:
  9. def __init__(self):
  10. super(DigitRecognizier, self).__init__()
  11. self.CUDA = torch.cuda.is_available()
  12. self.net = LeNet(10)
  13. if self.CUDA:
  14. self.net.cuda()
  15. state = torch.load(mod_file)
  16. self.net.load_state_dict(state)
  17. def pre_image(self, img):
  18. # 大小
  19. img = cv2.resize(img, (28,28))
  20. # 灰度
  21. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  22. # 小数类型
  23. img = img.astype("float32")
  24. # 逆转(F-> B, B->F)
  25. img = 255.0 -img
  26. # 去噪
  27. img[img <= 150] = 0
  28. cv2.imwrite("g.png", img.astype("uint8"))
  29. # 转换为张量
  30. img = torch.from_numpy(img).clone()
  31. # 转换为N C H W
  32. return img
  33. def recognize(self, img):
  34. result = [] # (数字,概率)
  35. # 根据模型计算输出
  36. p_img = self.pre_image(img)
  37. if self.CUDA:
  38. p_img = p_img.cuda()
  39. predict = self.net.forward(p_img.view(1, 1, 28, 28))
  40. pred_prob = torch.nn.functional.softmax(predict, dim=1)
  41. # 计算在gpu,速度快
  42. pred_prob = pred_prob[0]
  43. # pred_prob = torch.squeeze(pred_prob, 0)
  44. # pred_prob = pred_prob.view((pred_prob.shape[1]))
  45. # 找出最大概率及其下标,判定概率
  46. top1 = torch.argmax(pred_prob)
  47. pro1 = pred_prob[top1]
  48. result.append((top1.cpu().detach().numpy(), pro1.cpu().detach().numpy()))
  49. if pro1 < 1.0:
  50. # 把top1置为0,再找最大值
  51. pred_prob[top1] = 0.0
  52. top2 = torch.argmax(pred_prob)
  53. pro2 = pred_prob[top2]
  54. result.append((top2.cpu().detach().numpy(), pro2.cpu().detach().numpy()))
  55. return result # 返回长度为10的概率向量