"""Based on https://github.com/CaoWGG/multi-scale-training"""

import numpy as np
import torch
import torchvision
from torch.utils.data import Sampler, RandomSampler, SequentialSampler


class BatchSampler(object):
    """Wraps another sampler to yield batches of (index, image_size) pairs,
    drawing a new image size every `multiscale_step` batches."""

    def __init__(self, sampler, batch_size, drop_last, multiscale_step=None, img_sizes=None):
        if not isinstance(sampler, Sampler):
            raise ValueError("sampler should be an instance of "
                             "torch.utils.data.Sampler, but got sampler={}"
                             .format(sampler))
        if not isinstance(drop_last, bool):
            raise ValueError("drop_last should be a boolean value, but got "
                             "drop_last={}".format(drop_last))
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        if multiscale_step is not None and multiscale_step < 1:
            raise ValueError("multiscale_step should be > 0, but got "
                             "multiscale_step={}".format(multiscale_step))
        if multiscale_step is not None and img_sizes is None:
            raise ValueError("img_sizes must be a list, but got img_sizes={}".format(img_sizes))

        self.multiscale_step = multiscale_step
        self.img_sizes = img_sizes

    def __iter__(self):
        num_batch = 0
        batch = []
        size = 416  # initial image size
        for idx in self.sampler:
            batch.append([idx, size])
            if len(batch) == self.batch_size:
                yield batch
                num_batch += 1
                batch = []
                # Draw a new size every `multiscale_step` batches.
                if self.multiscale_step and num_batch % self.multiscale_step == 0:
                    size = np.random.choice(self.img_sizes)
        if len(batch) > 0 and not self.drop_last:
            yield batch

    def __len__(self):
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            # Ceiling division: the last partial batch still counts.
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
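
# A quick smoke test (hypothetical, not in the original source): ten indices,
# batches of four, and a new scale drawn after every batch. The first batch
# always uses the initial size of 416.
demo_sampler = BatchSampler(SequentialSampler(range(10)), batch_size=4,
                            drop_last=False, multiscale_step=1,
                            img_sizes=[320, 416, 512])
for demo_batch in demo_sampler:
    print(demo_batch)  # e.g. [[0, 416], [1, 416], [2, 416], [3, 416]]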


class MultiscaleDataSet(torchvision.datasets.ImageFolder):
    """Multiscale ImageFolder dataset that understands the (index, size)
    pairs produced by BatchSampler above."""

    def __getitem__(self, index):
        if isinstance(index, (tuple, list)):
            index, input_size = index
        else:
            # Default image size for plain integer indices.
            input_size = 448
        path, target = self.samples[index]
        sample = self.loader(path)
        # Resize the image to the scale chosen for this batch.
        sample = sample.resize((input_size, input_size))
        if self.transform is not None:
            sample = self.transform(sample)
        # Return the (transformed) image and its label.
        return sample, target
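
# Since every element of a batch shares one size, the default collate_fn can
# stack the resulting tensors as-is; no custom collate function is needed.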


# A minimal transform (an assumption; the original line was incomplete).
# ToTensor alone preserves whatever size the sampler chose for the batch.
transform = torchvision.transforms.ToTensor()

# Create the dataset and loader.
batch_size = 8  # example value, not from the original snippet
train_dataset = MultiscaleDataSet(
    root="data/train",
    transform=transform,
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_sampler=BatchSampler(RandomSampler(train_dataset),
                               batch_size=batch_size,
                               multiscale_step=1,
                               drop_last=True,
                               img_sizes=[320, 384, 448, 512, 576, 640]),
    num_workers=7,
)
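
# A minimal sketch of consuming the loader; the model, loss, and optimizer
# below are illustrative assumptions, not part of the original snippet. The
# AdaptiveAvgPool2d layer makes the toy model resolution-agnostic, which is
# what lets batches at different scales flow through the same weights.
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, len(train_dataset.classes)),
)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for images, targets in train_loader:
    # images has shape (batch, 3, S, S), where S varies across batches.
    optimizer.zero_grad()
    loss = criterion(model(images), targets)
    loss.backward()
    optimizer.step()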