未验证 提交 cf6ed7cb 编写于 作者: D denglianbin 提交者: GitHub

【Hackathon No.48】为 Paddle kthvalue 算子实现 float16 数据类型支持 (#53285)

上级 2a705b74
......@@ -40,11 +40,14 @@ class TestKthvalueOp(OpTest):
self.k = 5
self.axis = -1
def init_dtype(self):
self.dtype = np.float64
def setUp(self):
self.op_type = "kthvalue"
self.python_api = paddle.kthvalue
self.dtype = np.float64
self.input_data = np.random.random((2, 1, 2, 4, 10))
self.init_dtype()
self.input_data = np.random.random((2, 1, 2, 4, 10)).astype(self.dtype)
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis}
......@@ -62,17 +65,25 @@ class TestKthvalueOp(OpTest):
self.check_grad({'X'}, 'Out')
class TestKthvalueOpFp16(TestKthvalueOp):
def init_dtype(self):
self.dtype = np.float16
class TestKthvalueOpWithKeepdim(OpTest):
def init_args(self):
self.k = 2
self.axis = 1
def init_dtype(self):
self.dtype = np.float64
def setUp(self):
self.init_args()
self.init_dtype()
self.op_type = "kthvalue"
self.python_api = paddle.kthvalue
self.dtype = np.float64
self.input_data = np.random.random((1, 3, 2, 4, 10))
self.input_data = np.random.random((1, 3, 2, 4, 10)).astype(self.dtype)
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'keepdim': True}
output, indices = cal_kthvalue(
......@@ -89,6 +100,11 @@ class TestKthvalueOpWithKeepdim(OpTest):
self.check_grad({'X'}, 'Out')
class TestKthvalueOpWithKeepdimFp16(TestKthvalueOpWithKeepdim):
def init_dtype(self):
self.dtype = np.float16
class TestKthvalueOpKernels(unittest.TestCase):
def setUp(self):
self.axises = [2, -1]
......
......@@ -1074,7 +1074,7 @@ def kthvalue(x, k, axis=None, keepdim=False, name=None):
Find values and indices of the k-th smallest at the axis.
Args:
x(Tensor): A N-D Tensor with type float32, float64, int32, int64.
x(Tensor): A N-D Tensor with type float16, float32, float64, int32, int64.
k(int): The k for the k-th smallest number to look for along the axis.
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册