diff --git a/x2paddle/core/convert_prim.py b/x2paddle/core/convert_prim.py index eb8900083f7974f15fdce1c4dd61fd023ab73700..6765aa632b178895f290adc5149e6a7a50a8aeff 100644 --- a/x2paddle/core/convert_prim.py +++ b/x2paddle/core/convert_prim.py @@ -62,8 +62,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]): inputs_list = list(layer.inputs.values()) for i, input in enumerate(inputs_list): if input is None: - inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[ - i]]) + inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[i]]) inputs_str = ', '.join(inputs_list) line = "{} = [{}]".format(layer.outputs[0], inputs_str) elif layer.kernel == "prim.exception": @@ -131,6 +130,5 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]): attrs_str += "{}:".format(v) attrs_str = attrs_str[:-1] line = "{} = {}[{}]".format(layer.outputs[0], - list(layer.inputs.values())[0], - attrs_str) + list(layer.inputs.values())[0], attrs_str) forward_func.extend(gen_codes([line], indent=indent)) diff --git a/x2paddle/core/program.py b/x2paddle/core/program.py index 033795b262d3e2bcfaff1c3525e8b72d3d159dd6..2ee5efebaeeba4f5cee336a6f8f30359a788a4ed 100644 --- a/x2paddle/core/program.py +++ b/x2paddle/core/program.py @@ -129,7 +129,7 @@ class PaddleGraph(object): if len(layer.blocks) > 0: for block in layer.blocks: block.build(layer.inputs, layer.outputs) - + if self.graph_type == "dygraph": self.get_dygraph_inputs() self.get_dygraph_outputs() @@ -284,6 +284,7 @@ class PaddleGraph(object): for block in layer.blocks: block.get_dygraph_inputs() self.inputs.extend(block.inputs) + update(self.layers) self.inputs = list(set(self.inputs)) @@ -310,7 +311,7 @@ class PaddleGraph(object): else: codes.append(indent_blank + code_line + '\n') return codes - + def gen_head(): self.head = gen_codes( [ @@ -332,7 +333,7 @@ class PaddleGraph(object): gen_codes( ["def forward(self, {}):".format(input_data_name)], indent=1)) - + def write_code(code_dir): f = open(os.path.join(code_dir, 'code.py'), 'w') for code_line in self.head: @@ -396,9 +397,11 @@ class PaddleGraph(object): self.forward_func.extend(gen_codes([line], indent=indent)) elif "prim" in layer.kernel: from .convert_prim import convert_prim - convert_prim(layer, indent=indent, - init_func=self.init_func, - forward_func=self.forward_func) + convert_prim( + layer, + indent=indent, + init_func=self.init_func, + forward_func=self.forward_func) else: if len(layer.outputs) == 1: line = layer.outputs[0] diff --git a/x2paddle/optimizer/passes.py b/x2paddle/optimizer/passes.py index 9f10a7ed8aa987ae3ec283320ee40d8cd9fd24a6..2987f7934e05bf443c2531a1ac70d51dc697a98e 100644 --- a/x2paddle/optimizer/passes.py +++ b/x2paddle/optimizer/passes.py @@ -30,11 +30,11 @@ class PyTorchMatcher(Matcher): def __init__(self): super(PyTorchMatcher, self).__init__() - def match_pattern(self, pattern, graph, start_id): + def match_pattern(self, pattern, graph, start_index): pattern_index = 0 pattern_global_layers = pattern.get_global_layers() subgraph_global_layers = dict() - graph_layers = dict(list(graph.layers.items())[start_id:]) + graph_layers = dict(list(graph.layers.items())[start_index:]) for layer_id, layer in graph_layers.items(): pattern_layer = pattern.layers[list(pattern.layers.keys())[ pattern_index]]