|
@@ -20,20 +20,48 @@ class Trainer:
|
|
|
|
|
|
def __init__(self):
|
|
|
|
|
|
+ # 判定电脑安装GPU的环境:cuda,cnn,pytorch-
|
|
|
+
|
|
|
+ self.CUDA = False # torch.cuda.is_available() # 返回逻辑值:True:支持GPU,False不支持GPU
|
|
|
+
|
|
|
# 准备训练需要的数据、损失函数,优化器,学习率
|
|
|
|
|
|
- self.lr = 0.001
|
|
|
+ self.lr = 0.0001
|
|
|
|
|
|
self.m_file = "lenet5.pt" # 模型的保存文件
|
|
|
|
|
|
self.net = LeNet5() # 需要训练的网络
|
|
|
|
|
|
+ if self.CUDA:
|
|
|
+
|
|
|
+ self.net.cuda() # 把net网络模型存储到GPU上面
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+ # 加载上一次训练的模型
|
|
|
+
|
|
|
+ if os.path.exists(self.m_file):
|
|
|
+
|
|
|
+ # 存在就加载
|
|
|
+
|
|
|
+ print("加载模型中...")
|
|
|
+
|
|
|
+ state = torch.load(self.m_file) # 读取文件
|
|
|
+
|
|
|
+ self.net.load_state_dict(state)
|
|
|
+
|
|
|
+ else:
|
|
|
+
|
|
|
+ print("模型存在,重头训练")
|
|
|
+
|
|
|
self.loss_f = CrossEntropyLoss() # 损失函数
|
|
|
|
|
|
self.optimizer = Adam(self.net.parameters(), lr=self.lr) # 优化器
|
|
|
|
|
|
|
|
|
|
|
|
+
|
|
|
+
|
|
|
# 数据集 - 训练集
|
|
|
|
|
|
self.trans = Compose([ToTensor()])
|
|
@@ -68,6 +96,12 @@ class Trainer:
|
|
|
|
|
|
# print(F"\t|-第{batch:02d}批次训练。")
|
|
|
|
|
|
+ if self.CUDA:
|
|
|
+
|
|
|
+ x = x.cuda()
|
|
|
+
|
|
|
+ y = y.cuda()
|
|
|
+
|
|
|
y_ = self.net(x) # 进行预测
|
|
|
|
|
|
# 计算误差
|
|
@@ -100,6 +134,12 @@ class Trainer:
|
|
|
|
|
|
for t_x, t_y in self.bt_valid:
|
|
|
|
|
|
+ if self.CUDA:
|
|
|
+
|
|
|
+ t_x = t_x.cuda()
|
|
|
+
|
|
|
+ t_y = t_y.cuda()
|
|
|
+
|
|
|
# 统计样本数
|
|
|
|
|
|
all_num += len(t_x)
|
|
@@ -142,11 +182,9 @@ class Trainer:
|
|
|
|
|
|
self.valid()
|
|
|
|
|
|
-
|
|
|
-
|
|
|
- # 保存训练的模型
|
|
|
-
|
|
|
+ # 保存训练的模型
|
|
|
|
|
|
+ torch.save(self.net.state_dict(), self.m_file)
|
|
|
|
|
|
|
|
|
|
|
@@ -154,4 +192,22 @@ if __name__ == "__main__":
|
|
|
|
|
|
trainer = Trainer()
|
|
|
|
|
|
- trainer.train(10) # 训练10轮
|
|
|
+ trainer.train(50) # 训练10轮
|
|
|
+
|
|
|
+ # print(torch.cuda.is_available())
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+"""
|
|
|
+
|
|
|
+ 244K = 6 * 5 * 5 * 8 矩阵
|
|
|
+
|
|
|
+ 16 * 5 * 5 * 8
|
|
|
+
|
|
|
+ 120 * 5 * 5 * 8
|
|
|
+
|
|
|
+ 120 * 84* 8
|
|
|
+
|
|
|
+ 84 * 10*8
|
|
|
+
|
|
|
+"""
|