提交 4f892dd9 编写于 作者: S SunAhong1993

fix the gru

上级 78d96bbf
...@@ -804,6 +804,47 @@ def aten_clamp(mapper, graph, node): ...@@ -804,6 +804,47 @@ def aten_clamp(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_clamp_min(mapper, graph, node):
""" 构造元素剪裁的PaddleLayer。
TorchScript示例:
%56 : Tensor = aten::clamp_min(%input.1, %46)
参数含义:
%56 (Tensor): 输出,累加后的结果。
%input.1 (Tensor): 输入,需要剪裁的Tensor。
%46 (float/Tensor): 最小值。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%input.1
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["x"] = inputs_name[0]
# 获取当前节点输入、输出的list
current_inputs = list(layer_inputs.values())
# 处理输入1,即%46
if inputs_name[1] in mapper.attrs:
layer_attrs["min"] = mapper.attrs[inputs_name[1]]
else:
mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs, scope_name)
layer_inputs["min"] = inputs_name[1]
current_inputs.append(inputs_name[1])
graph.add_layer(
"paddle.clip",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def aten___contains__(mapper, graph, node): def aten___contains__(mapper, graph, node):
""" 构造in的PaddleLayer。 """ 构造in的PaddleLayer。
...@@ -3322,6 +3363,64 @@ def aten_neg(mapper, graph, node): ...@@ -3322,6 +3363,64 @@ def aten_neg(mapper, graph, node):
return current_inputs, current_outputs return current_inputs, current_outputs
def aten_norm(mapper, graph, node):
""" 构造计算范数的PaddleLayer。
TorchScript示例:
%25 = aten::norm(%input, %21, %58, %24)
参数含义:
%25 (Tensor): 取范数后的结果。
%input (Tensor): 输入。
%21 (int): 范数的种类。
%58 (int): 使用范数计算的轴。
%24 (bool): 是否在输出的Tensor中保留和输入一样的维度。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
# 处理输入0,即%input
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs, scope_name)
layer_inputs["x"] = inputs_name[0]
current_inputs = list(layer_inputs.values())
# 处理输入1,即%21
if inputs_name[1] in mapper.attrs:
layer_attrs["p"] = mapper.attrs[inputs_name[1]]
else:
mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs, scope_name)
layer_inputs["p"] = inputs_name[1]
current_inputs.append(inputs_name[1])
# 处理输入2,即%58
if inputs_name[1] in mapper.attrs:
layer_attrs["axis"] = mapper.attrs[inputs_name[2]]
else:
mapper._check_input(graph, inputs_node[2], inputs_name[2],
current_outputs, scope_name)
layer_inputs["axis"] = inputs_name[2]
current_inputs.append(inputs_name[2])
# 处理输入3,即%24
if inputs_name[1] in mapper.attrs:
layer_attrs["keepdim"] = mapper.attrs[inputs_name[3]]
else:
mapper._check_input(graph, inputs_node[3], inputs_name[3],
current_outputs, scope_name)
layer_inputs["keepdim"] = inputs_name[3]
current_inputs.append(inputs_name[3])
graph.add_layer(
"paddle.norm",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def aten___not__(mapper, graph, node): def aten___not__(mapper, graph, node):
""" 构造对bool型取负的PaddleLayer。 """ 构造对bool型取负的PaddleLayer。
......
...@@ -59,8 +59,17 @@ def prim_Constant(mapper, graph, node): ...@@ -59,8 +59,17 @@ def prim_Constant(mapper, graph, node):
scope_name=scope_name) scope_name=scope_name)
return [], [output_name] return [], [output_name]
else: else:
mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy() # mapper.pytorch_params[output_name] = tensor_value.cpu().detach().numpy()
mapper.paddle_params[output_name] = tensor_value.cpu().detach().numpy()
graph.add_layer(
"self.create_parameter",
inputs={},
outputs=[output_name],
scope_name=scope_name,
dtype=string(str(mapper.paddle_params[output_name].dtype)),
shape = mapper.paddle_params[output_name].shape,
default_initializer="paddle.nn.initializer.Constant(value=0.0)")
return [], [output_name]
if "inf" in str(value): if "inf" in str(value):
t = str(type(value)).split("'")[1] t = str(type(value)).split("'")[1]
if str(value).startswith("-"): if str(value).startswith("-"):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册