未验证 提交 43efb979 编写于 作者: C cyberslack_lee 提交者: GitHub

【Hackathon4 No58】kthvalue (#51615)

上级 7ecbcc08
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "paddle/phi/kernels/kthvalue_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
......@@ -76,4 +75,5 @@ PD_REGISTER_KERNEL(kthvalue_grad,
double,
int,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {}
......@@ -13,8 +13,8 @@
// limitations under the License.
#include "paddle/phi/kernels/kthvalue_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"
......@@ -287,6 +287,7 @@ PD_REGISTER_KERNEL(kthvalue,
double,
int,
int64_t,
phi::dtype::bfloat16,
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -15,10 +15,11 @@
import unittest
import numpy as np
from eager_op_test import OpTest
from eager_op_test import OpTest, convert_float_to_uint16
import paddle
from paddle import fluid
from paddle.fluid import core
def cal_kthvalue(x, k, axis, keepdim=False):
......@@ -207,5 +208,74 @@ class TestModeOpInStatic(unittest.TestCase):
np.testing.assert_allclose(paddle_result, expect_value, rtol=1e-05)
class TestKthvalueFP16Op(OpTest):
def init_args(self):
self.k = 5
self.axis = -1
self.keepdim = False
self.input_data = np.random.random((2, 1, 2, 4, 10))
self.dtype = np.float16
def setUp(self):
self.op_type = "kthvalue"
self.python_api = paddle.kthvalue
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'keepdim': self.keepdim}
output, indices = cal_kthvalue(
self.input_data, k=self.k, axis=self.axis, keepdim=self.keepdim
)
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
paddle.enable_static()
self.check_output()
def test_check_grad(self):
paddle.enable_static()
self.check_grad({'X'}, 'Out')
class TestKthvalueWithKeepdimFP16Op(TestKthvalueFP16Op):
def init_args(self):
self.k = 2
self.axis = 1
self.keepdim = True
self.input_data = np.random.random((1, 3, 2, 4, 10))
self.dtype = np.float16
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestKthvalueBF16Op(OpTest):
def init_args(self):
self.k = 2
self.axis = 1
def setUp(self):
self.init_args()
self.op_type = 'kthvalue'
self.python_api = paddle.kthvalue
self.dtype = np.uint16
x = np.random.random((1, 3, 2, 4, 10))
self.inputs = {'X': convert_float_to_uint16(x)}
self.attrs = {'k': self.k, 'axis': self.axis, 'keepdim': True}
out, indices = cal_kthvalue(x, k=self.k, axis=self.axis, keepdim=True)
self.outputs = {'Out': convert_float_to_uint16(out), 'Indices': indices}
def test_check_output(self):
paddle.enable_static()
place = core.CUDAPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
paddle.enable_static()
place = core.CUDAPlace(0)
self.check_grad_with_place(place, {'X'}, 'Out')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册