From feb1b54f9d14f7cb7f9f9630813301f2af299ffa Mon Sep 17 00:00:00 2001 From: chengduo Date: Sun, 31 Mar 2019 21:12:07 -0500 Subject: [PATCH] fix min and max bug (#16570) test=develop --- paddle/fluid/operators/arg_min_max_op_base.h | 2 ++ .../paddle/fluid/tests/unittests/test_arg_min_max_op.py | 8 ++++++++ 2 files changed, 10 insertions(+) diff --git a/paddle/fluid/operators/arg_min_max_op_base.h b/paddle/fluid/operators/arg_min_max_op_base.h index 6cbdaefeda0..bf7b83bb7a7 100644 --- a/paddle/fluid/operators/arg_min_max_op_base.h +++ b/paddle/fluid/operators/arg_min_max_op_base.h @@ -58,6 +58,8 @@ class ArgMinMaxKernel : public framework::OpKernel { auto& out = *(ctx.Output("Out")); out.mutable_data(ctx.GetPlace()); auto axis = ctx.Attr("axis"); + auto x_rank = x.dims().size(); + if (axis < 0) axis += x_rank; auto& dev_ctx = ctx.template device_context(); #define CALL_ARG_MINMAX_FUNCTOR(rank) \ diff --git a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py index 0712e102b30..4f9f1ec2253 100644 --- a/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py +++ b/python/paddle/fluid/tests/unittests/test_arg_min_max_op.py @@ -64,6 +64,14 @@ class TestCase2(BaseTestCase): self.axis = 0 +class TestCase2_1(BaseTestCase): + def initTestCase(self): + self.op_type = 'arg_max' + self.dims = (3, 4) + self.dtype = 'int64' + self.axis = -1 + + class TestCase3(BaseTestCase): def initTestCase(self): self.op_type = 'arg_max' -- GitLab