提交 612acfa8 编写于 作者: S sunyanfang01

modify transform

上级 312eeca0
......@@ -249,6 +249,10 @@ class YOLOv3(BaseAPI):
self.train_batch_size = train_batch_size
self.labels = train_dataset.labels
if pretrain_weights == "Object365":
for transform in train_dataset.transforms.transforms:
if isinstance(transform, paddlex.det.transforms.Normalize):
transform.is_scale = False
if self.use_iou_loss or self.use_iou_aware_loss:
if self.train_random_shapes is None or len(self.train_random_shapes) == 0:
for transform in train_dataset.transforms.transforms:
......
......@@ -494,6 +494,7 @@ class Normalize(DetTransform):
Args:
mean (list): 图像数据集的均值。默认为[0.485, 0.456, 0.406]。
std (list): 图像数据集的标准差。默认为[0.229, 0.224, 0.225]。
is_scale (bool): 是否对图像归一化。默认为True。
Raises:
TypeError: 形参数据类型不满足需求。
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册