未验证 提交 21ae549e 编写于 作者: N niuliling123 提交者: GitHub

Fix Argmax Layout autotune (#44080)

上级 61de8af8
...@@ -374,8 +374,9 @@ class ArgmaxOpTransformer ...@@ -374,8 +374,9 @@ class ArgmaxOpTransformer
bool keep_dims = BOOST_GET_CONST(bool, (*attrs)["keepdims"]); bool keep_dims = BOOST_GET_CONST(bool, (*attrs)["keepdims"]);
if (keep_dims) { if (keep_dims) {
if (var_layout != DataLayout::UNDEFINED) { if (var_layout != DataLayout::UNDEFINED) {
std::vector<int> perm_nhwc = {0, 2, 3, 1}; std::vector<int> perm_nhwc = {0, 3, 1, 2};
std::vector<int> perm_nchw = {0, 3, 1, 2}; std::vector<int> perm_nchw = {0, 2, 3, 1};
auto perm = var_layout == DataLayout::NHWC ? perm_nhwc : perm_nchw; auto perm = var_layout == DataLayout::NHWC ? perm_nhwc : perm_nchw;
switch (AttrTypeID((*attrs)["axis"])) { switch (AttrTypeID((*attrs)["axis"])) {
case paddle::framework::proto::AttrType::INT: { case paddle::framework::proto::AttrType::INT: {
......
...@@ -146,7 +146,7 @@ class LayoutAutoTune(unittest.TestCase): ...@@ -146,7 +146,7 @@ class LayoutAutoTune(unittest.TestCase):
out = paddle.argmax(conv_out, axis=1, keepdim=True) out = paddle.argmax(conv_out, axis=1, keepdim=True)
self.assertEqual(conv_out.shape, [1, 14, 12, 8]) self.assertEqual(conv_out.shape, [1, 14, 12, 8])
self.assertEqual(out.shape, [1, 14, 1, 8]) self.assertEqual(out.shape, [1, 14, 12, 1])
def test_argmax_op_transposer(self): def test_argmax_op_transposer(self):
if not self.use_autoune(): if not self.use_autoune():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册