From e16d731e3f1d499e1ef7e3e32ecc62b829976be0 Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Thu, 3 Dec 2020 16:29:25 +0800 Subject: [PATCH] fix the dygraph2static --- x2paddle/core/program.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index 2621edc..9440cfa 100644 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -76,8 +76,6 @@ class PaddleGraph(object): self.source_type = source_type self.custom_code = None self.inputs_info = None - self.can_dygraph2static = True - def set_name(self, name): self.name = name.replace("-", "_").replace("/", "_") @@ -167,8 +165,6 @@ class PaddleGraph(object): self.clear_edges() outputs_from_nodes = dict() for layer_id, layer in self.layers.items(): - if layer.kernel == "custom_layer:Gather": - self.can_dygraph2static = False for input_key, input_var in layer.inputs.items(): vs = input_var if not isinstance(vs, list): @@ -286,13 +282,21 @@ class PaddleGraph(object): self.gen_dygraph_code(save_dir) self.dump_dygraph_parameter(save_dir) # 动转静 - if len(self.inputs_info) > 0 and self.can_dygraph2static: + if len(self.inputs_info) > 0: input_shapes = list() input_types = list() for input_name in self.inputs: input_shapes.append(self.inputs_info[input_name][0]) input_types.append(self.inputs_info[input_name][1]) - self.dygraph2static(save_dir, input_shapes, input_types) + try: + self.dygraph2static(save_dir, input_shapes, input_types) + except Error as e: + print("The Dygraph2Static is failed! The possible reason are:\n" + + "1. The current model is not supported yet.\n" + + "2. The convertor of pytorch2paddle is wrong. You can run the code of x2paddle.py to confirm the convertor of pytorch2paddle is wrong.\n" + + "The Error is: \n" + + e) + exit(0) def gen_static_code(self, code_dir): def write_code(f, code_list, indent=0): -- GitLab