From 4a59b1715fd44b50128c3f5eb94c36c329fe009f Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Wed, 13 Jan 2021 16:48:30 +0800 Subject: [PATCH] fix3 --- x2paddle/core/program.py | 16 ++++++++-------- .../dygraph/pytorch2paddle/prim2code.py | 6 ++++-- x2paddle/optimizer/fusion/dygraph/if_fuser.py | 8 ++++++++ 3 files changed, 20 insertions(+), 10 deletions(-) diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index ebf0e81..2e32874 100644 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -285,14 +285,14 @@ class PaddleGraph(object): hierarchical_tree.save_source_files(save_dir) self.dump_dygraph_parameter(save_dir) else: - if self.source_type == "pytorch": - from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph - module_graph = ModuleGraph(self) - module_graph.save_source_files(save_dir) - self.dump_dygraph_parameter(save_dir) - else: - self.gen_dygraph_code(save_dir) - self.dump_dygraph_parameter(save_dir) +# if self.source_type == "pytorch": +# from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph +# module_graph = ModuleGraph(self) +# module_graph.save_source_files(save_dir) +# self.dump_dygraph_parameter(save_dir) +# else: + self.gen_dygraph_code(save_dir) + self.dump_dygraph_parameter(save_dir) # 动转静 code_path = osp.join(osp.abspath(save_dir), "x2paddle_code.py") print("Exporting inference model from python code ('{}')... \n".format(code_path)) diff --git a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py index e4f24e4..0ca02fc 100644 --- a/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py +++ b/x2paddle/op_mapper/dygraph/pytorch2paddle/prim2code.py @@ -182,7 +182,7 @@ def prim_equal(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di def prim_exception(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): - line = "raise RaiseException({})".format(get_value(layer, "input", different_attrs)) + line = "raise Exception({})".format(get_value(layer, "input", different_attrs)) forward_func.extend(gen_codes([line], indent=indent)) @@ -458,10 +458,12 @@ def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di forward_func.extend(gen_codes([line], indent=indent)) -def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): +def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): line = "{} = {}.startswith({})".format(layer.outputs[0], get_value(layer, "input", different_attrs), get_value(layer, "start_str", different_attrs)) + if is_return_line: + return line.split(" = ")[1] forward_func.extend(gen_codes([line], indent=indent)) diff --git a/x2paddle/optimizer/fusion/dygraph/if_fuser.py b/x2paddle/optimizer/fusion/dygraph/if_fuser.py index 70cffa7..877dcc1 100644 --- a/x2paddle/optimizer/fusion/dygraph/if_fuser.py +++ b/x2paddle/optimizer/fusion/dygraph/if_fuser.py @@ -41,13 +41,21 @@ class DygraphIfFuser(FuseBase): return for id in graph.edges_in[layer_id]: input_layer = graph.layers[id] + input_layer_id = id if input_layer.outputs == [layer.inputs["input"]]: if input_layer.kernel == "prim.if": matches.pop(layer_id) return input_id = id break + if list(layer.inputs.values()).count(input_layer.outputs[0]) > 1 or \ + (input_layer_id in graph.edges_out and len(graph.edges_out[input_layer_id]) > 1): + matches.pop(layer_id) + return func_name = input_layer.kernel.replace(".", "_") + if func_name in ["prim_if", "prim_loop"]: + matches.pop(layer_id) + return from x2paddle.op_mapper.dygraph.pytorch2paddle import prim2code func = getattr(prim2code, func_name) line = func(input_layer, is_return_line=True) -- GitLab