12345678910111213141516171819202122232425262728 |
- from torchvision.datasets import MNIST
- from torchvision.transforms import Compose, ToTensor
- import numpy as np
- import cv2
- # 加载数据集(如果不存在则下载)
- transform = Compose([ToTensor()]) # 把图像转换为float的张量
- ds_mnist = MNIST(root="datasets", train=True, download=True, transform=transform)
- # print(len(ds_mnist))
- # print(ds_mnist[0][0].shape)
- # print(type(ds_mnist))
- # print(ds_mnist[0][1]) # 类别
- # print(ds_mnist[0][0])
- # 把图像保存为文件
- for i in range(10):
- img, cls = ds_mnist[i] # img是0-1之间float
- img = img.mul(255) # 0-255
- img = img.numpy().copy()
- img = img.astype(np.uint8) # 转为整数
- # 转换成28 * 28 * 1的图像矩阵
- img = img.transpose(1, 2, 0) # 1 * 28 * 28 -> 28 * 28 * 1
- cv2.imwrite(F"{i:02d}_{cls}.jpg", img)
|