提交 e16d731e 编写于 作者: S SunAhong1993

fix the dygraph2static

上级 ba85e189
...@@ -76,8 +76,6 @@ class PaddleGraph(object): ...@@ -76,8 +76,6 @@ class PaddleGraph(object):
self.source_type = source_type self.source_type = source_type
self.custom_code = None self.custom_code = None
self.inputs_info = None self.inputs_info = None
self.can_dygraph2static = True
def set_name(self, name): def set_name(self, name):
self.name = name.replace("-", "_").replace("/", "_") self.name = name.replace("-", "_").replace("/", "_")
...@@ -167,8 +165,6 @@ class PaddleGraph(object): ...@@ -167,8 +165,6 @@ class PaddleGraph(object):
self.clear_edges() self.clear_edges()
outputs_from_nodes = dict() outputs_from_nodes = dict()
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
if layer.kernel == "custom_layer:Gather":
self.can_dygraph2static = False
for input_key, input_var in layer.inputs.items(): for input_key, input_var in layer.inputs.items():
vs = input_var vs = input_var
if not isinstance(vs, list): if not isinstance(vs, list):
...@@ -286,13 +282,21 @@ class PaddleGraph(object): ...@@ -286,13 +282,21 @@ class PaddleGraph(object):
self.gen_dygraph_code(save_dir) self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir) self.dump_dygraph_parameter(save_dir)
# 动转静 # 动转静
if len(self.inputs_info) > 0 and self.can_dygraph2static: if len(self.inputs_info) > 0:
input_shapes = list() input_shapes = list()
input_types = list() input_types = list()
for input_name in self.inputs: for input_name in self.inputs:
input_shapes.append(self.inputs_info[input_name][0]) input_shapes.append(self.inputs_info[input_name][0])
input_types.append(self.inputs_info[input_name][1]) input_types.append(self.inputs_info[input_name][1])
try:
self.dygraph2static(save_dir, input_shapes, input_types) self.dygraph2static(save_dir, input_shapes, input_types)
except Error as e:
print("The Dygraph2Static is failed! The possible reason are:\n" +
"1. The current model is not supported yet.\n" +
"2. The convertor of pytorch2paddle is wrong. You can run the code of x2paddle.py to confirm the convertor of pytorch2paddle is wrong.\n" +
"The Error is: \n" +
e)
exit(0)
def gen_static_code(self, code_dir): def gen_static_code(self, code_dir):
def write_code(f, code_list, indent=0): def write_code(f, code_list, indent=0):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册