Commit 4c85cdff authored by SunAhong1993

add comment and add readability

Parent 795b3c1b
@@ -59,14 +59,14 @@ class PaddleLayer(object):
class PaddleGraph(object):
def __init__(self, father_layer=None, graph_type="dygraph"):
def __init__(self, parent_layer=None, graph_type="dygraph"):
self.layers = OrderedDict()
self.edges_out = dict()
self.edges_in = dict()
self.inputs = list()
self.outputs = list()
self.parameters = dict()
self.father_layer = father_layer
self.parent_layer = parent_layer
self.graph_type = graph_type
def set_name(self, name):
@@ -89,9 +89,9 @@ class PaddleGraph(object):
def add_layer(self, kernel, inputs, outputs, **kwargs):
layer_id = str(len(self.layers))
if self.father_layer is not None:
layer_id = "{}.{}.{}".format(self.father_layer.id,
len(self.father_layer.blocks),
if self.parent_layer is not None:
layer_id = "{}.{}.{}".format(self.parent_layer.id,
len(self.parent_layer.blocks),
layer_id)
layer = PaddleLayer(layer_id, kernel, inputs, outputs, **kwargs)
self.layers[layer_id] = layer
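For readers skimming the diff, the id scheme used by `add_layer` is worth spelling out: a top-level layer keeps its plain counter, while a layer created inside a block of a `parent_layer` gets a dotted id `<parent id>.<block index>.<local index>`. Below is a minimal, self-contained sketch of that composition; `make_layer_id` is a hypothetical helper, not part of x2paddle.

```python
# Hypothetical helper mirroring the id logic in PaddleGraph.add_layer above.
def make_layer_id(local_index, parent_id=None, parent_block_count=0):
    layer_id = str(local_index)
    if parent_id is not None:
        # nested layer: "<parent id>.<block index>.<local index>"
        layer_id = "{}.{}.{}".format(parent_id, parent_block_count, layer_id)
    return layer_id

print(make_layer_id(3))            # "3"      -> layer in the top-level graph
print(make_layer_id(0, "3", 1))    # "3.1.0"  -> first layer of block 1 under layer 3
```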
@@ -135,7 +135,7 @@ class PaddleGraph(object):
self.get_dygraph_outputs()
def get_global_layers(self):
# the information of these global layers is organized in topological order
def update(layers):
global_layers = dict()
for layer_id, layer in layers.items():
@@ -295,8 +295,7 @@ class PaddleGraph(object):
continue
if self.edges_out.get(layer_id, 0) == 0:
for output_name in layer.outputs:
if output_name.endswith(
"_assert") or not output_name.startswith("x"):
if not output_name.startswith("x"):
continue
self.outputs.append(output_name)
self.outputs = list(set(self.outputs))
@@ -358,7 +357,7 @@ class PaddleGraph(object):
for layer_id, layer in self.layers.items():
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0:
layer_id, 0) == 0 and layer.kernel != "prim.assert":
continue
if "dygraph" in layer.kernel:
line = "{}".format(
......
@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import torch
......
@@ -19,13 +19,12 @@ from x2paddle.core.util import *
def prim_Constant(mapper, graph, node):
""" 构造constant的PaddleLayer,该节点实现常量赋值。
PyTorch Script 示例:
TorchScript示例:
%2 : int = prim::Constant[value=-1]()
参数含义:
%2 (常量类型由赋值类型定义,该示例中为int型): 常量赋值结果输出。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
output = list(node.outputs())[0]
value = output.toIValue()
mapper.attrs[output_name] = value
@@ -33,20 +32,19 @@ def prim_Constant(mapper, graph, node):
value = string(value)
graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], value=value)
return [], node_outputs
return [], [output_name]
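As a reading aid, here is a hedged sketch (plain Python, outside x2paddle) of the bookkeeping `prim_Constant` performs for the docstring example `%2 : int = prim::Constant[value=-1]()`; the output name `x2` is illustrative only.

```python
# Illustrative only: the constant value is cached on the mapper and a
# prim.constant layer is emitted for it.  "x2" is a hypothetical output name.
attrs = {}                      # stands in for mapper.attrs
output_name = "x2"
value = -1
attrs[output_name] = value
# graph.add_layer("prim.constant", inputs={}, outputs=[output_name], value=value)
print(attrs)                    # {'x2': -1}
```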
def prim_GetAttr(mapper, graph, node):
""" 获取attribute信息。
PyTorch Script 示例:
TorchScript示例:
%27 : Tensor? = prim::GetAttr[name="bias"](%7)
参数含义:
%7 (Tensor): 输入Tensor。
%27 (Tensor): 输入Tensor。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
field_name_list = [node.s('name')]
while True:
input_node = list(node.inputs())[0].node()
@@ -63,13 +61,13 @@ def prim_GetAttr(mapper, graph, node):
param = param.detach().numpy()
mapper.pytorch_params[output_name] = param
part_script = param
return [], node_outputs
return [], [output_name]
def prim_ListConstruct(mapper, graph, node):
""" 构造list的PaddleLayer。
PyTorch Script 示例:
TorchScript示例:
%86 : int[] = prim::ListConstruct(%84, %85)
参数含义:
%84 (int/其他): list第一个元素信息。
@@ -77,42 +75,48 @@
%86 (list): output of the list node.
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
inputs = {}
for i, input_ivalue in enumerate(node.inputs()):
input_node = input_ivalue.node()
script_input_unique_id = input_ivalue.unique()
input_node_name = mapper.outputs_info[script_input_unique_id]
inputs['input{}'.format(i)] = input_node_name
graph.add_layer("prim.list", inputs=inputs, outputs=[output_name])
return list(inputs.values()), node_outputs
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# process each input
for i, input_name in enumerate(inputs_name):
layer_inputs["input{}".format(i)] = input_name
# collect the current node's lists of inputs and outputs
current_inputs = list(layer_inputs.values())
current_outputs = layer_outputs
graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs
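The rewrite above is the template this commit applies to every `prim_*` converter: resolve the inputs once via `mapper._get_inputs_name`, build `layer_inputs`/`layer_outputs`, add the layer, and return the two lists. A self-contained sketch of the result for the docstring example `%86 : int[] = prim::ListConstruct(%84, %85)` follows; the value names `x84`/`x85`/`x86` are illustrative, not produced by this exact code.

```python
# Illustrative only: the data handed to graph.add_layer for
# %86 : int[] = prim::ListConstruct(%84, %85), assuming the mapper has
# already named the values "x84", "x85" and "x86" (hypothetical names).
inputs_name = ["x84", "x85"]        # would come from mapper._get_inputs_name(node)
layer_outputs = ["x86"]             # would come from mapper._get_outputs_name(node)
layer_inputs = {"input{}".format(i): name for i, name in enumerate(inputs_name)}
# graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs)
print(layer_inputs)                               # {'input0': 'x84', 'input1': 'x85'}
print(list(layer_inputs.values()), layer_outputs) # current_inputs, current_outputs
```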
def prim_RaiseException(mapper, graph, node):
""" 构造抛出异常的PaddleLayer。
PyTorch Script 示例:
TorchScript示例:
= prim::RaiseException(%76)
参数含义:
%76 (str): 异常信息。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
input_node = list(node.inputs())[0].node()
script_input_unique_id = list(node.inputs())[0].unique()
input_node_name = mapper.outputs_info[script_input_unique_id]
mapper._check_input(graph, input_node, input_node_name, node_outputs)
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# process input 0, i.e. %76
mapper._check_input(graph, inputs_node[0], inputs_name[0], layer_outputs)
layer_inputs["input"] = inputs_name[0]
# collect the current node's lists of inputs and outputs
current_inputs = list(layer_inputs.values())
current_outputs = layer_outputs
graph.add_layer(
"prim.exception",
inputs={'input': input_node_name},
outputs=[output_name])
return [input_node_name], node_outputs
"prim.exception", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs
def prim_Loop(mapper, graph, node):
""" 构造loop循环的PaddleLayer。
PyTorch Script 示例:
TorchScript示例:
%x : Tensor = prim::Loop(%4, %3, %x.3)
block0(%i : int, %x.12 : Tensor):
%72 : int[] = prim::Constant[value=[6, 6]]()
@@ -125,11 +129,10 @@ def prim_Loop(mapper, graph, node):
%x.3 (Tensor): the Tensor modified inside the loop.
%x (Tensor): output of the loop, corresponding to %x.5.
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
node_outputs = mapper._get_outputs_name(node)
loop_inputs = {}
block = list(node.blocks())[0]
loop_outputs = [output_name]
loop_outputs = node_outputs
for i, block_input_ivalue in enumerate(block.inputs()):
block_input_node_name = 'x' + str(mapper.output_index)
unique_id = block_input_ivalue.unique()
@@ -161,7 +164,7 @@ def prim_Loop(mapper, graph, node):
graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs)
current_layer = list(graph.layers.values())[-1]
block_graph, graph_inputs = mapper.traverse(block, node, current_layer)
block_graph, graph_inputs = mapper.traverse(block, current_layer)
for i, input_name in enumerate(graph_inputs):
if input_name == loop_outputs[1]:
continue
@@ -173,7 +176,7 @@ def prim_If(mapper, graph, node):
def prim_If(mapper, graph, node):
""" 构造if控制流的PaddleLayer。
PyTorch Script 示例:
TorchScript示例:
%input.5 : Tensor = prim::If(%107)
block0():
%109 : Tensor = aten::t(%102)
@@ -196,14 +199,14 @@ def prim_If(mapper, graph, node):
graph.add_layer("prim.if", {'input': input_node_name}, [output_name])
current_layer = list(graph.layers.values())[-1]
block0 = list(node.blocks())[0]
block0_graph, graph_inputs0 = mapper.traverse(block0, node, current_layer)
block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer)
len0 = 0
for i, input_name in enumerate(graph_inputs0):
current_layer.inputs['input-{}'.format(i)] = input_name
len0 = i
current_layer.add_block(block0_graph)
block1 = list(node.blocks())[1]
block1_graph, graph_inputs1 = mapper.traverse(block1, node, current_layer)
block1_graph, graph_inputs1 = mapper.traverse(block1, current_layer)
for i, input_name in enumerate(graph_inputs1):
current_layer.inputs['input-{}'.format(len0 + 1 + i)] = input_name
current_layer.add_block(block1_graph)
@@ -213,18 +216,22 @@ def prim_If(mapper, graph, node):
def prim_min(mapper, graph, node):
""" 构造min的PaddleLayer。
PyTorch Script 示例:
TorchScript示例:
%87 : int = prim::min(%86)
参数含义:
%86 (list): 输入。
%87 (int): 输出。
"""
output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
input_node = list(node.inputs())[0].node()
script_input_unique_id = list(node.inputs())[0].unique()
input_node_name = mapper.outputs_info[script_input_unique_id]
mapper._check_input(graph, input_node, input_node_name, node_outputs)
graph.add_layer(
"prim.min", inputs={'input': input_node_name}, outputs=[output_name])
return [input_node_name], node_outputs
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# process input 0, i.e. %86
mapper._check_input(graph, inputs_node[0], inputs_name[0], layer_outputs)
layer_inputs["input"] = inputs_name[0]
# collect the current node's lists of inputs and outputs
current_inputs = list(layer_inputs.values())
current_outputs = layer_outputs
graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs
@@ -34,7 +34,7 @@ class PyTorchOpMapper(OpMapper):
# conversion
self.graph, _ = self.traverse(decoder.graph)
def traverse(self, script_graph, control_node=None, father_layer=None):
def traverse(self, script_graph, parent_layer=None):
# used to collect the inputs of the graph
def _update_graph_inputs(inputs, outputs):
current_node_outputs.extend(outputs)
@@ -43,7 +43,7 @@ class PyTorchOpMapper(OpMapper):
graph_inputs.append(name)
# initialization
graph = PaddleGraph(father_layer)
graph = PaddleGraph(parent_layer)
current_node_outputs = []
graph_inputs = []
# convert the input nodes
@@ -71,7 +71,7 @@ class PyTorchOpMapper(OpMapper):
# convert the output nodes
if hasattr(script_graph, 'returnNode'):
for i, ivalue in enumerate(script_graph.returnNode().inputs()):
if control_node.kind() == "prim::Loop" and i == 0:
if parent_layer.kernel == "prim.loop" and i == 0:
continue
node = ivalue.node()
script_unique_id = ivalue.unique()
@@ -79,7 +79,7 @@ class PyTorchOpMapper(OpMapper):
graph,
node,
uid=script_unique_id,
control_node=control_node,
parent_layer=parent_layer,
index=i)
_update_graph_inputs(inputs, outputs)
# set the parameters of the graph
@@ -129,6 +129,17 @@ class PyTorchOpMapper(OpMapper):
value=string(param) if isinstance(param, str) else param)
node_outputs.append(output_name)
def _get_inputs_name(self, node):
inputs_name = []
inputs_node = []
for script_input_ivalue in node.inputs():
script_input_node = script_input_ivalue.node()
script_input_unique_id = script_input_ivalue.unique()
input_node_name = self.outputs_info[script_input_unique_id]
inputs_node.append(script_input_node)
inputs_name.append(input_node_name)
return inputs_name, inputs_node
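A minimal, runnable sketch of the traversal `_get_inputs_name` performs; the `FakeValue`/`FakeNode` classes below are stand-ins that imitate only the `unique()`/`node()`/`inputs()` slice of the TorchScript API the helper touches, and are not real torch classes.

```python
# Hypothetical stand-ins for torch._C.Value / torch._C.Node.
class FakeValue:
    def __init__(self, unique_id, producer):
        self._unique_id, self._producer = unique_id, producer
    def unique(self):
        return self._unique_id
    def node(self):
        return self._producer

class FakeNode:
    def __init__(self, inputs):
        self._inputs = inputs
    def inputs(self):
        return self._inputs

outputs_info = {10: "x10", 11: "x11"}   # unique id -> generated output name
node = FakeNode([FakeValue(10, "producer_a"), FakeValue(11, "producer_b")])

inputs_name, inputs_node = [], []       # same walk as _get_inputs_name
for ivalue in node.inputs():
    inputs_node.append(ivalue.node())
    inputs_name.append(outputs_info[ivalue.unique()])
print(inputs_name)                      # ['x10', 'x11']
```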
def data(self, graph, node, uid):
for output_ivalue in node.outputs():
script_unique_id = output_ivalue.unique()
@@ -145,17 +156,14 @@ class PyTorchOpMapper(OpMapper):
value=output_name)
return [], [output_name]
def equal(self, graph, node, uid=None, control_node=None, index=None):
if control_node is not None and index is not None:
kind = control_node.kind()
def equal(self, graph, node, uid=None, parent_layer=None, index=None):
if parent_layer is not None and index is not None:
# outputs of the block
input_node_name = self.outputs_info[uid]
control_output_id = index
if kind == "prim::Loop":
if parent_layer.kernel == "prim.loop":
control_output_id = index - 1
output_ivalue = list(control_node.outputs())[
control_output_id].unique()
output_node_name = self.outputs_info[output_ivalue]
output_node_name = parent_layer.outputs[control_output_id]
graph.add_layer(
"prim.equal",
inputs={'input': input_node_name},
......
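One detail of the `equal` change above deserves a note: for a `prim.loop` parent layer, the first value a TorchScript loop block returns is the loop-continue condition, so block return `i` matches parent output `i - 1`. A hedged arithmetic sketch with hypothetical names:

```python
# Illustrative only: index bookkeeping used by `equal` for block outputs.
parent_kernel = "prim.loop"
parent_outputs = ["x_loop_out"]   # outputs of the parent prim.loop layer (hypothetical)
index = 1                         # position of the value in the block's returnNode
control_output_id = index - 1 if parent_kernel == "prim.loop" else index
print(parent_outputs[control_output_id])   # 'x_loop_out'
```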