Commit e3b6e8a0 authored by SunAhong1993

add aten and prim for nlp

Parent 872d81fa
@@ -323,6 +323,7 @@ class PaddleGraph(object):
         [
             "from paddle.fluid.initializer import Constant",
             "from paddle.fluid.param_attr import ParamAttr",
+            "import paddle",
             "import paddle.fluid as fluid",
             "",
             "class {}(fluid.dygraph.Layer):".format(self.name),
@@ -369,7 +370,7 @@ class PaddleGraph(object):
                     and layer.kernel != "prim.exception" \
                     and layer.kernel != "prim.warnings":
                 continue
-            if "dygraph" in layer.kernel:
+            if "paddle.nn" in layer.kernel or layer.kernel == "fluid.dygraph.base.to_variable":
                 line = "{}".format(
                     layer.outputs[0]
                 ) if layer.kernel == "fluid.dygraph.base.to_variable" and not layer.attrs[
...
@@ -13,6 +13,7 @@
 # limitations under the License.
 import torch
+import numpy as np
 from x2paddle.core.util import *
@@ -27,9 +28,12 @@ def prim_Constant(mapper, graph, node):
     output_name = mapper._get_outputs_name(node)[0]
     output = list(node.outputs())[0]
     value = output.toIValue()
+    output_type = output.type()
     mapper.attrs[output_name] = value
     if isinstance(value, str):
         value = string(value)
+    if str(output_type) == "Tensor":
+        value = "paddle.to_tensor({})".format(value)
     graph.add_layer(
         "prim.constant", inputs={}, outputs=[output_name], value=value)
     return [], [output_name]
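A minimal standalone sketch of the new branch above (values and names here are illustrative, not taken from the commit): when the TorchScript constant is tensor-typed, the stored value is rewritten into a `paddle.to_tensor(...)` expression, so the generated forward code builds a Paddle tensor at runtime.

```python
# Illustrative sketch of the new tensor-constant branch (values are made up).
value = 3.14
output_type = "Tensor"  # type string reported by TorchScript for the output

if str(output_type) == "Tensor":
    # The constant is emitted as source text, e.g. `x0 = paddle.to_tensor(3.14)`
    value = "paddle.to_tensor({})".format(value)

print(value)  # paddle.to_tensor(3.14)
```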
@@ -89,6 +93,10 @@ def prim_GetAttr(mapper, graph, node):
             param = getattr(part_script, field_name)
             if isinstance(param, torch.Tensor):
                 param = param.detach().numpy()
+                if len(param.shape) == 0:
+                    param = np.reshape(param, 1)
+                if str(param.dtype) == "uint8":
+                    param = param.astype("int32")
             mapper.pytorch_params[output_name] = param
             part_script = param
     return [], [output_name]
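The two new guards normalize parameters pulled out of the TorchScript module before they are handed to Paddle. A numpy-only sketch of the same normalization, assuming `param` came from `tensor.detach().numpy()`:

```python
import numpy as np

param = np.array(0.5, dtype=np.float32)  # 0-dim array, as from a scalar tensor
if len(param.shape) == 0:
    # Promote scalars to shape [1] so they can serve as Paddle parameters.
    param = np.reshape(param, 1)
if str(param.dtype) == "uint8":
    # uint8 buffers are widened to int32 before being stored.
    param = param.astype("int32")

print(param.shape, param.dtype)  # (1,) float32
```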
@@ -276,6 +284,40 @@ def prim_min(mapper, graph, node):
     return current_inputs, current_outputs
+
+
+def prim_NumToTensor(mapper, graph, node):
+    """ Build the PaddleLayer that converts a number to a Tensor.
+        TorchScript example:
+            %other.2 : Tensor = prim::NumToTensor(%1736)
+        Parameter meaning:
+            %other.2 (Tensor): the output.
+            %1736 (-): the input.
+    """
+    output_name = mapper._get_outputs_name(node)[0]
+    layer_outputs = [output_name]
+    layer_inputs = {}
+    layer_attrs = {}
+    inputs_name, inputs_node = mapper._get_inputs_name(node)
+    # Collect the outputs of the current node
+    current_outputs = [output_name]
+    # Process input 0, i.e. %1736
+    mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs)
+    layer_inputs["value"] = inputs_name[0]
+    # Collect the inputs of the current node
+    current_inputs = list(layer_inputs.values())
+    input_type = list(node.inputs())[0].type()
+    layer_attrs["dtype"] = input_type
+    layer_attrs["persistable"] = True
+    layer_attrs["shape"] = [1]
+    graph.add_layer(
+        "fluid.layers.create_global_var",
+        inputs=layer_inputs,
+        outputs=layer_outputs,
+        **layer_attrs)
+    return current_inputs, current_outputs
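For reference, a hedged sketch of the code this layer ultimately generates: the scalar input is materialized as a persistable shape-`[1]` global variable. The variable names and dtype below are illustrative, and the call assumes the fluid API that the rest of the generated code already targets.

```python
import paddle.fluid as fluid

# Roughly what the converter emits for
#   %other.2 : Tensor = prim::NumToTensor(%1736)
# (x1736 / other_2 are illustrative generated names).
x1736 = 10
other_2 = fluid.layers.create_global_var(
    value=x1736,
    dtype="int64",       # taken from the TorchScript input type
    persistable=True,
    shape=[1])
```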
 def prim_RaiseException(mapper, graph, node):
     """ Build the PaddleLayer that raises an exception.
...
@@ -83,6 +83,17 @@ def prim_assert(layer, indent=1, init_func=[], forward_func=[]):
     forward_func.extend(gen_codes([line], indent=indent))
+
+
+def prim_check_dim(layer, indent=1, init_func=[], forward_func=[]):
+    lines = []
+    lines.append("if {} < 0:".format(get_value(layer, "dim")))
+    lines.append("    {} = {} + {}".format(layer.outputs[0],
+                                           get_value(layer, "dim"),
+                                           get_value(layer, "len")))
+    lines.append("else:")
+    lines.append("    {} = {}".format(layer.outputs[0],
+                                      get_value(layer, "dim")))
+    forward_func.extend(gen_codes(lines, indent=indent))
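The lines emitted by `prim_check_dim` normalize a possibly negative axis against the tensor rank. A plain-Python sketch of the generated pattern, with placeholder names:

```python
# Generated pattern (placeholder names): normalize a negative axis.
dim, length = -1, 4      # e.g. axis -1 of a rank-4 tensor
if dim < 0:
    dim_out = dim + length
else:
    dim_out = dim
print(dim_out)  # 3
```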
 def prim_constant(layer, indent=1, init_func=[], forward_func=[]):
     line = "{} = {}".format(layer.outputs[0], layer.attrs["value"])
     forward_func.extend(gen_codes([line], indent=indent))
@@ -100,6 +111,12 @@ def prim_dict(layer, indent=1, init_func=[], forward_func=[]):
     forward_func.extend(gen_codes([line], indent=indent))
+
+
+def prim_div(layer, indent=1, init_func=[], forward_func=[]):
+    line = "{} = {} / {}".format(layer.outputs[0],
+                                 get_value(layer, "x"), get_value(layer, "y"))
+    forward_func.extend(gen_codes([line], indent=indent))


 def prim_eq(layer, indent=1, init_func=[], forward_func=[]):
     line = "{} = {} == {}".format(layer.outputs[0],
                                   get_value(layer, "x"), get_value(layer, "y"))
@@ -163,6 +180,11 @@ def prim_if(layer, indent=1, init_func=[], forward_func=[]):
     forward_func.extend(b_forward_lines)
+
+
+def prim_int(layer, indent=1, init_func=[], forward_func=[]):
+    line = "{} = int({})".format(layer.outputs[0], get_value(layer, "input"))
+    forward_func.extend(gen_codes([line], indent=indent))


 def prim_is(layer, indent=1, init_func=[], forward_func=[]):
     line = "{} = {} is {}".format(layer.outputs[0],
                                   get_value(layer, "x"), get_value(layer, "y"))
@@ -235,6 +257,8 @@ def prim_mul(layer, indent=1, init_func=[], forward_func=[]):
     line = "{} = {} * {}".format(layer.outputs[0],
                                  get_value(layer, "x"), get_value(layer, "y"))
     forward_func.extend(gen_codes([line], indent=indent))


 def prim_ne(layer, indent=1, init_func=[], forward_func=[]):
@@ -266,6 +290,14 @@ def prim_requires_grad(layer, indent=1, init_func=[], forward_func=[]):
     forward_func.extend(gen_codes([line], indent=indent))
+
+
+def prim_rsub(layer, indent=1, init_func=[], forward_func=[]):
+    line = "{} = {} - {} * {}".format(layer.outputs[0],
+                                      get_value(layer, "y"),
+                                      get_value(layer, "x"),
+                                      get_value(layer, "alpha"))
+    forward_func.extend(gen_codes([line], indent=indent))
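The emitted expression follows PyTorch's `torch.rsub(x, y, alpha)` semantics, i.e. `out = y - x * alpha`. A trivial check with placeholder values:

```python
# prim.rsub emits: out = y - x * alpha (reversed-operand subtraction).
x, y, alpha = 2.0, 10.0, 3.0
out = y - x * alpha
print(out)  # 4.0
```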
 def prim_select(layer, indent=1, init_func=[], forward_func=[]):
     line = "{} = {}[".format(layer.outputs[0], get_value(layer, "input"))
     for dim in range(layer.attrs["dim"]):
@@ -291,6 +323,13 @@ def prim_shape(layer, indent=1, init_func=[], forward_func=[]):
     forward_func.extend(gen_codes([line], indent=indent))
+
+
+def prim_shape_dim(layer, indent=1, init_func=[], forward_func=[]):
+    line = "{} = {}.shape[{}]".format(layer.outputs[0],
+                                      get_value(layer, "input"),
+                                      get_value(layer, "dim"))
+    forward_func.extend(gen_codes([line], indent=indent))
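`prim_shape_dim` emits a plain shape lookup; a numpy sketch of the generated line (names are placeholders):

```python
import numpy as np

# prim.shape_dim emits: out = input.shape[dim]
x = np.zeros((2, 3, 5))
out = x.shape[1]
print(out)  # 3
```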


 def prim_slice(layer, indent=1, init_func=[], forward_func=[]):
     line = "{} = {}[{}: {}: {}]".format(layer.outputs[0],
                                         get_value(layer, "input"),
...
@@ -215,7 +215,7 @@ class BatchNorm2dFuser(FuseBase):
         pattern_block1 = PaddleGraph(if_layer3, graph_type="dygraph")
         if_layer3.add_block(pattern_block1)
         self.pattern.add_layer(
-            "fluid.dygraph.BatchNorm",
+            "paddle.nn.BatchNorm",
             inputs={"input": "bn-input-0"},
             outputs=[gen_name(34), gen_name(35)],
             is_test=True,
...
@@ -34,7 +34,7 @@ class FcFuser(FuseBase):
         classifier_6_weight = self.classifier_6_weight
         x136 = fluid.layers.transpose(x=classifier_6_weight, perm=[1, 0])
         classifier_6_bias = self.classifier_6_bias
-        x137 = fluid.layers.addmm(input=classifier_6_bias, x=x128, y=x136, beta=1, alpha=1)
+        x137 = paddle.addmm(input=classifier_6_bias, x=x128, y=x136, beta=1, alpha=1)
         x135 = x137
     else:
         classifier_6_weight = self.classifier_6_weight
@@ -82,7 +82,7 @@ class FcFuser(FuseBase):
             outputs=[gen_name(7)],
             value="params[{}]".format(string(gen_name(7))))
         pattern_block0.add_layer(
-            "fluid.layers.addmm",
+            "paddle.addmm",
             inputs={"input": gen_name(7),
                     "x": "fc-input-0",
                     "y": gen_name(6)},
@@ -155,7 +155,7 @@ class FcFuser(FuseBase):
                                       bias_name])
         new_layer = PaddleLayer(
             layers_id[0],
-            "fluid.dygraph.Linear",
+            "paddle.nn.Linear",
             inputs={"input": input_name},
             outputs=[linear_name, output_name],
             **attrs)
...
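After this fusion, the transpose-plus-addmm pattern is replaced by a single Paddle 2.x linear layer. A hedged sketch of the equivalent call, with illustrative sizes (the real sizes come from the matched weights):

```python
import paddle

# paddle.addmm(input=bias, x=x, y=weight_T, beta=1, alpha=1) is equivalent to a
# Linear layer that owns the weight and bias.
linear = paddle.nn.Linear(in_features=4096, out_features=1000)  # illustrative
x = paddle.randn([8, 4096])
out = linear(x)      # out = x @ weight + bias
print(out.shape)     # [8, 1000]
```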
@@ -49,10 +49,12 @@ class PatternMatcher(object):
             # Check whether the input connections match
             if layer_id in graph.edges_in:
                 if pattern_layer_id not in pattern.edges_in:
                     return False
                 else:
                     if len(graph.edges_in[layer_id]) != len(
                             pattern.edges_in[pattern_layer_id]):
                         return False
                 layer_in = graph.edges_in[layer_id]
                 pattern_layer_in = pattern.edges_in[pattern_layer_id]
@@ -66,6 +68,7 @@ class PatternMatcher(object):
                         # Check that the index of the pattern input in pattern_ids
                         # matches the index of the graph input in subgraph_ids
                         continue
                     return False
             # Check whether nodes in the subgraph are used by the outer graph
             # (the match is invalid if they are)
             if layer_id in graph.edges_out:
@@ -73,6 +76,7 @@ class PatternMatcher(object):
                 if not set(pattern_layer.outputs).issubset(
                         pattern.outputs):
                     # It is valid if the current pattern layer's outputs are
                     # also outputs of the pattern
                     return False
                 else:
                     if len(graph.edges_out[layer_id]) != len(
...