diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index 9c2a3102fe10372bcdadd834ce26e532a2e2acab..6da3476e18c0cd09f68bbb96abb3753f27bf405e 100644 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -490,11 +490,12 @@ class PaddleGraph(object): gen_codes( comment_list, indent=1)) + use_structured_name = False if self.source_type in ["tf", "onnx"] else True self.run_func.extend( gen_codes(["paddle.disable_static()", "params, _ = fluid.load_dygraph('{}/model')".format(code_dir), "model = {}()".format(self.name), - "model.set_dict(params)", + "model.set_dict(params, use_structured_name={})".format(use_structured_name),, "model.eval()", "out = model({})".format(input_data_name), "return out"], indent=1)) @@ -624,7 +625,7 @@ class PaddleGraph(object): paddle.disable_static() restore, _ = fluid.load_dygraph(osp.join(save_dir, "model")) model = getattr(x2paddle_code, self.name)() - if self.source_type == "tf": + if self.source_type in ["tf", "onnx"]: model.set_dict(restore, use_structured_name=False) else: model.set_dict(restore)