未验证 提交 feb1b54f 编写于 作者: C chengduo 提交者: GitHub

fix min and max bug (#16570)

test=develop
上级 5dea0bdd
...@@ -58,6 +58,8 @@ class ArgMinMaxKernel : public framework::OpKernel<T> { ...@@ -58,6 +58,8 @@ class ArgMinMaxKernel : public framework::OpKernel<T> {
auto& out = *(ctx.Output<framework::LoDTensor>("Out")); auto& out = *(ctx.Output<framework::LoDTensor>("Out"));
out.mutable_data<Tout>(ctx.GetPlace()); out.mutable_data<Tout>(ctx.GetPlace());
auto axis = ctx.Attr<int64_t>("axis"); auto axis = ctx.Attr<int64_t>("axis");
auto x_rank = x.dims().size();
if (axis < 0) axis += x_rank;
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
#define CALL_ARG_MINMAX_FUNCTOR(rank) \ #define CALL_ARG_MINMAX_FUNCTOR(rank) \
......
...@@ -64,6 +64,14 @@ class TestCase2(BaseTestCase): ...@@ -64,6 +64,14 @@ class TestCase2(BaseTestCase):
self.axis = 0 self.axis = 0
class TestCase2_1(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
self.dims = (3, 4)
self.dtype = 'int64'
self.axis = -1
class TestCase3(BaseTestCase): class TestCase3(BaseTestCase):
def initTestCase(self): def initTestCase(self):
self.op_type = 'arg_max' self.op_type = 'arg_max'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册