未验证 提交 e2b3e7e0 编写于 作者: J Jason 提交者: GitHub

Merge pull request #536 from SunAhong1993/develop

add PyTorch op
...@@ -76,6 +76,7 @@ class PaddleGraph(object): ...@@ -76,6 +76,7 @@ class PaddleGraph(object):
self.source_type = source_type self.source_type = source_type
self.custom_code = None self.custom_code = None
self.inputs_info = None self.inputs_info = None
self.has_unpack = False
def set_name(self, name): def set_name(self, name):
self.name = name.replace("-", "_").replace("/", "_") self.name = name.replace("-", "_").replace("/", "_")
...@@ -112,6 +113,8 @@ class PaddleGraph(object): ...@@ -112,6 +113,8 @@ class PaddleGraph(object):
layer_id) layer_id)
layer = PaddleLayer(layer_id, kernel, inputs, outputs, scope_name=scope_name, **kwargs) layer = PaddleLayer(layer_id, kernel, inputs, outputs, scope_name=scope_name, **kwargs)
self.layers[layer_id] = layer self.layers[layer_id] = layer
if layer.kernel in ["prim.list_unpack" or "prim.tuple_unpack"]:
self.has_unpack = True
return layer_id return layer_id
def del_layer(self, layer_id): def del_layer(self, layer_id):
...@@ -272,12 +275,16 @@ class PaddleGraph(object): ...@@ -272,12 +275,16 @@ class PaddleGraph(object):
def gen_dygraph_model(self, save_dir, jit_type=None): def gen_dygraph_model(self, save_dir, jit_type=None):
if jit_type == "trace": if jit_type == "trace":
from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree if not self.has_unpack:
hierarchical_tree = HierarchicalTree(self) from x2paddle.optimizer.pytorch_code_optimizer import HierarchicalTree
for layer_id, layer in self.layers.items(): hierarchical_tree = HierarchicalTree(self)
hierarchical_tree.insert(layer) for layer_id, layer in self.layers.items():
hierarchical_tree.save_source_files(save_dir) hierarchical_tree.insert(layer)
self.dump_dygraph_parameter(save_dir) hierarchical_tree.save_source_files(save_dir)
self.dump_dygraph_parameter(save_dir)
else:
self.gen_dygraph_code(save_dir)
self.dump_dygraph_parameter(save_dir)
else: else:
if self.source_type == "pytorch": if self.source_type == "pytorch":
from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph from x2paddle.optimizer.pytorch_code_optimizer import ModuleGraph
......
...@@ -64,5 +64,5 @@ class TraceDecoder(Decoder): ...@@ -64,5 +64,5 @@ class TraceDecoder(Decoder):
print(e) print(e)
exit(0) exit(0)
self.graph = self._optimize_graph(self.script.inlined_graph) self.graph = self._optimize_graph(self.script.inlined_graph)
self.input_examples = input_examples self.input_examples = input_examples
...@@ -101,6 +101,7 @@ class TFGraphNode(GraphNode): ...@@ -101,6 +101,7 @@ class TFGraphNode(GraphNode):
@property @property
def name(self): def name(self):
if hasattr(self, 'index'): if hasattr(self, 'index'):
print(self.layer_type)
return self.layer_name + "_p{}".format(self.index) return self.layer_name + "_p{}".format(self.index)
return self.layer_name return self.layer_name
...@@ -184,7 +185,7 @@ class TFGraph(Graph): ...@@ -184,7 +185,7 @@ class TFGraph(Graph):
node = super(TFGraph, self).get_node(new_node_name, copy) node = super(TFGraph, self).get_node(new_node_name, copy)
if node is None: if node is None:
return None return None
if node.layer_type == "Switch": if node.layer_type in ["Switch", "Reshape", "Sub"]:
if hasattr(node, 'index'): if hasattr(node, 'index'):
del node.index del node.index
if len(items) == 1 and node.layer_type in self.multi_out_ops: if len(items) == 1 and node.layer_type in self.multi_out_ops:
...@@ -284,6 +285,11 @@ class TFGraph(Graph): ...@@ -284,6 +285,11 @@ class TFGraph(Graph):
if node_name in self.output_nodes: if node_name in self.output_nodes:
idx = self.output_nodes.index(node_name) idx = self.output_nodes.index(node_name)
self.output_nodes[idx] = input_node.layer_name self.output_nodes[idx] = input_node.layer_name
if len(input_node.outputs) > 0:
self.output_nodes.pop(idx)
else:
self.output_nodes[idx] = input_node.layer_name
def _remove_cast_node(self): def _remove_cast_node(self):
cast_node = list() cast_node = list()
......
...@@ -48,13 +48,17 @@ def aten_abs(mapper, graph, node): ...@@ -48,13 +48,17 @@ def aten_abs(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%n.3 # 处理输入0,即%n.3
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.abs", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.abs",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -78,7 +82,8 @@ def aten_adaptive_avg_pool2d(mapper, graph, node): ...@@ -78,7 +82,8 @@ def aten_adaptive_avg_pool2d(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.3 # 处理输入0,即%x.3
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -128,13 +133,20 @@ def aten_addmm(mapper, graph, node): ...@@ -128,13 +133,20 @@ def aten_addmm(mapper, graph, node):
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%150 # 处理输入0,即%150
mapper._check_input( mapper._check_input(
graph, inputs_node[0], inputs_name[0], current_outputs, scope_name, add_dim=True) graph,
inputs_node[0],
inputs_name[0],
current_outputs,
scope_name,
add_dim=True)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 处理输入1,即%input.3 # 处理输入1,即%input.3
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[1] layer_inputs["x"] = inputs_name[1]
# 处理输入2,即%156 # 处理输入2,即%156
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[2] layer_inputs["y"] = inputs_name[2]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -182,16 +194,26 @@ def aten_add(mapper, graph, node): ...@@ -182,16 +194,26 @@ def aten_add(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%i.12 # 处理输入0,即%i.12
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%288 # 处理输入1,即%288
mapper._check_input( mapper._check_input(
graph, inputs_node[1], inputs_name[1], current_outputs, scope_name, add_dim=True) graph,
inputs_node[1],
inputs_name[1],
current_outputs,
scope_name,
add_dim=True)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.add", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.add",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -215,11 +237,17 @@ def aten_add_(mapper, graph, node): ...@@ -215,11 +237,17 @@ def aten_add_(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%output.2 # 处理输入0,即%output.2
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%150 # 处理输入1,即%150
mapper._check_input( mapper._check_input(
graph, inputs_node[1], inputs_name[1], current_outputs, scope_name, add_dim=True) graph,
inputs_node[1],
inputs_name[1],
current_outputs,
scope_name,
add_dim=True)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -233,7 +261,11 @@ def aten_add_(mapper, graph, node): ...@@ -233,7 +261,11 @@ def aten_add_(mapper, graph, node):
current_inputs.append(inputs_name[2]) current_inputs.append(inputs_name[2])
graph.add_layer( graph.add_layer(
"prim.add_", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, **layer_attrs) "prim.add_",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -255,15 +287,21 @@ def aten___and__(mapper, graph, node): ...@@ -255,15 +287,21 @@ def aten___and__(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%i.12 # 处理输入0,即%i.12
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%288 # 处理输入1,即%288
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.and", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.and",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -284,15 +322,21 @@ def aten_append(mapper, graph, node): ...@@ -284,15 +322,21 @@ def aten_append(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [inputs_name[0]] current_outputs = [inputs_name[0]]
# 处理输入0,即_output_size.1 # 处理输入0,即_output_size.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["list"] = inputs_name[0] layer_inputs["list"] = inputs_name[0]
# 处理输入1,即v.1 # 处理输入1,即v.1
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["element"] = inputs_name[1] layer_inputs["element"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.append", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.append",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -421,7 +465,8 @@ def aten_avg_pool2d(mapper, graph, node): ...@@ -421,7 +465,8 @@ def aten_avg_pool2d(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.34 # 处理输入0,即%x.34
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -454,6 +499,7 @@ def aten_avg_pool2d(mapper, graph, node): ...@@ -454,6 +499,7 @@ def aten_avg_pool2d(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_avg_pool3d(mapper, graph, node): def aten_avg_pool3d(mapper, graph, node):
""" 构造最大池化的PaddleLayer。 """ 构造最大池化的PaddleLayer。
...@@ -479,7 +525,8 @@ def aten_avg_pool3d(mapper, graph, node): ...@@ -479,7 +525,8 @@ def aten_avg_pool3d(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.34 # 处理输入0,即%x.34
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -537,7 +584,8 @@ def aten_avg_pool1d(mapper, graph, node): ...@@ -537,7 +584,8 @@ def aten_avg_pool1d(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.34 # 处理输入0,即%x.34
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -599,7 +647,8 @@ def aten_batch_norm(mapper, graph, node): ...@@ -599,7 +647,8 @@ def aten_batch_norm(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.80 # 处理输入0,即%input.80
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -652,16 +701,26 @@ def aten_bmm(mapper, graph, node): ...@@ -652,16 +701,26 @@ def aten_bmm(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%i.12 # 处理输入0,即%i.12
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%288 # 处理输入1,即%288
mapper._check_input( mapper._check_input(
graph, inputs_node[1], inputs_name[1], current_outputs, scope_name, add_dim=True) graph,
inputs_node[1],
inputs_name[1],
current_outputs,
scope_name,
add_dim=True)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("paddle.bmm", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"paddle.bmm",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -684,7 +743,8 @@ def aten_cat(mapper, graph, node): ...@@ -684,7 +743,8 @@ def aten_cat(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%13 # 处理输入0,即%13
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -725,7 +785,8 @@ def aten_chunk(mapper, graph, node): ...@@ -725,7 +785,8 @@ def aten_chunk(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.170 # 处理输入0,即%input.170
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -774,7 +835,8 @@ def aten_clamp(mapper, graph, node): ...@@ -774,7 +835,8 @@ def aten_clamp(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.1 # 处理输入0,即%input.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -804,6 +866,48 @@ def aten_clamp(mapper, graph, node): ...@@ -804,6 +866,48 @@ def aten_clamp(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_clamp_min(mapper, graph, node):
""" 构造元素剪裁的PaddleLayer。
TorchScript示例:
%56 : Tensor = aten::clamp_min(%input.1, %46)
参数含义:
%56 (Tensor): 输出,累加后的结果。
%input.1 (Tensor): 输入,需要剪裁的Tensor。
%46 (float/Tensor): 最小值。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%input.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values())
# 处理输入1,即%46
if inputs_name[1] in mapper.attrs:
layer_attrs["min"] = mapper.attrs[inputs_name[1]]
else:
mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs, scope_name)
layer_inputs["min"] = inputs_name[1]
current_inputs.append(inputs_name[1])
graph.add_layer(
"paddle.clip",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def aten___contains__(mapper, graph, node): def aten___contains__(mapper, graph, node):
""" 构造in的PaddleLayer。 """ 构造in的PaddleLayer。
...@@ -822,15 +926,21 @@ def aten___contains__(mapper, graph, node): ...@@ -822,15 +926,21 @@ def aten___contains__(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%50 # 处理输入0,即%50
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 处理输入1,即%name.1 # 处理输入1,即%name.1
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["element"] = inputs_name[1] layer_inputs["element"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.contain", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.contain",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -855,7 +965,8 @@ def aten_constant_pad_nd(mapper, graph, node): ...@@ -855,7 +965,8 @@ def aten_constant_pad_nd(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input1.24 # 处理输入0,即%input1.24
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -887,7 +998,8 @@ def aten_constant_pad_nd(mapper, graph, node): ...@@ -887,7 +998,8 @@ def aten_constant_pad_nd(mapper, graph, node):
outputs=[inputs_name[0] + "_if", output_name], outputs=[inputs_name[0] + "_if", output_name],
scope_name=scope_name) scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]] if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer( block.add_layer(
"prim.sub", "prim.sub",
inputs={"y": inputs_name[0] + "_len"}, inputs={"y": inputs_name[0] + "_len"},
...@@ -918,19 +1030,27 @@ def aten_constant_pad_nd(mapper, graph, node): ...@@ -918,19 +1030,27 @@ def aten_constant_pad_nd(mapper, graph, node):
outputs=[output_name], outputs=[output_name],
scope_name=scope_name) scope_name=scope_name)
if_layer.add_block(block) if_layer.add_block(block)
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
block.add_layer( block.add_layer(
kernel, inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, **layer_attrs) kernel,
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
if_layer.add_block(block) if_layer.add_block(block)
if_layer.inputs["input-0"] = inputs_name[0] if_layer.inputs["input-0"] = inputs_name[0]
if_layer.inputs["input-1"] = inputs_name[0] + "_len" if_layer.inputs["input-1"] = inputs_name[0] + "_len"
if len(layer_attrs["padding"]) == 2: if len(layer_attrs["padding"]) == 2:
layer_outputs[0] = layer_outputs[0].raplace("pad", "pad1d")
add_pad_layers("paddle.nn.Pad1D", 3) add_pad_layers("paddle.nn.Pad1D", 3)
elif len(layer_attrs["padding"]) == 4: elif len(layer_attrs["padding"]) == 4:
layer_outputs[0] = layer_outputs[0].raplace("pad", "pad2d")
add_pad_layers("paddle.nn.Pad2D", 4) add_pad_layers("paddle.nn.Pad2D", 4)
elif len(layer_attrs["padding"]) == 6: elif len(layer_attrs["padding"]) == 6:
layer_outputs[0] = layer_outputs[0].raplace("pad", "pad3d")
add_pad_layers("paddle.nn.Pad3D", 5) add_pad_layers("paddle.nn.Pad3D", 5)
else: else:
raise Exception("The lenght of padding list must be 2, 4 or 6!") raise Exception("The lenght of padding list must be 2, 4 or 6!")
...@@ -958,12 +1078,17 @@ def aten_contiguous(mapper, graph, node): ...@@ -958,12 +1078,17 @@ def aten_contiguous(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%4058 # 处理输入0,即%4058
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.equal",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -992,7 +1117,8 @@ def aten_conv2d(mapper, graph, node): ...@@ -992,7 +1117,8 @@ def aten_conv2d(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.8 # 处理输入0,即%input.8
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -1056,18 +1182,20 @@ def aten__convolution(mapper, graph, node): ...@@ -1056,18 +1182,20 @@ def aten__convolution(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.8 # 处理输入0,即%input.8
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
# 处理输入1,即%18 # 处理输入1,即%18
weights = mapper.pytorch_params[inputs_name[1]] weights = mapper.pytorch_params[inputs_name[1]]
mapper.paddle_params[op_name + ".weight"] = weights #np.swapaxes(weights, 0, 1) mapper.paddle_params[op_name +
".weight"] = weights #np.swapaxes(weights, 0, 1)
if mapper.attrs[inputs_name[6]]: if mapper.attrs[inputs_name[6]]:
layer_attrs["out_channels"] = weights.shape[1] layer_attrs["out_channels"] = weights.shape[1]
else: else:
layer_attrs["out_channels"] = weights.shape[0] layer_attrs["out_channels"] = weights.shape[0]
layer_attrs["kernel_size"] = weights.shape[2:] layer_attrs["kernel_size"] = weights.shape[2:]
# 处理输入2,即%10 # 处理输入2,即%10
if inputs_name[2] in mapper.pytorch_params: if inputs_name[2] in mapper.pytorch_params:
bias = mapper.pytorch_params[inputs_name[2]] bias = mapper.pytorch_params[inputs_name[2]]
...@@ -1090,11 +1218,11 @@ def aten__convolution(mapper, graph, node): ...@@ -1090,11 +1218,11 @@ def aten__convolution(mapper, graph, node):
# 处理输入8,即%12 # 处理输入8,即%12
layer_attrs["groups"] = mapper.attrs[inputs_name[8]] layer_attrs["groups"] = mapper.attrs[inputs_name[8]]
if mapper.attrs[inputs_name[6]]: if mapper.attrs[inputs_name[6]]:
layer_attrs['in_channels'] = weights.shape[0] * mapper.attrs[inputs_name[ layer_attrs['in_channels'] = weights.shape[0] * mapper.attrs[
8]] inputs_name[8]]
else: else:
layer_attrs['in_channels'] = weights.shape[1] * mapper.attrs[inputs_name[ layer_attrs['in_channels'] = weights.shape[1] * mapper.attrs[
8]] inputs_name[8]]
if mapper.attrs[inputs_name[6]]: if mapper.attrs[inputs_name[6]]:
graph.add_layer( graph.add_layer(
"paddle.nn.Conv2DTranspose", "paddle.nn.Conv2DTranspose",
...@@ -1138,7 +1266,8 @@ def aten_conv_transpose2d(mapper, graph, node): ...@@ -1138,7 +1266,8 @@ def aten_conv_transpose2d(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.8 # 处理输入0,即%input.8
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -1166,8 +1295,7 @@ def aten_conv_transpose2d(mapper, graph, node): ...@@ -1166,8 +1295,7 @@ def aten_conv_transpose2d(mapper, graph, node):
layer_attrs["groups"] = mapper.attrs[inputs_name[6]] layer_attrs["groups"] = mapper.attrs[inputs_name[6]]
# 处理输入7,即%22 # 处理输入7,即%22
layer_attrs["dilation"] = mapper.attrs[inputs_name[7]] layer_attrs["dilation"] = mapper.attrs[inputs_name[7]]
layer_attrs['in_channels'] = weights.shape[0] * mapper.attrs[inputs_name[ layer_attrs['in_channels'] = weights.shape[0] * mapper.attrs[inputs_name[6]]
6]]
graph.add_layer( graph.add_layer(
"paddle.nn.Conv2DTranspose", "paddle.nn.Conv2DTranspose",
inputs=layer_inputs, inputs=layer_inputs,
...@@ -1194,12 +1322,17 @@ def aten_cos(mapper, graph, node): ...@@ -1194,12 +1322,17 @@ def aten_cos(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%sinusoid_inp.1 # 处理输入0,即%sinusoid_inp.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("paddle.cos", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"paddle.cos",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1223,7 +1356,8 @@ def aten_cumsum(mapper, graph, node): ...@@ -1223,7 +1356,8 @@ def aten_cumsum(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%mask.1 # 处理输入0,即%mask.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -1270,11 +1404,16 @@ def aten_detach(mapper, graph, node): ...@@ -1270,11 +1404,16 @@ def aten_detach(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%end.1 # 处理输入0,即%end.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.equal",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1295,7 +1434,11 @@ def aten_dict(mapper, graph, node): ...@@ -1295,7 +1434,11 @@ def aten_dict(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
graph.add_layer("prim.dict", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.dict",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1315,15 +1458,22 @@ def aten_dim(mapper, graph, node): ...@@ -1315,15 +1458,22 @@ def aten_dim(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.8 # 处理输入0,即%input.8
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.shape", inputs=layer_inputs, outputs=[output_name], scope_name=scope_name) "prim.shape",
inputs=layer_inputs,
outputs=[output_name],
scope_name=scope_name)
graph.add_layer( graph.add_layer(
"prim.len", inputs={"input": output_name}, outputs=[output_name], scope_name=scope_name) "prim.len",
inputs={"input": output_name},
outputs=[output_name],
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1344,15 +1494,21 @@ def aten_div_(mapper, graph, node): ...@@ -1344,15 +1494,21 @@ def aten_div_(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%124 # 处理输入0,即%124
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%123 # 处理输入1,即%123
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.div", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.div",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1374,15 +1530,21 @@ def aten_div(mapper, graph, node): ...@@ -1374,15 +1530,21 @@ def aten_div(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%124 # 处理输入0,即%124
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%123 # 处理输入1,即%123
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.div", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.div",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1405,13 +1567,18 @@ def aten_dropout(mapper, graph, node): ...@@ -1405,13 +1567,18 @@ def aten_dropout(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%119 # 处理输入0,即%119
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.nn.Dropout", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, p=0.0) "paddle.nn.Dropout",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
p=0.0)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1434,13 +1601,18 @@ def aten_dropout_(mapper, graph, node): ...@@ -1434,13 +1601,18 @@ def aten_dropout_(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%119 # 处理输入0,即%119
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.nn.Dropout", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, p=0.0) "paddle.nn.Dropout",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
p=0.0)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1472,7 +1644,8 @@ def aten_embedding(mapper, graph, node): ...@@ -1472,7 +1644,8 @@ def aten_embedding(mapper, graph, node):
layer_attrs["num_embeddings"] = weights.shape[0] layer_attrs["num_embeddings"] = weights.shape[0]
layer_attrs["embedding_dim"] = weights.shape[1] layer_attrs["embedding_dim"] = weights.shape[1]
# 处理输入1,即%input_ids.1 # 处理输入1,即%input_ids.1
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[1] layer_inputs["input"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -1511,18 +1684,24 @@ def aten_eq(mapper, graph, node): ...@@ -1511,18 +1684,24 @@ def aten_eq(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%124 # 处理输入0,即%124
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
x_value = list(node.inputs())[0] x_value = list(node.inputs())[0]
x_type = x_value.type() x_type = x_value.type()
# 处理输入1,即%123 # 处理输入1,即%123
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
y_value = list(node.inputs())[1] y_value = list(node.inputs())[1]
y_type = y_value.type() y_type = y_value.type()
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.eq", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.eq",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1543,12 +1722,17 @@ def aten_erf(mapper, graph, node): ...@@ -1543,12 +1722,17 @@ def aten_erf(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%sinusoid_inp.1 # 处理输入0,即%sinusoid_inp.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("paddle.erf", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"paddle.erf",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1569,13 +1753,17 @@ def aten_exp(mapper, graph, node): ...@@ -1569,13 +1753,17 @@ def aten_exp(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.5 # 处理输入0,即%result.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.exp", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.exp",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1599,7 +1787,8 @@ def aten_expand(mapper, graph, node): ...@@ -1599,7 +1787,8 @@ def aten_expand(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%1875 # 处理输入0,即%1875
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
# 处理输入1,即%51 # 处理输入1,即%51
...@@ -1611,9 +1800,9 @@ def aten_expand(mapper, graph, node): ...@@ -1611,9 +1800,9 @@ def aten_expand(mapper, graph, node):
layer_inputs["shape"] = inputs_name[1] layer_inputs["shape"] = inputs_name[1]
current_inputs.append(inputs_name[1]) current_inputs.append(inputs_name[1])
graph.add_layer( graph.add_layer(
"paddle.expand", "paddle.expand",
inputs=layer_inputs, inputs=layer_inputs,
outputs=layer_outputs, outputs=layer_outputs,
scope_name=scope_name, scope_name=scope_name,
**layer_attrs) **layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1637,14 +1826,16 @@ def aten_expand_as(mapper, graph, node): ...@@ -1637,14 +1826,16 @@ def aten_expand_as(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%1875 # 处理输入0,即%1875
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%1888 # 处理输入1,即%1888
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.type", "prim.type",
inputs={"input": inputs_name[0]}, inputs={"input": inputs_name[0]},
...@@ -1666,7 +1857,8 @@ def aten_expand_as(mapper, graph, node): ...@@ -1666,7 +1857,8 @@ def aten_expand_as(mapper, graph, node):
outputs=[inputs_name[0] + "_if1"], outputs=[inputs_name[0] + "_if1"],
scope_name=scope_name) scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]] if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer( block.add_layer(
"prim.type", "prim.type",
inputs={"input": inputs_name[1]}, inputs={"input": inputs_name[1]},
...@@ -1679,18 +1871,23 @@ def aten_expand_as(mapper, graph, node): ...@@ -1679,18 +1871,23 @@ def aten_expand_as(mapper, graph, node):
scope_name=scope_name, scope_name=scope_name,
dtype=inputs_name[1] + "_type") dtype=inputs_name[1] + "_type")
if_layer.add_block(block) if_layer.add_block(block)
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
if_layer.add_block(block) if_layer.add_block(block)
if_layer.inputs["input-0"] = inputs_name[0] if_layer.inputs["input-0"] = inputs_name[0]
if_layer.inputs["input-1"] = inputs_name[1] if_layer.inputs["input-1"] = inputs_name[1]
graph.add_layer( graph.add_layer(
"paddle.expand_as", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.expand_as",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
graph.add_layer( graph.add_layer(
"prim.if", {'input': inputs_name[0] + "_cond"}, "prim.if", {'input': inputs_name[0] + "_cond"},
outputs=[inputs_name[0] + "_if2"], outputs=[inputs_name[0] + "_if2"],
scope_name=scope_name) scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]] if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer( block.add_layer(
"paddle.cast", "paddle.cast",
inputs={"x": layer_outputs[0]}, inputs={"x": layer_outputs[0]},
...@@ -1698,20 +1895,21 @@ def aten_expand_as(mapper, graph, node): ...@@ -1698,20 +1895,21 @@ def aten_expand_as(mapper, graph, node):
scope_name=scope_name, scope_name=scope_name,
dtype=string("bool")) dtype=string("bool"))
if_layer.add_block(block) if_layer.add_block(block)
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
if_layer.add_block(block) if_layer.add_block(block)
if_layer.inputs["input-0"] = layer_outputs[0] if_layer.inputs["input-0"] = layer_outputs[0]
# TODO(syf): check expand_as # TODO(syf): check expand_as
# # 处理输入0,即%1875 # # 处理输入0,即%1875
# mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) # mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
# layer_inputs["x"] = inputs_name[0] # layer_inputs["x"] = inputs_name[0]
# # 处理输入1,即%1888 # # 处理输入1,即%1888
# mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) # mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name)
# layer_inputs["y"] = inputs_name[1] # layer_inputs["y"] = inputs_name[1]
# # 获取当前节点输入的list # # 获取当前节点输入的list
# current_inputs = list(layer_inputs.values()) # current_inputs = list(layer_inputs.values())
# graph.add_layer( # graph.add_layer(
# "paddle.expand_as", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) # "paddle.expand_as", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1738,7 +1936,8 @@ def aten_eye(mapper, graph, node): ...@@ -1738,7 +1936,8 @@ def aten_eye(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%49 # 处理输入0,即%49
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["num_rows"] = inputs_name[0] layer_inputs["num_rows"] = inputs_name[0]
if len(inputs_name) > 5: if len(inputs_name) > 5:
# 处理输入1,即%_50 # 处理输入1,即%_50
...@@ -1758,6 +1957,7 @@ def aten_eye(mapper, graph, node): ...@@ -1758,6 +1957,7 @@ def aten_eye(mapper, graph, node):
**layer_attrs) **layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_feature_dropout(mapper, graph, node): def aten_feature_dropout(mapper, graph, node):
""" 构造Dropout的PaddleLayer。 """ 构造Dropout的PaddleLayer。
...@@ -1777,13 +1977,18 @@ def aten_feature_dropout(mapper, graph, node): ...@@ -1777,13 +1977,18 @@ def aten_feature_dropout(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%119 # 处理输入0,即%119
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.nn.Dropout", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, p=0.0) "paddle.nn.Dropout",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
p=0.0)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1808,7 +2013,8 @@ def aten_flatten(mapper, graph, node): ...@@ -1808,7 +2013,8 @@ def aten_flatten(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x # 处理输入0,即%x
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
# 处理输入1,即%4 # 处理输入1,即%4
layer_attrs["start_axis"] = mapper.attrs[inputs_name[1]] layer_attrs["start_axis"] = mapper.attrs[inputs_name[1]]
# 处理输入2,即%20 # 处理输入2,即%20
...@@ -1843,12 +2049,17 @@ def aten_Float(mapper, graph, node): ...@@ -1843,12 +2049,17 @@ def aten_Float(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%3991 # 处理输入0,即%3991
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.float", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.float",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1869,37 +2080,44 @@ def aten_floor(mapper, graph, node): ...@@ -1869,37 +2080,44 @@ def aten_floor(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%scale.18 # 处理输入0,即%scale.18
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.type", "prim.type", {'input': inputs_name[0]},
{'input': inputs_name[0]},
outputs=[inputs_name[0] + "_type"], outputs=[inputs_name[0] + "_type"],
scope_name=scope_name) scope_name=scope_name)
graph.add_layer( graph.add_layer(
"prim.str", "prim.str", {'input': inputs_name[0] + "_type"},
{'input': inputs_name[0] + "_type"},
outputs=[inputs_name[0] + "_type"], outputs=[inputs_name[0] + "_type"],
scope_name=scope_name) scope_name=scope_name)
graph.add_layer( graph.add_layer(
"prim.startswith", "prim.startswith", {'input': inputs_name[0] + "_type"},
{'input': inputs_name[0] + "_type"},
outputs=[inputs_name[0] + "_cond"], outputs=[inputs_name[0] + "_cond"],
scope_name=scope_name, scope_name=scope_name,
start_str=string("VarType")) start_str=string("VarType"))
graph.add_layer( graph.add_layer(
"prim.if", "prim.if", {'input': inputs_name[0] + "_cond"},
{'input': inputs_name[0] + "_cond"},
outputs=[inputs_name[0] + "_if"], outputs=[inputs_name[0] + "_if"],
scope_name=scope_name) scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]] if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
block.add_layer("paddle.floor", inputs=copy.deepcopy(layer_inputs), outputs=copy.deepcopy(layer_outputs), scope_name=scope_name) source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"paddle.floor",
inputs=copy.deepcopy(layer_inputs),
outputs=copy.deepcopy(layer_outputs),
scope_name=scope_name)
if_layer.add_block(block) if_layer.add_block(block)
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
block.add_layer("prim.floor", inputs=copy.deepcopy(layer_inputs), outputs=copy.deepcopy(layer_outputs), scope_name=scope_name) source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer(
"prim.floor",
inputs=copy.deepcopy(layer_inputs),
outputs=copy.deepcopy(layer_outputs),
scope_name=scope_name)
if_layer.add_block(block) if_layer.add_block(block)
if_layer.inputs["input-0"] = inputs_name[0] if_layer.inputs["input-0"] = inputs_name[0]
if_layer.outputs.append(output_name) if_layer.outputs.append(output_name)
...@@ -1924,15 +2142,21 @@ def aten_floordiv(mapper, graph, node): ...@@ -1924,15 +2142,21 @@ def aten_floordiv(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%124 # 处理输入0,即%124
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%123 # 处理输入1,即%123
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.floordiv", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.floordiv",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1954,15 +2178,21 @@ def aten_floor_divide(mapper, graph, node): ...@@ -1954,15 +2178,21 @@ def aten_floor_divide(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%124 # 处理输入0,即%124
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%123 # 处理输入1,即%123
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.floordiv", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.floordiv",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -1990,7 +2220,8 @@ def aten_full_like(mapper, graph, node): ...@@ -1990,7 +2220,8 @@ def aten_full_like(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%val_if_large.3 # 处理输入0,即%val_if_large.3
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -2035,20 +2266,22 @@ def aten_gather(mapper, graph, node): ...@@ -2035,20 +2266,22 @@ def aten_gather(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.5 # 处理输入0,即%result.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%18 # 处理输入1,即%18
layer_attrs["dim"] = mapper.attrs[inputs_name[1]] layer_attrs["dim"] = mapper.attrs[inputs_name[1]]
# 处理输入2,即%19 # 处理输入2,即%19
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
layer_inputs["index"] = inputs_name[2] layer_inputs["index"] = inputs_name[2]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"custom_layer:Gather", "custom_layer:Gather",
inputs=layer_inputs, inputs=layer_inputs,
outputs=layer_outputs, outputs=layer_outputs,
scope_name=scope_name, scope_name=scope_name,
**layer_attrs) **layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2074,13 +2307,17 @@ def aten_gelu(mapper, graph, node): ...@@ -2074,13 +2307,17 @@ def aten_gelu(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.5 # 处理输入0,即%result.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.nn.GELU", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.nn.GELU",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2102,15 +2339,21 @@ def aten___getitem__(mapper, graph, node): ...@@ -2102,15 +2339,21 @@ def aten___getitem__(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%72 # 处理输入0,即%72
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["list"] = inputs_name[0] layer_inputs["list"] = inputs_name[0]
# 处理输入1,即%88 # 处理输入1,即%88
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["index"] = inputs_name[1] layer_inputs["index"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.getitem", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.getitem",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2132,15 +2375,112 @@ def aten_gt(mapper, graph, node): ...@@ -2132,15 +2375,112 @@ def aten_gt(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%82 # 处理输入0,即%82
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%78 # 处理输入1,即%78
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.gt", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.gt",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs
def aten_gru(mapper, graph, node):
""" 构造门控循环单元网络(GRU)的PaddleLayer。
TorchScript示例:
%21, %22 = aten::gru(%input, %hx, %20, %11, %10, %9, %11, %8, %11)
参数含义:
%21 (Tensor): 输出,由前向和后向cell的输出拼接得到。
%22 (Tensor): 输出,最终状态。
%input (Tensor): 网络输入。
%hx (Tensor): 网络的初始状态。
%20 (list): 所有权重组合成的list。
%11 (bool): 是否使用bias。
%10 (int): 网络层数。
%9 (float): dropout概率。
%11 (bool): 是否为训练阶段。
%8 (bool): 是否使用双向LSTM。
%11 (bool): 第一个维度是否为batch size。
"""
scope_name = mapper.normalize_scope_name(node)
op_name = name_generator("gru", mapper.nn_name2id)
output_names = mapper._get_outputs_name(node)
layer_outputs = [op_name]
layer_outputs.extend(output_names)
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = output_names
# 处理输入0,即%input.95
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input0"] = inputs_name[0]
# 处理输入1,即%734
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["input1"] = inputs_name[1]
# 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values())
# 处理输入2,即%734
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
graph.layers.pop(mapper.output2id[inputs_name[2]])
param_inputs_name, _ = mapper._get_inputs_name(inputs_node[2])
new_param_inputs_name = list()
for i, param_name in enumerate(param_inputs_name):
if i == 0:
layer_attrs["hidden_size"] = int(
mapper.paddle_params[param_name].shape[0] / 3)
layer_attrs["input_size"] = int(mapper.paddle_params[param_name]
.shape[1])
if len(mapper.paddle_params[param_name].shape) > 1:
part_name = param_name.split("_weight_")[-1]
mapper.paddle_params["{}.weight_{}".format(
op_name, part_name)] = mapper.paddle_params[param_name]
new_param_inputs_name.append("{}.weight_{}".format(op_name,
part_name))
else:
part_name = param_name.split("_bias_")[-1]
mapper.paddle_params["{}.bias_{}".format(
op_name, part_name)] = mapper.paddle_params[param_name]
mapper.paddle_params.pop(param_name)
# 处理输入3,即%526
is_bias = mapper.attrs[inputs_name[3]]
if not is_bias:
for param_name in new_param_inputs_name:
bias_name = param_name.replace("weight", "bias")
bias_shape = mapper.paddle_params[param_name].shape[:1]
mapper.paddle_params[bias_name] = np.zeros(bias_shape).astype(
"float32")
# 处理输入4,即%525
layer_attrs["num_layers"] = mapper.attrs[inputs_name[4]]
# 处理输入5,即%524
layer_attrs["dropout"] = mapper.attrs[inputs_name[5]]
# 处理输入7,即%526
is_bidirectional = mapper.attrs[inputs_name[7]]
if is_bidirectional:
layer_attrs["direction"] = string("bidirectional")
# 处理输入8,即%526
batch_first = mapper.attrs[inputs_name[8]]
if not batch_first:
layer_attrs["time_major"] = True
graph.add_layer(
"paddle.nn.GRU",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2165,7 +2505,8 @@ def aten_hardtanh_(mapper, graph, node): ...@@ -2165,7 +2505,8 @@ def aten_hardtanh_(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.20 # 处理输入0,即%input.20
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -2174,9 +2515,12 @@ def aten_hardtanh_(mapper, graph, node): ...@@ -2174,9 +2515,12 @@ def aten_hardtanh_(mapper, graph, node):
# 处理输入2,即%66 # 处理输入2,即%66
layer_attrs["max"] = mapper.attrs[inputs_name[2]] layer_attrs["max"] = mapper.attrs[inputs_name[2]]
if layer_attrs["min"] ==0 and layer_attrs["max"] == 6: if layer_attrs["min"] == 0 and layer_attrs["max"] == 6:
graph.add_layer( graph.add_layer(
"paddle.nn.ReLU6", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.nn.ReLU6",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
else: else:
graph.add_layer( graph.add_layer(
'paddle.nn.Hardtanh', 'paddle.nn.Hardtanh',
...@@ -2207,7 +2551,8 @@ def aten_index_select(mapper, graph, node): ...@@ -2207,7 +2551,8 @@ def aten_index_select(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x2.3 # 处理输入0,即%x2.3
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%320 # 处理输入1,即%320
if inputs_name[1] in mapper.attrs: if inputs_name[1] in mapper.attrs:
...@@ -2218,7 +2563,8 @@ def aten_index_select(mapper, graph, node): ...@@ -2218,7 +2563,8 @@ def aten_index_select(mapper, graph, node):
layer_inputs["axis"] = inputs_name[1] layer_inputs["axis"] = inputs_name[1]
current_inputs.append(inputs_name[1]) current_inputs.append(inputs_name[1])
# 处理输入2,即%371 # 处理输入2,即%371
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
layer_inputs["index"] = inputs_name[2] layer_inputs["index"] = inputs_name[2]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -2232,6 +2578,69 @@ def aten_index_select(mapper, graph, node): ...@@ -2232,6 +2578,69 @@ def aten_index_select(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_instance_norm(mapper, graph, node):
"""构造InstanceNorm的PaddleLayer
TorchScript示例:
%res.7 : Tensor = aten::instance_norm(%res.5, %88, %85, %84, %83, %87, %91, %92, %87)
参数含义:
%res.7 (Tensor): 输出,InstanceNorm的结果。
%res.5 (Tensor): 需要进行InstanceNorm的特征层。
%88 (Tensor): weights。
%85 (Tensor): bias。
%84 (Tensor): 全局均值。
%83 (Tensor): 全局方差。
%87 (bool): 是否使用输入的统计。
%91 (float): momentum。
%92 (float): eps。
%87 (bool): 是否启用cudnn。
"""
scope_name = mapper.normalize_scope_name(node)
op_name = name_generator("instance_norm", mapper.nn_name2id)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [op_name, output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%input.80
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values())
# 处理输入1,即%88
if inputs_name[1] in mapper.pytorch_params:
weights = mapper.pytorch_params[inputs_name[1]]
mapper.paddle_params[op_name + ".weight"] = weights
layer_attrs['num_features'] = weights.shape[0]
# 处理输入2,即%85
if inputs_name[2] in mapper.pytorch_params:
bias = mapper.pytorch_params[inputs_name[2]]
mapper.paddle_params[op_name + ".bias"] = bias
# 处理输入3,即%84
if inputs_name[3] in mapper.pytorch_params:
mean = mapper.pytorch_params[inputs_name[3]]
mapper.paddle_params[op_name + "._mean"] = mean
# 处理输入4,即%83
if inputs_name[4] in mapper.pytorch_params:
var = mapper.pytorch_params[inputs_name[4]]
mapper.paddle_params[op_name + "._variance"] = var
# 处理输入6,即%91
layer_attrs["momentum"] = 1 - mapper.attrs[inputs_name[6]]
# 处理输入7,即%92
layer_attrs["epsilon"] = mapper.attrs[inputs_name[7]]
graph.add_layer(
"custom_layer:InstanceNorm",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def aten_Int(mapper, graph, node): def aten_Int(mapper, graph, node):
""" 构造强转为int的PaddleLayer。 """ 构造强转为int的PaddleLayer。
...@@ -2249,12 +2658,17 @@ def aten_Int(mapper, graph, node): ...@@ -2249,12 +2658,17 @@ def aten_Int(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%1738 # 处理输入0,即%1738
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.int", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.int",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2276,15 +2690,21 @@ def aten___is__(mapper, graph, node): ...@@ -2276,15 +2690,21 @@ def aten___is__(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%size.122 # 处理输入0,即%size.122
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%3931 # 处理输入1,即%3931
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.is", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.is",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2306,15 +2726,21 @@ def aten___isnot__(mapper, graph, node): ...@@ -2306,15 +2726,21 @@ def aten___isnot__(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%size.122 # 处理输入0,即%size.122
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%3931 # 处理输入1,即%3931
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.isnot", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.isnot",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2342,7 +2768,8 @@ def aten_layer_norm(mapper, graph, node): ...@@ -2342,7 +2768,8 @@ def aten_layer_norm(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.6 # 处理输入0,即%input.6
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -2388,15 +2815,21 @@ def aten_le(mapper, graph, node): ...@@ -2388,15 +2815,21 @@ def aten_le(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%78 # 处理输入0,即%78
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%79 # 处理输入1,即%79
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.le", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.le",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2420,7 +2853,8 @@ def aten_leaky_relu_(mapper, graph, node): ...@@ -2420,7 +2853,8 @@ def aten_leaky_relu_(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.5 # 处理输入0,即%result.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -2453,12 +2887,17 @@ def aten_len(mapper, graph, node): ...@@ -2453,12 +2887,17 @@ def aten_len(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%72 # 处理输入0,即%72
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.len", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.len",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2479,13 +2918,17 @@ def aten_log(mapper, graph, node): ...@@ -2479,13 +2918,17 @@ def aten_log(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%786 # 处理输入0,即%786
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.log", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.log",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2519,38 +2962,47 @@ def aten_lstm(mapper, graph, node): ...@@ -2519,38 +2962,47 @@ def aten_lstm(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = output_names current_outputs = output_names
# 处理输入0,即%input.95 # 处理输入0,即%input.95
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input0"] = inputs_name[0] layer_inputs["input0"] = inputs_name[0]
# 处理输入1,即%734 # 处理输入1,即%734
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["input1"] = inputs_name[1] layer_inputs["input1"] = inputs_name[1]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
# 处理输入2,即%734 # 处理输入2,即%734
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
graph.layers.pop(mapper.output2id[inputs_name[2]]) graph.layers.pop(mapper.output2id[inputs_name[2]])
param_inputs_name, _ = mapper._get_inputs_name(inputs_node[2]) param_inputs_name, _ = mapper._get_inputs_name(inputs_node[2])
new_param_inputs_name = list() new_param_inputs_name = list()
for i, param_name in enumerate(param_inputs_name): for i, param_name in enumerate(param_inputs_name):
if i == 0: if i == 0:
layer_attrs["hidden_size"] = int(mapper.paddle_params[param_name].shape[0] / 4) layer_attrs["hidden_size"] = int(
layer_attrs["input_size"] = int(mapper.paddle_params[param_name].shape[1]) mapper.paddle_params[param_name].shape[0] / 4)
layer_attrs["input_size"] = int(mapper.paddle_params[param_name]
.shape[1])
if len(mapper.paddle_params[param_name].shape) > 1: if len(mapper.paddle_params[param_name].shape) > 1:
part_name = param_name.split("_weight_")[-1] part_name = param_name.split("_weight_")[-1]
mapper.paddle_params["{}.weight_{}".format(op_name, part_name)] = mapper.paddle_params[param_name] mapper.paddle_params["{}.weight_{}".format(
new_param_inputs_name.append("{}.weight_{}".format(op_name, part_name)) op_name, part_name)] = mapper.paddle_params[param_name]
new_param_inputs_name.append("{}.weight_{}".format(op_name,
part_name))
else: else:
part_name = param_name.split("_bias_")[-1] part_name = param_name.split("_bias_")[-1]
mapper.paddle_params["{}.bias_{}".format(op_name, part_name)] = mapper.paddle_params[param_name] mapper.paddle_params["{}.bias_{}".format(
op_name, part_name)] = mapper.paddle_params[param_name]
mapper.paddle_params.pop(param_name) mapper.paddle_params.pop(param_name)
# 处理输入3,即%526 # 处理输入3,即%526
is_bias = mapper.attrs[inputs_name[3]] is_bias = mapper.attrs[inputs_name[3]]
if not is_bias: if not is_bias:
for param_name in new_param_inputs_name: for param_name in new_param_inputs_name:
bias_name = param_name.replace("weight", "bias") bias_name = param_name.replace("weight", "bias")
bias_shape= mapper.paddle_params[param_name].shape[:1] bias_shape = mapper.paddle_params[param_name].shape[:1]
mapper.paddle_params[bias_name] = np.zeros(bias_shape).astype("float32") mapper.paddle_params[bias_name] = np.zeros(bias_shape).astype(
"float32")
# 处理输入4,即%525 # 处理输入4,即%525
layer_attrs["num_layers"] = mapper.attrs[inputs_name[4]] layer_attrs["num_layers"] = mapper.attrs[inputs_name[4]]
# 处理输入5,即%524 # 处理输入5,即%524
...@@ -2590,15 +3042,21 @@ def aten_lt(mapper, graph, node): ...@@ -2590,15 +3042,21 @@ def aten_lt(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%78 # 处理输入0,即%78
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%79 # 处理输入1,即%79
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.lt", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.lt",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2623,7 +3081,8 @@ def aten_masked_fill_(mapper, graph, node): ...@@ -2623,7 +3081,8 @@ def aten_masked_fill_(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.4 # 处理输入0,即%input.4
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
current_inputs.append(inputs_name[0]) current_inputs.append(inputs_name[0])
graph.add_layer( graph.add_layer(
"prim.type", "prim.type",
...@@ -2631,7 +3090,8 @@ def aten_masked_fill_(mapper, graph, node): ...@@ -2631,7 +3090,8 @@ def aten_masked_fill_(mapper, graph, node):
outputs=[inputs_name[0] + "_type"], outputs=[inputs_name[0] + "_type"],
scope_name=scope_name) scope_name=scope_name)
# 处理输入1,即%scores.2 # 处理输入1,即%scores.2
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
current_inputs.append(inputs_name[1]) current_inputs.append(inputs_name[1])
graph.add_layer( graph.add_layer(
"paddle.logical_not", "paddle.logical_not",
...@@ -2657,7 +3117,8 @@ def aten_masked_fill_(mapper, graph, node): ...@@ -2657,7 +3117,8 @@ def aten_masked_fill_(mapper, graph, node):
outputs=[inputs_name[0] + "_not_mask"], outputs=[inputs_name[0] + "_not_mask"],
scope_name=scope_name) scope_name=scope_name)
# 处理输入2,即%46 # 处理输入2,即%46
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
graph.add_layer( graph.add_layer(
"prim.eq", "prim.eq",
inputs={"x": inputs_name[2]}, inputs={"x": inputs_name[2]},
...@@ -2683,14 +3144,16 @@ def aten_masked_fill_(mapper, graph, node): ...@@ -2683,14 +3144,16 @@ def aten_masked_fill_(mapper, graph, node):
outputs=[inputs_name[2] + "_if"], outputs=[inputs_name[2] + "_if"],
scope_name=scope_name) scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]] if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer( block.add_layer(
"prim.equal", "prim.equal",
inputs={"input": inputs_name[1] + "_mask"}, inputs={"input": inputs_name[1] + "_mask"},
outputs=[inputs_name[2] + "_1"], outputs=[inputs_name[2] + "_1"],
scope_name=scope_name) scope_name=scope_name)
if_layer.add_block(block) if_layer.add_block(block)
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer( block.add_layer(
"prim.mul", "prim.mul",
inputs={"x": inputs_name[1] + "_mask", inputs={"x": inputs_name[1] + "_mask",
...@@ -2731,7 +3194,8 @@ def aten_masked_fill(mapper, graph, node): ...@@ -2731,7 +3194,8 @@ def aten_masked_fill(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.4 # 处理输入0,即%input.4
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
current_inputs.append(inputs_name[0]) current_inputs.append(inputs_name[0])
graph.add_layer( graph.add_layer(
"prim.type", "prim.type",
...@@ -2739,7 +3203,8 @@ def aten_masked_fill(mapper, graph, node): ...@@ -2739,7 +3203,8 @@ def aten_masked_fill(mapper, graph, node):
outputs=[inputs_name[0] + "_type"], outputs=[inputs_name[0] + "_type"],
scope_name=scope_name) scope_name=scope_name)
# 处理输入1,即%scores.2 # 处理输入1,即%scores.2
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
current_inputs.append(inputs_name[1]) current_inputs.append(inputs_name[1])
graph.add_layer( graph.add_layer(
"paddle.logical_not", "paddle.logical_not",
...@@ -2765,7 +3230,8 @@ def aten_masked_fill(mapper, graph, node): ...@@ -2765,7 +3230,8 @@ def aten_masked_fill(mapper, graph, node):
outputs=[inputs_name[0] + "_not_mask"], outputs=[inputs_name[0] + "_not_mask"],
scope_name=scope_name) scope_name=scope_name)
# 处理输入2,即%46 # 处理输入2,即%46
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
graph.add_layer( graph.add_layer(
"prim.eq", "prim.eq",
inputs={"x": inputs_name[2]}, inputs={"x": inputs_name[2]},
...@@ -2791,14 +3257,16 @@ def aten_masked_fill(mapper, graph, node): ...@@ -2791,14 +3257,16 @@ def aten_masked_fill(mapper, graph, node):
outputs=[inputs_name[2] + "_if"], outputs=[inputs_name[2] + "_if"],
scope_name=scope_name) scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]] if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer( block.add_layer(
"prim.equal", "prim.equal",
inputs={"input": inputs_name[1] + "_mask"}, inputs={"input": inputs_name[1] + "_mask"},
outputs=[inputs_name[2] + "_1"], outputs=[inputs_name[2] + "_1"],
scope_name=scope_name) scope_name=scope_name)
if_layer.add_block(block) if_layer.add_block(block)
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer( block.add_layer(
"prim.mul", "prim.mul",
inputs={"x": inputs_name[1] + "_mask", inputs={"x": inputs_name[1] + "_mask",
...@@ -2848,7 +3316,10 @@ def aten_max(mapper, graph, node): ...@@ -2848,7 +3316,10 @@ def aten_max(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.maximum", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.maximum",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
else: else:
pass pass
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2879,7 +3350,8 @@ def aten_max_pool2d(mapper, graph, node): ...@@ -2879,7 +3350,8 @@ def aten_max_pool2d(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.11 # 处理输入0,即%result.11
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -2904,7 +3376,7 @@ def aten_max_pool2d(mapper, graph, node): ...@@ -2904,7 +3376,7 @@ def aten_max_pool2d(mapper, graph, node):
# 处理输入5,即%19 # 处理输入5,即%19
layer_attrs["ceil_mode"] = mapper.attrs[inputs_name[5]] layer_attrs["ceil_mode"] = mapper.attrs[inputs_name[5]]
layer_attrs_tmp["ceil_mode"] = mapper.attrs[inputs_name[5]] layer_attrs_tmp["ceil_mode"] = mapper.attrs[inputs_name[5]]
graph.add_layer( graph.add_layer(
"paddle.nn.MaxPool2D", "paddle.nn.MaxPool2D",
inputs=layer_inputs, inputs=layer_inputs,
...@@ -2932,15 +3404,21 @@ def aten_matmul(mapper, graph, node): ...@@ -2932,15 +3404,21 @@ def aten_matmul(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%101 # 处理输入0,即%101
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%102 # 处理输入1,即%102
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("paddle.matmul", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"paddle.matmul",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -2974,7 +3452,10 @@ def aten_min(mapper, graph, node): ...@@ -2974,7 +3452,10 @@ def aten_min(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.minimum", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.minimum",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
else: else:
pass pass
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3001,7 +3482,8 @@ def aten_mean(mapper, graph, node): ...@@ -3001,7 +3482,8 @@ def aten_mean(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.1 # 处理输入0,即%result.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
# 处理输入1,即%4967 # 处理输入1,即%4967
...@@ -3047,13 +3529,18 @@ def aten_meshgrid(mapper, graph, node): ...@@ -3047,13 +3529,18 @@ def aten_meshgrid(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.1 # 处理输入0,即%input.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["args"] = inputs_name[0] layer_inputs["args"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = layer_inputs.values() current_inputs = layer_inputs.values()
current_outputs = layer_outputs current_outputs = layer_outputs
graph.add_layer("paddle.meshgrid", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"paddle.meshgrid",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3075,16 +3562,22 @@ def aten_mul(mapper, graph, node): ...@@ -3075,16 +3562,22 @@ def aten_mul(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%size_prods.38 # 处理输入0,即%size_prods.38
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%114 # 处理输入1,即%114
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
current_outputs = layer_outputs current_outputs = layer_outputs
graph.add_layer("prim.mul", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.mul",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3106,16 +3599,22 @@ def aten_mul_(mapper, graph, node): ...@@ -3106,16 +3599,22 @@ def aten_mul_(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%size_prods.38 # 处理输入0,即%size_prods.38
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%114 # 处理输入1,即%114
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
current_outputs = layer_outputs current_outputs = layer_outputs
graph.add_layer("prim.mul", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.mul",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3137,15 +3636,21 @@ def aten_ne(mapper, graph, node): ...@@ -3137,15 +3636,21 @@ def aten_ne(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%124 # 处理输入0,即%124
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%123 # 处理输入1,即%123
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.ne", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.ne",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3166,12 +3671,76 @@ def aten_neg(mapper, graph, node): ...@@ -3166,12 +3671,76 @@ def aten_neg(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%124 # 处理输入0,即%124
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.neg", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.neg",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs
def aten_norm(mapper, graph, node):
""" 构造计算范数的PaddleLayer。
TorchScript示例:
%25 = aten::norm(%input, %21, %58, %24)
参数含义:
%25 (Tensor): 取范数后的结果。
%input (Tensor): 输入。
%21 (int): 范数的种类。
%58 (int): 使用范数计算的轴。
%24 (bool): 是否在输出的Tensor中保留和输入一样的维度。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%input
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0]
current_inputs = list(layer_inputs.values())
# 处理输入1,即%21
if inputs_name[1] in mapper.attrs:
layer_attrs["p"] = mapper.attrs[inputs_name[1]]
else:
mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs, scope_name)
layer_inputs["p"] = inputs_name[1]
current_inputs.append(inputs_name[1])
# 处理输入2,即%58
if inputs_name[1] in mapper.attrs:
layer_attrs["axis"] = mapper.attrs[inputs_name[2]]
else:
mapper._check_input(graph, inputs_node[2], inputs_name[2],
current_outputs, scope_name)
layer_inputs["axis"] = inputs_name[2]
current_inputs.append(inputs_name[2])
# 处理输入3,即%24
if inputs_name[1] in mapper.attrs:
layer_attrs["keepdim"] = mapper.attrs[inputs_name[3]]
else:
mapper._check_input(graph, inputs_node[3], inputs_name[3],
current_outputs, scope_name)
layer_inputs["keepdim"] = inputs_name[3]
current_inputs.append(inputs_name[3])
graph.add_layer(
"paddle.norm",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3192,12 +3761,17 @@ def aten___not__(mapper, graph, node): ...@@ -3192,12 +3761,17 @@ def aten___not__(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%124 # 处理输入0,即%124
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.not", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.not",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3262,7 +3836,8 @@ def aten_permute(mapper, graph, node): ...@@ -3262,7 +3836,8 @@ def aten_permute(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%cls_confs0.2 # 处理输入0,即%cls_confs0.2
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -3303,7 +3878,8 @@ def aten_pixel_shuffle(mapper, graph, node): ...@@ -3303,7 +3878,8 @@ def aten_pixel_shuffle(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.101 # 处理输入0,即%input.101
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
# 处理输入1,即%726 # 处理输入1,即%726
...@@ -3317,6 +3893,7 @@ def aten_pixel_shuffle(mapper, graph, node): ...@@ -3317,6 +3893,7 @@ def aten_pixel_shuffle(mapper, graph, node):
**layer_attrs) **layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_pow(mapper, graph, node): def aten_pow(mapper, graph, node):
""" 构造指数激活的PaddleLayer。 """ 构造指数激活的PaddleLayer。
...@@ -3335,7 +3912,8 @@ def aten_pow(mapper, graph, node): ...@@ -3335,7 +3912,8 @@ def aten_pow(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%4700 # 处理输入0,即%4700
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -3376,7 +3954,8 @@ def aten_prelu(mapper, graph, node): ...@@ -3376,7 +3954,8 @@ def aten_prelu(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.150 # 处理输入0,即%result.150
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%999 # 处理输入1,即%999
weight = mapper.pytorch_params[inputs_name[1]] weight = mapper.pytorch_params[inputs_name[1]]
...@@ -3385,14 +3964,108 @@ def aten_prelu(mapper, graph, node): ...@@ -3385,14 +3964,108 @@ def aten_prelu(mapper, graph, node):
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.nn.PReLU", "paddle.nn.PReLU",
inputs=layer_inputs, inputs=layer_inputs,
outputs=layer_outputs, outputs=layer_outputs,
scope_name=scope_name, scope_name=scope_name,
num_parameters=weight.shape[0]) num_parameters=weight.shape[0])
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_reflection_pad1d(mapper, graph, node):
""" 构造1维映射填充的PaddleLayer。
TorchScript示例:
%6 = aten::reflection_pad1d(%input, %7)
参数含义:
%6 (Tensor): 输出,填充后的Tensor。
%input (Tensor): 需要填充的Tensor。
%7 (list|Tensor): 填充大小。
"""
scope_name = mapper.normalize_scope_name(node)
op_name = name_generator("pad1d", mapper.nn_name2id)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [op_name, output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%input
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
# 处理输入1,即%7
if inputs_name[1] in mapper.attrs:
layer_attrs["padding"] = mapper.attrs[inputs_name[1]]
else:
mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs, scope_name)
ipt_node = inputs_node[1]
while ipt_node.kind() != "prim::GetAttr":
inputs_name, inputs_node = mapper._get_inputs_name(ipt_node)
ipt_node = inputs_node[0]
layer_attrs["padding"] = list(mapper.pytorch_params[inputs_name[0]])
layer_attrs["mode"] = string("reflect")
graph.add_layer(
"paddle.nn.Pad1D",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def aten_reflection_pad2d(mapper, graph, node):
""" 构造2维映射填充的PaddleLayer。
TorchScript示例:
%6 = aten::reflection_pad2d(%input, %7)
参数含义:
%6 (Tensor): 输出,填充后的Tensor。
%input (Tensor): 需要填充的Tensor。
%7 (list|Tensor): 填充大小。
"""
scope_name = mapper.normalize_scope_name(node)
op_name = name_generator("pad2d", mapper.nn_name2id)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [op_name, output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%input
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
# 处理输入1,即%7
if inputs_name[1] in mapper.attrs:
layer_attrs["padding"] = mapper.attrs[inputs_name[1]]
else:
mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs, scope_name)
ipt_node = inputs_node[1]
while ipt_node.kind() != "prim::GetAttr":
inputs_name, inputs_node = mapper._get_inputs_name(ipt_node)
ipt_node = inputs_node[0]
layer_attrs["padding"] = list(mapper.pytorch_params[inputs_name[0]])
layer_attrs["mode"] = string("reflect")
graph.add_layer(
"paddle.nn.Pad2D",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def aten_relu(mapper, graph, node): def aten_relu(mapper, graph, node):
""" 构造ReLU激活的PaddleLayer。 """ 构造ReLU激活的PaddleLayer。
...@@ -3413,13 +4086,17 @@ def aten_relu(mapper, graph, node): ...@@ -3413,13 +4086,17 @@ def aten_relu(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.5 # 处理输入0,即%result.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.nn.ReLU", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.nn.ReLU",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3443,13 +4120,17 @@ def aten_relu_(mapper, graph, node): ...@@ -3443,13 +4120,17 @@ def aten_relu_(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.5 # 处理输入0,即%result.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.nn.ReLU", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.nn.ReLU",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3473,13 +4154,17 @@ def aten_relu6(mapper, graph, node): ...@@ -3473,13 +4154,17 @@ def aten_relu6(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.5 # 处理输入0,即%result.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.nn.ReLU6", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.nn.ReLU6",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3502,7 +4187,8 @@ def aten_repeat(mapper, graph, node): ...@@ -3502,7 +4187,8 @@ def aten_repeat(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%699 # 处理输入0,即%699
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -3543,7 +4229,8 @@ def aten_reshape(mapper, graph, node): ...@@ -3543,7 +4229,8 @@ def aten_reshape(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%4700 # 处理输入0,即%4700
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -3555,7 +4242,7 @@ def aten_reshape(mapper, graph, node): ...@@ -3555,7 +4242,7 @@ def aten_reshape(mapper, graph, node):
current_outputs, scope_name) current_outputs, scope_name)
layer_inputs["shape"] = inputs_name[1] layer_inputs["shape"] = inputs_name[1]
current_inputs.append(inputs_name[1]) current_inputs.append(inputs_name[1])
graph.add_layer( graph.add_layer(
"paddle.reshape", "paddle.reshape",
inputs=layer_inputs, inputs=layer_inputs,
...@@ -3585,18 +4272,25 @@ def aten_rsub(mapper, graph, node): ...@@ -3585,18 +4272,25 @@ def aten_rsub(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%30 # 处理输入0,即%30
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%13 # 处理输入1,即%13
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 处理输入2,即%7 # 处理输入2,即%7
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
layer_inputs["alpha"] = inputs_name[2] layer_inputs["alpha"] = inputs_name[2]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.rsub", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.rsub",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3620,14 +4314,18 @@ def aten_ScalarImplicit(mapper, graph, node): ...@@ -3620,14 +4314,18 @@ def aten_ScalarImplicit(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%end.1 # 处理输入0,即%end.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
input_type = list(node.inputs())[0].type() input_type = list(node.inputs())[0].type()
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
if str(input_type) == "Tensor": if str(input_type) == "Tensor":
graph.add_layer( graph.add_layer(
"prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.equal",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
else: else:
raise Exception( raise Exception(
"The input type {} of aten::ScalarImplicit is not implemented yet!" "The input type {} of aten::ScalarImplicit is not implemented yet!"
...@@ -3655,12 +4353,14 @@ def aten_select(mapper, graph, node): ...@@ -3655,12 +4353,14 @@ def aten_select(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%18 # 处理输入0,即%18
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 处理输入1,即%8 # 处理输入1,即%8
layer_attrs["dim"] = mapper.attrs[inputs_name[1]] layer_attrs["dim"] = mapper.attrs[inputs_name[1]]
# 处理输入2,即%75 # 处理输入2,即%75
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
layer_inputs["index"] = inputs_name[2] layer_inputs["index"] = inputs_name[2]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -3690,18 +4390,22 @@ def aten__set_item(mapper, graph, node): ...@@ -3690,18 +4390,22 @@ def aten__set_item(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [] current_outputs = []
# 处理输入0,即%features.1 # 处理输入0,即%features.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["dict"] = inputs_name[0] layer_inputs["dict"] = inputs_name[0]
# 处理输入1,即%out_name.1 # 处理输入1,即%out_name.1
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["key"] = inputs_name[1] layer_inputs["key"] = inputs_name[1]
# 处理输入2,即%x.3 # 处理输入2,即%x.3
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
layer_inputs["value"] = inputs_name[2] layer_inputs["value"] = inputs_name[2]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.set_item", inputs=layer_inputs, outputs=[], scope_name=scope_name) graph.add_layer(
"prim.set_item", inputs=layer_inputs, outputs=[], scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3723,13 +4427,17 @@ def aten_sigmoid(mapper, graph, node): ...@@ -3723,13 +4427,17 @@ def aten_sigmoid(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%54 # 处理输入0,即%54
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.nn.Sigmoid", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.nn.Sigmoid",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3750,12 +4458,17 @@ def aten_sin(mapper, graph, node): ...@@ -3750,12 +4458,17 @@ def aten_sin(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%sinusoid_inp.1 # 处理输入0,即%sinusoid_inp.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("paddle.sin", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"paddle.sin",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3778,7 +4491,8 @@ def aten_size(mapper, graph, node): ...@@ -3778,7 +4491,8 @@ def aten_size(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.12 # 处理输入0,即%x.12
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -3800,7 +4514,10 @@ def aten_size(mapper, graph, node): ...@@ -3800,7 +4514,10 @@ def aten_size(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
graph.add_layer( graph.add_layer(
"prim.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.shape",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3939,7 +4656,10 @@ def aten_slice(mapper, graph, node): ...@@ -3939,7 +4656,10 @@ def aten_slice(mapper, graph, node):
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.slice", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.slice",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -3964,7 +4684,8 @@ def aten_softmax(mapper, graph, node): ...@@ -3964,7 +4684,8 @@ def aten_softmax(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.31 # 处理输入0,即%x.31
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4000,7 +4721,8 @@ def aten_softplus(mapper, graph, node): ...@@ -4000,7 +4721,8 @@ def aten_softplus(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.31 # 处理输入0,即%x.31
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4036,7 +4758,8 @@ def aten_split_with_sizes(mapper, graph, node): ...@@ -4036,7 +4758,8 @@ def aten_split_with_sizes(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%1446 # 处理输入0,即%1446
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%1750 # 处理输入1,即%1750
if inputs_name[1] in mapper.attrs: if inputs_name[1] in mapper.attrs:
...@@ -4083,13 +4806,17 @@ def aten_sqrt(mapper, graph, node): ...@@ -4083,13 +4806,17 @@ def aten_sqrt(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%786 # 处理输入0,即%786
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.sqrt", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.sqrt",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -4112,7 +4839,8 @@ def aten_squeeze(mapper, graph, node): ...@@ -4112,7 +4839,8 @@ def aten_squeeze(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%start_logits.1 # 处理输入0,即%start_logits.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4152,7 +4880,8 @@ def aten_stack(mapper, graph, node): ...@@ -4152,7 +4880,8 @@ def aten_stack(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%13 # 处理输入0,即%13
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4193,11 +4922,17 @@ def aten_sub(mapper, graph, node): ...@@ -4193,11 +4922,17 @@ def aten_sub(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%839 # 处理输入0,即%839
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%836 # 处理输入1,即%836
mapper._check_input( mapper._check_input(
graph, inputs_node[1], inputs_name[1], current_outputs, scope_name, add_dim=True) graph,
inputs_node[1],
inputs_name[1],
current_outputs,
scope_name,
add_dim=True)
layer_inputs["y"] = inputs_name[1] layer_inputs["y"] = inputs_name[1]
# 处理输入2,即%3 # 处理输入2,即%3
if len(inputs_node) > 2: if len(inputs_node) > 2:
...@@ -4213,7 +4948,12 @@ def aten_sub(mapper, graph, node): ...@@ -4213,7 +4948,12 @@ def aten_sub(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.sub", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name, **layer_attrs) graph.add_layer(
"prim.sub",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -4230,6 +4970,7 @@ def aten_sub_(mapper, graph, node): ...@@ -4230,6 +4970,7 @@ def aten_sub_(mapper, graph, node):
""" """
return aten_sub(mapper, graph, node) return aten_sub(mapper, graph, node)
def aten_t(mapper, graph, node): def aten_t(mapper, graph, node):
""" 构造矩阵转置的PaddleLayer。 """ 构造矩阵转置的PaddleLayer。
...@@ -4247,7 +4988,8 @@ def aten_t(mapper, graph, node): ...@@ -4247,7 +4988,8 @@ def aten_t(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.12 # 处理输入0,即%x.12
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4279,13 +5021,17 @@ def aten_tanh(mapper, graph, node): ...@@ -4279,13 +5021,17 @@ def aten_tanh(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%result.5 # 处理输入0,即%result.5
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.nn.Tanh", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.nn.Tanh",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -4309,13 +5055,16 @@ def aten_split(mapper, graph, node): ...@@ -4309,13 +5055,16 @@ def aten_split(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%159 # 处理输入0,即%159
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入2,即%723 # 处理输入2,即%723
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
layer_inputs["axis"] = inputs_name[2] layer_inputs["axis"] = inputs_name[2]
# 处理输入1,即%135 # 处理输入1,即%135
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
input_type = list(node.inputs())[0].type() input_type = list(node.inputs())[0].type()
if "[]" in str(input_type): if "[]" in str(input_type):
layer_inputs["num_or_sections"] = inputs_name[1] layer_inputs["num_or_sections"] = inputs_name[1]
...@@ -4353,16 +5102,19 @@ def aten_transpose(mapper, graph, node): ...@@ -4353,16 +5102,19 @@ def aten_transpose(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.21 # 处理输入0,即%x.21
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%704 # 处理输入1,即%704
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
dim1 = inputs_name[1] dim1 = inputs_name[1]
# 处理输入2,即%705 # 处理输入2,即%705
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
dim2 = inputs_name[2] dim2 = inputs_name[2]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.shape", "prim.shape",
inputs={"input": inputs_name[0]}, inputs={"input": inputs_name[0]},
...@@ -4441,7 +5193,8 @@ def aten_to(mapper, graph, node): ...@@ -4441,7 +5193,8 @@ def aten_to(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%13 # 处理输入0,即%13
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4478,12 +5231,14 @@ def aten_type_as(mapper, graph, node): ...@@ -4478,12 +5231,14 @@ def aten_type_as(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%56 # 处理输入0,即%56
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
# 处理输入0,即%mask.1 # 处理输入0,即%mask.1
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
graph.add_layer( graph.add_layer(
"prim.type", "prim.type",
inputs={"input": inputs_name[1]}, inputs={"input": inputs_name[1]},
...@@ -4493,7 +5248,10 @@ def aten_type_as(mapper, graph, node): ...@@ -4493,7 +5248,10 @@ def aten_type_as(mapper, graph, node):
current_inputs.append(inputs_name[1]) current_inputs.append(inputs_name[1])
graph.add_layer( graph.add_layer(
"paddle.cast", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.cast",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -4516,7 +5274,8 @@ def aten_unsqueeze(mapper, graph, node): ...@@ -4516,7 +5274,8 @@ def aten_unsqueeze(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%13 # 处理输入0,即%13
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4559,7 +5318,8 @@ def aten_upsample_bilinear2d(mapper, graph, node): ...@@ -4559,7 +5318,8 @@ def aten_upsample_bilinear2d(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.13 # 处理输入0,即%x.13
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4583,14 +5343,16 @@ def aten_upsample_bilinear2d(mapper, graph, node): ...@@ -4583,14 +5343,16 @@ def aten_upsample_bilinear2d(mapper, graph, node):
outputs=[inputs_name[0] + "_if1"], outputs=[inputs_name[0] + "_if1"],
scope_name=scope_name) scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]] if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer( block.add_layer(
"prim.var2list", "prim.var2list",
inputs={"input": inputs_name[1]}, inputs={"input": inputs_name[1]},
outputs=[inputs_name[1]], outputs=[inputs_name[1]],
scope_name=scope_name) scope_name=scope_name)
if_layer.add_block(block) if_layer.add_block(block)
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
if_layer.add_block(block) if_layer.add_block(block)
if_layer.inputs["input-0"] = inputs_name[1] if_layer.inputs["input-0"] = inputs_name[1]
# 处理输入2,即%5421 # 处理输入2,即%5421
...@@ -4615,6 +5377,7 @@ def aten_upsample_bilinear2d(mapper, graph, node): ...@@ -4615,6 +5377,7 @@ def aten_upsample_bilinear2d(mapper, graph, node):
**layer_attrs) **layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_upsample_nearest2d(mapper, graph, node): def aten_upsample_nearest2d(mapper, graph, node):
""" 构造使用nearest上采样的PaddleLayer。 """ 构造使用nearest上采样的PaddleLayer。
...@@ -4636,7 +5399,8 @@ def aten_upsample_nearest2d(mapper, graph, node): ...@@ -4636,7 +5399,8 @@ def aten_upsample_nearest2d(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.13 # 处理输入0,即%x.13
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4660,14 +5424,16 @@ def aten_upsample_nearest2d(mapper, graph, node): ...@@ -4660,14 +5424,16 @@ def aten_upsample_nearest2d(mapper, graph, node):
outputs=[inputs_name[0] + "_if1"], outputs=[inputs_name[0] + "_if1"],
scope_name=scope_name) scope_name=scope_name)
if_layer = graph.layers[list(graph.layers.keys())[-1]] if_layer = graph.layers[list(graph.layers.keys())[-1]]
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
block.add_layer( block.add_layer(
"prim.var2list", "prim.var2list",
inputs={"input": inputs_name[1]}, inputs={"input": inputs_name[1]},
outputs=[inputs_name[1]], outputs=[inputs_name[1]],
scope_name=scope_name) scope_name=scope_name)
if_layer.add_block(block) if_layer.add_block(block)
block = PaddleGraph(source_type="pytorch", parent_layer=if_layer, graph_type="dygraph") block = PaddleGraph(
source_type="pytorch", parent_layer=if_layer, graph_type="dygraph")
if_layer.add_block(block) if_layer.add_block(block)
if_layer.inputs["input-0"] = inputs_name[1] if_layer.inputs["input-0"] = inputs_name[1]
if "size" in layer_attrs and layer_attrs["size"] is None: if "size" in layer_attrs and layer_attrs["size"] is None:
...@@ -4702,12 +5468,17 @@ def aten_values(mapper, graph, node): ...@@ -4702,12 +5468,17 @@ def aten_values(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%78 # 处理输入0,即%78
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.dict2values", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.dict2values",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -4736,7 +5507,8 @@ def aten_view(mapper, graph, node): ...@@ -4736,7 +5507,8 @@ def aten_view(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%x.20 # 处理输入0,即%x.20
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list # 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4775,7 +5547,8 @@ def aten_warn(mapper, graph, node): ...@@ -4775,7 +5547,8 @@ def aten_warn(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%3 # 处理输入0,即%3
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
...@@ -4816,18 +5589,25 @@ def aten_where(mapper, graph, node): ...@@ -4816,18 +5589,25 @@ def aten_where(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%209 # 处理输入0,即%209
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["condition"] = inputs_name[0] layer_inputs["condition"] = inputs_name[0]
# 处理输入1,即%w0.2 # 处理输入1,即%w0.2
mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs, scope_name) mapper._check_input(graph, inputs_node[1], inputs_name[1], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[1] layer_inputs["x"] = inputs_name[1]
# 处理输入1,即%w0.2 # 处理输入1,即%w0.2
mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs, scope_name) mapper._check_input(graph, inputs_node[2], inputs_name[2], current_outputs,
scope_name)
layer_inputs["y"] = inputs_name[2] layer_inputs["y"] = inputs_name[2]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("paddle.where", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"paddle.where",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -4896,7 +5676,8 @@ def aten_zeros_like(mapper, graph, node): ...@@ -4896,7 +5676,8 @@ def aten_zeros_like(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%n.2 # 处理输入0,即%n.2
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0] layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
......
...@@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node): ...@@ -37,20 +37,24 @@ def prim_Constant(mapper, graph, node):
tensor_value = value tensor_value = value
value = "{}".format(value) value = "{}".format(value)
if "tensor" in value: if "tensor" in value:
if isinstance(tensor_value, list) or isinstance(tensor_value, tuple): if isinstance(tensor_value, list) or isinstance(tensor_value,
tuple):
name_dict = dict() name_dict = dict()
for i, tv in enumerate(tensor_value): for i, tv in enumerate(tensor_value):
output_name_i = "{}_p{}".format(output_name,i) output_name_i = "{}_p{}".format(output_name, i)
key_i = "input{}".format(i) key_i = "input{}".format(i)
mapper.paddle_params[output_name_i] = tv.cpu().detach().numpy() mapper.paddle_params[output_name_i] = tv.cpu().detach(
).numpy()
graph.add_layer( graph.add_layer(
"self.create_parameter", "self.create_parameter",
inputs={}, inputs={},
outputs=[output_name_i], outputs=[output_name_i],
scope_name=scope_name, scope_name=scope_name,
dtype=string(str(mapper.paddle_params[output_name_i].dtype)), dtype=string(
shape = mapper.paddle_params[output_name_i].shape, str(mapper.paddle_params[output_name_i].dtype)),
default_initializer="paddle.nn.initializer.Constant(value=0.0)") shape=mapper.paddle_params[output_name_i].shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)"
)
name_dict[key_i] = output_name_i name_dict[key_i] = output_name_i
graph.add_layer( graph.add_layer(
"prim.list", "prim.list",
...@@ -59,8 +63,19 @@ def prim_Constant(mapper, graph, node): ...@@ -59,8 +63,19 @@ def prim_Constant(mapper, graph, node):
scope_name=scope_name) scope_name=scope_name)
return [], [output_name] return [], [output_name]
else: else:
mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy() # mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
mapper.paddle_params[output_name] = tensor_value.cpu().detach(
).numpy()
graph.add_layer(
"self.create_parameter",
inputs={},
outputs=[output_name],
scope_name=scope_name,
dtype=string(str(mapper.paddle_params[output_name].dtype)),
shape=mapper.paddle_params[output_name].shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)"
)
return [], [output_name]
if "inf" in str(value): if "inf" in str(value):
t = str(type(value)).split("'")[1] t = str(type(value)).split("'")[1]
if str(value).startswith("-"): if str(value).startswith("-"):
...@@ -72,7 +87,11 @@ def prim_Constant(mapper, graph, node): ...@@ -72,7 +87,11 @@ def prim_Constant(mapper, graph, node):
value = int(math.pow(2, 31) - 1) value = int(math.pow(2, 31) - 1)
mapper.attrs[output_name] = value mapper.attrs[output_name] = value
graph.add_layer( graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=value) "prim.constant",
inputs={},
outputs=[output_name],
scope_name=scope_name,
value=value)
return [], [output_name] return [], [output_name]
...@@ -96,18 +115,23 @@ def prim_data(mapper, graph, node): ...@@ -96,18 +115,23 @@ def prim_data(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%4336 # 处理输入0,即%4336
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.equal",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
def prim_DictConstruct(mapper, graph, node): def prim_DictConstruct(mapper, graph, node):
""" 构建dict。 """ 构建dict。
TorchScript示例: TorchScript示例:
%32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29) %32 : Dict(str, Tensor) = prim::DictConstruct(%30, %23, %31, %29)
参数含义: 参数含义:
...@@ -127,22 +151,22 @@ def prim_DictConstruct(mapper, graph, node): ...@@ -127,22 +151,22 @@ def prim_DictConstruct(mapper, graph, node):
current_outputs = [output_name] current_outputs = [output_name]
# 处理每个输入 # 处理每个输入
for i, input_name in enumerate(inputs_name): for i, input_name in enumerate(inputs_name):
if i%2 == 0: if i % 2 == 0:
layer_attrs["key{}".format(int(i/2))] = mapper.attrs[input_name] layer_attrs["key{}".format(int(i / 2))] = mapper.attrs[input_name]
else: else:
layer_inputs["value{}".format(int(i/2))] = input_name layer_inputs["value{}".format(int(i / 2))] = input_name
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.dict_construct", graph.add_layer(
inputs=layer_inputs, "prim.dict_construct",
outputs=layer_outputs, inputs=layer_inputs,
scope_name=scope_name, outputs=layer_outputs,
**layer_attrs) scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
def prim_GetAttr(mapper, graph, node): def prim_GetAttr(mapper, graph, node):
""" 获取attribute信息。 """ 获取attribute信息。
...@@ -203,8 +227,13 @@ def prim_If(mapper, graph, node): ...@@ -203,8 +227,13 @@ def prim_If(mapper, graph, node):
input_node = list(node.inputs())[0].node() input_node = list(node.inputs())[0].node()
script_input_unique_id = list(node.inputs())[0].unique() script_input_unique_id = list(node.inputs())[0].unique()
input_node_name = mapper.outputs_info[script_input_unique_id] input_node_name = mapper.outputs_info[script_input_unique_id]
mapper._check_input(graph, input_node, input_node_name, current_outputs, scope_name) mapper._check_input(graph, input_node, input_node_name, current_outputs,
graph.add_layer("prim.if", inputs={'input': input_node_name}, outputs=node_outputs, scope_name=scope_name) scope_name)
graph.add_layer(
"prim.if",
inputs={'input': input_node_name},
outputs=node_outputs,
scope_name=scope_name)
current_layer = list(graph.layers.values())[-1] current_layer = list(graph.layers.values())[-1]
block0 = list(node.blocks())[0] block0 = list(node.blocks())[0]
block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer) block0_graph, graph_inputs0 = mapper.traverse(block0, current_layer)
...@@ -240,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node): ...@@ -240,12 +269,17 @@ def prim_ListConstruct(mapper, graph, node):
current_outputs = [output_name] current_outputs = [output_name]
# 处理每个输入 # 处理每个输入
for i, input_name in enumerate(inputs_name): for i, input_name in enumerate(inputs_name):
mapper._check_input(graph, inputs_node[i], input_name, current_outputs, scope_name) mapper._check_input(graph, inputs_node[i], input_name, current_outputs,
scope_name)
layer_inputs["input{}".format(i)] = input_name layer_inputs["input{}".format(i)] = input_name
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
layer_id = graph.add_layer("prim.list", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) layer_id = graph.add_layer(
"prim.list",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
mapper.output2id[output_name] = layer_id mapper.output2id[output_name] = layer_id
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -268,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node): ...@@ -268,13 +302,17 @@ def prim_ListUnpack(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = layer_outputs.copy() current_outputs = layer_outputs.copy()
# 处理输入0,即%4354 # 处理输入0,即%4354
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.list_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.list_unpack",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs) mapper.split_len[list(layer_inputs.values())[0]] = len(layer_outputs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -333,7 +371,11 @@ def prim_Loop(mapper, graph, node): ...@@ -333,7 +371,11 @@ def prim_Loop(mapper, graph, node):
scope_name=scope_name) scope_name=scope_name)
node_outputs.append(block_input_node_name) node_outputs.append(block_input_node_name)
graph.add_layer("prim.loop", inputs=loop_inputs, outputs=loop_outputs, scope_name=scope_name) graph.add_layer(
"prim.loop",
inputs=loop_inputs,
outputs=loop_outputs,
scope_name=scope_name)
current_layer = list(graph.layers.values())[-1] current_layer = list(graph.layers.values())[-1]
block_graph, graph_inputs = mapper.traverse(block, current_layer) block_graph, graph_inputs = mapper.traverse(block, current_layer)
for i, input_name in enumerate(graph_inputs): for i, input_name in enumerate(graph_inputs):
...@@ -361,12 +403,17 @@ def prim_min(mapper, graph, node): ...@@ -361,12 +403,17 @@ def prim_min(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.min", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.min",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -388,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node): ...@@ -388,14 +435,19 @@ def prim_NumToTensor(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(inputs_node[0]) scope_name)
inputs_inputs_name, inputs_inputs_node = mapper._get_inputs_name(
inputs_node[0])
if inputs_node[0].kind() == "aten::size" and len(inputs_inputs_name) > 1: if inputs_node[0].kind() == "aten::size" and len(inputs_inputs_name) > 1:
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim_equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim_equal",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
else: else:
layer_inputs["fill_value"] = inputs_name[0] layer_inputs["fill_value"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
...@@ -428,13 +480,17 @@ def prim_RaiseException(mapper, graph, node): ...@@ -428,13 +480,17 @@ def prim_RaiseException(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%76 # 处理输入0,即%76
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.exception", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.exception",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -455,13 +511,17 @@ def prim_requires_grad(mapper, graph, node): ...@@ -455,13 +511,17 @@ def prim_requires_grad(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%86 # 处理输入0,即%86
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.requires_grad", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.requires_grad",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -518,13 +578,17 @@ def prim_shape(mapper, graph, node): ...@@ -518,13 +578,17 @@ def prim_shape(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%input.8 # 处理输入0,即%input.8
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"paddle.shape", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "paddle.shape",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -551,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node): ...@@ -551,7 +615,11 @@ def prim_TupleConstruct(mapper, graph, node):
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.tuple", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.tuple",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -569,15 +637,23 @@ def prim_TupleUnpack(mapper, graph, node): ...@@ -569,15 +637,23 @@ def prim_TupleUnpack(mapper, graph, node):
outputs_name = mapper._get_outputs_name(node) outputs_name = mapper._get_outputs_name(node)
layer_outputs = outputs_name layer_outputs = outputs_name
layer_inputs = {} layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node) inputs_name, inputs_node = mapper._get_inputs_name(node)
if inputs_node[0].kind() == "prim::GetAttr":
layer_attrs["input"] = list(mapper.pytorch_params[inputs_name[0]])
else:
layer_inputs["input"] = inputs_name[0]
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = outputs_name current_outputs = outputs_name
layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer( graph.add_layer(
"prim.tuple_unpack", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) "prim.tuple_unpack",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -601,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node): ...@@ -601,12 +677,17 @@ def prim_unchecked_cast(mapper, graph, node):
# 获取当前节点输出的list # 获取当前节点输出的list
current_outputs = [output_name] current_outputs = [output_name]
# 处理输入0,即%size.63 # 处理输入0,即%size.63
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name) mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["input"] = inputs_name[0] layer_inputs["input"] = inputs_name[0]
# 获取当前节点输入的list # 获取当前节点输入的list
current_inputs = list(layer_inputs.values()) current_inputs = list(layer_inputs.values())
graph.add_layer("prim.equal", inputs=layer_inputs, outputs=layer_outputs, scope_name=scope_name) graph.add_layer(
"prim.equal",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name)
return current_inputs, current_outputs return current_inputs, current_outputs
...@@ -623,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node): ...@@ -623,5 +704,9 @@ def prim_Uninitialized(mapper, graph, node):
output = list(node.outputs())[0] output = list(node.outputs())[0]
mapper.attrs[output_name] = None mapper.attrs[output_name] = None
graph.add_layer( graph.add_layer(
"prim.constant", inputs={}, outputs=[output_name], scope_name=scope_name, value=None) "prim.constant",
inputs={},
outputs=[output_name],
scope_name=scope_name,
value=None)
return [], [output_name] return [], [output_name]
...@@ -14,7 +14,8 @@ ...@@ -14,7 +14,8 @@
# limitations under the License. # limitations under the License.
NO_OUTPUT_COUNT = 0 NO_OUTPUT_COUNT = 0
def gen_codes(code_list, indent=0): def gen_codes(code_list, indent=0):
indent_blank = " " * indent indent_blank = " " * indent
codes = [] codes = []
...@@ -53,13 +54,24 @@ def get_value(layer, key, layer_id=None, different_attrs=None): ...@@ -53,13 +54,24 @@ def get_value(layer, key, layer_id=None, different_attrs=None):
return str(layer.attrs[key]) return str(layer.attrs[key])
def prim_add(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_add(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} + {}".format(layer.outputs[0], line = "{} = {} + {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_add_(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} + {} * {}".format(layer.outputs[0], line = "{} = {} + {} * {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
layer.attrs["alpha"], layer.attrs["alpha"],
...@@ -67,22 +79,39 @@ def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif ...@@ -67,22 +79,39 @@ def prim_add_(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_and(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_and(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} and {}".format(layer.outputs[0], line = "{} = {} and {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "y", different_attrs)) get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_append(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_append(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{}.append({})".format( line = "{}.append({})".format(
get_value(layer, "list", layer_id, different_attrs), get_value(layer, "list", layer_id, different_attrs),
get_value(layer, "element", layer_id, different_attrs)) get_value(layer, "element", layer_id, different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_assert(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
if layer.attrs["type"] == "eq": if layer.attrs["type"] == "eq":
values = get_value(layer, "key") values = get_value(layer, "key")
if "value" in layer.attrs: if "value" in layer.attrs:
...@@ -93,15 +122,15 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d ...@@ -93,15 +122,15 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d
s += "{} == {} or ".format(get_value(layer, "key"), v) s += "{} == {} or ".format(get_value(layer, "key"), v)
if len(s) > 0: if len(s) > 0:
s = s[:-4] s = s[:-4]
lc=locals() lc = locals()
exec("assert_result = {}".format(s)) exec("assert_result = {}".format(s))
assert_result = lc['assert_result'] assert_result = lc['assert_result']
line = "assert {}, \'The {} must be {}!\'".format( line = "assert {}, \'The {} must be {}!\'".format(
s, get_value(layer, "key"), get_value(layer, "value")) s, get_value(layer, "key"), get_value(layer, "value"))
else: else:
s = "{} == {}".format(get_value(layer, "key"), s = "{} == {}".format(
get_value(layer, "value")) get_value(layer, "key"), get_value(layer, "value"))
lc=locals() lc = locals()
exec("assert_result = {}".format(s)) exec("assert_result = {}".format(s))
assert_result = lc['assert_result'] assert_result = lc['assert_result']
line = "assert {}, \'The {} must be {}!\'".format( line = "assert {}, \'The {} must be {}!\'".format(
...@@ -112,7 +141,12 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d ...@@ -112,7 +141,12 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[], layer_id=None, d
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_check_dim(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
lines = [] lines = []
dim = get_value(layer, "dim", different_attrs) dim = get_value(layer, "dim", different_attrs)
lines.append("if {} < 0:".format(dim)) lines.append("if {} < 0:".format(dim))
...@@ -123,12 +157,23 @@ def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None ...@@ -123,12 +157,23 @@ def prim_check_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None
forward_func.extend(gen_codes(lines, indent=indent)) forward_func.extend(gen_codes(lines, indent=indent))
def prim_constant(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_constant(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}".format(layer.outputs[0], layer.attrs["value"]) line = "{} = {}".format(layer.outputs[0], layer.attrs["value"])
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_contain(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} in {}".format(layer.outputs[0], line = "{} = {} in {}".format(layer.outputs[0],
get_value(layer, "element", different_attrs), get_value(layer, "element", different_attrs),
get_value(layer, "input", different_attrs)) get_value(layer, "input", different_attrs))
...@@ -137,108 +182,182 @@ def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None, ...@@ -137,108 +182,182 @@ def prim_contain(layer, indent=1, init_func=[], forward_func=[], layer_id=None,
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_dict(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_dict(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = dict()".format(layer.outputs[0]) line = "{} = dict()".format(layer.outputs[0])
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_dict_construct(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_dict_construct(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
lines = list() lines = list()
line = "{} = dict()".format(layer.outputs[0]) line = "{} = dict()".format(layer.outputs[0])
lines.append(line) lines.append(line)
for i in range(len(layer.inputs)): for i in range(len(layer.inputs)):
line = "{}[{}] = {}".format(layer.outputs[0], line = "{}[{}] = {}".format(
get_value(layer, "key{}".format(i), different_attrs), layer.outputs[0],
get_value(layer, "value{}".format(i), different_attrs)) get_value(layer, "key{}".format(i), different_attrs),
get_value(layer, "value{}".format(i), different_attrs))
lines.append(line) lines.append(line)
forward_func.extend(gen_codes(lines, indent=indent)) forward_func.extend(gen_codes(lines, indent=indent))
def prim_dict2values(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_dict2values(layer,
line = "{} = list({}.values())".format(layer.outputs[0], indent=1,
get_value(layer, "x", different_attrs)) init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = list({}.values())".format(
layer.outputs[0], get_value(layer, "x", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_div(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_div(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} / {}".format(layer.outputs[0], line = "{} = {} / {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_eq(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None,is_return_line=False): def prim_eq(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} == {}".format(layer.outputs[0], line = "{} = {} == {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_equal(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_equal(layer,
line = "{} = {}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_exception(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_exception(layer,
line = "raise Exception({})".format(get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "raise Exception({})".format(
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_float(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_float(layer,
line = "{} = float({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = float({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_floor(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_floor(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = math.floor({})".format(layer.outputs[0], line = "{} = math.floor({})".format(layer.outputs[0],
get_value(layer, "x", different_attrs)) get_value(layer, "x", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_floordiv(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_floordiv(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} // {}".format(layer.outputs[0], line = "{} = {} // {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_getitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_getitem(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}[{}]".format(layer.outputs[0], line = "{} = {}[{}]".format(layer.outputs[0],
get_value(layer, "list", different_attrs), get_value(layer, "list", different_attrs),
get_value(layer, "index", different_attrs)) get_value(layer, "index", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_gt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_gt(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} > {}".format(layer.outputs[0], line = "{} = {} > {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_if(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
try: try:
exec_s = None exec_s = None
for line in forward_func: for line in forward_func:
s = line.replace(" ", "") s = line.replace(" ", "")
if s.startswith("{} = ".format(get_value(layer, "input", different_attrs))): if s.startswith("{} = ".format(
get_value(layer, "input", different_attrs))):
exec_s = s.split(" = ")[1] exec_s = s.split(" = ")[1]
lc=locals() lc = locals()
if exec_s is not None: if exec_s is not None:
exec("if_result = {}".format(exec_s)) exec("if_result = {}".format(exec_s))
else: else:
exec("if_result = {}".format(get_value(layer, "input", different_attrs))) exec("if_result = {}".format(
get_value(layer, "input", different_attrs)))
if_result = lc['if_result'] if_result = lc['if_result']
if if_result: if if_result:
block = layer.blocks[0] block = layer.blocks[0]
else: else:
block = layer.blocks[1] block = layer.blocks[1]
if len(block.layers) > 0: if len(block.layers) > 0:
b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent) b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent)
init_func.extend(b_init_lines) init_func.extend(b_init_lines)
forward_func.extend(b_forward_lines) forward_func.extend(b_forward_lines)
except: except:
...@@ -249,7 +368,8 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe ...@@ -249,7 +368,8 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe
line = "pass" line = "pass"
forward_func.extend(gen_codes([line], indent=indent + 1)) forward_func.extend(gen_codes([line], indent=indent + 1))
else: else:
b_init_lines, b_forward_lines = block.gen_dygraph_code(indent=indent + 1) b_init_lines, b_forward_lines = block.gen_dygraph_code(
indent=indent + 1)
init_func.extend(b_init_lines) init_func.extend(b_init_lines)
forward_func.extend(b_forward_lines) forward_func.extend(b_forward_lines)
block = layer.blocks[1] block = layer.blocks[1]
...@@ -263,30 +383,54 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe ...@@ -263,30 +383,54 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[], layer_id=None, diffe
forward_func.extend(b_forward_lines) forward_func.extend(b_forward_lines)
def prim_int(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_int(layer,
line = "{} = int({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = int({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_is(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_is(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} is {}".format(layer.outputs[0], line = "{} = {} is {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_isinstance(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_isinstance(layer,
line = "{} = isinstance({}, {})".format(layer.outputs[0], indent=1,
get_value(layer, "input", different_attrs), init_func=[],
layer.attrs["cls"]) forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = isinstance({}, {})".format(
layer.outputs[0],
get_value(layer, "input", different_attrs), layer.attrs["cls"])
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_isnot(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} is not {}".format(layer.outputs[0], line = "{} = {} is not {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
...@@ -295,53 +439,94 @@ def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di ...@@ -295,53 +439,94 @@ def prim_isnot(layer, indent=1, init_func=[], forward_func=[], layer_id=None, di
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_le(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_le(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} <= {}".format(layer.outputs[0], line = "{} = {} <= {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_len(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_len(layer,
line = "{} = len({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = len({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_len2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_len2list(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
lines = [] lines = []
lines.append("{} = []".format(layer.outputs[0])) lines.append("{} = []".format(layer.outputs[0]))
lines.append("for i in range({}):".format(get_value(layer, "len", different_attrs))) lines.append("for i in range({}):".format(
get_value(layer, "len", different_attrs)))
lines.append(" {}.append(i)".format(layer.outputs[0])) lines.append(" {}.append(i)".format(layer.outputs[0]))
forward_func.extend(gen_codes(lines, indent=indent)) forward_func.extend(gen_codes(lines, indent=indent))
def prim_lt(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_lt(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} < {}".format(layer.outputs[0], line = "{} = {} < {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_list(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
input_len = len(layer.inputs) + len(layer.attrs) input_len = len(layer.inputs) + len(layer.attrs)
inputs_list = list() inputs_list = list()
for i in range(input_len): for i in range(input_len):
inputs_list.append(get_value(layer, "input{}".format(i), different_attrs)) inputs_list.append(
get_value(layer, "input{}".format(i), different_attrs))
inputs_str = ', '.join(inputs_list) inputs_str = ', '.join(inputs_list)
line = "{} = [{}]".format(layer.outputs[0], inputs_str) line = "{} = [{}]".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_list_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_list_unpack(layer,
line = "{} = {}".format(", ".join(layer.outputs), get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}".format(", ".join(layer.outputs),
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_loop(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
loop_range = get_value(layer, "input", different_attrs) loop_range = get_value(layer, "input", different_attrs)
line = "for {} in range({}):".format(layer.outputs[1], loop_range) line = "for {} in range({}):".format(layer.outputs[1], loop_range)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
...@@ -351,171 +536,303 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif ...@@ -351,171 +536,303 @@ def prim_loop(layer, indent=1, init_func=[], forward_func=[], layer_id=None, dif
forward_func.extend(b_forward_lines) forward_func.extend(b_forward_lines)
def prim_min(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_min(layer,
line = "{} = min({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = min({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_mul(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_mul(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {} * {}".format(layer.outputs[0], line = "{} = {} * {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_ne(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_ne(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} != {}".format(layer.outputs[0], line = "{} = {} != {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_neg(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_neg(layer,
line = "{} = -{}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = -{}".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_not(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_not(layer,
line = "{} = not {}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = not {}".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_or(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): def prim_or(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {} or {}".format(layer.outputs[0], line = "{} = {} or {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_replaceitem(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_replaceitem(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{}[{}] = {}".format( line = "{}[{}] = {}".format(
get_value(layer, "list", layer_id, different_attrs), get_value(layer, "list", layer_id, different_attrs),
get_value(layer, "index", layer_id, different_attrs), get_value(layer, "index", layer_id, different_attrs),
get_value(layer, "item", layer_id, different_attrs)) get_value(layer, "item", layer_id, different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_requires_grad(layer,
line = "{} = not {}.stop_gradient".format(layer.outputs[0], indent=1,
get_value(layer, "input", different_attrs)) init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = not {}.stop_gradient".format(
layer.outputs[0], get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_rsub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_rsub(layer,
line = "{} = {} - {} * {}".format(layer.outputs[0], indent=1,
get_value(layer, "y", different_attrs), init_func=[],
get_value(layer, "x", different_attrs), forward_func=[],
get_value(layer, "alpha", different_attrs)) layer_id=None,
different_attrs=None):
line = "{} = {} - {} * {}".format(
layer.outputs[0],
get_value(layer, "y", different_attrs),
get_value(layer, "x", different_attrs),
get_value(layer, "alpha", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_select(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_select(layer,
line = "{} = {}[".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}[".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
for dim in range(layer.attrs["dim"]): for dim in range(layer.attrs["dim"]):
line += ":, " line += ":, "
line += (get_value(layer, "index", different_attrs) + "]") line += (get_value(layer, "index", different_attrs) + "]")
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_set_attr(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_set_attr(layer,
line = "{} = {}".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_set_item(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_set_item(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{}[{}] = {}".format( line = "{}[{}] = {}".format(
get_value(layer, "dict", different_attrs), get_value(layer, "dict", different_attrs),
get_value(layer, "key", different_attrs), get_value(layer, "value", different_attrs)) get_value(layer, "key", different_attrs),
get_value(layer, "value", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_shape(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None):
line = "{} = {}.shape".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent))
def prim_shape_dim(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_shape(layer,
line = "{} = {}.shape[{}]".format(layer.outputs[0], indent=1,
get_value(layer, "input", different_attrs), init_func=[],
get_value(layer, "dim", different_attrs)) forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}.shape".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_slice(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_shape_dim(layer,
line = "{} = {}[{}: {}: {}]".format(layer.outputs[0], indent=1,
get_value(layer, "input", different_attrs), init_func=[],
get_value(layer, "start", different_attrs), forward_func=[],
get_value(layer, "end", different_attrs), layer_id=None,
get_value(layer, "step", different_attrs)) different_attrs=None):
forward_func.extend(gen_codes([line], indent=indent)) line = "{} = {}.shape[{}]".format(
layer.outputs[0],
get_value(layer, "input", different_attrs),
def prim_startswith(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None, is_return_line=False): get_value(layer, "dim", different_attrs))
line = "{} = {}.startswith({})".format(layer.outputs[0], forward_func.extend(gen_codes([line], indent=indent))
get_value(layer, "input", different_attrs),
get_value(layer, "start_str", different_attrs))
def prim_slice(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}[{}: {}: {}]".format(
layer.outputs[0],
get_value(layer, "input", different_attrs),
get_value(layer, "start", different_attrs),
get_value(layer, "end", different_attrs),
get_value(layer, "step", different_attrs))
forward_func.extend(gen_codes([line], indent=indent))
def prim_startswith(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None,
is_return_line=False):
line = "{} = {}.startswith({})".format(
layer.outputs[0],
get_value(layer, "input", different_attrs),
get_value(layer, "start_str", different_attrs))
if is_return_line: if is_return_line:
return line.split(" = ")[1] return line.split(" = ")[1]
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_str(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_str(layer,
line = "{} = str({})".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = str({})".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_sub(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_sub(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
if int(float(get_value(layer, "alpha", different_attrs))) == 1: if int(float(get_value(layer, "alpha", different_attrs))) == 1:
line = "{} = {} - {}".format(layer.outputs[0], line = "{} = {} - {}".format(layer.outputs[0],
get_value(layer, "x", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "y", different_attrs))
else: else:
line = "{} = {} - {} * {}".format(layer.outputs[0], line = "{} = {} - {} * {}".format(
get_value(layer, "x", different_attrs), layer.outputs[0],
get_value(layer, "alpha", different_attrs), get_value(layer, "x", different_attrs),
get_value(layer, "y", different_attrs)) get_value(layer, "alpha", different_attrs),
get_value(layer, "y", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_tuple(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_tuple(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
input_len = len(layer.inputs) + len(layer.attrs) input_len = len(layer.inputs) + len(layer.attrs)
inputs_list = list() inputs_list = list()
for i in range(input_len): for i in range(input_len):
inputs_list.append(get_value(layer, "input{}".format(i), different_attrs)) inputs_list.append(
get_value(layer, "input{}".format(i), different_attrs))
inputs_str = ', '.join(inputs_list) inputs_str = ', '.join(inputs_list)
line = "{} = ({})".format(layer.outputs[0], inputs_str) line = "{} = ({})".format(layer.outputs[0], inputs_str)
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_tuple_unpack(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_tuple_unpack(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
outputs_str = ', '.join(layer.outputs) outputs_str = ', '.join(layer.outputs)
line = "{} = {}".format(outputs_str, get_value(layer, "input", different_attrs)) line = "{} = {}".format(outputs_str,
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_type(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_type(layer,
line = "{} = {}.dtype".format(layer.outputs[0], get_value(layer, "input", different_attrs)) indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}.dtype".format(layer.outputs[0],
get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_var2list(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_var2list(layer,
line = "{} = {}.numpy().tolist()".format(layer.outputs[0], indent=1,
get_value(layer, "input", different_attrs)) init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
line = "{} = {}.numpy().tolist()".format(
layer.outputs[0], get_value(layer, "input", different_attrs))
forward_func.extend(gen_codes([line], indent=indent)) forward_func.extend(gen_codes([line], indent=indent))
def prim_warnings(layer, indent=1, init_func=[], forward_func=[], layer_id=None, different_attrs=None): def prim_warnings(layer,
indent=1,
init_func=[],
forward_func=[],
layer_id=None,
different_attrs=None):
lines = ["import warnings"] lines = ["import warnings"]
line = "warnings.warn({}, stacklevel={})".format( line = "warnings.warn({}, stacklevel={})".format(
get_value(layer, "input", different_attrs), layer.attrs["stacklevel"]) get_value(layer, "input", different_attrs), layer.attrs["stacklevel"])
lines.append(line) lines.append(line)
forward_func.extend(gen_codes(lines, indent=indent)) forward_func.extend(gen_codes(lines, indent=indent))
...@@ -13,4 +13,5 @@ ...@@ -13,4 +13,5 @@
# limitations under the License. # limitations under the License.
from .gather import Gather from .gather import Gather
\ No newline at end of file from .instance_norm import InstanceNorm
\ No newline at end of file
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.nn.functional import instance_norm
from paddle.fluid.initializer import Constant
class InstanceNorm(paddle.nn.Layer):
"""
This class is based class for InstanceNorm1D, 2d, 3d.
See InstaceNorm1D, InstanceNorm2D or InstanceNorm3D for more details.
"""
def __init__(self,
num_features,
epsilon=1e-5,
momentum=0.9,
weight_attr=None,
bias_attr=None,
data_format="NCHW",
name=None):
super(InstanceNorm, self).__init__()
if weight_attr == False or bias_attr == False:
assert weight_attr == bias_attr, "weight_attr and bias_attr must be set to Fasle at the same time in InstanceNorm"
self._epsilon = epsilon
self._weight_attr = weight_attr
self._bias_attr = bias_attr
if weight_attr != False and bias_attr != False:
self.scale = self.create_parameter(
attr=self._weight_attr,
shape=[num_features],
default_initializer=Constant(1.0),
is_bias=False)
self.bias = self.create_parameter(
attr=self._bias_attr,
shape=[num_features],
default_initializer=Constant(0.0),
is_bias=True)
else:
self.scale = None
self.bias = None
def forward(self, input):
return instance_norm(
input, weight=self.scale, bias=self.bias, eps=self._epsilon)
def extra_repr(self):
return 'num_features={}, epsilon={}'.format(self.scale.shape[0],
self._epsilon)
...@@ -37,7 +37,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -37,7 +37,7 @@ class PyTorchOpMapper(OpMapper):
self.scope_name_list = list() self.scope_name_list = list()
self.scope_name2id = dict() self.scope_name2id = dict()
self.inputs_info = dict() self.inputs_info = dict()
self.output2id = dict() # output名字和layer_id的关系,用于lstm去除前面的node self.output2id = dict() # output名字和layer_id的关系,用于lstm去除前面的node
# 转换 # 转换
if not self.op_checker(decoder.graph): if not self.op_checker(decoder.graph):
raise Exception("Model is not supported yet.") raise Exception("Model is not supported yet.")
...@@ -50,6 +50,7 @@ class PyTorchOpMapper(OpMapper): ...@@ -50,6 +50,7 @@ class PyTorchOpMapper(OpMapper):
op_list.append(node.kind()) op_list.append(node.kind())
for block in node.blocks(): for block in node.blocks():
_update_op_list(block) _update_op_list(block)
op_list = list() op_list = list()
_update_op_list(script_graph) _update_op_list(script_graph)
op_list = list(set(op_list)) op_list = list(set(op_list))
...@@ -62,11 +63,11 @@ class PyTorchOpMapper(OpMapper): ...@@ -62,11 +63,11 @@ class PyTorchOpMapper(OpMapper):
return True return True
else: else:
if len(unsupported_ops) > 0: if len(unsupported_ops) > 0:
print("\n========= {} OPs are not supported yet ===========".format( print("\n========= {} OPs are not supported yet ===========".
len(unsupported_ops))) format(len(unsupported_ops)))
for op in unsupported_ops: for op in unsupported_ops:
print("========== {} ============".format(op)) print("========== {} ============".format(op))
return False return False
def traverse(self, script_graph, parent_layer=None): def traverse(self, script_graph, parent_layer=None):
# 用于获取graph的输入 # 用于获取graph的输入
...@@ -85,20 +86,24 @@ class PyTorchOpMapper(OpMapper): ...@@ -85,20 +86,24 @@ class PyTorchOpMapper(OpMapper):
current_node_outputs.extend(outputs) current_node_outputs.extend(outputs)
# 初始化 # 初始化
graph = PaddleGraph(source_type="pytorch", parent_layer=parent_layer, graph_type="dygraph") graph = PaddleGraph(
source_type="pytorch",
parent_layer=parent_layer,
graph_type="dygraph")
if "TopLevelTracedModule" in str(type(self.script)): if "TopLevelTracedModule" in str(type(self.script)):
graph.set_script(self.script) graph.set_script(self.script)
current_node_outputs = [] current_node_outputs = []
graph_inputs = [] graph_inputs = []
# 转换输入节点 # 转换输入节点
if isinstance(script_graph, torch._C.Graph): if isinstance(script_graph, torch._C.Graph):
input_ct = 0 input_ct = 0
for i, ivalue in enumerate(script_graph.inputs()): for i, ivalue in enumerate(script_graph.inputs()):
node = ivalue.node() node = ivalue.node()
if str(ivalue.type()) not in ["Tensor", "Dict[str, Tensor]"]: if str(ivalue.type()) not in ["Tensor", "Dict[str, Tensor]"]:
graph.set_name(str(ivalue.type()).split(".")[-1]) graph.set_name(str(ivalue.type()).split(".")[-1])
continue continue
inputs, outputs = self.data(graph, node, ivalue.unique(), input_ct) inputs, outputs = self.data(graph, node,
ivalue.unique(), input_ct)
input_ct += 1 input_ct += 1
# 转换中间节点 # 转换中间节点
for node in script_graph.nodes(): for node in script_graph.nodes():
...@@ -183,8 +188,9 @@ class PyTorchOpMapper(OpMapper): ...@@ -183,8 +188,9 @@ class PyTorchOpMapper(OpMapper):
outputs=[output_name], outputs=[output_name],
scope_name=scope_name, scope_name=scope_name,
dtype=string(str(param.dtype)), dtype=string(str(param.dtype)),
shape = param.shape, shape=param.shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)") default_initializer="paddle.nn.initializer.Constant(value=0.0)"
)
self.output2id[output_name] = layer_id self.output2id[output_name] = layer_id
else: else:
if isinstance(param, dict) and "Tensor" in param and \ if isinstance(param, dict) and "Tensor" in param and \
...@@ -211,8 +217,9 @@ class PyTorchOpMapper(OpMapper): ...@@ -211,8 +217,9 @@ class PyTorchOpMapper(OpMapper):
outputs=[output_name], outputs=[output_name],
scope_name=scope_name, scope_name=scope_name,
dtype=string(str(param.dtype)), dtype=string(str(param.dtype)),
shape = param.shape, shape=param.shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)") default_initializer="paddle.nn.initializer.Constant(value=0.0)"
)
node_outputs.append(output_name) node_outputs.append(output_name)
self.output2id[output_name] = layer_id self.output2id[output_name] = layer_id
return return
...@@ -232,7 +239,8 @@ class PyTorchOpMapper(OpMapper): ...@@ -232,7 +239,8 @@ class PyTorchOpMapper(OpMapper):
value=string(param) value=string(param)
if isinstance(param, str) else param) if isinstance(param, str) else param)
node_outputs.append(output_name) node_outputs.append(output_name)
elif node.kind() == "prim::Constant" and output_name in self.pytorch_params: elif node.kind(
) == "prim::Constant" and output_name in self.pytorch_params:
param = self.pytorch_params[output_name] param = self.pytorch_params[output_name]
self.paddle_params[output_name] = param self.paddle_params[output_name] = param
layer_id = graph.add_layer( layer_id = graph.add_layer(
...@@ -241,11 +249,10 @@ class PyTorchOpMapper(OpMapper): ...@@ -241,11 +249,10 @@ class PyTorchOpMapper(OpMapper):
outputs=[output_name], outputs=[output_name],
scope_name=scope_name, scope_name=scope_name,
dtype=string(str(param.dtype)), dtype=string(str(param.dtype)),
shape = param.shape, shape=param.shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)") default_initializer="paddle.nn.initializer.Constant(value=0.0)")
self.output2id[output_name] = layer_id self.output2id[output_name] = layer_id
def _get_inputs_name(self, node): def _get_inputs_name(self, node):
inputs_name = [] inputs_name = []
inputs_node = [] inputs_node = []
...@@ -256,7 +263,6 @@ class PyTorchOpMapper(OpMapper): ...@@ -256,7 +263,6 @@ class PyTorchOpMapper(OpMapper):
inputs_node.append(script_input_node) inputs_node.append(script_input_node)
inputs_name.append(input_name) inputs_name.append(input_name)
return inputs_name, inputs_node return inputs_name, inputs_node
def data(self, graph, node, uid, input_ct): def data(self, graph, node, uid, input_ct):
scope_name = self.normalize_scope_name(node) scope_name = self.normalize_scope_name(node)
...@@ -276,7 +282,8 @@ class PyTorchOpMapper(OpMapper): ...@@ -276,7 +282,8 @@ class PyTorchOpMapper(OpMapper):
data=output_name) data=output_name)
if self.input_examples is not None: if self.input_examples is not None:
input_np = self.input_examples[input_ct].detach().numpy() input_np = self.input_examples[input_ct].detach().numpy()
self.inputs_info[output_name] = [list(input_np.shape), str(input_np.dtype)] self.inputs_info[
output_name] = [list(input_np.shape), str(input_np.dtype)]
return [], [output_name] return [], [output_name]
def equal(self, graph, node, uid=None, parent_layer=None, index=None): def equal(self, graph, node, uid=None, parent_layer=None, index=None):
...@@ -289,7 +296,8 @@ class PyTorchOpMapper(OpMapper): ...@@ -289,7 +296,8 @@ class PyTorchOpMapper(OpMapper):
control_output_id = index - 1 control_output_id = index - 1
output_node_name = parent_layer.outputs[control_output_id] output_node_name = parent_layer.outputs[control_output_id]
current_outputs = [output_node_name] current_outputs = [output_node_name]
self._check_input(graph, node, input_node_name, current_outputs, scope_name) self._check_input(graph, node, input_node_name, current_outputs,
scope_name)
graph.add_layer( graph.add_layer(
"prim.equal", "prim.equal",
inputs={'input': input_node_name}, inputs={'input': input_node_name},
...@@ -321,20 +329,20 @@ class PyTorchOpMapper(OpMapper): ...@@ -321,20 +329,20 @@ class PyTorchOpMapper(OpMapper):
self.scope_name2id[i][ns] = 0 self.scope_name2id[i][ns] = 0
real_scope_name = "/".join(name_segments[1:]) real_scope_name = "/".join(name_segments[1:])
real_father_scope_name = "/".join(name_segments[1:-1]) real_father_scope_name = "/".join(name_segments[1:-1])
for i, ns in enumerate(name_segments): for i, ns in enumerate(name_segments):
if i == 0: if i == 0:
continue continue
if self.scope_name2id[i][ns] != 0: if self.scope_name2id[i][ns] != 0:
name_segments[i] = name_segments[i] + \ name_segments[i] = name_segments[i] + \
"__{}".format(self.scope_name2id[i][ns]) "__{}".format(self.scope_name2id[i][ns])
prefix_scope_name = "/".join(name_segments[1 :i + 1]) prefix_scope_name = "/".join(name_segments[1:i + 1])
is_found = False is_found = False
for j in range(len(self.scope_name_list)): for j in range(len(self.scope_name_list)):
last_scope_name = self.scope_name_list[-1-j] last_scope_name = self.scope_name_list[-1 - j]
if last_scope_name.startswith(prefix_scope_name + "/") \ if last_scope_name.startswith(prefix_scope_name + "/") \
or last_scope_name == prefix_scope_name: or last_scope_name == prefix_scope_name:
if j != 0: # and i != len(name_segments) - 1: if j != 0: # and i != len(name_segments) - 1:
is_found = True is_found = True
origin_name_segment_i = name_segments[i].split("__")[0] origin_name_segment_i = name_segments[i].split("__")[0]
self.scope_name2id[i][origin_name_segment_i] += 1 self.scope_name2id[i][origin_name_segment_i] += 1
...@@ -346,4 +354,3 @@ class PyTorchOpMapper(OpMapper): ...@@ -346,4 +354,3 @@ class PyTorchOpMapper(OpMapper):
real_scope_name = "/".join(name_segments[1:]) real_scope_name = "/".join(name_segments[1:])
self.scope_name_list.append(real_scope_name) self.scope_name_list.append(real_scope_name)
return real_scope_name return real_scope_name
\ No newline at end of file
...@@ -248,8 +248,10 @@ class TFOpMapper(OpMapper): ...@@ -248,8 +248,10 @@ class TFOpMapper(OpMapper):
def Transpose(self, node): def Transpose(self, node):
input = self.graph.get_input_node(node, 0) input = self.graph.get_input_node(node, 0)
perm = self.graph.get_input_node(node, 1) perm = self.graph.get_input_node(node, 1)
assert perm.layer_type == "Const", "Perm of transpose OP should be Const" if perm.layer_type == "Const":
perm = perm.value.tolist() perm = perm.value.tolist()
else:
perm = self.decoder.infer_tensor(perm, use_diff_inputs=False).tolist()
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.transpose", "paddle.transpose",
...@@ -641,12 +643,18 @@ class TFOpMapper(OpMapper): ...@@ -641,12 +643,18 @@ class TFOpMapper(OpMapper):
paddings = self.graph.get_input_node(node, 1) paddings = self.graph.get_input_node(node, 1)
assert paddings.layer_type == "Const", "Padding should be Const" assert paddings.layer_type == "Const", "Padding should be Const"
paddings = paddings.value.flatten().tolist() paddings = paddings.value.flatten().tolist()
constant_values = 0
if len(node.layer.input) > 2:
constant_values = self.graph.get_input_node(node, 2)
assert constant_values.layer_type == "Const", "Padding should be Const"
constant_values = constant_values.value
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad", kernel="paddle.nn.functional.pad",
inputs={"x": input.name}, inputs={"x": input.name},
outputs=[node.name], outputs=[node.name],
pad=paddings) pad=paddings,
value=constant_values)
def MirrorPad(self, node): def MirrorPad(self, node):
self.Pad(node) self.Pad(node)
......
...@@ -238,8 +238,10 @@ class TFOpMapper(OpMapper): ...@@ -238,8 +238,10 @@ class TFOpMapper(OpMapper):
def Transpose(self, node): def Transpose(self, node):
input = self.graph.get_node(node.layer.input[0]) input = self.graph.get_node(node.layer.input[0])
perm = self.graph.get_node(node.layer.input[1]) perm = self.graph.get_node(node.layer.input[1])
assert perm.layer_type == "Const", "Perm of transpose OP should be Const" if perm.layer_type == "Const":
perm = perm.value.tolist() perm = perm.value.tolist()
else:
perm = self.decoder.infer_tensor(perm, use_diff_inputs=False).tolist()
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.transpose", kernel="paddle.transpose",
...@@ -629,12 +631,18 @@ class TFOpMapper(OpMapper): ...@@ -629,12 +631,18 @@ class TFOpMapper(OpMapper):
paddings = self.graph.get_input_node(node, 1) paddings = self.graph.get_input_node(node, 1)
assert paddings.layer_type == "Const", "Padding should be Const" assert paddings.layer_type == "Const", "Padding should be Const"
paddings = paddings.value.flatten().tolist() paddings = paddings.value.flatten().tolist()
constant_values = 0
if len(node.layer.input) > 2:
constant_values = self.graph.get_input_node(node, 2)
assert constant_values.layer_type == "Const", "Padding should be Const"
constant_values = constant_values.value
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad", kernel="paddle.nn.functional.pad",
inputs={"x": input.name}, inputs={"x": input.name},
outputs=[node.name], outputs=[node.name],
pad=paddings) pad=paddings,
value=constant_values)
def MirrorPad(self, node): def MirrorPad(self, node):
self.Pad(node) self.Pad(node)
......
...@@ -27,6 +27,8 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn", ...@@ -27,6 +27,8 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
"paddle.nn.Linear": "linear", "paddle.nn.Linear": "linear",
"paddle.nn.Conv2DTranspose": "conv", "paddle.nn.Conv2DTranspose": "conv",
"paddle.nn.LSTM": "lstm", "paddle.nn.LSTM": "lstm",
"paddle.nn.GRU": "gru",
"custom_layer:InstanceNorm": "instance_norm",
"paddle.nn.PReLU": "prelu", "paddle.nn.PReLU": "prelu",
"paddle.nn.ReLU": "relu", "paddle.nn.ReLU": "relu",
"paddle.nn.ReLU6": "relu", "paddle.nn.ReLU6": "relu",
...@@ -35,14 +37,14 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn", ...@@ -35,14 +37,14 @@ NN_KERNEL_NAME = {"paddle.nn.BatchNorm": "bn",
"paddle.nn.Tanh": "tanh", "paddle.nn.Tanh": "tanh",
"paddle.nn.AvgPool2D": "avgpool", "paddle.nn.AvgPool2D": "avgpool",
"paddle.nn.MaxPool2D": "maxpool", "paddle.nn.MaxPool2D": "maxpool",
"paddle.nn.Pad1D": "pad", "paddle.nn.Pad1D": "pad1d",
"paddle.nn.Pad2D": "pad", "paddle.nn.Pad2D": "pad2d",
"paddle.nn.Pad3D": "pad", "paddle.nn.Pad3D": "pad3d",
"paddle.nn.Dropout": "dropout", "paddle.nn.Dropout": "dropout",
"paddle.nn.GELU": "gelu", "paddle.nn.GELU": "gelu",
"paddle.nn.Hardtanh": "tanh", "paddle.nn.Hardtanh": "tanh",
"paddle.nn.LeakyReLU": "leakly_relu"} "paddle.nn.LeakyReLU": "leakly_relu"}
NN_KERNEL_WITH_PARAMS = list(NN_KERNEL_NAME.keys())[:8] NN_KERNEL_WITH_PARAMS = list(NN_KERNEL_NAME.keys())[:10]
def rename_layers(layers, param_tree=None, is_rename_module=False): def rename_layers(layers, param_tree=None, is_rename_module=False):
""" 对子模块的输入输出等进行重命名。 """ 对子模块的输入输出等进行重命名。
...@@ -143,7 +145,10 @@ def _update_attrs(layer, different_attrs): ...@@ -143,7 +145,10 @@ def _update_attrs(layer, different_attrs):
if key_name in different_attrs: if key_name in different_attrs:
common_attrs.pop(k) common_attrs.pop(k)
special_attrs[k] = v special_attrs[k] = v
remove_default_attrs(layer.kernel, common_attrs) remove_kernel = layer.kernel
if remove_kernel == "custom_layer:InstanceNorm":
remove_kernel = "paddle.nn.InstanceNorm2D"
remove_default_attrs(remove_kernel, common_attrs)
common_attrs.update(special_attrs) common_attrs.update(special_attrs)
layer.attrs = common_attrs layer.attrs = common_attrs
......
...@@ -212,6 +212,8 @@ class ModuleGraph(object): ...@@ -212,6 +212,8 @@ class ModuleGraph(object):
layer_id_list2 = list(sub_layers2.keys()) layer_id_list2 = list(sub_layers2.keys())
for i, layer_id1 in enumerate(layer_id_list1): for i, layer_id1 in enumerate(layer_id_list1):
layer_id2 = layer_id_list2[i] layer_id2 = layer_id_list2[i]
if layer_id2 not in self.pd_graph.edges_in:
return False
if len(self.pd_graph.edges_in[layer_id1]) != len(self.pd_graph.edges_in[layer_id2]): if len(self.pd_graph.edges_in[layer_id1]) != len(self.pd_graph.edges_in[layer_id2]):
return False return False
for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[layer_id1]): for j, ipt_layer_id1 in enumerate(self.pd_graph.edges_in[layer_id1]):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册