1234567891011121314151617181920212223242526272829303132333435363738 |
- from model import LeNet5
- import torch
- import cv2
- import numpy as np
- class DigitClassifier:
- def __init__(self): # 初始化
- super(DigitClassifier, self).__init__()
- # 创建网络
- self.net = LeNet5()
- # 加载模型算子(训练好的模型)
- state = torch.load("lenet5.pt")
- self.net.load_state_dict(state)
-
- def recognize_file(self, digit_file): # 输入图像文件
- # 1. 读取文件
- img = cv2.imread(digit_file)
- img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
- # 2. 处理文件:图像转为NCHW格式的float张量
- img = img.astype(np.float32)
- img = img / 255.0 # 像素转换为0-1之间的值
- img = torch.from_numpy(img).clone() # 把矩阵转换为张量
- img = img.view(1, 1, 28, 28) # 模型支持4维图像 NCHW
- # 3. 调用self.net预测
- y = self.net(img)
- # 4. 处理预测结果:类别与概率
- cls = torch.argmax(y, dim=1).item() # item取张量的值
- prob = y[0][cls].item()
- print(cls, prob)
- return cls, prob# 返回类别,返回这个类别的概率
- if __name__ == "__main__":
- classifier = DigitClassifier() # 生成分类器
- cls, prob = classifier.recognize_file("02_2.jpg")
- print(F"类别:{cls},概率:{prob}")
|