reco.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. from lenet5 import Lenet5
  2. import torch
  3. import cv2
  4. import numpy as np
  class Lenet5Recognizier:
      """Classify a single image with a LeNet-5 network loaded from disk.

      The network weights are loaded onto the CPU and the model is put in
      evaluation mode, so the instance is ready for inference right after
      construction.
      """

      # (height, width) fed to the network; the batch is shaped
      # [1, 1, height, width] before the forward pass.
      _INPUT_SIZE = (28, 28)

      def __init__(self, model_file="lenet5.pt"):
          """Build the network and restore its weights.

          Args:
              model_file: Path to a state-dict file readable by ``torch.load``.
          """
          super(Lenet5Recognizier, self).__init__()
          # Create the model.
          self.net = Lenet5()
          # NOTE(security): torch.load unpickles the file; only load
          # checkpoints from a trusted source.
          self.state = torch.load(model_file, map_location='cpu')
          self.net.load_state_dict(self.state)
          # Inference only: disable training-time behaviour such as dropout.
          self.net.eval()

      def recognizie(self, img):
          """Predict the class of one image.

          Args:
              img: Image array. A 3-D array is converted with
                  ``cv2.COLOR_BGR2GRAY``; a 2-D array is used as grayscale
                  as-is. Pixel values are divided by 255. Images that are
                  not already ``_INPUT_SIZE`` are resized to it.

          Returns:
              Tuple ``(cls_id, cls_pr)``: the predicted class index as an
              ``int`` and its softmax probability as a 0-d NumPy array.

          Raises:
              ValueError: If ``img`` is ``None``.
          """
          if img is None:
              raise ValueError("img is None; expected an image array")

          # Grayscale conversion (skipped when the input is already 2-D).
          if img.ndim == 3:
              img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

          # Type conversion and normalisation; the division allocates a new
          # array, so the caller's data is never modified.
          img = np.ascontiguousarray(img, dtype=np.float32) / 255.0

          height, width = self._INPUT_SIZE
          if img.shape != (height, width):
              # cv2.resize takes the target size as (width, height).
              img = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)

          # Tensor layout: [N, C, H, W].
          batch = torch.from_numpy(img).reshape(1, 1, height, width)

          # Prediction; no autograd graph is needed for inference.
          with torch.no_grad():
              logits = self.net(batch)
              prob = torch.softmax(logits, dim=1)

          cls_id = torch.argmax(prob, dim=1).item()
          cls_pr = prob[0][cls_id]
          return cls_id, cls_pr.numpy()

      def recognizie_file(self, img_file):
          """Load an image from ``img_file`` and classify it.

          Args:
              img_file: Path of the image, read with ``cv2.imread``.

          Returns:
              The ``(cls_id, cls_pr)`` tuple produced by ``recognizie``.

          Raises:
              OSError: If ``cv2.imread`` could not read the file (it returns
                  ``None`` instead of raising).
          """
          # Image loading.
          img = cv2.imread(img_file)
          if img is None:
              raise OSError("could not read image file: {!r}".format(img_file))
          return self.recognizie(img)
  if __name__ == "__main__":
      # Demo: classify one sample image with the default weights file and
      # print the predicted class index followed by its probability.
      prediction = Lenet5Recognizier().recognizie_file("09_4.jpg")
      print(*prediction)