diff --git a/paddlex/cv/models/classifier.py b/paddlex/cv/models/classifier.py index ba3e868b51c38a29e369bb187d9f1099e985922a..65a594ec3a7caee401cdf239f861a0e0e98667d9 100644 --- a/paddlex/cv/models/classifier.py +++ b/paddlex/cv/models/classifier.py @@ -129,9 +129,7 @@ class BaseClassifier(BaseAPI): ValueError: 模型从inference model进行加载。 """ if not self.trainable: - raise ValueError( - "Model is not trainable since it was loaded from a inference model." - ) + raise ValueError("Model is not trainable from load_model method.") self.labels = train_dataset.labels if optimizer is None: num_steps_each_epoch = train_dataset.num_samples // train_batch_size @@ -300,17 +298,18 @@ class ResNet101_vd(BaseClassifier): def __init__(self, num_classes=1000): super(ResNet101_vd, self).__init__( model_name='ResNet101_vd', num_classes=num_classes) - - + + class ResNet50_vd_ssld(BaseClassifier): def __init__(self, num_classes=1000): - super(ResNet50_vd_ssld, self).__init__(model_name='ResNet50_vd_ssld', - num_classes=num_classes) - + super(ResNet50_vd_ssld, self).__init__( + model_name='ResNet50_vd_ssld', num_classes=num_classes) + + class ResNet101_vd_ssld(BaseClassifier): def __init__(self, num_classes=1000): - super(ResNet101_vd_ssld, self).__init__(model_name='ResNet101_vd_ssld', - num_classes=num_classes) + super(ResNet101_vd_ssld, self).__init__( + model_name='ResNet101_vd_ssld', num_classes=num_classes) class DarkNet53(BaseClassifier): @@ -341,19 +340,18 @@ class MobileNetV3_large(BaseClassifier): def __init__(self, num_classes=1000): super(MobileNetV3_large, self).__init__( model_name='MobileNetV3_large', num_classes=num_classes) - - - + + class MobileNetV3_small_ssld(BaseClassifier): def __init__(self, num_classes=1000): - super(MobileNetV3_small_ssld, self).__init__(model_name='MobileNetV3_small_ssld', - num_classes=num_classes) + super(MobileNetV3_small_ssld, self).__init__( + model_name='MobileNetV3_small_ssld', num_classes=num_classes) class MobileNetV3_large_ssld(BaseClassifier): def __init__(self, num_classes=1000): - super(MobileNetV3_large_ssld, self).__init__(model_name='MobileNetV3_large_ssld', - num_classes=num_classes) + super(MobileNetV3_large_ssld, self).__init__( + model_name='MobileNetV3_large_ssld', num_classes=num_classes) class Xception65(BaseClassifier): diff --git a/paddlex/cv/models/deeplabv3p.py b/paddlex/cv/models/deeplabv3p.py index 8f46baf31542ef58adc8d1c07eae52755dd09e04..0fa0ac195cb7b05de1fc16a6f2ee2b300155389f 100644 --- a/paddlex/cv/models/deeplabv3p.py +++ b/paddlex/cv/models/deeplabv3p.py @@ -257,9 +257,7 @@ class DeepLabv3p(BaseAPI): ValueError: 模型从inference model进行加载。 """ if not self.trainable: - raise ValueError( - "Model is not trainable since it was loaded from a inference model." - ) + raise ValueError("Model is not trainable from load_model method.") self.labels = train_dataset.labels diff --git a/paddlex/cv/models/faster_rcnn.py b/paddlex/cv/models/faster_rcnn.py index 3d585987de23f6fcc097769f29690fe7f6de9cee..47dbd75696369fd813438eb20db49a45024b6fc7 100644 --- a/paddlex/cv/models/faster_rcnn.py +++ b/paddlex/cv/models/faster_rcnn.py @@ -203,9 +203,7 @@ class FasterRCNN(BaseAPI): assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'" self.metric = metric if not self.trainable: - raise ValueError( - "Model is not trainable since it was loaded from a inference model." - ) + raise ValueError("Model is not trainable from load_model method.") self.labels = copy.deepcopy(train_dataset.labels) self.labels.insert(0, 'background') # 构建训练网络 diff --git a/paddlex/cv/models/load_model.py b/paddlex/cv/models/load_model.py index 271be123325b71304ac7b796017158e672dc6b3f..2469ef2c9094dcba1ff12d234b7ebcd7b6bdc779 100644 --- a/paddlex/cv/models/load_model.py +++ b/paddlex/cv/models/load_model.py @@ -98,6 +98,7 @@ def load_model(model_dir): model.__dict__[k] = v logging.info("Model[{}] loaded.".format(info['Model'])) + model.trainable = False return model diff --git a/paddlex/cv/models/mask_rcnn.py b/paddlex/cv/models/mask_rcnn.py index b30dd1e00a2856c79ac179c01967d1cddf053122..bfdc9f1092e7ce82cecb869dcd5364a0d34aff2e 100644 --- a/paddlex/cv/models/mask_rcnn.py +++ b/paddlex/cv/models/mask_rcnn.py @@ -165,9 +165,7 @@ class MaskRCNN(FasterRCNN): assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'" self.metric = metric if not self.trainable: - raise Exception( - "Model is not trainable since it was loaded from a inference model." - ) + raise Exception("Model is not trainable from load_model method.") self.labels = copy.deepcopy(train_dataset.labels) self.labels.insert(0, 'background') # 构建训练网络 diff --git a/paddlex/cv/models/yolo_v3.py b/paddlex/cv/models/yolo_v3.py index 80205238ff181be75f76fdbf32b8a1e99c497c9c..75658547f537b046a95ae290a7799b803f3de502 100644 --- a/paddlex/cv/models/yolo_v3.py +++ b/paddlex/cv/models/yolo_v3.py @@ -194,9 +194,7 @@ class YOLOv3(BaseAPI): ValueError: 模型从inference model进行加载。 """ if not self.trainable: - raise ValueError( - "Model is not trainable since it was loaded from a inference model." - ) + raise ValueError("Model is not trainable from load_model method.") if metric is None: if isinstance(train_dataset, paddlex.datasets.CocoDetection): metric = 'COCO'