提交 795b3c1b 编写于 作者: S SunAhong1993

fix the code style

上级 a7fdf1da
...@@ -62,8 +62,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]): ...@@ -62,8 +62,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
inputs_list = list(layer.inputs.values()) inputs_list = list(layer.inputs.values())
for i, input in enumerate(inputs_list): for i, input in enumerate(inputs_list):
if input is None: if input is None:
inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[ inputs_list[i] = str(layer.attrs[list(layer.inputs.keys())[i]])
i]])
inputs_str = ', '.join(inputs_list) inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str) line = "{} = [{}]".format(layer.outputs[0], inputs_str)
elif layer.kernel == "prim.exception": elif layer.kernel == "prim.exception":
...@@ -131,6 +130,5 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]): ...@@ -131,6 +130,5 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
attrs_str += "{}:".format(v) attrs_str += "{}:".format(v)
attrs_str = attrs_str[:-1] attrs_str = attrs_str[:-1]
line = "{} = {}[{}]".format(layer.outputs[0], line = "{} = {}[{}]".format(layer.outputs[0],
list(layer.inputs.values())[0], list(layer.inputs.values())[0], attrs_str)
attrs_str)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
...@@ -129,7 +129,7 @@ class PaddleGraph(object): ...@@ -129,7 +129,7 @@ class PaddleGraph(object):
if len(layer.blocks) > 0: if len(layer.blocks) > 0:
for block in layer.blocks: for block in layer.blocks:
block.build(layer.inputs, layer.outputs) block.build(layer.inputs, layer.outputs)
if self.graph_type == "dygraph": if self.graph_type == "dygraph":
self.get_dygraph_inputs() self.get_dygraph_inputs()
self.get_dygraph_outputs() self.get_dygraph_outputs()
...@@ -284,6 +284,7 @@ class PaddleGraph(object): ...@@ -284,6 +284,7 @@ class PaddleGraph(object):
for block in layer.blocks: for block in layer.blocks:
block.get_dygraph_inputs() block.get_dygraph_inputs()
self.inputs.extend(block.inputs) self.inputs.extend(block.inputs)
update(self.layers) update(self.layers)
self.inputs = list(set(self.inputs)) self.inputs = list(set(self.inputs))
...@@ -310,7 +311,7 @@ class PaddleGraph(object): ...@@ -310,7 +311,7 @@ class PaddleGraph(object):
else: else:
codes.append(indent_blank + code_line + '\n') codes.append(indent_blank + code_line + '\n')
return codes return codes
def gen_head(): def gen_head():
self.head = gen_codes( self.head = gen_codes(
[ [
...@@ -332,7 +333,7 @@ class PaddleGraph(object): ...@@ -332,7 +333,7 @@ class PaddleGraph(object):
gen_codes( gen_codes(
["def forward(self, {}):".format(input_data_name)], ["def forward(self, {}):".format(input_data_name)],
indent=1)) indent=1))
def write_code(code_dir): def write_code(code_dir):
f = open(os.path.join(code_dir, 'code.py'), 'w') f = open(os.path.join(code_dir, 'code.py'), 'w')
for code_line in self.head: for code_line in self.head:
...@@ -396,9 +397,11 @@ class PaddleGraph(object): ...@@ -396,9 +397,11 @@ class PaddleGraph(object):
self.forward_func.extend(gen_codes([line], indent=indent)) self.forward_func.extend(gen_codes([line], indent=indent))
elif "prim" in layer.kernel: elif "prim" in layer.kernel:
from .convert_prim import convert_prim from .convert_prim import convert_prim
convert_prim(layer, indent=indent, convert_prim(
init_func=self.init_func, layer,
forward_func=self.forward_func) indent=indent,
init_func=self.init_func,
forward_func=self.forward_func)
else: else:
if len(layer.outputs) == 1: if len(layer.outputs) == 1:
line = layer.outputs[0] line = layer.outputs[0]
......
...@@ -30,11 +30,11 @@ class PyTorchMatcher(Matcher): ...@@ -30,11 +30,11 @@ class PyTorchMatcher(Matcher):
def __init__(self): def __init__(self):
super(PyTorchMatcher, self).__init__() super(PyTorchMatcher, self).__init__()
def match_pattern(self, pattern, graph, start_id): def match_pattern(self, pattern, graph, start_index):
pattern_index = 0 pattern_index = 0
pattern_global_layers = pattern.get_global_layers() pattern_global_layers = pattern.get_global_layers()
subgraph_global_layers = dict() subgraph_global_layers = dict()
graph_layers = dict(list(graph.layers.items())[start_id:]) graph_layers = dict(list(graph.layers.items())[start_index:])
for layer_id, layer in graph_layers.items(): for layer_id, layer in graph_layers.items():
pattern_layer = pattern.layers[list(pattern.layers.keys())[ pattern_layer = pattern.layers[list(pattern.layers.keys())[
pattern_index]] pattern_index]]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册