提交 f02f0476 编写于 作者: J jiangjiajun

modify load model

上级 2868f100
...@@ -129,9 +129,7 @@ class BaseClassifier(BaseAPI): ...@@ -129,9 +129,7 @@ class BaseClassifier(BaseAPI):
ValueError: 模型从inference model进行加载。 ValueError: 模型从inference model进行加载。
""" """
if not self.trainable: if not self.trainable:
raise ValueError( raise ValueError("Model is not trainable from load_model method.")
"Model is not trainable since it was loaded from a inference model."
)
self.labels = train_dataset.labels self.labels = train_dataset.labels
if optimizer is None: if optimizer is None:
num_steps_each_epoch = train_dataset.num_samples // train_batch_size num_steps_each_epoch = train_dataset.num_samples // train_batch_size
...@@ -300,17 +298,18 @@ class ResNet101_vd(BaseClassifier): ...@@ -300,17 +298,18 @@ class ResNet101_vd(BaseClassifier):
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000):
super(ResNet101_vd, self).__init__( super(ResNet101_vd, self).__init__(
model_name='ResNet101_vd', num_classes=num_classes) model_name='ResNet101_vd', num_classes=num_classes)
class ResNet50_vd_ssld(BaseClassifier): class ResNet50_vd_ssld(BaseClassifier):
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000):
super(ResNet50_vd_ssld, self).__init__(model_name='ResNet50_vd_ssld', super(ResNet50_vd_ssld, self).__init__(
num_classes=num_classes) model_name='ResNet50_vd_ssld', num_classes=num_classes)
class ResNet101_vd_ssld(BaseClassifier): class ResNet101_vd_ssld(BaseClassifier):
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000):
super(ResNet101_vd_ssld, self).__init__(model_name='ResNet101_vd_ssld', super(ResNet101_vd_ssld, self).__init__(
num_classes=num_classes) model_name='ResNet101_vd_ssld', num_classes=num_classes)
class DarkNet53(BaseClassifier): class DarkNet53(BaseClassifier):
...@@ -341,19 +340,18 @@ class MobileNetV3_large(BaseClassifier): ...@@ -341,19 +340,18 @@ class MobileNetV3_large(BaseClassifier):
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000):
super(MobileNetV3_large, self).__init__( super(MobileNetV3_large, self).__init__(
model_name='MobileNetV3_large', num_classes=num_classes) model_name='MobileNetV3_large', num_classes=num_classes)
class MobileNetV3_small_ssld(BaseClassifier): class MobileNetV3_small_ssld(BaseClassifier):
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000):
super(MobileNetV3_small_ssld, self).__init__(model_name='MobileNetV3_small_ssld', super(MobileNetV3_small_ssld, self).__init__(
num_classes=num_classes) model_name='MobileNetV3_small_ssld', num_classes=num_classes)
class MobileNetV3_large_ssld(BaseClassifier): class MobileNetV3_large_ssld(BaseClassifier):
def __init__(self, num_classes=1000): def __init__(self, num_classes=1000):
super(MobileNetV3_large_ssld, self).__init__(model_name='MobileNetV3_large_ssld', super(MobileNetV3_large_ssld, self).__init__(
num_classes=num_classes) model_name='MobileNetV3_large_ssld', num_classes=num_classes)
class Xception65(BaseClassifier): class Xception65(BaseClassifier):
......
...@@ -257,9 +257,7 @@ class DeepLabv3p(BaseAPI): ...@@ -257,9 +257,7 @@ class DeepLabv3p(BaseAPI):
ValueError: 模型从inference model进行加载。 ValueError: 模型从inference model进行加载。
""" """
if not self.trainable: if not self.trainable:
raise ValueError( raise ValueError("Model is not trainable from load_model method.")
"Model is not trainable since it was loaded from a inference model."
)
self.labels = train_dataset.labels self.labels = train_dataset.labels
......
...@@ -203,9 +203,7 @@ class FasterRCNN(BaseAPI): ...@@ -203,9 +203,7 @@ class FasterRCNN(BaseAPI):
assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'" assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
self.metric = metric self.metric = metric
if not self.trainable: if not self.trainable:
raise ValueError( raise ValueError("Model is not trainable from load_model method.")
"Model is not trainable since it was loaded from a inference model."
)
self.labels = copy.deepcopy(train_dataset.labels) self.labels = copy.deepcopy(train_dataset.labels)
self.labels.insert(0, 'background') self.labels.insert(0, 'background')
# 构建训练网络 # 构建训练网络
......
...@@ -98,6 +98,7 @@ def load_model(model_dir): ...@@ -98,6 +98,7 @@ def load_model(model_dir):
model.__dict__[k] = v model.__dict__[k] = v
logging.info("Model[{}] loaded.".format(info['Model'])) logging.info("Model[{}] loaded.".format(info['Model']))
model.trainable = False
return model return model
......
...@@ -165,9 +165,7 @@ class MaskRCNN(FasterRCNN): ...@@ -165,9 +165,7 @@ class MaskRCNN(FasterRCNN):
assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'" assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
self.metric = metric self.metric = metric
if not self.trainable: if not self.trainable:
raise Exception( raise Exception("Model is not trainable from load_model method.")
"Model is not trainable since it was loaded from a inference model."
)
self.labels = copy.deepcopy(train_dataset.labels) self.labels = copy.deepcopy(train_dataset.labels)
self.labels.insert(0, 'background') self.labels.insert(0, 'background')
# 构建训练网络 # 构建训练网络
......
...@@ -194,9 +194,7 @@ class YOLOv3(BaseAPI): ...@@ -194,9 +194,7 @@ class YOLOv3(BaseAPI):
ValueError: 模型从inference model进行加载。 ValueError: 模型从inference model进行加载。
""" """
if not self.trainable: if not self.trainable:
raise ValueError( raise ValueError("Model is not trainable from load_model method.")
"Model is not trainable since it was loaded from a inference model."
)
if metric is None: if metric is None:
if isinstance(train_dataset, paddlex.datasets.CocoDetection): if isinstance(train_dataset, paddlex.datasets.CocoDetection):
metric = 'COCO' metric = 'COCO'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册