ds.py 818 B

12345678910111213141516171819202122232425262728
  1. from torchvision.datasets import MNIST
  2. from torchvision.transforms import Compose, ToTensor
  3. import numpy as np
  4. import cv2
  5. # 加载数据集(如果不存在则下载)
  6. transform = Compose([ToTensor()]) # 把图像转换为float的张量
  7. ds_mnist = MNIST(root="datasets", train=True, download=True, transform=transform)
  8. # print(len(ds_mnist))
  9. # print(ds_mnist[0][0].shape)
  10. # print(type(ds_mnist))
  11. # print(ds_mnist[0][1]) # 类别
  12. # print(ds_mnist[0][0])
  13. # 把图像保存为文件
  14. for i in range(10):
  15. img, cls = ds_mnist[i] # img是0-1之间float
  16. img = img.mul(255) # 0-255
  17. img = img.numpy().copy()
  18. img = img.astype(np.uint8) # 转为整数
  19. # 转换成28 * 28 * 1的图像矩阵
  20. img = img.transpose(1, 2, 0) # 1 * 28 * 28 -> 28 * 28 * 1
  21. cv2.imwrite(F"{i:02d}_{cls}.jpg", img)