提交 4c85cdff 编写于 作者: S SunAhong1993

add comment and add readability

上级 795b3c1b
...@@ -59,14 +59,14 @@ class PaddleLayer(object): ...@@ -59,14 +59,14 @@ class PaddleLayer(object):
class PaddleGraph(object): class PaddleGraph(object):
def __init__(self, father_layer=None, graph_type="dygraph"): def __init__(self, parent_layer=None, graph_type="dygraph"):
self.layers = OrderedDict() self.layers = OrderedDict()
self.edges_out = dict() self.edges_out = dict()
self.edges_in = dict() self.edges_in = dict()
self.inputs = list() self.inputs = list()
self.outputs = list() self.outputs = list()
self.parameters = dict() self.parameters = dict()
self.father_layer = father_layer self.parent_layer = parent_layer
self.graph_type = graph_type self.graph_type = graph_type
def set_name(self, name): def set_name(self, name):
...@@ -89,9 +89,9 @@ class PaddleGraph(object): ...@@ -89,9 +89,9 @@ class PaddleGraph(object):
def add_layer(self, kernel, inputs, outputs, **kwargs): def add_layer(self, kernel, inputs, outputs, **kwargs):
layer_id = str(len(self.layers)) layer_id = str(len(self.layers))
if self.father_layer is not None: if self.parent_layer is not None:
layer_id = "{}.{}.{}".format(self.father_layer.id, layer_id = "{}.{}.{}".format(self.parent_layer.id,
len(self.father_layer.blocks), len(self.parent_layer.blocks),
layer_id) layer_id)
layer = PaddleLayer(layer_id, kernel, inputs, outputs, **kwargs) layer = PaddleLayer(layer_id, kernel, inputs, outputs, **kwargs)
self.layers[layer_id] = layer self.layers[layer_id] = layer
...@@ -135,7 +135,7 @@ class PaddleGraph(object): ...@@ -135,7 +135,7 @@ class PaddleGraph(object):
self.get_dygraph_outputs() self.get_dygraph_outputs()
def get_global_layers(self): def get_global_layers(self):
# 该全局layers的信息是按住奥拓扑排序组成的 # 该全局layers的信息是按拓扑排序组成的
def update(layers): def update(layers):
global_layers = dict() global_layers = dict()
for layer_id, layer in layers.items(): for layer_id, layer in layers.items():
...@@ -295,8 +295,7 @@ class PaddleGraph(object): ...@@ -295,8 +295,7 @@ class PaddleGraph(object):
continue continue
if self.edges_out.get(layer_id, 0) == 0: if self.edges_out.get(layer_id, 0) == 0:
for output_name in layer.outputs: for output_name in layer.outputs:
if output_name.endswith( if not output_name.startswith("x"):
"_assert") or not output_name.startswith("x"):
continue continue
self.outputs.append(output_name) self.outputs.append(output_name)
self.outputs = list(set(self.outputs)) self.outputs = list(set(self.outputs))
...@@ -358,7 +357,7 @@ class PaddleGraph(object): ...@@ -358,7 +357,7 @@ class PaddleGraph(object):
for layer_id, layer in self.layers.items(): for layer_id, layer in self.layers.items():
if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get( if self.edges_in.get(layer_id, 0) == 0 and self.edges_out.get(
layer_id, 0) == 0: layer_id, 0) == 0 and layer.kernel != "prim.assert":
continue continue
if "dygraph" in layer.kernel: if "dygraph" in layer.kernel:
line = "{}".format( line = "{}".format(
......
...@@ -12,8 +12,6 @@ ...@@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import os
import re
import torch import torch
......
...@@ -19,13 +19,12 @@ from x2paddle.core.util import * ...@@ -19,13 +19,12 @@ from x2paddle.core.util import *
def prim_Constant(mapper, graph, node): def prim_Constant(mapper, graph, node):
""" 构造constant的PaddleLayer,该节点实现常量赋值。 """ 构造constant的PaddleLayer,该节点实现常量赋值。
PyTorch Script 示例: TorchScript示例:
%2 : int = prim::Constant[value=-1]() %2 : int = prim::Constant[value=-1]()
参数含义: 参数含义:
%2 (常量类型由赋值类型定义,该示例中为int型): 常量赋值结果输出。 %2 (常量类型由赋值类型定义,该示例中为int型): 常量赋值结果输出。
""" """
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
output = list(node.outputs())[0] output = list(node.outputs())[0]
value = output.toIValue() value = output.toIValue()
mapper.attrs[output_name] = value mapper.attrs[output_name] = value
...@@ -33,20 +32,19 @@ def prim_Constant(mapper, graph, node): ...@@ -33,20 +32,19 @@ def prim_Constant(mapper, graph, node):
value = string(value) value = string(value)
graph.add_layer( graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], value=value) "prim.constant", inputs={}, outputs=[output_name], value=value)
return [], node_outputs return [], [output_name]
def prim_GetAttr(mapper, graph, node): def prim_GetAttr(mapper, graph, node):
""" 获取attribute信息。 """ 获取attribute信息。
PyTorch Script 示例: TorchScript示例:
%27 : Tensor? = prim::GetAttr[name="bias"](%7) %27 : Tensor? = prim::GetAttr[name="bias"](%7)
参数含义: 参数含义:
%7 (Tensor): 输入Tensor。 %7 (Tensor): 输入Tensor。
%27 (Tensor): 输入Tensor。 %27 (Tensor): 输入Tensor。
""" """
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name]
field_name_list = [node.s('name')] field_name_list = [node.s('name')]
while True: while True:
input_node = list(node.inputs())[0].node() input_node = list(node.inputs())[0].node()
...@@ -63,13 +61,13 @@ def prim_GetAttr(mapper, graph, node): ...@@ -63,13 +61,13 @@ def prim_GetAttr(mapper, graph, node):
param = param.detach().numpy() param = param.detach().numpy()
mapper.pytorch_params[output_name] = param mapper.pytorch_params[output_name] = param
part_script = param part_script = param
return [], node_outputs return [], [output_name]
def prim_ListConstruct(mapper, graph, node): def prim_ListConstruct(mapper, graph, node):
""" 构造list的PaddleLayer。 """ 构造list的PaddleLayer。
PyTorch Script 示例: TorchScript示例:
%86 : int[] = prim::ListConstruct(%84, %85) %86 : int[] = prim::ListConstruct(%84, %85)
参数含义: 参数含义:
%84 (int/其他): list第一个元素信息。 %84 (int/其他): list第一个元素信息。
...@@ -77,42 +75,48 @@ def prim_ListConstruct(mapper, graph, node): ...@@ -77,42 +75,48 @@ def prim_ListConstruct(mapper, graph, node):
%86 (list): list节点输出。 %86 (list): list节点输出。
""" """
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name] layer_outputs = [output_name]
inputs = {} layer_inputs = {}
for i, input_ivalue in enumerate(node.inputs()): inputs_name, inputs_node = mapper._get_inputs_name(node)
input_node = input_ivalue.node() # 处理每个输入
script_input_unique_id = input_ivalue.unique() for i, input_name in enumerate(inputs_name):
input_node_name = mapper.outputs_info[script_input_unique_id] layer_inputs["input{}".format(i)] = input_name
inputs['input{}'.format(i)] = input_node_name # 获取当前节点输入、输出的list
graph.add_layer("prim.list", inputs=inputs, outputs=[output_name]) current_inputs = list(layer_inputs.values())
return list(inputs.values()), node_outputs current_outputs = layer_outputs
graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs
def prim_RaiseException(mapper, graph, node): def prim_RaiseException(mapper, graph, node):
""" 构造抛出异常的PaddleLayer。 """ 构造抛出异常的PaddleLayer。
PyTorch Script 示例: TorchScript示例:
= prim::RaiseException(%76) = prim::RaiseException(%76)
参数含义: 参数含义:
%76 (str): 异常信息。 %76 (str): 异常信息。
""" """
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name] layer_outputs = [output_name]
input_node = list(node.inputs())[0].node() layer_inputs = {}
script_input_unique_id = list(node.inputs())[0].unique() inputs_name, inputs_node = mapper._get_inputs_name(node)
input_node_name = mapper.outputs_info[script_input_unique_id] # 处理输入0,即%76
mapper._check_input(graph, input_node, input_node_name, node_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], layer_outputs)
layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values())
current_outputs = layer_outputs
graph.add_layer( graph.add_layer(
"prim.exception", "prim.exception", inputs=layer_inputs, outputs=layer_outputs)
inputs={'input': input_node_name}, return current_inputs, current_outputs
outputs=[output_name])
return [input_node_name], node_outputs
def prim_Loop(mapper, graph, node): def prim_Loop(mapper, graph, node):
""" 构造loop循环的PaddleLayer。 """ 构造loop循环的PaddleLayer。
PyTorch Script 示例: TorchScript示例:
%x : Tensor = prim::Loop(%4, %3, %x.3) %x : Tensor = prim::Loop(%4, %3, %x.3)
block0(%i : int, %x.12 : Tensor): block0(%i : int, %x.12 : Tensor):
%72 : int[] = prim::Constant[value=[6, 6]]() %72 : int[] = prim::Constant[value=[6, 6]]()
...@@ -125,11 +129,10 @@ def prim_Loop(mapper, graph, node): ...@@ -125,11 +129,10 @@ def prim_Loop(mapper, graph, node):
%x.3 (Tensor): 循环中修改的Tensor。 %x.3 (Tensor): 循环中修改的Tensor。
%x (Tensor): loop循环的输出,与%x.5对应。 %x (Tensor): loop循环的输出,与%x.5对应。
""" """
output_name = mapper._get_outputs_name(node)[0] node_outputs = mapper._get_outputs_name(node)
node_outputs = [output_name]
loop_inputs = {} loop_inputs = {}
block = list(node.blocks())[0] block = list(node.blocks())[0]
loop_outputs = [output_name] loop_outputs = node_outputs
for i, block_input_ivalue in enumerate(block.inputs()): for i, block_input_ivalue in enumerate(block.inputs()):
block_input_node_name = 'x' + str(mapper.output_index) block_input_node_name = 'x' + str(mapper.output_index)
unique_id = block_input_ivalue.unique() unique_id = block_input_ivalue.unique()
...@@ -161,7 +164,7 @@ def prim_Loop(mapper, graph, node): ...@@ -161,7 +164,7 @@ def prim_Loop(mapper, graph, node):
graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs) graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs)
current_layer = list(graph.layers.values())[-1] current_layer = list(graph.layers.values())[-1]
block_graph, graph_inputs = mapper.traverse(block, node, current_layer) block_graph, graph_inputs = mapper.traverse(block, current_layer)
for i, input_name in enumerate(graph_inputs): for i, input_name in enumerate(graph_inputs):
if input_name == loop_outputs[1]: if input_name == loop_outputs[1]:
continue continue
...@@ -173,7 +176,7 @@ def prim_Loop(mapper, graph, node): ...@@ -173,7 +176,7 @@ def prim_Loop(mapper, graph, node):
def prim_If(mapper, graph, node): def prim_If(mapper, graph, node):
""" 构造if控制流的PaddleLayer。 """ 构造if控制流的PaddleLayer。
PyTorch Script 示例: TorchScript示例:
%input.5 : Tensor = prim::If(%107) %input.5 : Tensor = prim::If(%107)
block0(): block0():
%109 : Tensor = aten::t(%102) %109 : Tensor = aten::t(%102)
...@@ -196,14 +199,14 @@ def prim_If(mapper, graph, node): ...@@ -196,14 +199,14 @@ def prim_If(mapper, graph, node):
graph.add_layer("prim.if", {'input': input_node_name}, [output_name]) graph.add_layer("prim.if", {'input': input_node_name}, [output_name])
current_layer = list(graph.layers.values())[-1] current_layer = list(graph.layers.values())[-1]
block0 = list(node.blocks())[0] block0 = list(node.blocks())[0]
block0_graph, graph_inputs0 = mapper.traverse(block0, node, current_layer) block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer)
len0 = 0 len0 = 0
for i, input_name in enumerate(graph_inputs0): for i, input_name in enumerate(graph_inputs0):
current_layer.inputs['input-{}'.format(i)] = input_name current_layer.inputs['input-{}'.format(i)] = input_name
len0 = i len0 = i
current_layer.add_block(block0_graph) current_layer.add_block(block0_graph)
block1 = list(node.blocks())[1] block1 = list(node.blocks())[1]
block1_graph, graph_inputs1 = mapper.traverse(block1, node, current_layer) block1_graph, graph_inputs1 = mapper.traverse(block1, current_layer)
for i, input_name in enumerate(graph_inputs1): for i, input_name in enumerate(graph_inputs1):
current_layer.inputs['input-{}'.format(len0 + 1 + i)] = input_name current_layer.inputs['input-{}'.format(len0 + 1 + i)] = input_name
current_layer.add_block(block1_graph) current_layer.add_block(block1_graph)
...@@ -213,18 +216,22 @@ def prim_If(mapper, graph, node): ...@@ -213,18 +216,22 @@ def prim_If(mapper, graph, node):
def prim_min(mapper, graph, node): def prim_min(mapper, graph, node):
""" 构造min的PaddleLayer。 """ 构造min的PaddleLayer。
PyTorch Script 示例: TorchScript示例:
%87 : int = prim::min(%86) %87 : int = prim::min(%86)
参数含义: 参数含义:
%86 (list): 输入。 %86 (list): 输入。
%87 (int): 输出。 %87 (int): 输出。
""" """
output_name = mapper._get_outputs_name(node)[0] output_name = mapper._get_outputs_name(node)[0]
node_outputs = [output_name] layer_outputs = [output_name]
input_node = list(node.inputs())[0].node() layer_inputs = {}
script_input_unique_id = list(node.inputs())[0].unique() inputs_name, inputs_node = mapper._get_inputs_name(node)
input_node_name = mapper.outputs_info[script_input_unique_id] # 处理输入0,即%86
mapper._check_input(graph, input_node, input_node_name, node_outputs) mapper._check_input(graph, inputs_node[0], inputs_name[0], layer_outputs)
graph.add_layer( layer_inputs["input"] = inputs_name[0]
"prim.min", inputs={'input': input_node_name}, outputs=[output_name]) # 获取当前节点输入、输出的list
return [input_node_name], node_outputs current_inputs = list(layer_inputs.values())
current_outputs = layer_outputs
graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs
...@@ -34,7 +34,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -34,7 +34,7 @@ class PyTorchOpMapper(OpMapper):
# 转换 # 转换
self.graph, _ = self.traverse(decoder.graph) self.graph, _ = self.traverse(decoder.graph)
def traverse(self, script_graph, control_node=None, father_layer=None): def traverse(self, script_graph, parent_layer=None):
# 用于获取graph的输入 # 用于获取graph的输入
def _update_graph_inputs(inputs, outputs): def _update_graph_inputs(inputs, outputs):
current_node_outputs.extend(outputs) current_node_outputs.extend(outputs)
...@@ -43,7 +43,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -43,7 +43,7 @@ class PyTorchOpMapper(OpMapper):
graph_inputs.append(name) graph_inputs.append(name)
# 初始化 # 初始化
graph = PaddleGraph(father_layer) graph = PaddleGraph(parent_layer)
current_node_outputs = [] current_node_outputs = []
graph_inputs = [] graph_inputs = []
# 转换输入节点 # 转换输入节点
...@@ -71,7 +71,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -71,7 +71,7 @@ class PyTorchOpMapper(OpMapper):
# 转换输出节点 # 转换输出节点
if hasattr(script_graph, 'returnNode'): if hasattr(script_graph, 'returnNode'):
for i, ivalue in enumerate(script_graph.returnNode().inputs()): for i, ivalue in enumerate(script_graph.returnNode().inputs()):
if control_node.kind() == "prim::Loop" and i == 0: if parent_layer.kernel == "prim.loop" and i == 0:
continue continue
node = ivalue.node() node = ivalue.node()
script_unique_id = ivalue.unique() script_unique_id = ivalue.unique()
...@@ -79,7 +79,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -79,7 +79,7 @@ class PyTorchOpMapper(OpMapper):
graph, graph,
node, node,
uid=script_unique_id, uid=script_unique_id,
control_node=control_node, parent_layer=parent_layer,
index=i) index=i)
_update_graph_inputs(inputs, outputs) _update_graph_inputs(inputs, outputs)
# 设置graph的参数 # 设置graph的参数
...@@ -129,6 +129,17 @@ class PyTorchOpMapper(OpMapper): ...@@ -129,6 +129,17 @@ class PyTorchOpMapper(OpMapper):
value=string(param) if isinstance(param, str) else param) value=string(param) if isinstance(param, str) else param)
node_outputs.append(output_name) node_outputs.append(output_name)
def _get_inputs_name(self, node):
inputs_name = []
inputs_node = []
for script_input_ivalue in node.inputs():
script_input_node = script_input_ivalue.node()
script_input_unique_id = script_input_ivalue.unique()
input_node_name = self.outputs_info[script_input_unique_id]
inputs_node.append(script_input_node)
inputs_name.append(input_node_name)
return inputs_name, inputs_node
def data(self, graph, node, uid): def data(self, graph, node, uid):
for output_ivalue in node.outputs(): for output_ivalue in node.outputs():
script_unique_id = output_ivalue.unique() script_unique_id = output_ivalue.unique()
...@@ -145,17 +156,14 @@ class PyTorchOpMapper(OpMapper): ...@@ -145,17 +156,14 @@ class PyTorchOpMapper(OpMapper):
value=output_name) value=output_name)
return [], [output_name] return [], [output_name]
def equal(self, graph, node, uid=None, control_node=None, index=None): def equal(self, graph, node, uid=None, parent_layer=None, index=None):
if control_node is not None and index is not None: if parent_layer is not None and index is not None:
kind = control_node.kind()
# block的输出 # block的输出
input_node_name = self.outputs_info[uid] input_node_name = self.outputs_info[uid]
control_output_id = index control_output_id = index
if kind == "prim::Loop": if parent_layer.kernel == "prim.loop":
control_output_id = index - 1 control_output_id = index - 1
output_ivalue = list(control_node.outputs())[ output_node_name = parent_layer.outputs[control_output_id]
control_output_id].unique()
output_node_name = self.outputs_info[output_ivalue]
graph.add_layer( graph.add_layer(
"prim.equal", "prim.equal",
inputs={'input': input_node_name}, inputs={'input': input_node_name},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册