提交 186683aa 编写于 作者: J jiangjiajun

add build_transforms_v1 for old version paddlex

上级 4a2d8927
...@@ -28,7 +28,12 @@ def load_model(model_dir): ...@@ -28,7 +28,12 @@ def load_model(model_dir):
raise Exception("There's not model.yml in {}".format(model_dir)) raise Exception("There's not model.yml in {}".format(model_dir))
with open(osp.join(model_dir, "model.yml")) as f: with open(osp.join(model_dir, "model.yml")) as f:
info = yaml.load(f.read(), Loader=yaml.Loader) info = yaml.load(f.read(), Loader=yaml.Loader)
status = info['status']
if 'status' in info:
status = info['status']
elif 'save_method' in info:
# 兼容老版本PaddleX
status = info['save_method']
if not hasattr(paddlex.cv.models, info['Model']): if not hasattr(paddlex.cv.models, info['Model']):
raise Exception("There's no attribute {} in paddlex.cv.models".format( raise Exception("There's no attribute {} in paddlex.cv.models".format(
...@@ -40,7 +45,7 @@ def load_model(model_dir): ...@@ -40,7 +45,7 @@ def load_model(model_dir):
model = getattr(paddlex.cv.models, model = getattr(paddlex.cv.models,
info['Model'])(**info['_init_params']) info['Model'])(**info['_init_params'])
if status == "Normal" or \ if status == "Normal" or \
status == "Prune": status == "Prune" or status == "fluid.save":
startup_prog = fluid.Program() startup_prog = fluid.Program()
model.test_prog = fluid.Program() model.test_prog = fluid.Program()
with fluid.program_guard(model.test_prog, startup_prog): with fluid.program_guard(model.test_prog, startup_prog):
...@@ -59,7 +64,7 @@ def load_model(model_dir): ...@@ -59,7 +64,7 @@ def load_model(model_dir):
fluid.io.set_program_state(model.test_prog, load_dict) fluid.io.set_program_state(model.test_prog, load_dict)
elif status == "Infer" or \ elif status == "Infer" or \
status == "Quant": status == "Quant" or status == "fluid.save_inference_model":
[prog, input_names, outputs] = fluid.io.load_inference_model( [prog, input_names, outputs] = fluid.io.load_inference_model(
model_dir, model.exe, params_filename='__params__') model_dir, model.exe, params_filename='__params__')
model.test_prog = prog model.test_prog = prog
...@@ -77,9 +82,15 @@ def load_model(model_dir): ...@@ -77,9 +82,15 @@ def load_model(model_dir):
to_rgb = True to_rgb = True
else: else:
to_rgb = False to_rgb = False
model.test_transforms = build_transforms(model.model_type, if 'BatchTransforms' in info:
info['Transforms'], to_rgb) # 兼容老版本PaddleX模型
model.eval_transforms = copy.deepcopy(model.test_transforms) model.test_transforms = build_transforms_v1(
model.model_type, info['Transforms'], info['BatchTransforms'])
model.eval_transforms = copy.deepcopy(model.test_transforms)
else:
model.test_transforms = build_transforms(
model.model_type, info['Transforms'], to_rgb)
model.eval_transforms = copy.deepcopy(model.test_transforms)
if '_Attributes' in info: if '_Attributes' in info:
for k, v in info['_Attributes'].items(): for k, v in info['_Attributes'].items():
...@@ -109,3 +120,46 @@ def build_transforms(model_type, transforms_info, to_rgb=True): ...@@ -109,3 +120,46 @@ def build_transforms(model_type, transforms_info, to_rgb=True):
eval_transforms = T.Compose(transforms) eval_transforms = T.Compose(transforms)
eval_transforms.to_rgb = to_rgb eval_transforms.to_rgb = to_rgb
return eval_transforms return eval_transforms
def build_transforms_v1(model_type, transforms_info, batch_transforms_info):
""" 老版本模型加载,仅支持PaddleX前端导出的模型
"""
logging.debug("Use build_transforms_v1 to reconstruct transforms")
if model_type == "classifier":
import paddlex.cv.transforms.cls_transforms as T
elif model_type == "detector":
import paddlex.cv.transforms.det_transforms as T
elif model_type == "segmenter":
import paddlex.cv.transforms.seg_transforms as T
transforms = list()
for op_info in transforms_info:
op_name = op_info[0]
op_attr = op_info[1]
if op_name == 'DecodeImage':
continue
if op_name == 'Permute':
continue
if op_name == 'ResizeByShort':
op_attr_new = dict()
if 'short_size' in op_attr:
op_attr_new['short_size'] = op_attr['short_size']
else:
op_attr_new['short_size'] = op_attr['target_size']
op_attr_new['max_size'] = op_attr.get('max_size', -1)
op_attr = op_attr_new
if op_name.startswith('Arrange'):
continue
if not hasattr(T, op_name):
raise Exception(
"There's no operator named '{}' in transforms of {}".format(
op_name, model_type))
transforms.append(getattr(T, op_name)(**op_attr))
if model_type == "detector" and len(batch_transforms_info) > 0:
op_name = batch_transforms_info[0][0]
op_attr = batch_transforms_info[0][1]
assert op_name == "PaddingMiniBatch", "Only PaddingMiniBatch transform is supported for batch transform"
padding = T.Padding(coarsest_stride=op_attr['coarsest_stride'])
transforms.append(padding)
eval_transforms = T.Compose(transforms)
return eval_transforms
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册