|
@@ -0,0 +1,102 @@
|
|
|
+from lenet5 import Lenet5
|
|
|
+from torch.nn import CrossEntropyLoss
|
|
|
+from torch.optim import Adam
|
|
|
+from torchvision.datasets import MNIST
|
|
|
+from torchvision.transforms import Compose, ToTensor
|
|
|
+from torch.utils.data import DataLoader
|
|
|
+import torch
|
|
|
+import os
|
|
|
+
|
|
|
+class Lenet5Trainer:
|
|
|
+ def __init__(self, lr=0.0001, model_file="lenet5.pt", ds_path="datasets"):
|
|
|
+ """
|
|
|
+ lr:学习率
|
|
|
+ model_file:保存的模型文件
|
|
|
+ ds_path:数据集目录
|
|
|
+ """
|
|
|
+ super(Lenet5Trainer, self).__init__()
|
|
|
+ # GPU运算
|
|
|
+ self.CUDA = torch.cuda.is_available()
|
|
|
+ self.lr = lr
|
|
|
+ self.model_file = model_file
|
|
|
+ self.ds_path = ds_path
|
|
|
+ # 神经网络模型
|
|
|
+ self.net = Lenet5()
|
|
|
+ if self.CUDA:
|
|
|
+ self.net.cuda()
|
|
|
+
|
|
|
+ # 判定模型文件是否存在
|
|
|
+ if os.path.exists(self.model_file):
|
|
|
+ print("模型存在")
|
|
|
+ state = torch.load(self.model_file, map_location='cpu')
|
|
|
+ self.net.load_state_dict(state)
|
|
|
+ else:
|
|
|
+ print("模型不存在,从头训练")
|
|
|
+ # 损失函数
|
|
|
+ self.loss_f = CrossEntropyLoss()
|
|
|
+ # 优化器
|
|
|
+ self.optimizer = Adam(self.net.parameters(), lr=self.lr)
|
|
|
+ # 数据集的初始化
|
|
|
+ self.transform = Compose([ToTensor()])
|
|
|
+ self.ds_train = MNIST(self.ds_path, download=True, train=True, transform=self.transform)
|
|
|
+ self.ds_valid = MNIST(self.ds_path, download=True, train=False, transform=self.transform)
|
|
|
+ # 批次数据集
|
|
|
+ self.loader_train = DataLoader(self.ds_train, shuffle=True, batch_size=1000)
|
|
|
+ self.loader_valid = DataLoader(self.ds_valid, shuffle=True, batch_size=1000)
|
|
|
+
|
|
|
+
|
|
|
+ def train_one(self):
|
|
|
+ for x, y in self.loader_train:
|
|
|
+ if self.CUDA:
|
|
|
+ x = x.cuda()
|
|
|
+ y = y.cuda()
|
|
|
+ # 计算预测输出
|
|
|
+ y_ = self.net(x)
|
|
|
+ # 计算损失
|
|
|
+ loss = self.loss_f(y_, y) # 单热编码one-hot
|
|
|
+ # 求导
|
|
|
+ self.optimizer.zero_grad()
|
|
|
+ loss.backward()
|
|
|
+ # 梯度更新
|
|
|
+ self.optimizer.step()
|
|
|
+
|
|
|
+ @torch.no_grad()
|
|
|
+ def valid(self):
|
|
|
+ all_num = 0.0
|
|
|
+ acc_num = 0.0
|
|
|
+ all_loss = 0.0
|
|
|
+ for t_x, t_y in self.loader_valid:
|
|
|
+ if self.CUDA:
|
|
|
+ t_x = t_x.cuda()
|
|
|
+ t_y = t_y.cuda()
|
|
|
+ all_num += len(t_y) # 累计所有批次的总数
|
|
|
+ t_y_ = self.net(t_x) # 批次预测
|
|
|
+ # 累计计算损失
|
|
|
+ all_loss += self.loss_f(t_y_, t_y)
|
|
|
+ # 统计识别正确数
|
|
|
+ prob = torch.softmax(t_y_, dim=1)
|
|
|
+ y_cls = torch.argmax(prob, dim=1)
|
|
|
+ # 统计
|
|
|
+ acc_num += (y_cls == t_y).float().sum()
|
|
|
+
|
|
|
+ # 输出
|
|
|
+ print(F"测试集损失:{all_loss:8.6f}")
|
|
|
+ print(F"\t|-测试集识别正确率:{100.0 * acc_num / all_num:5.2f}%")
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ def train(self, epoch, interval=1):
|
|
|
+ # 迭代训练
|
|
|
+ for e in range(epoch):
|
|
|
+ # 训练一轮
|
|
|
+ self.train_one()
|
|
|
+ # 验证判定
|
|
|
+ if e % interval == 0:
|
|
|
+ # 验证
|
|
|
+ self.valid()
|
|
|
+ # 保存模型
|
|
|
+ torch.save(self.net.state_dict(), self.model_file)
|
|
|
+
|
|
|
+if __name__ == "__main__":
|
|
|
+ trainer = Lenet5Trainer()
|
|
|
+ trainer.train(5)
|