12345678910111213141516171819202122232425262728293031323334353637383940414243444546 |
import cv2
import numpy as np
import torch

from lenet5 import Lenet5
class Lenet5Recognizier:
    """Classify a single image with a pretrained LeNet-5 model.

    The model weights are loaded once at construction. ``recognizie`` then
    converts an image to a single-channel 28x28 float tensor scaled to
    [0, 1], runs the network, and returns the predicted class id together
    with its softmax probability.
    """

    # Side length of the square input fed to the network.
    INPUT_SIZE = 28

    def __init__(self, model_file="lenet5.pt"):
        """Build the network and load its weights.

        Args:
            model_file: Path to a file holding the ``state_dict`` for
                ``Lenet5``. Tensors are mapped to the CPU, so a checkpoint
                saved on a GPU still loads on a CPU-only machine.
        """
        self.net = Lenet5()
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints (recent torch versions also accept weights_only=True).
        self.state = torch.load(model_file, map_location="cpu")
        self.net.load_state_dict(self.state)
        # Inference mode: makes layers such as dropout / batch-norm (if
        # Lenet5 has any) behave deterministically.
        self.net.eval()

    def recognizie(self, img):
        """Predict the class of an in-memory image.

        Args:
            img: NumPy image array. It is either color in OpenCV BGR(A)
                layout (H, W, C), as returned by ``cv2.imread``, or already
                grayscale (H, W). Pixel values are assumed to lie in
                [0, 255]; TODO confirm this matches the training
                preprocessing. Images that are not 28x28 are resized.

        Returns:
            tuple: ``(cls_id, cls_pr)``. ``cls_id`` is the predicted class
            index as a Python ``int``. ``cls_pr`` is its softmax probability
            as a 0-d NumPy array.
        """
        size = self.INPUT_SIZE
        # Collapse color input to one channel; grayscale input is used as-is.
        if img.ndim == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        # The network needs a fixed 28x28 input. Resizing avoids view()
        # failing, or silently scrambling pixels, on other sizes.
        if img.shape != (size, size):
            img = cv2.resize(img, (size, size), interpolation=cv2.INTER_AREA)
        # Contiguous float32 copy scaled to [0, 1]. Contiguity is required
        # for the .view() below to succeed.
        img = np.ascontiguousarray(img, dtype=np.float32) / 255.0
        # Layout [N, C, H, W]: one image with one channel.
        x = torch.from_numpy(img).view(1, 1, size, size)
        # Pure inference: skip building the autograd graph.
        with torch.no_grad():
            logits = self.net(x)
            prob = torch.softmax(logits, dim=1)
        cls_id = int(torch.argmax(prob, dim=1).item())
        cls_pr = prob[0, cls_id]
        return cls_id, cls_pr.numpy()

    def recognizie_file(self, img_file):
        """Load an image file with OpenCV and classify it.

        Args:
            img_file: Path of the image to read (loaded as BGR color).

        Returns:
            The same ``(cls_id, cls_pr)`` tuple as ``recognizie``.

        Raises:
            FileNotFoundError: If OpenCV cannot read ``img_file`` (missing,
                unreadable, or an unsupported format).
        """
        img = cv2.imread(img_file)
        # cv2.imread reports failure by returning None instead of raising.
        if img is None:
            raise FileNotFoundError(f"cannot read image file: {img_file}")
        return self.recognizie(img)
if __name__ == "__main__":
    # Demo entry point: classify one sample image and print (class id, probability).
    result_id, result_prob = Lenet5Recognizier().recognizie_file("09_4.jpg")
    print(result_id, result_prob)
|