提交 c9e7f42e 编写于 作者: J jiangjiajun

replace pad with pad2d

上级 cdd27c7a
......@@ -667,8 +667,14 @@ class TFOpMapper(OpMapper):
paddings = paddings.value.flatten().tolist()
if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4:
paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]]
pad_op = "pad"
if len(input.out_shapes[0]) == 4:
if paddings[0] + paddings[1] + paddings[2] + paddings[3] == 0:
paddings = paddings[4:]
pad_op = "pad2d"
attr = {"paddings": paddings}
node.fluid_code.add_layer("pad",
node.fluid_code.add_layer(pad_op,
inputs=input,
output=node,
param_attr=attr)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册