未验证 提交 2e92357b 编写于 作者: H Haohongxiang 提交者: GitHub

[API/OP] Support FP16/BF16 in paddle.nonzero API/OP (#51640)

上级 ad5536eb
......@@ -90,6 +90,7 @@ PD_REGISTER_KERNEL(nonzero,
int64_t,
int,
int16_t,
phi::dtype::bfloat16,
bool,
float,
double) {
......
......@@ -81,6 +81,8 @@ PD_REGISTER_KERNEL(nonzero,
int64_t,
int,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16,
bool,
float,
double) {
......
......@@ -15,10 +15,17 @@
import unittest
import numpy as np
from eager_op_test import OpTest
import paddle
from paddle import fluid
from paddle.fluid import Program, program_guard
from paddle.fluid.tests.unittests.op_test import convert_float_to_uint16
def call_nonzero(x):
input = paddle.to_tensor(x)
return paddle.nonzero(x=input)
class TestNonZeroAPI(unittest.TestCase):
......@@ -88,5 +95,86 @@ class TestNonZeroAPI(unittest.TestCase):
expect_out = np.array([[0, 0], [1, 1]])
# Base case
class TestNonzeroOp(OpTest):
def setUp(self):
'''Test where_index op with random value'''
np.random.seed(2023)
self.op_type = "where_index"
self.python_api = call_nonzero
self.init_shape()
self.init_dtype()
self.inputs = self.create_inputs()
self.outputs = self.return_outputs()
def test_check_output(self):
self.check_output()
def init_shape(self):
self.shape = [8, 8]
def init_dtype(self):
self.dtype = np.float64
def create_inputs(self):
return {
'Condition': np.random.randint(5, size=self.shape).astype(
self.dtype
)
}
def return_outputs(self):
return {'Out': np.transpose(np.nonzero(self.inputs['Condition']))}
class TestNonzeroFP32Op(TestNonzeroOp):
def init_shape(self):
self.shape = [2, 10, 2]
def init_dtype(self):
self.dtype = np.float32
class TestNonzeroFP16Op(TestNonzeroOp):
def init_shape(self):
self.shape = [3, 4, 7]
def init_dtype(self):
self.dtype = np.float16
class TestNonzeroBF16(OpTest):
def setUp(self):
'''Test where_index op with bfloat16 dtype'''
np.random.seed(2023)
self.op_type = "where_index"
self.python_api = call_nonzero
self.init_shape()
self.init_dtype()
self.inputs = self.create_inputs()
self.outputs = self.return_outputs()
def test_check_output(self):
self.check_output()
def init_shape(self):
self.shape = [12, 9]
def init_dtype(self):
self.dtype = np.uint16
def create_inputs(self):
return {
'Condition': convert_float_to_uint16(
np.random.randint(5, size=self.shape).astype(np.float32)
)
}
def return_outputs(self):
return {'Out': np.transpose(np.nonzero(self.inputs['Condition']))}
if __name__ == "__main__":
unittest.main()
......@@ -430,6 +430,22 @@ def nonzero(x, as_tuple=False):
if in_dygraph_mode():
outs = _C_ops.nonzero(x)
else:
check_variable_and_dtype(
x,
'x',
[
'int16',
'int32',
'int64',
'uint16',
'float16',
'float32',
'float64',
'bool',
],
'where_index',
)
helper = LayerHelper("where_index", **locals())
outs = helper.create_variable_for_type_inference(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册