diff --git a/paddlex/cv/transforms/det_transforms.py b/paddlex/cv/transforms/det_transforms.py index 2be4b3fe8af4d76e5261ee1f8c5abdea55158661..607c0a6b53013dd89b22cd3e94a1b7551c0b0649 100644 --- a/paddlex/cv/transforms/det_transforms.py +++ b/paddlex/cv/transforms/det_transforms.py @@ -757,9 +757,9 @@ class RandomExpand: return (im, im_info, label_info) y = np.random.randint(0, h - height) x = np.random.randint(0, w - width) - canvas = np.ones((h, w, 3), dtype=np.uint8) - canvas *= np.array(self.fill_value, dtype=np.uint8) - canvas[y:y + height, x:x + width, :] = im.astype(np.uint8) + canvas = np.ones((h, w, 3), dtype=np.float32) + canvas *= np.array(self.fill_value, dtype=np.float32) + canvas[y:y + height, x:x + width, :] = im im_info['augment_shape'] = np.array([h, w]).astype('int32') if 'gt_bbox' in label_info and len(label_info['gt_bbox']) > 0: