spidermanYT vor 1 Jahr
Ursprung
Commit
c799b8e915

BIN
day06/lenet/__pycache__/model.cpython-39.pyc


BIN
day06/lenet/datasets/MNIST/raw/t10k-images-idx3-ubyte


BIN
day06/lenet/datasets/MNIST/raw/t10k-images-idx3-ubyte.gz


BIN
day06/lenet/datasets/MNIST/raw/t10k-labels-idx1-ubyte


BIN
day06/lenet/datasets/MNIST/raw/t10k-labels-idx1-ubyte.gz


BIN
day06/lenet/datasets/MNIST/raw/train-images-idx3-ubyte


BIN
day06/lenet/datasets/MNIST/raw/train-images-idx3-ubyte.gz


BIN
day06/lenet/datasets/MNIST/raw/train-labels-idx1-ubyte


BIN
day06/lenet/datasets/MNIST/raw/train-labels-idx1-ubyte.gz


BIN
day06/lenet/lenet5.pt


+ 62 - 6
day06/lenet/train.py

@@ -20,20 +20,48 @@ class Trainer:
 
     def __init__(self):
 
+        # 判定电脑安装GPU的环境:cuda,cnn,pytorch-
+
+        self.CUDA = False # torch.cuda.is_available()   # 返回逻辑值:True:支持GPU,False不支持GPU
+
         # 准备训练需要的数据、损失函数,优化器,学习率
 
-        self.lr = 0.001
+        self.lr = 0.0001
 
         self.m_file = "lenet5.pt"  # 模型的保存文件
 
         self.net = LeNet5() # 需要训练的网络
 
+        if self.CUDA:
+
+            self.net.cuda()   # 把net网络模型存储到GPU上面
+
+
+
+        # 加载上一次训练的模型
+
+        if os.path.exists(self.m_file):
+
+            # 存在就加载
+
+            print("加载模型中...")
+
+            state = torch.load(self.m_file)  # 读取文件
+
+            self.net.load_state_dict(state)
+
+        else:
+
+            print("模型存在,重头训练")    
+
         self.loss_f = CrossEntropyLoss()  # 损失函数
 
         self.optimizer = Adam(self.net.parameters(), lr=self.lr)  # 优化器
 
 
 
+
+
         # 数据集 - 训练集
 
         self.trans = Compose([ToTensor()])
@@ -68,6 +96,12 @@ class Trainer:
 
             # print(F"\t|-第{batch:02d}批次训练。")
 
+            if self.CUDA:
+
+                x = x.cuda()
+
+                y = y.cuda()
+
             y_ = self.net(x)  # 进行预测
 
             # 计算误差
@@ -100,6 +134,12 @@ class Trainer:
 
         for t_x, t_y in self.bt_valid:
 
+            if self.CUDA:
+
+                t_x = t_x.cuda()
+
+                t_y = t_y.cuda()
+
             # 统计样本数
 
             all_num  += len(t_x)
@@ -142,11 +182,9 @@ class Trainer:
 
             self.valid()
 
-        
-
-        # 保存训练的模型
-
+            # 保存训练的模型
 
+            torch.save(self.net.state_dict(), self.m_file)
 
 
 
@@ -154,4 +192,22 @@ if __name__ == "__main__":
 
     trainer = Trainer()
 
-    trainer.train(10) # 训练10轮
+    trainer.train(50) # 训练10轮
+
+    # print(torch.cuda.is_available()) 
+
+
+
+"""
+
+    244K = 6 * 5 * 5 * 8 矩阵  
+
+          16 * 5 * 5 * 8 
+
+         120 * 5 * 5 * 8
+
+             120 * 84* 8
+
+              84 * 10*8
+
+"""