提交 2621ae8e 编写于 作者: S sunyanfang01

add init params

上级 14f72db0
...@@ -144,7 +144,7 @@ class YOLOv3(BaseAPI): ...@@ -144,7 +144,7 @@ class YOLOv3(BaseAPI):
iou_aware_factor=self.iou_aware_factor, iou_aware_factor=self.iou_aware_factor,
use_drop_block=self.use_drop_block, use_drop_block=self.use_drop_block,
batch_size=self.train_batch_size, batch_size=self.train_batch_size,
max_shape=self.max_shape) max_shape=self.max_shape if hasattr(self, 'max_shape') else 608)
inputs = model.generate_inputs() inputs = model.generate_inputs()
model_out = model.build_net(inputs) model_out = model.build_net(inputs)
outputs = OrderedDict([('bbox', model_out)]) outputs = OrderedDict([('bbox', model_out)])
...@@ -253,6 +253,7 @@ class YOLOv3(BaseAPI): ...@@ -253,6 +253,7 @@ class YOLOv3(BaseAPI):
if isinstance(transform, paddlex.det.transforms.Normalize): if isinstance(transform, paddlex.det.transforms.Normalize):
transform.is_scale = False transform.is_scale = False
if self.use_iou_loss or self.use_iou_aware_loss: if self.use_iou_loss or self.use_iou_aware_loss:
self.init_params['train_batch_size'] = train_batch_size
self.max_shape = 0 self.max_shape = 0
for transform in train_dataset.transforms.transforms: for transform in train_dataset.transforms.transforms:
if isinstance(transform, paddlex.det.transforms.Resize): if isinstance(transform, paddlex.det.transforms.Resize):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册