diff --git a/paddle/fluid/imperative/layout_transformer.h b/paddle/fluid/imperative/layout_transformer.h
index 50d3e2b6ac1392fd7d57cb3f43bde51fa6d17f98..ab7619dedb2e9df8c8299e5856cddc15bcbc1c1e 100644
--- a/paddle/fluid/imperative/layout_transformer.h
+++ b/paddle/fluid/imperative/layout_transformer.h
@@ -374,8 +374,9 @@ class ArgmaxOpTransformer
     bool keep_dims = BOOST_GET_CONST(bool, (*attrs)["keepdims"]);
     if (keep_dims) {
       if (var_layout != DataLayout::UNDEFINED) {
-        std::vector<int> perm_nhwc = {0, 2, 3, 1};
-        std::vector<int> perm_nchw = {0, 3, 1, 2};
+        std::vector<int> perm_nhwc = {0, 3, 1, 2};
+        std::vector<int> perm_nchw = {0, 2, 3, 1};
+        auto perm = var_layout == DataLayout::NHWC ? perm_nhwc : perm_nchw;
 
         switch (AttrTypeID((*attrs)["axis"])) {
           case paddle::framework::proto::AttrType::INT: {
diff --git a/python/paddle/fluid/tests/unittests/test_layout_autotune.py b/python/paddle/fluid/tests/unittests/test_layout_autotune.py
index 6e25e3719d3cd6790a9b49a4ed0444466a43eb6a..fc9b51c5fc0402ed21b9830e5e49d8b57e7516e4 100644
--- a/python/paddle/fluid/tests/unittests/test_layout_autotune.py
+++ b/python/paddle/fluid/tests/unittests/test_layout_autotune.py
@@ -146,7 +146,7 @@ class LayoutAutoTune(unittest.TestCase):
         out = paddle.argmax(conv_out, axis=1, keepdim=True)
 
         self.assertEqual(conv_out.shape, [1, 14, 12, 8])
-        self.assertEqual(out.shape, [1, 14, 1, 8])
+        self.assertEqual(out.shape, [1, 14, 12, 1])
 
     def test_argmax_op_transposer(self):
         if not self.use_autoune():
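
Note on the fix (a sketch, not part of the diff): the two permutation vectors map a reduction axis given in the tensor's original layout to its position after layout autotune has transposed the data. Assuming the switch that follows this hunk rewrites the attribute as axis = perm[axis] (that code is not shown here), the corrected perm_nhwc = {0, 3, 1, 2} sends the NCHW channel axis 1 to NHWC axis 3, which is what the updated unit test expects; with the old values, axis 1 mapped to 2, matching the previous (wrong) expectation [1, 14, 1, 8]. The NumPy sketch below reproduces the arithmetic under these assumptions; remap_axis is an illustrative helper, not Paddle's API.

    # Illustrative sketch only (NumPy >= 1.22 for keepdims); not Paddle's
    # implementation, but the shapes mirror the updated unit test.
    import numpy as np

    def remap_axis(axis, perm):
        # perm[i] is where axis i of the original (NCHW) layout ends up
        # after layout autotune has transposed the tensor to NHWC.
        return perm[axis]

    x_nchw = np.random.rand(1, 8, 14, 12)     # N=1, C=8, H=14, W=12
    x_nhwc = x_nchw.transpose(0, 2, 3, 1)     # stored layout: (1, 14, 12, 8)

    perm_nhwc = [0, 3, 1, 2]                  # corrected mapping from the diff
    axis = remap_axis(1, perm_nhwc)           # user's channel axis 1 -> 3

    out = x_nhwc.argmax(axis=axis, keepdims=True)
    print(out.shape)                          # (1, 14, 12, 1), as the test asserts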