diff --git a/x2paddle/op_mapper/tf_op_mapper.py b/x2paddle/op_mapper/tf_op_mapper.py index 5140ab7aa7745895c917f3c02e16cb1341469fd9..21366046b873dcfca6d4bd9259499bc11cc2be79 100644 --- a/x2paddle/op_mapper/tf_op_mapper.py +++ b/x2paddle/op_mapper/tf_op_mapper.py @@ -667,8 +667,14 @@ class TFOpMapper(OpMapper): paddings = paddings.value.flatten().tolist() if input.tf_data_format == "NHWC" and len(input.out_shapes[0]) == 4: paddings = [paddings[i] for i in [0, 1, 6, 7, 2, 3, 4, 5]] + + pad_op = "pad" + if len(input.out_shapes[0]) == 4: + if paddings[0] + paddings[1] + paddings[2] + paddings[3] == 0: + paddings = paddings[4:] + pad_op = "pad2d" attr = {"paddings": paddings} - node.fluid_code.add_layer("pad", + node.fluid_code.add_layer(pad_op, inputs=input, output=node, param_attr=attr)