diff --git a/paddlex/cv/transforms/det_transforms.py b/paddlex/cv/transforms/det_transforms.py index 07b2785a5e6194b9d7c2cd6ab02ea15c7af693ab..19db33173b87b7cc20b87054cfbc1241176abc58 100644 --- a/paddlex/cv/transforms/det_transforms.py +++ b/paddlex/cv/transforms/det_transforms.py @@ -156,7 +156,6 @@ class Compose(DetTransform): if not isinstance(augmenters, list): raise Exception( "augmenters should be list type in func add_augmenters()") - assert mode == 'train', "There should be exists augmenters while on train mode" self.transforms = augmenters + self.transforms.transforms diff --git a/paddlex/cv/transforms/seg_transforms.py b/paddlex/cv/transforms/seg_transforms.py index dd2b250b5b9cf80ff73b5cbb873f45fad183de27..d3c67648d500d915315c5607cfc5c2f5538a9090 100644 --- a/paddlex/cv/transforms/seg_transforms.py +++ b/paddlex/cv/transforms/seg_transforms.py @@ -108,6 +108,12 @@ class Compose(SegTransform): outputs = (im, im_info) return outputs + def add_augmenters(self, augmenters): + if not isinstance(augmenters, list): + raise Exception( + "augmenters should be list type in func add_augmenters()") + self.transforms = augmenters + self.transforms.transforms + class RandomHorizontalFlip(SegTransform): """以一定的概率对图像进行水平翻转。当存在标注图像时,则同步进行翻转。