How to load a custom dataset in PyTorch (create a custom DataLoader in PyTorch)

Felipe A. Moreno
Mar 16, 2021

Hi, today I'm going to show you how to create a custom DataLoader for a dataset composed of pairs (X, y), as commonly used in classification and regression tasks.

When we work with PyTorch, we know about torchvision.datasets, which ships CIFAR, MNIST, etc. But what happens when we want to load our own data and still get the same kind of DataLoader object that the built-in datasets provide?

There are several ways to approach this; I'll show you the two I find most useful:

First

Use the TensorDataset Class:

import torch
from torch.utils.data import TensorDataset

dataset_train = TensorDataset(torch.tensor(train_x), torch.tensor(train_y))
dataset_test = TensorDataset(torch.tensor(test_x), torch.tensor(test_y))

Note that a TensorDataset does not produce batches by itself: to iterate over mini-batches you wrap it in a DataLoader. Its limitation is that it cannot apply per-sample transforms, which is what the second approach adds.
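A minimal sketch of this first approach, using toy NumPy arrays as stand-ins for train_x/train_y (the shapes and sizes here are illustrative assumptions):

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

# Toy data standing in for a pre-processed (X, y) pair:
# 100 samples with 3 features each, and binary labels.
train_x = np.random.rand(100, 3).astype(np.float32)
train_y = np.random.randint(0, 2, size=100)

dataset_train = TensorDataset(torch.tensor(train_x), torch.tensor(train_y))

# The DataLoader is what produces the mini-batches.
loader = DataLoader(dataset_train, batch_size=10, shuffle=True)

xb, yb = next(iter(loader))  # one mini-batch of 10 (x, y) pairs
```

Each iteration yields a tuple of batched tensors, exactly like the loaders built from torchvision's default datasets.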

Second

Personally, I prefer this one because we can add specific functionality. We need to create a new class:

import torch
from torch.utils.data import Dataset

class CustomTensorDataset(Dataset):
    def __init__(self, dataset, transform_list=None):
        [data_X, data_y] = dataset
        X_tensor, y_tensor = torch.tensor(data_X), torch.tensor(data_y)
        # X_tensor, y_tensor = Tensor(data_X), Tensor(data_y)
        tensors = (X_tensor, y_tensor)
        # X and y must contain the same number of samples
        assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
        self.tensors = tensors
        self.transforms = transform_list

    def __getitem__(self, index):
        x = self.tensors[0][index]

        if self.transforms:
            # for transform in self.transforms:
            #     x = transform(x)
            x = self.transforms(x)

        y = self.tensors[1][index]

        return x, y

    def __len__(self):
        return self.tensors[0].size(0)

We need to import the Dataset class and create our new one, called CustomTensorDataset:

import torch
from torch.utils.data import Dataset

class CustomTensorDataset(Dataset):
    def __init__(self, dataset, transform_list=None):

Our constructor requires 2 parameters: the dataset (pre-processed beforehand) and a transform pipeline (for images), for example:

from torchvision import transforms

tfr_ = transforms.Compose([
    # transforms.RandomCrop(im_size, padding=4),
    # transforms.RandomHorizontalFlip(),
    transforms.ToPILImage(),
    transforms.Resize((im_size, im_size)),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))])

Then, we convert the input data:

We can convert our input data (generally NumPy arrays) in two ways: torch.tensor infers the dtype automatically, while torch.Tensor always returns a torch.FloatTensor.
I recommend sticking to torch.tensor, which also accepts arguments like dtype if you want to change the type.
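A quick demonstration of that difference:

```python
import torch

a = torch.tensor([1, 2, 3])                       # dtype inferred from the data: int64
b = torch.Tensor([1, 2, 3])                       # always a FloatTensor (float32)
c = torch.tensor([1, 2, 3], dtype=torch.float64)  # dtype set explicitly
```

With integer labels, torch.Tensor would silently turn them into floats, which breaks loss functions like CrossEntropyLoss that expect integer targets.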

[data_X, data_y] = dataset
X_tensor, y_tensor = torch.tensor(data_X), torch.tensor(data_y)
#X_tensor, y_tensor = Tensor(data_X), Tensor(data_y)
tensors = (X_tensor, y_tensor)

Next, we verify that X and y contain the same number of samples:

assert all(tensors[0].size(0) == tensor.size(0) for tensor in tensors)
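To see what this assert protects against, here is a hypothetical mismatch (10 inputs but only 9 labels, e.g. after a slicing mistake):

```python
import torch

X = torch.zeros(10, 3)  # 10 samples
y = torch.zeros(9)      # oops: only 9 labels

tensors = (X, y)
try:
    assert all(tensors[0].size(0) == t.size(0) for t in tensors)
    mismatch_caught = False
except AssertionError:
    mismatch_caught = True  # the dataset refuses inconsistent inputs
```

Failing fast here is much easier to debug than an IndexError deep inside a training loop.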

Let's define a function to get the length (overriding the parent class):

def __len__(self):
    return self.tensors[0].size(0)

Finally, we define the __getitem__ method, which is called whenever a sample is indexed and is what the DataLoader uses to build batches:

def __getitem__(self, index):
    x = self.tensors[0][index]

    if self.transforms:
        # for transform in self.transforms:
        #     x = transform(x)
        x = self.transforms(x)

    y = self.tensors[1][index]

    return x, y

Here is how to use this code:

from torch.utils.data import DataLoader

trainset, testset = PrepareDataset(...)
# trainset = [X_train, y_train]
# testset = [X_test, y_test]

if transforms_list is not None:
    dataset_train = CustomTensorDataset(dataset=trainset, transform_list=transforms_list_train)
    dataset_test = CustomTensorDataset(dataset=testset, transform_list=transforms_list_test)
else:
    dataset_train = CustomTensorDataset(dataset=trainset)
    dataset_test = CustomTensorDataset(dataset=testset)

trainloader = DataLoader(dataset=dataset_train, batch_size=batch_size, shuffle=True, num_workers=8)
testloader = DataLoader(dataset=dataset_test, batch_size=100, shuffle=True, num_workers=5)
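Putting it all together, here is a self-contained end-to-end sketch. The toy arrays stand in for whatever your own data-preparation step returns, and num_workers=0 is used for portability:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class CustomTensorDataset(Dataset):
    """Dataset wrapping an (X, y) pair with optional per-sample transforms."""
    def __init__(self, dataset, transform_list=None):
        data_X, data_y = dataset
        self.tensors = (torch.tensor(data_X), torch.tensor(data_y))
        # X and y must contain the same number of samples
        assert all(self.tensors[0].size(0) == t.size(0) for t in self.tensors)
        self.transforms = transform_list

    def __getitem__(self, index):
        x = self.tensors[0][index]
        if self.transforms:
            x = self.transforms(x)
        return x, self.tensors[1][index]

    def __len__(self):
        return self.tensors[0].size(0)

# Toy stand-in for the output of a data-preparation step:
# 64 fake 8x8 RGB images with labels in 0..9.
X_train = np.random.rand(64, 3, 8, 8).astype(np.float32)
y_train = np.random.randint(0, 10, size=64)

dataset_train = CustomTensorDataset(dataset=[X_train, y_train])
trainloader = DataLoader(dataset_train, batch_size=16, shuffle=True, num_workers=0)

xb, yb = next(iter(trainloader))  # one mini-batch of 16 images and labels
```

The loader behaves exactly like one built on a default torchvision dataset, so the rest of a training loop needs no changes.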
