From 21ae549ea77d9de256449e09e6feaeebc7a6d6f6 Mon Sep 17 00:00:00 2001
From: niuliling123 <51102941+niuliling123@users.noreply.github.com>
Date: Fri, 8 Jul 2022 16:49:46 +0800
Subject: [PATCH] Fix Argmax Layout autotune (#44080)

---
 paddle/fluid/imperative/layout_transformer.h                | 5 +++--
 python/paddle/fluid/tests/unittests/test_layout_autotune.py | 2 +-
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/paddle/fluid/imperative/layout_transformer.h b/paddle/fluid/imperative/layout_transformer.h
index 50d3e2b6ac1..ab7619dedb2 100644
--- a/paddle/fluid/imperative/layout_transformer.h
+++ b/paddle/fluid/imperative/layout_transformer.h
@@ -374,8 +374,9 @@ class ArgmaxOpTransformer
       bool keep_dims = BOOST_GET_CONST(bool, (*attrs)["keepdims"]);
       if (keep_dims) {
         if (var_layout != DataLayout::UNDEFINED) {
-          std::vector<int> perm_nhwc = {0, 2, 3, 1};
-          std::vector<int> perm_nchw = {0, 3, 1, 2};
+          std::vector<int> perm_nhwc = {0, 3, 1, 2};
+          std::vector<int> perm_nchw = {0, 2, 3, 1};
+
           auto perm = var_layout == DataLayout::NHWC ? perm_nhwc : perm_nchw;
           switch (AttrTypeID((*attrs)["axis"])) {
             case paddle::framework::proto::AttrType::INT: {
diff --git a/python/paddle/fluid/tests/unittests/test_layout_autotune.py b/python/paddle/fluid/tests/unittests/test_layout_autotune.py
index 6e25e3719d3..fc9b51c5fc0 100644
--- a/python/paddle/fluid/tests/unittests/test_layout_autotune.py
+++ b/python/paddle/fluid/tests/unittests/test_layout_autotune.py
@@ -146,7 +146,7 @@ class LayoutAutoTune(unittest.TestCase):
 
             out = paddle.argmax(conv_out, axis=1, keepdim=True)
             self.assertEqual(conv_out.shape, [1, 14, 12, 8])
-            self.assertEqual(out.shape, [1, 14, 1, 8])
+            self.assertEqual(out.shape, [1, 14, 12, 1])
 
     def test_argmax_op_transposer(self):
         if not self.use_autoune():
--
GitLab
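The core of the fix is that the two permutation tables were swapped. When the tensor is physically NHWC but the `axis` attribute was written against NCHW semantics, the axis must be remapped with {0, 3, 1, 2} (N→0, C→3, H→1, W→2); the reverse table {0, 2, 3, 1} applies when the variable is NCHW. A minimal NumPy sketch of the remapping, using the shapes from the updated unit test (illustrative only, not Paddle's implementation):

```python
import numpy as np

# Maps an axis index given in NCHW semantics to its position in an
# NHWC tensor, and vice versa (matching the corrected tables above).
perm_nhwc = [0, 3, 1, 2]  # variable is laid out as NHWC
perm_nchw = [0, 2, 3, 1]  # variable is laid out as NCHW

x_nhwc = np.random.rand(1, 14, 12, 8)  # conv output in NHWC, as in the test
axis_nchw = 1                          # caller asks for the channel axis
axis_actual = perm_nhwc[axis_nchw]     # -> 3, where C lives in NHWC

out = np.argmax(x_nhwc, axis=axis_actual)
out = np.expand_dims(out, axis=axis_actual)  # emulate keepdim=True
print(out.shape)  # (1, 14, 12, 1), matching the corrected assertion
```

With the old, swapped table, axis 1 would have been remapped to 2, reducing over H instead of C and producing the [1, 14, 1, 8] shape that the test previously asserted.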