在CIFAR数据集上使用Pytorch创建自定义数据集

我没有使用这些数据集的Pytorch内置API,而是尝试创建自己的数据集并将该数据集馈送到Pytorch的DATASET API和DATALOADER API。但是不知何故,我遇到了一些错误。

我的数据是通过将所有4个火车泡菜合并为一个而创建的。 IMAGES LABELS

创建数据并遵循此[CustomDataset] [3]之后,我编写了以下代码:

import numpy as np
import pickle as pkl
import cv2
import torch
from torch.utils.data import Dataset,DataLoader
from torchvision import transforms,utils


# For custom dataset inherit the parent Dataset class into the child class
class CIFARDataset(Dataset):
    """CIFAR dataset."""

    def __init__(self,pckl_path,transform=None):
        """

        :param pckl_path:
        :param transform:
        """
        " Load the pickle files data"
        pckl_fd = open(pckl_path,"rb")
        self.data_pckl = pkl.load(pckl_fd)

        self.transform = transform

    def __len__(self):
        return len(self.data_pckl)

    def __getitem__(self,idx):
        print("inside __get_item")
        if torch.is_tensor(idx):
            idx = idx.tolist()
        sample = {'image': self.data_pckl['images'][idx],'label': self.data_pckl['labels'][idx]}
        if self.transform:
            sample = self.transform(sample)

        return sample

class ToTensor(object):
    """Convert ndarrays in sample to Tensors."""

    def __call__(self,sample):
        print("In ToTensor")
        image,label = sample['images'],sample['labels']
        image = image.transpose((2,1))
        return {'image': torch.from_numpy(image),'label': torch.from_numpy(np.ndarray(label))}


dataset= CIFARDataset('cifar/train_set.pickle',transform=transforms.Compose(ToTensor()))
# composed = transforms.Compose([ToTensor()])
# sample = dataset.data_pckl
sample1 = {'images':None,'labels': None}

data = dataset[0]

运行此命令时,出现以下错误:

错误:

data = dataset[0]
  File "/home/garud/Documents/DSP_notes/Project/create_dataset.py",line 34,in __getitem__
    sample = self.transform(sample)
  File "/home/garud/anaconda3/lib/python3.7/site-packages/torchvision/transforms/transforms.py",line 60,in __call__
    for t in self.transforms:
TypeError: 'ToTensor' object is not iterable

我调试并检查的示例是将传递给transform函数的字典。不知道哪里出了问题。

请忠告什么是错误的,以及需要遵循哪些最佳实践才能更好地做到这一点。

yanchengwanghao 回答:在CIFAR数据集上使用Pytorch创建自定义数据集

使用transforms.Compose编写转换时,需要提供转换的列表
试试:

transforms.Compose([ToTensor(),])

您仍然只提供一个转换,但是它包装在一个列表中。

本文链接:https://www.f2er.com/3161893.html

大家都在问