提交 5a7cd9c6 编写于 作者: W wangguanzhong 提交者: GitHub

add multiscale-training & refine train+eval (#2911)

上级 b0683058
......@@ -127,18 +127,24 @@ class ResizeImage(BaseOperator):
use_cv2=True):
"""
Args:
target_size (int): the taregt size of image's short side
target_size (int|list): the target size of image's short side,
multi-scale training is adopted when type is list.
max_size (int): the max size of image
interp (int): the interpolation method
use_cv2 (bool): use the cv2 interpolation method or use PIL interpolation method
use_cv2 (bool): use the cv2 interpolation method or use PIL
interpolation method
"""
super(ResizeImage, self).__init__()
self.target_size = int(target_size)
self.max_size = int(max_size)
self.interp = int(interp)
self.use_cv2 = use_cv2
if not (isinstance(self.target_size, int) and isinstance(
self.max_size, int) and isinstance(self.interp, int)):
if not (isinstance(target_size, int) or isinstance(target_size, list)):
raise TypeError(
"Type of target_size is invalid. Must be Integer or List, now is {}".
format(type(target_size)))
self.target_size = target_size
if not (isinstance(self.max_size, int) and isinstance(self.interp,
int)):
raise TypeError("{}: input type is invalid.".format(self))
def __call__(self, sample, context=None):
......@@ -152,10 +158,15 @@ class ResizeImage(BaseOperator):
im_shape = im.shape
im_size_min = np.min(im_shape[0:2])
im_size_max = np.max(im_shape[0:2])
if isinstance(self.target_size, list):
# Case for multi-scale training
selected_size = random.choice(self.target_size)
else:
selected_size = self.target_size
if float(im_size_min) == 0:
raise ZeroDivisionError('{}: min size of image is 0'.format(self))
if self.max_size != 0:
im_scale = float(self.target_size) / float(im_size_min)
im_scale = float(selected_size) / float(im_size_min)
# Prevent the biggest axis from being more than max_size
if np.round(im_scale * im_size_max) > self.max_size:
im_scale = float(self.max_size) / float(im_size_max)
......@@ -168,8 +179,8 @@ class ResizeImage(BaseOperator):
],
dtype=np.float32)
else:
im_scale_x = float(self.target_size) / float(im_shape[1])
im_scale_y = float(self.target_size) / float(im_shape[0])
im_scale_x = float(selected_size) / float(im_shape[1])
im_scale_y = float(selected_size) / float(im_shape[0])
if self.use_cv2:
im = cv2.resize(
im,
......@@ -180,7 +191,9 @@ class ResizeImage(BaseOperator):
interpolation=self.interp)
else:
im = Image.fromarray(im)
im = im.resize((self.target_size, self.target_size), self.interp)
resize_w = selected_size * im_scale_x
resize_h = selected_size * im_scale_y
im = im.resize((resize_w, resize_h), self.interp)
im = np.array(im)
sample['image'] = im
......
......@@ -171,8 +171,9 @@ def main():
it, np.mean(outs[-1]), logs, end_time - start_time)
logger.info(strs)
if it > 0 and it % cfg.snapshot_iter == 0:
checkpoint.save(exe, train_prog, os.path.join(save_dir, str(it)))
if it > 0 and it % cfg.snapshot_iter == 0 or it == cfg.max_iters - 1:
save_name = str(it) if it != cfg.max_iters - 1 else "model_final"
checkpoint.save(exe, train_prog, os.path.join(save_dir, save_name))
if FLAGS.eval:
# evaluation
......@@ -184,7 +185,6 @@ def main():
eval_results(results, eval_feed, cfg.metric, resolution,
FLAGS.output_file)
checkpoint.save(exe, train_prog, os.path.join(save_dir, "model_final"))
train_pyreader.reset()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册