Unverified commit 7003dcaa, authored by sneaxiy, committed by GitHub

Support FP16 argmax/argmin kernel (#42038)

* support FP16 argmax kernel

* add fp16 test
Parent 9774f965
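With these registrations in place, paddle.argmax and paddle.argmin accept float16 tensors on GPU. A minimal usage sketch (assuming a CUDA build of Paddle; illustrative only, not part of this diff):

import paddle

paddle.set_device('gpu')
x = paddle.rand([3, 4, 5]).astype('float16')   # FP16 input on GPU
print(paddle.argmax(x, axis=0))                # int64 indices, shape [4, 5]
print(paddle.argmin(x, axis=1))                # int64 indices, shape [3, 5]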
@@ -259,6 +259,7 @@ PD_REGISTER_KERNEL(arg_min,
                    GPU,
                    ALL_LAYOUT,
                    phi::ArgMinKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int32_t,
@@ -270,6 +271,7 @@ PD_REGISTER_KERNEL(arg_max,
                    GPU,
                    ALL_LAYOUT,
                    phi::ArgMaxKernel,
+                   phi::dtype::float16,
                    float,
                    double,
                    int32_t,
......
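PD_REGISTER_KERNEL enumerates the dtypes a kernel is instantiated for, so adding phi::dtype::float16 to both lists is all the GPU side needs. A quick probe of the dtypes visible above (a sketch, assuming a CUDA build; the registered list is truncated in the diff):

import paddle

paddle.set_device('gpu')
for dtype in ('float16', 'float32', 'float64', 'int32'):
    x = paddle.ones([2, 3]).astype(dtype)
    paddle.argmax(x, axis=0)   # would raise if the dtype were not registered
    print(dtype, 'ok')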
@@ -68,6 +68,26 @@ class TestCase2(BaseTestCase):
         self.axis = 0
 
 
+@unittest.skipIf(not paddle.is_compiled_with_cuda(),
+                 "FP16 test runs only on GPU")
+class TestCase0FP16(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_max'
+        self.dims = (3, 4, 5)
+        self.dtype = np.float16
+        self.axis = 0
+
+
+@unittest.skipIf(not paddle.is_compiled_with_cuda(),
+                 "FP16 test runs only on GPU")
+class TestCase1FP16(BaseTestCase):
+    def initTestCase(self):
+        self.op_type = 'arg_min'
+        self.dims = (3, 4)
+        self.dtype = np.float16
+        self.axis = 1
+
+
 class TestCase2_1(BaseTestCase):
     def initTestCase(self):
         self.op_type = 'arg_max'
@@ -202,4 +222,5 @@ class BaseTestComplex2_2(OpTest):
 
 
 if __name__ == '__main__':
+    paddle.enable_static()
     unittest.main()
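The new FP16 cases run through the OpTest machinery; the same check can be written standalone against a NumPy reference (a sketch mirroring TestCase1FP16, assuming a CUDA build):

import numpy as np
import paddle

paddle.set_device('gpu')
x_np = np.random.random((3, 4)).astype(np.float16)
x = paddle.to_tensor(x_np)
out = paddle.argmin(x, axis=1).numpy()
# Ties after the fp16 cast are unlikely here but possible with random data.
assert np.array_equal(out, np.argmin(x_np, axis=1))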