diff --git a/paddlex/cv/models/yolo_v3.py b/paddlex/cv/models/yolo_v3.py index 88c71c80ddd74270a503b880be842e1c50b6f3db..89c3bbc75e15cdd7b1af4839798e33e114107c5f 100644 --- a/paddlex/cv/models/yolo_v3.py +++ b/paddlex/cv/models/yolo_v3.py @@ -144,7 +144,7 @@ class YOLOv3(BaseAPI): iou_aware_factor=self.iou_aware_factor, use_drop_block=self.use_drop_block, batch_size=self.train_batch_size, - max_shape=self.max_shape if hasattr(self, 'max_shape') else 608) + max_shape=self.max_shape if hasattr(self, 'max_shape') else None) inputs = model.generate_inputs() model_out = model.build_net(inputs) outputs = OrderedDict([('bbox', model_out)]) @@ -266,6 +266,7 @@ class YOLOv3(BaseAPI): if isinstance(bt, paddlex.det.transforms.BatchRandomShape): self.max_shape = max(bt.random_shapes) break + self.init_params['max_shape'] = max_shape iou_bt = paddlex.det.transforms.GenerateYoloTarget train_dataset.transforms.batch_transforms.append(iou_bt(anchors=self.anchors, anchor_masks=self.anchor_masks,