提交 55bfc88f 编写于 作者: S SunAhong1993

fix the program

上级 e16d731e
...@@ -292,8 +292,8 @@ class PaddleGraph(object): ...@@ -292,8 +292,8 @@ class PaddleGraph(object):
self.dygraph2static(save_dir, input_shapes, input_types) self.dygraph2static(save_dir, input_shapes, input_types)
except Error as e: except Error as e:
print("The Dygraph2Static is failed! The possible reason are:\n" + print("The Dygraph2Static is failed! The possible reason are:\n" +
"1. The current model is not supported yet.\n" + "1. The convertor of dygraph2static of current model is not supported yet.\n" +
"2. The convertor of pytorch2paddle is wrong. You can run the code of x2paddle.py to confirm the convertor of pytorch2paddle is wrong.\n" + "2. The convertor of pytorch2paddle is wrong. You can run the code of x2paddle_model.py in your save_dir to check whether the convertor of pytorch2paddle is wrong.\n" +
"The Error is: \n" + "The Error is: \n" +
e) e)
exit(0) exit(0)
...@@ -502,7 +502,7 @@ class PaddleGraph(object): ...@@ -502,7 +502,7 @@ class PaddleGraph(object):
use_structured_name = False if self.source_type in ["tf", "onnx"] else True use_structured_name = False if self.source_type in ["tf", "onnx"] else True
self.run_func.extend( self.run_func.extend(
gen_codes(["paddle.disable_static()", gen_codes(["paddle.disable_static()",
"params, _ = fluid.load_dygraph('{}/model')".format(code_dir), "params = paddle.load('{}/model.pdparams')".format(os.path.abspath(code_dir)),
"model = {}()".format(self.name), "model = {}()".format(self.name),
"model.set_dict(params, use_structured_name={})".format(use_structured_name), "model.set_dict(params, use_structured_name={})".format(use_structured_name),
"model.eval()", "model.eval()",
...@@ -622,9 +622,8 @@ class PaddleGraph(object): ...@@ -622,9 +622,8 @@ class PaddleGraph(object):
return self.init_func, self.forward_func return self.init_func, self.forward_func
def dump_dygraph_parameter(self, code_dir): def dump_dygraph_parameter(self, code_dir):
params_output = open(os.path.join(code_dir, 'model.pdparams'), 'wb') save_path = os.path.join(code_dir, 'model.pdparams')
pickle.dump(self.parameters, params_output) paddle.save(self.parameters, save_path)
params_output.close()
def dygraph2static(self, save_dir, input_shapes=[], input_types=[]): def dygraph2static(self, save_dir, input_shapes=[], input_types=[]):
from paddle.fluid.dygraph.jit import declarative from paddle.fluid.dygraph.jit import declarative
...@@ -638,7 +637,7 @@ class PaddleGraph(object): ...@@ -638,7 +637,7 @@ class PaddleGraph(object):
sys.path.insert(0, save_dir) sys.path.insert(0, save_dir)
import x2paddle_code import x2paddle_code
paddle.disable_static() paddle.disable_static()
restore, _ = fluid.load_dygraph(osp.join(save_dir, "model")) restore = paddle.load(osp.join(save_dir, "model.pdparams"))
model = getattr(x2paddle_code, self.name)() model = getattr(x2paddle_code, self.name)()
if self.source_type in ["tf", "onnx"]: if self.source_type in ["tf", "onnx"]:
model.set_dict(restore, use_structured_name=False) model.set_dict(restore, use_structured_name=False)
......
...@@ -358,7 +358,7 @@ class HierarchicalTree(Tree): ...@@ -358,7 +358,7 @@ class HierarchicalTree(Tree):
run_func_list.append(" # {}: 形状为{},类型为{}。".format(k, v[0], v[1])) run_func_list.append(" # {}: 形状为{},类型为{}。".format(k, v[0], v[1]))
run_func_list.extend( run_func_list.extend(
[" paddle.disable_static()", [" paddle.disable_static()",
" params, _ = fluid.load_dygraph('{}/model')".format(save_dir), " params = paddle.load('{}/model.pdparams')".format(osp.abspath(save_dir)),
" model = {}()".format(self.pd_graph.name), " model = {}()".format(self.pd_graph.name),
" model.set_dict(params)", " model.set_dict(params)",
" model.eval()", " model.eval()",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册