Unverified commit 7b19efe4, authored by zxcd, committed by GitHub

Support more dtypes for the any/all API. (#55253)

* add more data types for all/any.

* remove xpu fix.

* add unit tests.

* fix template type parameter name.

* fix output data type.
Parent 6692dc9a
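A minimal usage sketch of the behavior this commit adds (the tensor values and shapes below are illustrative, not taken from the diff): paddle.all and paddle.any now accept float32/float64/int32/int64 inputs in addition to bool, and the result tensor is always bool; non-zero elements are treated as True.

```python
# Illustrative sketch (assumed values): paddle.all / paddle.any with non-bool input.
import numpy as np
import paddle

x = paddle.to_tensor(np.array([[1.0, 0.0], [2.5, 3.0]], dtype="float32"))

out_all = paddle.all(x)          # one element is 0.0 -> False
out_any = paddle.any(x, axis=1)  # per-row reduction -> [True, True]

assert out_all.dtype == paddle.bool  # output dtype is bool regardless of input dtype
assert out_any.dtype == paddle.bool
```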
@@ -52,6 +52,7 @@ void Reduce(const DeviceContext& dev_ctx,
phi::funcs::ReduceKernelImpl<DeviceContext, T, data_t, Functor>(
dev_ctx, x, out, dims, keep_dim, reduce_all);
}));
} else {
// cast x tensor to out_dtype
auto tmp_tensor = phi::Cast<T, DeviceContext>(dev_ctx, x, out_dtype);
@@ -65,7 +66,7 @@ void Reduce(const DeviceContext& dev_ctx,
}
}
template <typename DeviceContext, typename OutT, typename Functor>
template <typename DeviceContext, typename T, typename Functor>
void BoolReduceKernel(const DeviceContext& dev_ctx,
const phi::DenseTensor& input,
const std::vector<int64_t>& dims,
@@ -73,7 +74,7 @@ void BoolReduceKernel(const DeviceContext& dev_ctx,
bool reduce_all,
phi::DenseTensor* output) {
reduce_all = recompute_reduce_all(input, dims, reduce_all);
dev_ctx.template Alloc<OutT>(output);
dev_ctx.template Alloc<bool>(output);
// The dims has full dim, set the reduce_all is True
const auto& input_dim_size = input.dims().size();
@@ -86,9 +87,15 @@ void BoolReduceKernel(const DeviceContext& dev_ctx,
}
}
reduce_all = (reduce_all || full_dim);
funcs::ReduceKernelImpl<DeviceContext, bool, OutT, Functor>(
dev_ctx, input, output, dims, keep_dim, reduce_all);
DenseTensor tmp_tensor;
if (input.dtype() != phi::DataType::BOOL) {
tmp_tensor =
phi::Cast<T, DeviceContext>(dev_ctx, input, phi::DataType::BOOL);
} else {
tmp_tensor = input;
}
funcs::ReduceKernelImpl<DeviceContext, bool, bool, Functor>(
dev_ctx, tmp_tensor, output, dims, keep_dim, reduce_all);
}
} // namespace phi
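The BoolReduceKernel change above casts a non-bool input to BOOL before running the logical reduction, instead of reducing in the input type. A rough NumPy sketch of the equivalent semantics (an assumption of how the cast-then-reduce path behaves, not the kernel code itself):

```python
# Semantics sketch (assumed): cast the input to bool, then apply the logical reduce.
import numpy as np

def bool_reduce(x, op, axis=None, keepdims=False):
    x_bool = x.astype(bool)  # mirrors the cast to phi::DataType::BOOL
    if op == "all":          # LogicalAndFunctor
        return np.all(x_bool, axis=axis, keepdims=keepdims)
    return np.any(x_bool, axis=axis, keepdims=keepdims)  # LogicalOrFunctor

x = np.array([[0.0, 2.0], [1.5, 3.0]], dtype=np.float32)
print(bool_reduce(x, "all", axis=0))  # [False  True]
```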
@@ -35,6 +35,14 @@ void AllRawKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(all_raw, CPU, ALL_LAYOUT, phi::AllRawKernel, bool) {
PD_REGISTER_KERNEL(all_raw,
CPU,
ALL_LAYOUT,
phi::AllRawKernel,
float,
double,
int,
int64_t,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
@@ -35,4 +35,14 @@ void AnyRawKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(any_raw, CPU, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
PD_REGISTER_KERNEL(any_raw,
CPU,
ALL_LAYOUT,
phi::AnyRawKernel,
float,
double,
int,
int64_t,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
@@ -26,7 +26,7 @@ void AllRawKernel(const Context& dev_ctx,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
auto out_dtype = phi::DataType::BOOL;
phi::Reduce<T, kps::LogicalAndFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
@@ -38,7 +38,15 @@ PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#else
PD_REGISTER_KERNEL(all_raw, KPS, ALL_LAYOUT, phi::AllRawKernel, bool) {
PD_REGISTER_KERNEL(all_raw,
KPS,
ALL_LAYOUT,
phi::AllRawKernel,
float,
double,
int,
int64_t,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#endif
@@ -26,7 +26,7 @@ void AnyRawKernel(const Context& dev_ctx,
bool reduce_all,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
auto out_dtype = phi::DataType::BOOL;
phi::Reduce<T, kps::LogicalOrFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims, keep_dim, out_dtype, out);
}
@@ -36,5 +36,15 @@ void AnyRawKernel(const Context& dev_ctx,
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(any_raw, KPS, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
#else
PD_REGISTER_KERNEL(any_raw, KPS, ALL_LAYOUT, phi::AnyRawKernel, bool) {}
PD_REGISTER_KERNEL(any_raw,
KPS,
ALL_LAYOUT,
phi::AnyRawKernel,
float,
double,
int,
int64_t,
bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#endif
@@ -38,10 +38,16 @@ void AllKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(all, CPU, ALL_LAYOUT, phi::AllKernel, bool) {}
PD_REGISTER_KERNEL(
all, CPU, ALL_LAYOUT, phi::AllKernel, float, double, int, int64_t, bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(all, GPU, ALL_LAYOUT, phi::AllKernel, bool) {}
PD_REGISTER_KERNEL(
all, GPU, ALL_LAYOUT, phi::AllKernel, float, double, int, int64_t, bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#endif
#if defined(PADDLE_WITH_XPU_KP)
@@ -31,10 +31,16 @@ void AnyKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(any, CPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
PD_REGISTER_KERNEL(
any, CPU, ALL_LAYOUT, phi::AnyKernel, float, double, int64_t, int, bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(any, GPU, ALL_LAYOUT, phi::AnyKernel, bool) {}
PD_REGISTER_KERNEL(
any, GPU, ALL_LAYOUT, phi::AnyKernel, float, double, int, int64_t, bool) {
kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
}
#endif
#if defined(PADDLE_WITH_XPU_KP) && !defined(PADDLE_WITH_XPU)
@@ -4096,12 +4096,13 @@ def all(x, axis=None, keepdim=False, name=None):
'keep_dim': keepdim,
'reduce_all': reduce_all,
}
check_variable_and_dtype(x, 'x', ['bool'], 'all')
check_variable_and_dtype(
x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'all'
)
check_type(axis, 'axis', (int, list, tuple, type(None)), 'all')
helper = LayerHelper('all', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=paddle.bool)
helper.append_op(
type='reduce_all',
inputs={'X': x},
@@ -4170,13 +4171,13 @@ def any(x, axis=None, keepdim=False, name=None):
'keep_dim': keepdim,
'reduce_all': reduce_all,
}
check_variable_and_dtype(x, 'x', ['bool'], 'any')
check_variable_and_dtype(
x, 'x', ['bool', 'float32', 'float64', 'int32', 'int64'], 'any'
)
check_type(axis, 'axis', (int, list, tuple, type(None)), 'any')
helper = LayerHelper('any', **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
out = helper.create_variable_for_type_inference(dtype=paddle.bool)
helper.append_op(
type='reduce_any',
inputs={'X': x},
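In the static-graph branches above, the dtype check now admits float and integer inputs and the output variable is created as bool. A small static-graph sketch under those assumptions (variable names and shapes are illustrative):

```python
# Static-graph sketch (illustrative names/shapes): float32 input, bool result.
import numpy as np
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
    out = paddle.any(x, axis=1)

exe = paddle.static.Executor(paddle.CPUPlace())
x_np = np.array([[0, 0, 0], [0, 2, 0]], dtype=np.float32)
(res,) = exe.run(main_prog, feed={"x": x_np}, fetch_list=[out])
print(res.dtype, res)  # bool [False  True]
```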
@@ -804,6 +804,30 @@ class TestAllOp(OpTest):
self.check_output()
class TestAllFloatOp(OpTest):
def setUp(self):
self.op_type = "reduce_all"
self.python_api = reduce_all_wrapper
self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("float")}
self.outputs = {'Out': self.inputs['X'].all()}
self.attrs = {'reduce_all': True}
def test_check_output(self):
self.check_output()
class TestAllIntOp(OpTest):
def setUp(self):
self.op_type = "reduce_all"
self.python_api = reduce_all_wrapper
self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("int")}
self.outputs = {'Out': self.inputs['X'].all()}
self.attrs = {'reduce_all': True}
def test_check_output(self):
self.check_output()
class TestAllOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.all
@@ -900,11 +924,6 @@ class TestAllOpError(unittest.TestCase):
# The input type of reduce_all_op must be Variable.
input1 = 12
self.assertRaises(TypeError, paddle.all, input1)
# The input dtype of reduce_all_op must be bool.
input2 = paddle.static.data(
name='input2', shape=[-1, 12, 10], dtype="int32"
)
self.assertRaises(TypeError, paddle.all, input2)
def reduce_any_wrapper(x, axis=None, keepdim=False, reduce_all=True, name=None):
@@ -923,6 +942,30 @@ class TestAnyOp(OpTest):
self.check_output()
class TestAnyFloatOp(OpTest):
def setUp(self):
self.op_type = "reduce_any"
self.python_api = reduce_any_wrapper
self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("float")}
self.outputs = {'Out': self.inputs['X'].any()}
self.attrs = {'reduce_all': True}
def test_check_output(self):
self.check_output()
class TestAnyIntOp(OpTest):
def setUp(self):
self.op_type = "reduce_any"
self.python_api = reduce_any_wrapper
self.inputs = {'X': np.random.randint(0, 2, (5, 6, 10)).astype("int")}
self.outputs = {'Out': self.inputs['X'].any()}
self.attrs = {'reduce_all': True}
def test_check_output(self):
self.check_output()
class TestAnyOp_ZeroDim(OpTest):
def setUp(self):
self.python_api = paddle.any
@@ -1021,11 +1064,6 @@ class TestAnyOpError(unittest.TestCase):
# The input type of reduce_any_op must be Variable.
input1 = 12
self.assertRaises(TypeError, paddle.any, input1)
# The input dtype of reduce_any_op must be bool.
input2 = paddle.static.data(
name='input2', shape=[-1, 12, 10], dtype="int32"
)
self.assertRaises(TypeError, paddle.any, input2)
class Test1DReduce(OpTest):
@@ -1645,11 +1683,43 @@ class TestAllAPI(unittest.TestCase):
feed={"input": input_np},
fetch_list=[result],
)
np.testing.assert_allclose(fetches[0], np.all(input_np), rtol=1e-05)
self.assertTrue((fetches[0] == np.all(input_np)).all())
def check_static_float_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input = paddle.static.data(
name="input", shape=[4, 4], dtype="float"
)
result = paddle.all(x=input)
input_np = np.random.randint(0, 2, [4, 4]).astype("float")
exe = fluid.Executor(place)
fetches = exe.run(
fluid.default_main_program(),
feed={"input": input_np},
fetch_list=[result],
)
self.assertTrue((fetches[0] == np.all(input_np)).all())
def check_static_int_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input = paddle.static.data(name="input", shape=[4, 4], dtype="int")
result = paddle.all(x=input)
input_np = np.random.randint(0, 2, [4, 4]).astype("int")
exe = fluid.Executor(place)
fetches = exe.run(
fluid.default_main_program(),
feed={"input": input_np},
fetch_list=[result],
)
self.assertTrue((fetches[0] == np.all(input_np)).all())
def test_static(self):
for place in self.places:
self.check_static_result(place=place)
self.check_static_float_result(place=place)
self.check_static_int_result(place=place)
def test_dygraph(self):
paddle.disable_static()
@@ -1679,6 +1749,18 @@ class TestAllAPI(unittest.TestCase):
expect_res4 = np.all(np_x, axis=1, keepdims=True)
self.assertTrue((np_out4 == expect_res4).all())
x = paddle.cast(x, 'float')
out5 = paddle.all(x)
np_out5 = out5.numpy()
expect_res5 = np.all(np_x)
self.assertTrue((np_out5 == expect_res5).all())
x = paddle.cast(x, 'int')
out6 = paddle.all(x)
np_out6 = out6.numpy()
expect_res6 = np.all(np_x)
self.assertTrue((np_out6 == expect_res6).all())
paddle.enable_static()
@@ -1702,11 +1784,43 @@ class TestAnyAPI(unittest.TestCase):
feed={"input": input_np},
fetch_list=[result],
)
np.testing.assert_allclose(fetches[0], np.any(input_np), rtol=1e-05)
self.assertTrue((fetches[0] == np.any(input_np)).all())
def check_static_float_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input = paddle.static.data(
name="input", shape=[4, 4], dtype="float"
)
result = paddle.any(x=input)
input_np = np.random.randint(0, 2, [4, 4]).astype("float")
exe = fluid.Executor(place)
fetches = exe.run(
fluid.default_main_program(),
feed={"input": input_np},
fetch_list=[result],
)
self.assertTrue((fetches[0] == np.any(input_np)).all())
def check_static_int_result(self, place):
with fluid.program_guard(fluid.Program(), fluid.Program()):
input = paddle.static.data(name="input", shape=[4, 4], dtype="int")
result = paddle.any(x=input)
input_np = np.random.randint(0, 2, [4, 4]).astype("int")
exe = fluid.Executor(place)
fetches = exe.run(
fluid.default_main_program(),
feed={"input": input_np},
fetch_list=[result],
)
self.assertTrue((fetches[0] == np.any(input_np)).all())
def test_static(self):
for place in self.places:
self.check_static_result(place=place)
self.check_static_float_result(place=place)
self.check_static_int_result(place=place)
def test_dygraph(self):
paddle.disable_static()
@@ -1736,6 +1850,21 @@ class TestAnyAPI(unittest.TestCase):
expect_res4 = np.any(np_x, axis=1, keepdims=True)
self.assertTrue((np_out4 == expect_res4).all())
np_x = np.random.randint(0, 2, (12, 10)).astype(np.float32)
x = paddle.assign(np_x)
x = paddle.cast(x, 'float32')
out5 = paddle.any(x)
np_out5 = out5.numpy()
expect_res5 = np.any(np_x)
self.assertTrue((np_out5 == expect_res5).all())
x = paddle.cast(x, 'int')
out6 = paddle.any(x)
np_out6 = out6.numpy()
expect_res6 = np.any(np_x)
self.assertTrue((np_out6 == expect_res6).all())
paddle.enable_static()