提交 795b3c1b 编写于 作者: S SunAhong1993

fix the code style

上级 a7fdf1da
......@@ -62,8 +62,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
inputs_list = list(layer.inputs.values())
for i, input in enumerate(inputs_list):
if input is None:
inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[
i]])
inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[i]])
inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str)
elif layer.kernel == "prim.exception":
......@@ -131,6 +130,5 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
attrs_str += "{}:".format(v)
attrs_str = attrs_str[:-1]
line = "{} = {}[{}]".format(layer.outputs[0],
list(layer.inputs.values())[0],
attrs_str)
list(layer.inputs.values())[0], attrs_str)
forward_func.extend(gen_codes([line], indent=indent))
......@@ -129,7 +129,7 @@ class PaddleGraph(object):
if len(layer.blocks) > 0:
for block in layer.blocks:
block.build(layer.inputs, layer.outputs)
if self.graph_type == "dygraph":
self.get_dygraph_inputs()
self.get_dygraph_outputs()
......@@ -284,6 +284,7 @@ class PaddleGraph(object):
for block in layer.blocks:
block.get_dygraph_inputs()
self.inputs.extend(block.inputs)
update(self.layers)
self.inputs = list(set(self.inputs))
......@@ -310,7 +311,7 @@ class PaddleGraph(object):
else:
codes.append(indent_blank + code_line + '\n')
return codes
def gen_head():
self.head = gen_codes(
[
......@@ -332,7 +333,7 @@ class PaddleGraph(object):
gen_codes(
["def forward(self, {}):".format(input_data_name)],
indent=1))
def write_code(code_dir):
f = open(os.path.join(code_dir, 'code.py'), 'w')
for code_line in self.head:
......@@ -396,9 +397,11 @@ class PaddleGraph(object):
self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel:
from .convert_prim import convert_prim
convert_prim(layer, indent=indent,
init_func=self.init_func,
forward_func=self.forward_func)
convert_prim(
layer,
indent=indent,
init_func=self.init_func,
forward_func=self.forward_func)
else:
if len(layer.outputs) == 1:
line = layer.outputs[0]
......
......@@ -30,11 +30,11 @@ class PyTorchMatcher(Matcher):
def __init__(self):
super(PyTorchMatcher, self).__init__()
def match_pattern(self, pattern, graph, start_id):
def match_pattern(self, pattern, graph, start_index):
pattern_index = 0
pattern_global_layers = pattern.get_global_layers()
subgraph_global_layers = dict()
graph_layers = dict(list(graph.layers.items())[start_id:])
graph_layers = dict(list(graph.layers.items())[start_index:])
for layer_id, layer in graph_layers.items():
pattern_layer = pattern.layers[list(pattern.layers.keys())[
pattern_index]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册