From e5b030126e1192f68496e352e52d8fde6b175a43 Mon Sep 17 00:00:00 2001 From: FlyingQianMM <245467267@qq.com> Date: Thu, 9 Jul 2020 10:58:58 +0800 Subject: [PATCH] delete self.arrange_transforms --- paddlex/cv/models/classifier.py | 7 +++++-- paddlex/cv/models/faster_rcnn.py | 7 +++++-- paddlex/cv/models/mask_rcnn.py | 7 +++++-- paddlex/cv/models/yolo_v3.py | 7 +++++-- 4 files changed, 20 insertions(+), 8 deletions(-) diff --git a/paddlex/cv/models/classifier.py b/paddlex/cv/models/classifier.py index feace07..a58b8f3 100644 --- a/paddlex/cv/models/classifier.py +++ b/paddlex/cv/models/classifier.py @@ -223,8 +223,11 @@ class BaseClassifier(BaseAPI): tuple (metrics, eval_details): 当return_details为True时,增加返回dict, 包含关键字:'true_labels'、'pred_scores',分别代表真实类别id、每个类别的预测得分。 """ - self.arrange_transforms( - transforms=eval_dataset.transforms, mode='eval') + arrange_transforms( + model_type=self.model_type, + class_name=self.__class__.__name__, + transforms=eval_dataset.transforms, + mode='eval') data_generator = eval_dataset.generator( batch_size=batch_size, drop_last=False) k = min(5, self.num_classes) diff --git a/paddlex/cv/models/faster_rcnn.py b/paddlex/cv/models/faster_rcnn.py index e8a6194..68e2d32 100644 --- a/paddlex/cv/models/faster_rcnn.py +++ b/paddlex/cv/models/faster_rcnn.py @@ -312,8 +312,11 @@ class FasterRCNN(BaseAPI): eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、 预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。 """ - self.arrange_transforms( - transforms=eval_dataset.transforms, mode='eval') + arrange_transforms( + model_type=self.model_type, + class_name=self.__class__.__name__, + transforms=eval_dataset.transforms, + mode='eval') if metric is None: if hasattr(self, 'metric') and self.metric is not None: metric = self.metric diff --git a/paddlex/cv/models/mask_rcnn.py b/paddlex/cv/models/mask_rcnn.py index 3c7fd5c..e0ffb00 100644 --- a/paddlex/cv/models/mask_rcnn.py +++ b/paddlex/cv/models/mask_rcnn.py @@ -254,8 +254,11 @@ class MaskRCNN(FasterRCNN): 预测框坐标、预测框得分;'mask',对应元素预测区域结果列表,每个预测结果由图像id、 预测区域类别id、预测区域坐标、预测区域得分;’gt‘:真实标注框和标注区域相关信息。 """ - self.arrange_transforms( - transforms=eval_dataset.transforms, mode='eval') + arrange_transforms( + model_type=self.model_type, + class_name=self.__class__.__name__, + transforms=eval_dataset.transforms, + mode='eval') if metric is None: if hasattr(self, 'metric') and self.metric is not None: metric = self.metric diff --git a/paddlex/cv/models/yolo_v3.py b/paddlex/cv/models/yolo_v3.py index dce18af..2cfce7a 100644 --- a/paddlex/cv/models/yolo_v3.py +++ b/paddlex/cv/models/yolo_v3.py @@ -289,8 +289,11 @@ class YOLOv3(BaseAPI): eval_details为dict,包含关键字:'bbox',对应元素预测结果列表,每个预测结果由图像id、 预测框类别id、预测框坐标、预测框得分;’gt‘:真实标注框相关信息。 """ - self.arrange_transforms( - transforms=eval_dataset.transforms, mode='eval') + arrange_transforms( + model_type=self.model_type, + class_name=self.__class__.__name__, + transforms=eval_dataset.transforms, + mode='eval') if metric is None: if hasattr(self, 'metric') and self.metric is not None: metric = self.metric -- GitLab