提交 92001121 编写于 作者: S SunAhong1993

fix the conflict

上级 adc64c7c
...@@ -490,11 +490,12 @@ class PaddleGraph(object): ...@@ -490,11 +490,12 @@ class PaddleGraph(object):
gen_codes( gen_codes(
comment_list, comment_list,
indent=1)) indent=1))
use_structured_name = False if self.source_type in ["tf", "onnx"] else True
self.run_func.extend( self.run_func.extend(
gen_codes(["paddle.disable_static()", gen_codes(["paddle.disable_static()",
"params, _ = fluid.load_dygraph('{}/model')".format(code_dir), "params, _ = fluid.load_dygraph('{}/model')".format(code_dir),
"model = {}()".format(self.name), "model = {}()".format(self.name),
"model.set_dict(params)", "model.set_dict(params, use_structured_name={})".format(use_structured_name),,
"model.eval()", "model.eval()",
"out = model({})".format(input_data_name), "out = model({})".format(input_data_name),
"return out"], indent=1)) "return out"], indent=1))
...@@ -624,7 +625,7 @@ class PaddleGraph(object): ...@@ -624,7 +625,7 @@ class PaddleGraph(object):
paddle.disable_static() paddle.disable_static()
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model")) restore, _ = fluid.load_dygraph(osp.join(save_dir, "model"))
model = getattr(x2paddle_code, self.name)() model = getattr(x2paddle_code, self.name)()
if self.source_type == "tf": if self.source_type in ["tf", "onnx"]:
model.set_dict(restore, use_structured_name=False) model.set_dict(restore, use_structured_name=False)
else: else:
model.set_dict(restore) model.set_dict(restore)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册