未验证 提交 79c06f88 编写于 作者: S SunAhong1993 提交者: GitHub

fix the pad in tf (#567)

* fix the convert.py args

* fix teh pad and add log_softmax
上级 269d21ff
......@@ -2962,6 +2962,51 @@ def aten_log(mapper, graph, node):
return current_inputs, current_outputs
def aten_log_softmax(mapper, graph, node):
""" 构造log_softmax的PaddleLayer。
TorchScript示例:
%4 = aten::log_softmax(%input, %2, %3)
参数含义:
%4 (Tensor): 输出的Tensor。
%input (Tensor): 输入的Tensor。
%2 (int): 指定对输入进行运算的轴。
%3 (int): 输入Tensor的数据类型。
"""
scope_name = mapper.normalize_scope_name(node)
output_name = mapper._get_outputs_name(node)[0]
layer_outputs = [output_name]
layer_inputs = {}
layer_attrs = {}
inputs_name, inputs_node = mapper._get_inputs_name(node)
# 获取当前节点输出的list
current_outputs = [output_name]
current_inputs = []
# 处理输入0,即%input
mapper._check_input(graph, inputs_node[0], inputs_name[0], current_outputs,
scope_name)
layer_inputs["x"] = inputs_name[0]
# 处理输入1,即%2,代表dtype
if inputs_name[1] in mapper.attrs:
layer_attrs["axis"] = mapper.attrs[inputs_name[1]]
else:
mapper._check_input(graph, inputs_node[1], inputs_name[1],
current_outputs, scope_name)
layer_inputs["axis"] = inputs_name[1]
# 处理输入2,即%3,代表dtype
if mapper.attrs[inputs_name[2]] is not None:
layer_attrs["dtype"] = dtype_dict[mapper.attrs[inputs_name[2]]]
# 获取当前节点输入的list
current_inputs = list(layer_inputs.values())
graph.add_layer(
"paddle.nn.functional.log_softmax",
inputs=layer_inputs,
outputs=layer_outputs,
scope_name=scope_name,
**layer_attrs)
return current_inputs, current_outputs
def aten_lstm(mapper, graph, node):
""" 构造长短期记忆网络(LSTM)的PaddleLayer。
TorchScript示例:
......
......@@ -649,18 +649,30 @@ class TFOpMapper(OpMapper):
paddings = self.graph.get_input_node(node, 1)
assert paddings.layer_type == "Const", "Padding should be Const"
paddings = paddings.value.flatten().tolist()
constant_values = 0
if len(node.layer.input) > 2:
constant_values = self.graph.get_input_node(node, 2)
assert constant_values.layer_type == "Const", "Padding should be Const"
constant_values = constant_values.value
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad",
inputs={"x": input.name},
outputs=[node.name],
pad=paddings,
value=constant_values)
if len(paddings) == 8 and sum(paddings[:2]) == 0 \
and sum(paddings[-2:]) == 0:
paddings = paddings[2: -2]
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad",
inputs={"x": input.name},
outputs=[node.name],
pad=paddings,
value=constant_values,
data_format=string('NHWC'))
else:
self.paddle_graph.add_layer(
kernel="paddle.nn.functional.pad",
inputs={"x": input.name},
outputs=[node.name],
pad=paddings,
value=constant_values)
def MirrorPad(self, node):
self.Pad(node)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册