From 4c8b9537fe84d14c3692d4fb9ff7fbbcee7f2915 Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Thu, 4 Mar 2021 17:34:49 +0800 Subject: [PATCH] fix the path str --- x2paddle/core/program.py | 2 +- x2paddle/optimizer/pytorch_code_optimizer/hierachical_tree.py | 2 +- x2paddle/optimizer/pytorch_code_optimizer/module_graph.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index ddc1519..e32c07a 100644 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -518,7 +518,7 @@ class PaddleGraph(object): use_structured_name = False if self.source_type in ["tf"] else True self.run_func.extend( gen_codes(["paddle.disable_static()", - "params = paddle.load('{}/model.pdparams')".format(osp.abspath(code_dir)), + "params = paddle.load('{}')".format(osp.join(osp.abspath(code_dir), "model.pdparams")), "model = {}()".format(self.name), "model.set_dict(params, use_structured_name={})".format(use_structured_name), "model.eval()", diff --git a/x2paddle/optimizer/pytorch_code_optimizer/hierachical_tree.py b/x2paddle/optimizer/pytorch_code_optimizer/hierachical_tree.py index 5ae708b..a803e34 100644 --- a/x2paddle/optimizer/pytorch_code_optimizer/hierachical_tree.py +++ b/x2paddle/optimizer/pytorch_code_optimizer/hierachical_tree.py @@ -388,7 +388,7 @@ class HierarchicalTree(Tree): run_func_list.append(" # {}: shape-{},type-{}。".format(k, v[0], v[1])) run_func_list.extend( [" paddle.disable_static()", - " params = paddle.load('{}/model.pdparams')".format(osp.abspath(save_dir)), + " params = paddle.load('{}')".format(osp.join(osp.abspath(save_dir), "model.pdparams")), " model = {}()".format(self.pd_graph.name), " model.set_dict(params)", " model.eval()", diff --git a/x2paddle/optimizer/pytorch_code_optimizer/module_graph.py b/x2paddle/optimizer/pytorch_code_optimizer/module_graph.py index b6fb8f7..9a9d375 100644 --- a/x2paddle/optimizer/pytorch_code_optimizer/module_graph.py +++ b/x2paddle/optimizer/pytorch_code_optimizer/module_graph.py @@ -354,7 +354,7 @@ class ModuleGraph(object): run_func_list.append(" # {}: shape-{},type-{}.".format(k, v[0], v[1])) run_func_list.extend( [" paddle.disable_static()", - " params = paddle.load('{}/model.pdparams')".format(osp.abspath(save_dir)), + " params = paddle.load('{}')".format(osp.join(osp.abspath(save_dir), "model.pdparams")), " model = {}()".format(self.pd_graph.name), " model.set_dict(params)", " model.eval()", -- GitLab