diff --git a/tools/cpp_infer.py b/tools/cpp_infer.py index 6e8f0b5b8fc24af4cd2378d786171378bcad01ca..41f6ffe80677b95dbce09943f01922db5c474b3d 100644 --- a/tools/cpp_infer.py +++ b/tools/cpp_infer.py @@ -108,10 +108,11 @@ class Resize(object): self.max_size = max_size self.interp = interp - def __call__(self, im): + def __call__(self, im, arch): origin_shape = im.shape[:2] im_c = im.shape[2] - if self.max_size != 0: + scale_set = {'RCNN', 'RetinaNet'} + if self.max_size != 0 and arch in scale_set: im_size_min = np.min(origin_shape[0:2]) im_size_max = np.max(origin_shape[0:2]) im_scale = float(self.target_size) / float(im_size_min) @@ -132,7 +133,7 @@ class Resize(object): fy=im_scale_y, interpolation=self.interp) # padding im - if self.max_size != 0: + if self.max_size != 0 and arch in scale_set: padding_im = np.zeros( (self.max_size, self.max_size, im_c), dtype=np.float32) im_h, im_w = im.shape[:2] @@ -178,7 +179,7 @@ def Preprocess(img_path, arch, config): obj = data_aug_conf.pop('type') preprocess = eval(obj)(**data_aug_conf) if obj == 'Resize': - img, scale = preprocess(img) + img, scale = preprocess(img, arch) else: img = preprocess(img)