提交 795b3c1b 编写于 作者: S SunAhong1993

fix the code style

上级 a7fdf1da
......@@ -62,8 +62,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
inputs_list = list(layer.inputs.values())
for i, input in enumerate(inputs_list):
if input is None:
inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[
i]])
inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[i]])
inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str)
elif layer.kernel == "prim.exception":
......@@ -131,6 +130,5 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
attrs_str += "{}:".format(v)
attrs_str = attrs_str[:-1]
line = "{} = {}[{}]".format(layer.outputs[0],
list(layer.inputs.values())[0],
attrs_str)
list(layer.inputs.values())[0], attrs_str)
forward_func.extend(gen_codes([line], indent=indent))
......@@ -284,6 +284,7 @@ class PaddleGraph(object):
for block in layer.blocks:
block.get_dygraph_inputs()
self.inputs.extend(block.inputs)
update(self.layers)
self.inputs = list(set(self.inputs))
......@@ -396,7 +397,9 @@ class PaddleGraph(object):
self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel:
from .convert_prim import convert_prim
convert_prim(layer, indent=indent,
convert_prim(
layer,
indent=indent,
init_func=self.init_func,
forward_func=self.forward_func)
else:
......
......@@ -30,11 +30,11 @@ class PyTorchMatcher(Matcher):
def __init__(self):
super(PyTorchMatcher, self).__init__()
def match_pattern(self, pattern, graph, start_id):
def match_pattern(self, pattern, graph, start_index):
pattern_index = 0
pattern_global_layers = pattern.get_global_layers()
subgraph_global_layers = dict()
graph_layers = dict(list(graph.layers.items())[start_id:])
graph_layers = dict(list(graph.layers.items())[start_index:])
for layer_id, layer in graph_layers.items():
pattern_layer = pattern.layers[list(pattern.layers.keys())[
pattern_index]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册