未验证 提交 3161e6c3 编写于 作者: zhenhailiu's avatar zhenhailiu 提交者: GitHub

[AMP OP&Test] Arg min max bf16 test (#52276)

* polish

* add type check
上级 bed54a70
......@@ -16,7 +16,7 @@ import os
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
from test_attribute_var import UnittestBase
import paddle
......@@ -96,6 +96,57 @@ class TestCase1FP16(BaseTestCase):
self.axis = 1
@unittest.skipIf(
not paddle.is_compiled_with_cuda(), "BFP16 test runs only on GPU"
)
class TestArgMinBF16OP(OpTest):
def initTestType(self):
self.op_type = 'arg_min'
self.python_api = paddle.tensor.argmin
def initTestCase(self):
self.initTestType()
self.dims = (3, 4, 5)
self.axis = 0
self.dtype = np.uint16
def setUp(self):
self.initTestCase()
x = np.random.random(self.dims).astype("float32")
self.x = convert_float_to_uint16(x)
self.inputs = {'X': self.x}
self.attrs = {'axis': self.axis}
if self.op_type == "arg_min":
self.outputs = {'Out': np.argmin(x, axis=self.axis)}
else:
self.outputs = {'Out': np.argmax(x, axis=self.axis)}
def test_check_output(self):
self.check_output_with_place(paddle.CUDAPlace(0))
class TestArgMaxBF16OP(TestArgMinBF16OP):
def initTestType(self):
self.op_type = 'arg_max'
self.python_api = paddle.tensor.argmax
class TestArgMinMaxTypeCheck(unittest.TestCase):
def test_type_error(self):
# in static mode
with program_guard(Program(), Program()):
x = paddle.static.data(name='x', shape=[100, 10], dtype='bool')
self.assertRaises(TypeError, paddle.argmin, x)
self.assertRaises(TypeError, paddle.argmax, x)
def test_bfp16(self):
# in static mode
with program_guard(Program(), Program()):
x = paddle.zeros(name='x', shape=[100, 10], dtype='uint16')
t1 = paddle.argmin(x)
t2 = paddle.argmax(x)
class TestCase2_1(BaseTestCase):
def initTestCase(self):
self.op_type = 'arg_max'
......
......@@ -194,6 +194,7 @@ def argmax(x, axis=None, keepdim=False, dtype="int64", name=None):
x,
'x',
[
'uint16',
'float16',
'float32',
'float64',
......@@ -283,6 +284,7 @@ def argmin(x, axis=None, keepdim=False, dtype="int64", name=None):
x,
'x',
[
'uint16',
'float16',
'float32',
'float64',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册