提交 795b3c1b 编写于 作者: S SunAhong1993

fix the code style

上级 a7fdf1da
...@@ -62,8 +62,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]): ...@@ -62,8 +62,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
inputs_list = list(layer.inputs.values()) inputs_list = list(layer.inputs.values())
for i, input in enumerate(inputs_list): for i, input in enumerate(inputs_list):
if input is None: if input is None:
inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[ inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[i]])
i]])
inputs_str = ', '.join(inputs_list) inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str) line = "{} = [{}]".format(layer.outputs[0], inputs_str)
elif layer.kernel == "prim.exception": elif layer.kernel == "prim.exception":
...@@ -131,6 +130,5 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]): ...@@ -131,6 +130,5 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
attrs_str += "{}:".format(v) attrs_str += "{}:".format(v)
attrs_str = attrs_str[:-1] attrs_str = attrs_str[:-1]
line = "{} = {}[{}]".format(layer.outputs[0], line = "{} = {}[{}]".format(layer.outputs[0],
list(layer.inputs.values())[0], list(layer.inputs.values())[0], attrs_str)
attrs_str)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
...@@ -284,6 +284,7 @@ class PaddleGraph(object): ...@@ -284,6 +284,7 @@ class PaddleGraph(object):
for block in layer.blocks: for block in layer.blocks:
block.get_dygraph_inputs() block.get_dygraph_inputs()
self.inputs.extend(block.inputs) self.inputs.extend(block.inputs)
update(self.layers) update(self.layers)
self.inputs = list(set(self.inputs)) self.inputs = list(set(self.inputs))
...@@ -396,7 +397,9 @@ class PaddleGraph(object): ...@@ -396,7 +397,9 @@ class PaddleGraph(object):
self.forward_func.extend(gen_codes([line], indent=indent)) self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel: elif "prim" in layer.kernel:
from .convert_prim import convert_prim from .convert_prim import convert_prim
convert_prim(layer, indent=indent, convert_prim(
layer,
indent=indent,
init_func=self.init_func, init_func=self.init_func,
forward_func=self.forward_func) forward_func=self.forward_func)
else: else:
......
...@@ -30,11 +30,11 @@ class PyTorchMatcher(Matcher): ...@@ -30,11 +30,11 @@ class PyTorchMatcher(Matcher):
def __init__(self): def __init__(self):
super(PyTorchMatcher, self).__init__() super(PyTorchMatcher, self).__init__()
def match_pattern(self, pattern, graph, start_id): def match_pattern(self, pattern, graph, start_index):
pattern_index = 0 pattern_index = 0
pattern_global_layers = pattern.get_global_layers() pattern_global_layers = pattern.get_global_layers()
subgraph_global_layers = dict() subgraph_global_layers = dict()
graph_layers = dict(list(graph.layers.items())[start_id:]) graph_layers = dict(list(graph.layers.items())[start_index:])
for layer_id, layer in graph_layers.items(): for layer_id, layer in graph_layers.items():
pattern_layer = pattern.layers[list(pattern.layers.keys())[ pattern_layer = pattern.layers[list(pattern.layers.keys())[
pattern_index]] pattern_index]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册