diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index 9440cfa964de5891cc7c53ef901bbb542bd8c30d..a559c6a8603771725cfd563765f464670b42f047 100644 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -292,8 +292,8 @@ class PaddleGraph(object): self.dygraph2static(save_dir, input_shapes, input_types) except Error as e: print("The Dygraph2Static is failed! The possible reason are:\n" + - "1. The current model is not supported yet.\n" + - "2. The convertor of pytorch2paddle is wrong. You can run the code of x2paddle.py to confirm the convertor of pytorch2paddle is wrong.\n" + + "1. The convertor of dygraph2static of current model is not supported yet.\n" + + "2. The convertor of pytorch2paddle is wrong. You can run the code of x2paddle_model.py in your save_dir to check whether the convertor of pytorch2paddle is wrong.\n" + "The Error is: \n" + e) exit(0) @@ -502,7 +502,7 @@ class PaddleGraph(object): use_structured_name = False if self.source_type in ["tf", "onnx"] else True self.run_func.extend( gen_codes(["paddle.disable_static()", - "params, _ = fluid.load_dygraph('{}/model')".format(code_dir), + "params = paddle.load('{}/model.pdparams')".format(os.path.abspath(code_dir)), "model = {}()".format(self.name), "model.set_dict(params, use_structured_name={})".format(use_structured_name), "model.eval()", @@ -622,9 +622,8 @@ class PaddleGraph(object): return self.init_func, self.forward_func def dump_dygraph_parameter(self, code_dir): - params_output = open(os.path.join(code_dir, 'model.pdparams'), 'wb') - pickle.dump(self.parameters, params_output) - params_output.close() + save_path = os.path.join(code_dir, 'model.pdparams') + paddle.save(self.parameters, save_path) def dygraph2static(self, save_dir, input_shapes=[], input_types=[]): from paddle.fluid.dygraph.jit import declarative @@ -638,7 +637,7 @@ class PaddleGraph(object): sys.path.insert(0, save_dir) import x2paddle_code paddle.disable_static() - restore, _ = fluid.load_dygraph(osp.join(save_dir, "model")) + restore = paddle.load(osp.join(save_dir, "model.pdparams")) model = getattr(x2paddle_code, self.name)() if self.source_type in ["tf", "onnx"]: model.set_dict(restore, use_structured_name=False) diff --git a/x2paddle/optimizer/code_optimizer/hierachical_tree.py b/x2paddle/optimizer/code_optimizer/hierachical_tree.py index 7566281b07ae44b2db08cf41be20aa1f1005f61b..774bad699f9f319de28165c82282e7a313799800 100644 --- a/x2paddle/optimizer/code_optimizer/hierachical_tree.py +++ b/x2paddle/optimizer/code_optimizer/hierachical_tree.py @@ -358,7 +358,7 @@ class HierarchicalTree(Tree): run_func_list.append(" # {}: 形状为{},类型为{}。".format(k, v[0], v[1])) run_func_list.extend( [" paddle.disable_static()", - " params, _ = fluid.load_dygraph('{}/model')".format(save_dir), + " params = paddle.load('{}/model.pdparams')".format(osp.abspath(save_dir)), " model = {}()".format(self.pd_graph.name), " model.set_dict(params)", " model.eval()",