1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950 |
- from Train.lenet5 import Lenet5
- import torch
- import cv2
- import numpy as np
- import torchvision.transforms as transforms
- from PIL import Image
class Lenet5Recognizier:
    """Classify a single image with a LeNet-5 network loaded from a checkpoint.

    The checkpoint is loaded onto the CPU at construction time and the
    network is put in evaluation mode, so an instance can be reused for any
    number of ``recognizie`` / ``recognizie_file`` calls.
    """

    def __init__(self, model_file="./lenet5.pt"):
        """Create the network and load its weights.

        Args:
            model_file: Path to a state-dict checkpoint readable by
                ``torch.load``.
        """
        super(Lenet5Recognizier, self).__init__()
        # Create the model.
        self.net = Lenet5()
        # Map every tensor to the CPU so a checkpoint saved on a CUDA device
        # still loads on a machine without one.
        self.state = torch.load(model_file, map_location='cpu')
        self.net.load_state_dict(self.state)
        # Inference only: make dropout / batch-norm layers (if the network
        # has any) behave deterministically.
        self.net.eval()

    @staticmethod
    def _to_input_tensor(img):
        """Convert an OpenCV image array into a normalized 1x1x150x150 tensor."""
        # torchvision's Resize/Grayscale accept PIL images or tensors, not raw
        # numpy arrays, so wrap the array in a PIL image first. OpenCV stores
        # colour images as BGR(A); PIL expects RGB.
        # NOTE(review): assumes a uint8 array as produced by cv2.imread.
        if img.ndim == 2:
            pil_img = Image.fromarray(img)
        elif img.shape[2] == 4:
            pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGRA2RGB))
        else:
            pil_img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

        data_transforms = transforms.Compose([
            transforms.Resize((150, 150)),  # resize the image to 150x150
            transforms.Grayscale(),  # collapse to a single grayscale channel
            transforms.ToTensor(),  # convert to a float tensor
            transforms.Normalize(mean=[0.485], std=[0.229]),  # single-channel normalization
        ])
        # Add the leading batch dimension expected by the network.
        return data_transforms(pil_img).unsqueeze(0)

    def recognizie(self, img):
        """Classify an image given as an OpenCV array.

        Args:
            img: Image array as returned by ``cv2.imread``.

        Returns:
            A ``(class_id, probability)`` tuple: the index of the most likely
            class as an ``int`` and its softmax probability as a 0-d numpy
            array.

        Raises:
            ValueError: If ``img`` is ``None`` (e.g. ``cv2.imread`` failed).
        """
        if img is None:
            raise ValueError(
                "img is None; expected an image array (was the file readable?)"
            )

        # cv2.resize takes (width, height).
        img = cv2.resize(img, (200, 300))
        # Apply the preprocessing pipeline.
        input_tensor = self._to_input_tensor(img)

        # No gradients are needed for prediction.
        with torch.no_grad():
            y = self.net(input_tensor)
            # Turn the raw scores into per-class probabilities.
            prob = torch.softmax(y, dim=1)

        cls_id = torch.argmax(prob, dim=1)
        cls_pr = prob[0][cls_id.item()]
        return cls_id.detach().item(), cls_pr.detach().numpy()

    def recognizie_file(self, img_file):
        """Load an image from disk and classify it.

        Args:
            img_file: Path of the image file to read with ``cv2.imread``.

        Returns:
            The ``(class_id, probability)`` tuple produced by ``recognizie``.

        Raises:
            ValueError: If the file is missing or cannot be decoded.
        """
        # Load the image; cv2.imread signals failure by returning None.
        img = cv2.imread(img_file)
        if img is None:
            raise ValueError(
                f"could not read image file (missing or undecodable): {img_file!r}"
            )
        return self.recognizie(img)
if __name__ == "__main__":
    # Manual smoke run: load the default checkpoint, classify one sample
    # image from the working directory, and print the predicted class id
    # followed by its probability.
    classifier = Lenet5Recognizier()
    prediction = classifier.recognizie_file("output1.jpg")
    label, confidence = prediction
    print(label, confidence)
|