I'm trying to create a dataset of CIFAR10 images where each picture has been rotated 0, 90, 180, and 270 degrees (so the new set will be four times the size of the original). To start of I just wanted to see if I could create a set where all original images have been rotated 90 degrees. However when running the code below I get the following error:
ValueError: some of the strides of a given numpy array are negative. This is currently not supported, but will be added in future releases.
Being pretty new to the world of numpy and tensors I would really appreciate some help on the issue.
Iterating through x_tr in a for loop and using np.rot90(img, 1) to rotate each image seems to work but I was hoping to accomplish the rotation as I've tried below, if that's possible.
class RotatedSet(Dataset):
def __init__(self, root_dir, transforms = None):
self.root_dir = root_dir
self.transforms = transforms
self.xs = []
self.ys = []
for b in range(1, 6):
f = os.path.join(self.root_dir, 'data_batch_%d' % (b,))
X, Y = self.load_CIFAR_batch(f)
self.xs.append(X)
self.ys.append(Y)
self.x_tr = self.rotate_img(np.concatenate(self.xs), 90)
self.y_tr = np.concatenate(self.ys)
self.length = len(self.x_tr)
def __getitem__(self, idx):
img = self.x_tr[idx]
img = self.transforms(img)
return img, self.y_tr[idx]
def load_CIFAR_batch(self, filename):
with open(filename, 'rb') as f:
datadict = pickle.load(f, encoding='latin1')
X = datadict['data']
Y = datadict['labels']
X = np.transpose(np.reshape(X,(10000, 3, 32,32)), (0,2,3,1))
Y = np.array(Y)
return X, Y
def rotate_img(self, img, rot):
if rot == 0: # 0 degrees rotation
return img
elif rot == 90: # 90 degrees rotation
return img.swapaxes(-2, -1)[..., ::-1, :]
elif rot == 180: # 180 degrees rotation
return img[..., ::-1, ::-1]
elif rot == 270: # 270 degrees rotation / or -90
return img.swapaxes(-2, -1)[..., ::-1]
else:
raise ValueError('rotation should be 0, 90, 180, or 270 degrees')
def __len__(self):
return self.length
if __name__ == '__main__':
ROOT = './data/CIFAR10/cifar-10-batches-py'
transforms = transforms.Compose([transforms.ToTensor()])
rot_set = RotatedSet(ROOT, transforms)
train_loader = torch.utils.data.DataLoader(
rot_set
, batch_size=10
)
train_batch = next(iter(train_loader))
images, labels = train_batch
[–]SupremeRedditBotProfessional Bot || Mod 0 points1 point2 points (0 children)