提交 5bcd803c 编写于 作者: S SunAhong1993

add models

上级 4c85cdff
......@@ -73,7 +73,7 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
elif layer.kernel == "prim.min":
line = "{} = min({})".format(layer.outputs[0],
list(layer.inputs.values())[0])
elif layer.kernel == "prim.add":
elif layer.kernel == "prim.add_":
line = "{} = {} + {} * {}".format(layer.outputs[0],
list(layer.inputs.values())[0],
layer.attrs["alpha"],
......@@ -124,11 +124,33 @@ def convert_prim(layer, indent=1, init_func=[], forward_func=[]):
if list(layer.inputs.values())[1] is None:
item1 = str(layer.attrs[list(layer.inputs.keys())[1]])
line = "{} = {} < {}".format(layer.outputs[0], item0, item1)
elif layer.kernel == "prim.ne":
item0 = list(layer.inputs.values())[0]
item1 = list(layer.inputs.values())[1]
line = "{} = {} < {}".format(layer.outputs[0], item0, item1)
elif layer.kernel == "prim.slice":
attrs_str = ""
for k, v in layer.attrs.items():
attrs_str += "{}:".format(v)
attrs_str = attrs_str[:-1]
inputs_str = ""
for v in list(layer.inputs.values())[1:]:
inputs_str += "{}:".format(v)
inputs_str = inputs_str[:-1]
line = "{} = {}[{}]".format(layer.outputs[0],
list(layer.inputs.values())[0], attrs_str)
list(layer.inputs.values())[0], inputs_str)
elif layer.kernel == "prim.add":
line = "{} = {} + {}".format(layer.outputs[0],
list(layer.inputs.values())[0],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.sub":
line = "{} = {} - {}".format(layer.outputs[0],
list(layer.inputs.values())[0],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.mul":
line = "{} = {} * {}".format(layer.outputs[0],
list(layer.inputs.values())[0],
list(layer.inputs.values())[1])
elif layer.kernel == "prim.neg":
line = "{} = -{}".format(layer.outputs[0],
list(layer.inputs.values())[0])
else:
print(layer.kernel)
line = ""
forward_func.extend(gen_codes([line], indent=indent))
......@@ -297,6 +297,7 @@ class PaddleGraph(object):
for output_name in layer.outputs:
if not output_name.startswith("x"):
continue
print(layer.kernel)
self.outputs.append(output_name)
self.outputs = list(set(self.outputs))
......
......@@ -53,14 +53,18 @@ def prim_GetAttr(mapper, graph, node):
node = input_node
except Exception:
break
part_script = mapper.script
for field_name in field_name_list:
if hasattr(part_script, field_name):
param = getattr(part_script, field_name)
if isinstance(param, torch.Tensor):
param = param.detach().numpy()
mapper.pytorch_params[output_name] = param
part_script = param
if ".".join(field_name_list) in mapper.pytorch_params:
mapper.pytorch_params[output_name] = mapper.pytorch_params[".".join(
field_name_list)]
else:
part_script = mapper.script
for field_name in field_name_list:
if hasattr(part_script, field_name):
param = getattr(part_script, field_name)
if isinstance(param, torch.Tensor):
param = param.detach().numpy()
mapper.pytorch_params[output_name] = param
part_script = param
return [], [output_name]
......@@ -78,12 +82,13 @@ def prim_ListConstruct(mapper, graph, node):
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理每个输入
for i, input_name in enumerate(inputs_name):
layer_inputs["input{}".format(i)] = input_name
# 获取当前节点输入、输出的list
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
current_outputs = layer_outputs
graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs
......@@ -101,12 +106,13 @@ def prim_RaiseException(mapper, graph, node):
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%76
mapper._check_input(graph, inputs_node[0], inputs_name[0], layer_outputs)
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs)
layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入、输出的list
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
current_outputs = layer_outputs
graph.add_layer(
"prim.exception", inputs=layer_inputs, outputs=layer_outputs)
......@@ -134,7 +140,10 @@ def prim_Loop(mapper, graph, node):
block = list(node.blocks())[0]
loop_outputs = node_outputs
for i, block_input_ivalue in enumerate(block.inputs()):
block_input_node_name = 'x' + str(mapper.output_index)
if i == 0:
block_input_node_name = '_x' + str(mapper.output_index)
else:
block_input_node_name = 'x' + str(mapper.output_index)
unique_id = block_input_ivalue.unique()
if unique_id not in mapper.outputs_info:
mapper.outputs_info[unique_id] = block_input_node_name
......@@ -226,12 +235,65 @@ def prim_min(mapper, graph, node):
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], layer_outputs)
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs)
layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入、输出的list
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
current_outputs = layer_outputs
graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs
def prim_SetAttr(mapper, graph, node):
""" 设置attribute信息。
TorchScript示例:
= prim::SetAttr[name="num_batches_tracked"](%260, %277)
参数含义:
%260 (-): 属性名前缀。
%277 (-): 需要设置的值。
"""
output_name = mapper._get_outputs_name(node)[0]
field_name_list = []
tmp_node = node
while True:
input_node = list(tmp_node.inputs())[0].node()
try:
field_name_list.insert(0, input_node.s('name'))
tmp_node = input_node
except Exception:
break
field_name_list.append(node.s('name'))
inputs_name, inputs_node = mapper._get_inputs_name(node)
param = {"Tensor": inputs_name[1]}
mapper.pytorch_params[".".join(field_name_list)] = param
return [], [output_name]
def prim_shape(mapper, graph, node):
""" 构造获取shape的PaddleLayer。
TorchScript示例:
%4701 : int[] = prim::shape(%result.1)
参数含义:
%4701 (list): 输出,shape信息。
%result.1 (Tensor): 需要获取shape的值。
"""
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%input.8
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs)
layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
graph.add_layer("prim.shape", inputs=layer_inputs, outputs=layer_outputs)
return current_inputs, current_outputs
......@@ -32,8 +32,28 @@ class PyTorchOpMapper(OpMapper):
self.output_index = 0
self.dygraph_name_id = {} # 动态图__init__输出名字中的id,key为kernel类型,value为id
# 转换
self.check_op(decoder.graph)
self.graph, _ = self.traverse(decoder.graph)
def check_op(self, script_graph):
def _update_op_list(graph):
for node in graph.nodes():
op_list.append(node.kind())
for block in node.blocks():
_update_op_list(block)
op_list = list()
_update_op_list(script_graph)
op_list = list(set(op_list))
unsupported_op_list = []
for op in op_list:
func_name = op.replace('::', '_')
if not (hasattr(prim, func_name) or hasattr(aten, func_name)):
unsupported_op_list.append(op)
if len(unsupported_op_list) > 0:
raise Exception("The kind {} in model is not supported yet.".format(
unsupported_op_list))
def traverse(self, script_graph, parent_layer=None):
# 用于获取graph的输入
def _update_graph_inputs(inputs, outputs):
......@@ -65,9 +85,7 @@ class PyTorchOpMapper(OpMapper):
func = getattr(aten, func_name)
inputs, outputs = func(self, graph, node)
_update_graph_inputs(inputs, outputs)
else:
raise Exception("The kind {} in model is not supported yet.".
format(node.kind()))
# 转换输出节点
if hasattr(script_graph, 'returnNode'):
for i, ivalue in enumerate(script_graph.returnNode().inputs()):
......@@ -97,9 +115,9 @@ class PyTorchOpMapper(OpMapper):
self.outputs_info[script_unique_id] = output_name
self.output_index += 1
outputs_name.append(output_name)
# if节点没有输出的情况
# if或loop节点没有输出的情况
if len(list(node.outputs())) == 0:
output_name = 'x' + str(self.output_index)
output_name = '_x' + str(self.output_index)
self.output_index += 1
outputs_name.append(output_name)
return outputs_name
......@@ -122,11 +140,19 @@ class PyTorchOpMapper(OpMapper):
outputs=[output_name],
value="params[{}]".format(string(output_name)))
else:
graph.add_layer(
"prim.constant",
inputs={},
outputs=[output_name],
value=string(param) if isinstance(param, str) else param)
if isinstance(param, dict) and "Tensor" in param:
graph.add_layer(
"prim.constant",
inputs={},
outputs=[output_name],
value=param["Tensor"])
else:
graph.add_layer(
"prim.constant",
inputs={},
outputs=[output_name],
value=string(param)
if isinstance(param, str) else param)
node_outputs.append(output_name)
def _get_inputs_name(self, node):
......@@ -135,9 +161,9 @@ class PyTorchOpMapper(OpMapper):
for script_input_ivalue in node.inputs():
script_input_node = script_input_ivalue.node()
script_input_unique_id = script_input_ivalue.unique()
input_node_name = self.outputs_info[script_input_unique_id]
input_name = self.outputs_info[script_input_unique_id]
inputs_node.append(script_input_node)
inputs_name.append(input_node_name)
inputs_name.append(input_name)
return inputs_name, inputs_node
def data(self, graph, node, uid):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册