from model import LeNet5 import torch import cv2 import numpy as np class DigitClassifier: def __init__(self): # 初始化 super(DigitClassifier, self).__init__() # 创建网络 self.net = LeNet5() # 加载模型算子(训练好的模型) state = torch.load("lenet5.pt") self.net.load_state_dict(state) def recognize_file(self, digit_file): # 输入图像文件 # 1. 读取文件 img = cv2.imread(digit_file) img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) # 2. 处理文件:图像转为NCHW格式的float张量 img = img.astype(np.float32) img = img / 255.0 # 像素转换为0-1之间的值 img = torch.from_numpy(img).clone() # 把矩阵转换为张量 img = img.view(1, 1, 28, 28) # 模型支持4维图像 NCHW # 3. 调用self.net预测 y = self.net(img) # 4. 处理预测结果:类别与概率 cls = torch.argmax(y, dim=1).item() # item取张量的值 prob = y[0][cls].item() print(cls, prob) return cls, prob# 返回类别,返回这个类别的概率 if __name__ == "__main__": classifier = DigitClassifier() # 生成分类器 cls, prob = classifier.recognize_file("01_0.jpg") print(F"类别:{cls},概率:{prob}")