from torchvision.datasets import MNIST from torchvision.transforms import Compose, ToTensor import numpy as np import cv2 # 加载数据集(如果不存在则下载) transform = Compose([ToTensor()]) # 把图像转换为float的张量 ds_mnist = MNIST(root="datasets", train=True, download=True, transform=transform) # print(len(ds_mnist)) # print(ds_mnist[0][0].shape) # print(type(ds_mnist)) # print(ds_mnist[0][1]) # 类别 # print(ds_mnist[0][0]) # 把图像保存为文件 for i in range(10): img, cls = ds_mnist[i] # img是0-1之间float img = img.mul(255) # 0-255 img = img.numpy().copy() img = img.astype(np.uint8) # 转为整数 # 转换成28 * 28 * 1的图像矩阵 img = img.transpose(1, 2, 0) # 1 * 28 * 28 -> 28 * 28 * 1 cv2.imwrite(F"{i:02d}_{cls}.jpg", img)