infer.py 1.3 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. from model import LeNet5
  2. import torch
  3. import cv2
  4. import numpy as np
  5. class DigitClassifier:
  6. def __init__(self): # 初始化
  7. super(DigitClassifier, self).__init__()
  8. # 创建网络
  9. self.net = LeNet5()
  10. # 加载模型算子(训练好的模型)
  11. state = torch.load("lenet5.pt")
  12. self.net.load_state_dict(state)
  13. def recognize_file(self, digit_file): # 输入图像文件
  14. # 1. 读取文件
  15. img = cv2.imread(digit_file)
  16. img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
  17. # 2. 处理文件:图像转为NCHW格式的float张量
  18. img = img.astype(np.float32)
  19. img = img / 255.0 # 像素转换为0-1之间的值
  20. img = torch.from_numpy(img).clone() # 把矩阵转换为张量
  21. img = img.view(1, 1, 28, 28) # 模型支持4维图像 NCHW
  22. # 3. 调用self.net预测
  23. y = self.net(img)
  24. # 4. 处理预测结果:类别与概率
  25. cls = torch.argmax(y, dim=1).item() # item取张量的值
  26. prob = y[0][cls].item()
  27. print(cls, prob)
  28. return cls, prob# 返回类别,返回这个类别的概率
  29. if __name__ == "__main__":
  30. classifier = DigitClassifier() # 生成分类器
  31. cls, prob = classifier.recognize_file("01_0.jpg")
  32. print(F"类别:{cls},概率:{prob}")