From 795b3c1b827d7e651be2a5bc2a1eb67f38e78304 Mon Sep 17 00:00:00 2001 From: SunAhong1993 Date: Tue, 11 Aug 2020 20:48:17 +0800 Subject: [PATCH] fix the code style --- x2paddle/core/convert_prim.py | 6 ++---- x2paddle/core/program.py | 15 +++++++++------ x2paddle/optimizer/passes.py | 4 ++-- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/x2paddle/core/convert_prim.py b/x2paddle/core/convert_prim.py index eb89000..6765aa6 100644 --- a/x2paddle/core/convert_prim.py +++ b/x2paddle/core/convert_prim.py @@ -62,8 +62,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]): inputs_list = list(layer.inputs.values()) for i, input in enumerate(inputs_list): if input is None: - inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[ - i]]) + inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[i]]) inputs_str = ', '.join(inputs_list) line = "{} = [{}]".format(layer.outputs[0], inputs_str) elif layer.kernel == "prim.exception": @@ -131,6 +130,5 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]): attrs_str += "{}:".format(v) attrs_str = attrs_str[:-1] line = "{} = {}[{}]".format(layer.outputs[0], - list(layer.inputs.values())[0], - attrs_str) + list(layer.inputs.values())[0], attrs_str) forward_func.extend(gen_codes([line], indent=indent)) diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index 033795b..2ee5efe 100644 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -129,7 +129,7 @@ class PaddleGraph(object): if len(layer.blocks) > 0: for block in layer.blocks: block.build(layer.inputs, layer.outputs) - + if self.graph_type == "dygraph": self.get_dygraph_inputs() self.get_dygraph_outputs() @@ -284,6 +284,7 @@ class PaddleGraph(object): for block in layer.blocks: block.get_dygraph_inputs() self.inputs.extend(block.inputs) + update(self.layers) self.inputs = list(set(self.inputs)) @@ -310,7 +311,7 @@ class PaddleGraph(object): else: codes.append(indent_blank + code_line + '\n') return codes - + def gen_head(): self.head = gen_codes( [ @@ -332,7 +333,7 @@ class PaddleGraph(object): gen_codes( ["def forward(self, {}):".format(input_data_name)], indent=1)) - + def write_code(code_dir): f = open(os.path.join(code_dir, 'code.py'), 'w') for code_line in self.head: @@ -396,9 +397,11 @@ class PaddleGraph(object): self.forward_func.extend(gen_codes([line], indent=indent)) elif "prim" in layer.kernel: from .convert_prim import convert_prim - convert_prim(layer, indent=indent, - init_func=self.init_func, - forward_func=self.forward_func) + convert_prim( + layer, + indent=indent, + init_func=self.init_func, + forward_func=self.forward_func) else: if len(layer.outputs) == 1: line = layer.outputs[0] diff --git a/x2paddle/optimizer/passes.py b/x2paddle/optimizer/passes.py index 9f10a7e..2987f79 100644 --- a/x2paddle/optimizer/passes.py +++ b/x2paddle/optimizer/passes.py @@ -30,11 +30,11 @@ class PyTorchMatcher(Matcher): def __init__(self): super(PyTorchMatcher, self).__init__() - def match_pattern(self, pattern, graph, start_id): + def match_pattern(self, pattern, graph, start_index): pattern_index = 0 pattern_global_layers = pattern.get_global_layers() subgraph_global_layers = dict() - graph_layers = dict(list(graph.layers.items())[start_id:]) + graph_layers = dict(list(graph.layers.items())[start_index:]) for layer_id, layer in graph_layers.items(): pattern_layer = pattern.layers[list(pattern.layers.keys())[ pattern_index]] -- GitLab