# reco.py - LeNet-5 image recognizer (inference helper)
  1. from Train.lenet5 import Lenet5
  2. import torch
  3. import cv2
  4. import numpy as np
  5. import torchvision.transforms as transforms
  6. from PIL import Image
  7. class Lenet5Recognizier:
  8. def __init__(self, model_file="./lenet5.pt"):
  9. super(Lenet5Recognizier, self).__init__()
  10. # 创建模型
  11. self.net = Lenet5()
  12. # cuda
  13. self.state = torch.load(model_file, map_location='cpu')
  14. self.net.load_state_dict(self.state)
  15. def recognizie(self, img):
  16. img = cv2.resize(img, (200, 300))
  17. data_transforms = transforms.Compose([
  18. transforms.Resize((150, 150)), # 调整图像大小为150x150
  19. transforms.Grayscale(), # 将图像转换为灰度图像
  20. transforms.ToTensor(), # 将图像转换为张量
  21. transforms.Normalize(mean=[0.485], std=[0.229]) # 标准化图像(只有一个通道)
  22. ])
  23. # 应用数据预处理操作
  24. input_tensor = data_transforms(img).unsqueeze(0)
  25. y = self.net(input_tensor)
  26. # 处理预测
  27. prob = torch.softmax(y, dim=1)
  28. cls_id = torch.argmax(prob, dim=1)
  29. cls_pr = prob[0][cls_id.item()]
  30. return cls_id.detach().item(), cls_pr.detach().numpy()
  31. def recognizie_file(self, img_file):
  32. # 图像加载
  33. img = cv2.imread(img_file)
  34. return self.recognizie(img)
  35. if __name__ == "__main__":
  36. recognizier = Lenet5Recognizier()
  37. c_id, c_pr = recognizier.recognizie_file("output1.jpg")
  38. print(c_id, c_pr)