提交 f02f0476 编写于 作者: J jiangjiajun

modify load model

上级 2868f100
......@@ -129,9 +129,7 @@ class BaseClassifier(BaseAPI):
ValueError: 模型从inference model进行加载。
"""
if not self.trainable:
raise ValueError(
"Model is not trainable since it was loaded from a inference model."
)
raise ValueError("Model is not trainable from load_model method.")
self.labels = train_dataset.labels
if optimizer is None:
num_steps_each_epoch = train_dataset.num_samples // train_batch_size
......@@ -300,17 +298,18 @@ class ResNet101_vd(BaseClassifier):
def __init__(self, num_classes=1000):
super(ResNet101_vd, self).__init__(
model_name='ResNet101_vd', num_classes=num_classes)
class ResNet50_vd_ssld(BaseClassifier):
def __init__(self, num_classes=1000):
super(ResNet50_vd_ssld, self).__init__(model_name='ResNet50_vd_ssld',
num_classes=num_classes)
super(ResNet50_vd_ssld, self).__init__(
model_name='ResNet50_vd_ssld', num_classes=num_classes)
class ResNet101_vd_ssld(BaseClassifier):
def __init__(self, num_classes=1000):
super(ResNet101_vd_ssld, self).__init__(model_name='ResNet101_vd_ssld',
num_classes=num_classes)
super(ResNet101_vd_ssld, self).__init__(
model_name='ResNet101_vd_ssld', num_classes=num_classes)
class DarkNet53(BaseClassifier):
......@@ -341,19 +340,18 @@ class MobileNetV3_large(BaseClassifier):
def __init__(self, num_classes=1000):
super(MobileNetV3_large, self).__init__(
model_name='MobileNetV3_large', num_classes=num_classes)
class MobileNetV3_small_ssld(BaseClassifier):
def __init__(self, num_classes=1000):
super(MobileNetV3_small_ssld, self).__init__(model_name='MobileNetV3_small_ssld',
num_classes=num_classes)
super(MobileNetV3_small_ssld, self).__init__(
model_name='MobileNetV3_small_ssld', num_classes=num_classes)
class MobileNetV3_large_ssld(BaseClassifier):
def __init__(self, num_classes=1000):
super(MobileNetV3_large_ssld, self).__init__(model_name='MobileNetV3_large_ssld',
num_classes=num_classes)
super(MobileNetV3_large_ssld, self).__init__(
model_name='MobileNetV3_large_ssld', num_classes=num_classes)
class Xception65(BaseClassifier):
......
......@@ -257,9 +257,7 @@ class DeepLabv3p(BaseAPI):
ValueError: 模型从inference model进行加载。
"""
if not self.trainable:
raise ValueError(
"Model is not trainable since it was loaded from a inference model."
)
raise ValueError("Model is not trainable from load_model method.")
self.labels = train_dataset.labels
......
......@@ -203,9 +203,7 @@ class FasterRCNN(BaseAPI):
assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
self.metric = metric
if not self.trainable:
raise ValueError(
"Model is not trainable since it was loaded from a inference model."
)
raise ValueError("Model is not trainable from load_model method.")
self.labels = copy.deepcopy(train_dataset.labels)
self.labels.insert(0, 'background')
# 构建训练网络
......
......@@ -98,6 +98,7 @@ def load_model(model_dir):
model.__dict__[k] = v
logging.info("Model[{}] loaded.".format(info['Model']))
model.trainable = False
return model
......
......@@ -165,9 +165,7 @@ class MaskRCNN(FasterRCNN):
assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
self.metric = metric
if not self.trainable:
raise Exception(
"Model is not trainable since it was loaded from a inference model."
)
raise Exception("Model is not trainable from load_model method.")
self.labels = copy.deepcopy(train_dataset.labels)
self.labels.insert(0, 'background')
# 构建训练网络
......
......@@ -194,9 +194,7 @@ class YOLOv3(BaseAPI):
ValueError: 模型从inference model进行加载。
"""
if not self.trainable:
raise ValueError(
"Model is not trainable since it was loaded from a inference model."
)
raise ValueError("Model is not trainable from load_model method.")
if metric is None:
if isinstance(train_dataset, paddlex.datasets.CocoDetection):
metric = 'COCO'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册