diff --git a/paddlex/cv/models/yolo_v3.py b/paddlex/cv/models/yolo_v3.py index 4ad13c09528b3dc8b15377436ad28a0980a69f4d..88c71c80ddd74270a503b880be842e1c50b6f3db 100644 --- a/paddlex/cv/models/yolo_v3.py +++ b/paddlex/cv/models/yolo_v3.py @@ -144,7 +144,7 @@ class YOLOv3(BaseAPI): iou_aware_factor=self.iou_aware_factor, use_drop_block=self.use_drop_block, batch_size=self.train_batch_size, - max_shape=self.max_shape) + max_shape=self.max_shape if hasattr(self, 'max_shape') else 608) inputs = model.generate_inputs() model_out = model.build_net(inputs) outputs = OrderedDict([('bbox', model_out)]) @@ -253,6 +253,7 @@ class YOLOv3(BaseAPI): if isinstance(transform, paddlex.det.transforms.Normalize): transform.is_scale = False if self.use_iou_loss or self.use_iou_aware_loss: + self.init_params['train_batch_size'] = train_batch_size self.max_shape = 0 for transform in train_dataset.transforms.transforms: if isinstance(transform, paddlex.det.transforms.Resize):