未验证 提交 f8150efe 编写于 作者: G gaotingquan

fix: (h, w) -> (w, h)

上级 e4380ce5
......@@ -72,11 +72,11 @@ class MultiScaleDataset(Dataset):
for resize in resize_op:
if resize in op:
if self.has_crop_flag:
logger.error(
logger.warning(
"Multi scale dataset will crop image according to the multi scale resolution"
)
self.transform_ops[i][resize] = {
'size': (img_height, img_width)
'size': (img_width, img_height)
}
has_crop = True
self.has_crop_flag = 0
......
......@@ -64,11 +64,11 @@ class MultiScaleSampler(Sampler):
base_elements = base_im_w * base_im_h * base_batch_size
for (h, w) in zip(height_dims, width_dims):
batch_size = int(max(1, (base_elements / (h * w))))
img_batch_pairs.append((h, w, batch_size))
img_batch_pairs.append((w, h, batch_size))
self.img_batch_pairs = img_batch_pairs
self.shuffle = True
else:
self.img_batch_pairs = [(base_im_h, base_im_w, base_batch_size)]
self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]
self.img_indices = img_indices
self.n_samples_per_replica = num_samples_per_replica
......@@ -81,7 +81,7 @@ class MultiScaleSampler(Sampler):
indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
self.num_replicas]
while self.current < self.n_samples_per_replica:
curr_h, curr_w, curr_bsz = random.choice(self.img_batch_pairs)
curr_w, curr_h, curr_bsz = random.choice(self.img_batch_pairs)
end_index = min(self.current + curr_bsz,
self.n_samples_per_replica)
......@@ -93,7 +93,7 @@ class MultiScaleSampler(Sampler):
self.current += curr_bsz
if len(batch_ids) > 0:
batch = [curr_h, curr_w, len(batch_ids)]
batch = [curr_w, curr_h, len(batch_ids)]
self.batch_list.append(batch)
self.length = len(self.batch_list)
......@@ -113,7 +113,7 @@ class MultiScaleSampler(Sampler):
start_index = 0
for batch_tuple in self.batch_list:
curr_h, curr_w, curr_bsz = batch_tuple
curr_w, curr_h, curr_bsz = batch_tuple
end_index = min(start_index + curr_bsz, self.n_samples_per_replica)
batch_ids = indices_rank_i[start_index:end_index]
n_batch_samples = len(batch_ids)
......@@ -122,7 +122,7 @@ class MultiScaleSampler(Sampler):
start_index += curr_bsz
if len(batch_ids) > 0:
batch = [(curr_h, curr_w, b_id) for b_id in batch_ids]
batch = [(curr_w, curr_h, b_id) for b_id in batch_ids]
yield batch
def set_epoch(self, epoch: int):
......
......@@ -50,22 +50,17 @@ class UnifiedResize(object):
}
def _pil_resize(src, size, resample):
# to be accordance with opencv, the input size is (h,w)
pil_img = Image.fromarray(src)
pil_img = pil_img.resize(size, resample)
return np.asarray(pil_img)
def _cv2_resize(src, size, interpolation):
cv_img = cv2.resize(src, size[::-1], interpolation)
return cv_img
if backend.lower() == "cv2":
if isinstance(interpolation, str):
interpolation = _cv2_interp_from_str[interpolation.lower()]
# compatible with opencv < version 4.4.0
elif interpolation is None:
interpolation = cv2.INTER_LINEAR
self.resize_func = partial(_cv2_resize, interpolation=interpolation)
self.resize_func = partial(cv2.resize, interpolation=interpolation)
elif backend.lower() == "pil":
if isinstance(interpolation, str):
interpolation = _pil_interp_from_str[interpolation.lower()]
......@@ -128,8 +123,8 @@ class ResizeImage(object):
self.h = None
elif size is not None:
self.resize_short = None
self.h = size if type(size) is int else size[0]
self.w = size if type(size) is int else size[1]
self.w = size if type(size) is int else size[0]
self.h = size if type(size) is int else size[1]
else:
raise OperatorParamError("invalid params for ReisizeImage for '\
'both 'size' and 'resize_short' are None")
......@@ -146,7 +141,7 @@ class ResizeImage(object):
else:
w = self.w
h = self.h
return self._resize_func(img, (h, w))
return self._resize_func(img, (w, h))
class CropImage(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册