Unverified commit 27417f1f authored by will-jl944, committed by GitHub

Logical Ops support more data types (#34141)

* logical ops support int8, int16, int32, int64, float, double

* update docs of logical ops

* fix npu and xpu logical ops

* fix npu and xpu logical ops

* fix bug in xpu logical op code

* update test_logical_op_npu and test_logical_op_xpu

* correct error type
Parent 63f6ce7b
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -26,15 +23,16 @@ class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
OpComment comment;
AddInput("X", string::Sprintf("Left hand operand of %s operator. Must be "
"a Variable of type bool.",
"a Variable of type being one of bool, int8, "
"int16, int32, int64, float32, float64.",
comment.type));
AddInput("Y", string::Sprintf("Right hand operand of %s operator. Must be "
"a Variable of type bool.",
"a Variable of type being one of bool, int8, "
"int16, int32, int64, float32, float64.",
comment.type));
AddOutput("Out", string::Sprintf("n-dim bool Variable"));
AddComment(string::Sprintf(R"DOC(%s Operator
It operates element-wise on X and Y, and returns the Out. X, Y and Out are N-dim boolean LoDTensor or Tensor.
It operates element-wise on X and Y, and returns the Out. X, Y and Out are N-dim LoDTensor or Tensor.
Each element of Out is calculated by %s
)DOC",
comment.type, comment.equation));
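For reference, a minimal sketch of the behavior the updated comment describes: the binary logical ops now accept non-bool inputs and still produce a bool Out of the same shape. This is a usage sketch against the public Python API (paddle.logical_and plus the paddle.static helpers used in the tests further down), assuming a Paddle build that includes this change; it is not part of the patch itself.
import numpy as np
import paddle

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name='x', shape=[3], dtype='int32')
    y = paddle.static.data(name='y', shape=[3], dtype='int32')
    out = paddle.logical_and(x, y)   # Out is a bool Variable even for int32 inputs
exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)
res, = exe.run(main_prog,
               feed={'x': np.array([0, 1, 2], dtype='int32'),
                     'y': np.array([1, 0, 2], dtype='int32')},
               fetch_list=[out])
print(res)   # [False False  True], dtype=bool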
......@@ -46,13 +44,14 @@ class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
OpComment comment;
AddInput("X", string::Sprintf("Operand of %s operator. Must be "
"a LoDTensor or Tensor of type bool.",
comment.type));
AddInput("X",
string::Sprintf("Operand of %s operator. Must be "
"a LoDTensor or Tensor of type being one of bool, "
"int8, int16, int32, int64, float32, float64.",
comment.type));
AddOutput("Out", string::Sprintf("n-dim bool LoDTensor or Tensor."));
AddComment(string::Sprintf(R"DOC(%s Operator
It operates element-wise on X, and returns the Out. X and Out are N-dim boolean LoDTensor or Tensor.
It operates element-wise on X, and returns the Out. X and Out are N-dim LoDTensor or Tensor.
Each element of Out is calculated by %s
)DOC",
comment.type, comment.equation));
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -21,13 +18,13 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {
#define LOGICAL_BINARY_FUNCTOR(func_name, op) \
template <typename T> \
struct func_name { \
using ELEMENT_TYPE = T; \
HOSTDEVICE bool operator()(const T* args) const { \
return args[0] op args[1]; \
} \
#define LOGICAL_BINARY_FUNCTOR(func_name, op) \
template <typename T> \
struct func_name { \
using ELEMENT_TYPE = T; \
HOSTDEVICE bool operator()(const T* args) const { \
return static_cast<bool>(args[0]) op static_cast<bool>(args[1]); \
} \
};
LOGICAL_BINARY_FUNCTOR(CudaOrFunctor, ||)
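The static_cast<bool> added above is what gives non-bool inputs the usual C++ truth semantics: any non-zero element counts as True and zero counts as False. A small dygraph sketch of that rule, with numpy as the reference (assuming a build with these kernels):
import numpy as np
import paddle

x = paddle.to_tensor([0, 2, -3, 0], dtype='int32')
y = paddle.to_tensor([1, 0, -1, 0], dtype='int32')
out = paddle.logical_and(x, y)                 # non-zero elements are treated as True
print(out.numpy())                             # [False False  True False]
print(np.logical_and(x.numpy(), y.numpy()))    # numpy reference gives the same result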
......@@ -68,10 +65,16 @@ class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
} // namespace operators
} // namespace paddle
#define REGISTER_LOGICAL_CUDA_KERNEL(op_name, func) \
REGISTER_OP_CUDA_KERNEL( \
op_name, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<bool>>);
#define REGISTER_LOGICAL_CUDA_KERNEL(op_name, func) \
REGISTER_OP_CUDA_KERNEL( \
op_name, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<bool>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<int8_t>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<int16_t>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<int>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<int64_t>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<float>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<double>>);
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, CudaOrFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, CudaAndFunctor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -82,12 +79,36 @@ class UnaryLogicalOpKernel
} // namespace operators
} // namespace paddle
#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, ::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<bool>>);
#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, ::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<bool>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int8_t>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int16_t>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int64_t>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<float>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<double>>);
#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, ::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<bool>>);
#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, ::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<bool>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int8_t>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int16_t>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int64_t>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<float>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<double>>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
......@@ -82,11 +79,29 @@ class LogicalAndPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(logical_not,
ops::LogicalNotNPUKernel<plat::NPUDeviceContext, bool>);
REGISTER_OP_NPU_KERNEL(
logical_not, ops::LogicalNotNPUKernel<plat::NPUDeviceContext, bool>,
ops::LogicalNotNPUKernel<plat::NPUDeviceContext, int8_t>,
ops::LogicalNotNPUKernel<plat::NPUDeviceContext, int16_t>,
ops::LogicalNotNPUKernel<plat::NPUDeviceContext, int>,
ops::LogicalNotNPUKernel<plat::NPUDeviceContext, int64_t>,
ops::LogicalNotNPUKernel<plat::NPUDeviceContext, float>,
ops::LogicalNotNPUKernel<plat::NPUDeviceContext, double>);
REGISTER_OP_NPU_KERNEL(logical_or,
ops::LogicalOrNPUKernel<plat::NPUDeviceContext, bool>);
ops::LogicalOrNPUKernel<plat::NPUDeviceContext, bool>,
ops::LogicalOrNPUKernel<plat::NPUDeviceContext, int8_t>,
ops::LogicalOrNPUKernel<plat::NPUDeviceContext, int16_t>,
ops::LogicalOrNPUKernel<plat::NPUDeviceContext, int>,
ops::LogicalOrNPUKernel<plat::NPUDeviceContext, int64_t>,
ops::LogicalOrNPUKernel<plat::NPUDeviceContext, float>,
ops::LogicalOrNPUKernel<plat::NPUDeviceContext, double>);
REGISTER_OP_NPU_KERNEL(logical_and,
ops::LogicalAndPUKernel<plat::NPUDeviceContext, bool>);
ops::LogicalAndPUKernel<plat::NPUDeviceContext, bool>,
ops::LogicalAndPUKernel<plat::NPUDeviceContext, int8_t>,
ops::LogicalAndPUKernel<plat::NPUDeviceContext, int16_t>,
ops::LogicalAndPUKernel<plat::NPUDeviceContext, int>,
ops::LogicalAndPUKernel<plat::NPUDeviceContext, int64_t>,
ops::LogicalAndPUKernel<plat::NPUDeviceContext, float>,
ops::LogicalAndPUKernel<plat::NPUDeviceContext, double>);
......@@ -45,7 +45,7 @@ class BinaryLogicalOpXPUKernel : public framework::OpKernel<T> {
auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
T* out_ptr = out->mutable_data<T>(context.GetPlace());
bool* out_ptr = out->mutable_data<bool>(context.GetPlace());
const T* x_ptr = x->data<T>();
const T* y_ptr = y->data<T>();
auto& dev_ctx =
......@@ -153,7 +153,7 @@ class UnaryLogicalOpXPUKernel : public framework::OpKernel<T> {
if (x->numel() == 0) {
return;
}
out->mutable_data<T>(context.GetPlace());
out->mutable_data<bool>(context.GetPlace());
auto& dev_ctx =
context.template device_context<paddle::platform::XPUDeviceContext>();
int ret = xpu::logical_not<bool>(dev_ctx.x_context(), x->data<T>(),
......
......@@ -17,5 +17,11 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
logical_and,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, bool>);
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, bool>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int8_t>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int16_t>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, int64_t>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, float>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_AND, double>);
#endif
......@@ -15,5 +15,11 @@ limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/controlflow/logical_op_xpu.h"
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(logicalnot, ops::UnaryLogicalOpXPUKernel<bool>);
REGISTER_OP_XPU_KERNEL(logicalnot, ops::UnaryLogicalOpXPUKernel<bool>,
ops::UnaryLogicalOpXPUKernel<int8_t>,
ops::UnaryLogicalOpXPUKernel<int16_t>,
ops::UnaryLogicalOpXPUKernel<int>,
ops::UnaryLogicalOpXPUKernel<int64_t>,
ops::UnaryLogicalOpXPUKernel<float>,
ops::UnaryLogicalOpXPUKernel<double>);
#endif
......@@ -18,5 +18,11 @@ limitations under the License. */
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
logical_or,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, bool>);
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, bool>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int8_t>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int16_t>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, int64_t>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, float>,
ops::BinaryLogicalOpXPUKernel<ops::XpuLogicalType::XPU_OR, double>);
#endif
......@@ -12147,17 +12147,22 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
return op(x, y)
else:
return op(x)
check_variable_and_dtype(x, "x", ["bool"], op_name)
check_variable_and_dtype(x, "x", [
"bool", "int8", "int16", "int32", "int64", "float32", "float64"
], op_name)
if y is not None:
check_variable_and_dtype(y, "y", ["bool"], op_name)
check_variable_and_dtype(y, "y", [
"bool", "int8", "int16", "int32", "int64", "float32", "float64"
], op_name)
if out is not None:
check_type(out, "out", Variable, op_name)
helper = LayerHelper(op_name, **locals())
if binary_op:
assert x.dtype == y.dtype
if binary_op and x.dtype != y.dtype:
raise ValueError(
"(InvalidArgument) The DataType of %s Op's Variable must be consistent, but received %s and %s."
% (op_name, x.dtype, y.dtype))
if out is None:
out = helper.create_variable_for_type_inference(dtype=x.dtype)
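On the static-graph path shown here, the two operands must share a dtype; a mismatch is now rejected with a ValueError before any kernel runs. A hedged sketch of that error path (assuming a build with this change):
import paddle

paddle.enable_static()
x = paddle.static.data(name='x', shape=[2], dtype='int32')
y = paddle.static.data(name='y', shape=[2], dtype='float32')
try:
    paddle.logical_and(x, y)       # int32 vs float32: inconsistent dtypes
except ValueError as e:
    print(e)                       # "(InvalidArgument) The DataType of logical_and Op's Variable must be consistent ..."
Note that the dygraph branch of _logical_op returns earlier, so in dygraph mode a mismatch surfaces as a C++ error instead; the tests below accordingly expect BaseException in that case.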
......@@ -12175,7 +12180,7 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
def logical_and(x, y, out=None, name=None):
r"""
``logical_and`` operator computes element-wise logical AND on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Tensor``.
``logical_and`` operator computes element-wise logical AND on ``x`` and ``y``, and returns ``out``. ``out`` is N-dim boolean ``Tensor``.
Each element of ``out`` is calculated by
.. math::
......@@ -12186,8 +12191,8 @@ def logical_and(x, y, out=None, name=None):
``paddle.logical_and`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
Args:
x (Tensor): the input tensor, it's data type should be bool.
y (Tensor): the input tensor, it's data type should be bool.
x (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64.
y (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64.
out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
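A short dygraph sketch of the widened logical_and, including the broadcasting mentioned above (float32 inputs, bool output; a usage sketch, not the docstring's official example):
import paddle

x = paddle.to_tensor([[0.0, 1.5, 0.0], [2.0, 0.0, 3.0]], dtype='float32')
y = paddle.to_tensor([1.0, 0.0, 2.5], dtype='float32')
out = paddle.logical_and(x, y)     # y is broadcast across the rows of x
print(out.numpy())
# [[False False False]
#  [ True False  True]]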
......@@ -12211,7 +12216,7 @@ def logical_and(x, y, out=None, name=None):
def logical_or(x, y, out=None, name=None):
"""
``logical_or`` operator computes element-wise logical OR on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Tensor``.
``logical_or`` operator computes element-wise logical OR on ``x`` and ``y``, and returns ``out``. ``out`` is N-dim boolean ``Tensor``.
Each element of ``out`` is calculated by
.. math::
......@@ -12222,8 +12227,8 @@ def logical_or(x, y, out=None, name=None):
``paddle.logical_or`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
Args:
x (Tensor): the input tensor, it's data type should be bool.
y (Tensor): the input tensor, it's data type should be bool.
x (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64.
y (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64.
out(Tensor): The ``Variable`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
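A corresponding sketch for logical_or with int8 operands (a usage sketch, assuming a build with this change):
import paddle

x = paddle.to_tensor([0, 0, 3, -2], dtype='int8')
y = paddle.to_tensor([0, 1, 0, 0], dtype='int8')
print(paddle.logical_or(x, y).numpy())   # [False  True  True  True]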
......@@ -12250,7 +12255,7 @@ def logical_or(x, y, out=None, name=None):
def logical_xor(x, y, out=None, name=None):
r"""
``logical_xor`` operator computes element-wise logical XOR on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Tensor``.
``logical_xor`` operator computes element-wise logical XOR on ``x`` and ``y``, and returns ``out``. ``out`` is N-dim boolean ``Tensor``.
Each element of ``out`` is calculated by
.. math::
......@@ -12261,8 +12266,8 @@ def logical_xor(x, y, out=None, name=None):
``paddle.logical_xor`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
Args:
x (Tensor): the input tensor, it's data type should be bool.
y (Tensor): the input tensor, it's data type should be bool.
x (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64.
y (Tensor): the input tensor, its data type should be one of bool, int8, int16, int32, int64, float32, float64.
out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
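A corresponding sketch for logical_xor with int64 operands (a usage sketch under the same assumptions):
import paddle

x = paddle.to_tensor([4, 0, 0, 7], dtype='int64')
y = paddle.to_tensor([0, 0, 5, 7], dtype='int64')
print(paddle.logical_xor(x, y).numpy())  # [ True False  True False]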
......@@ -12290,7 +12295,7 @@ def logical_xor(x, y, out=None, name=None):
def logical_not(x, out=None, name=None):
"""
``logical_not`` operator computes element-wise logical NOT on ``x``, and returns ``out``. ``x`` and ``out`` are N-dim boolean ``Variable``.
``logical_not`` operator computes element-wise logical NOT on ``x``, and returns ``out``. ``out`` is N-dim boolean ``Variable``.
Each element of ``out`` is calculated by
.. math::
......@@ -12298,7 +12303,7 @@ def logical_not(x, out=None, name=None):
out = !x
Args:
x(Tensor): Operand of logical_not operator. Must be a Tensor of type bool.
x(Tensor): Operand of logical_not operator. Must be a Tensor of type bool, int8, int16, int32, int64, float32, or float64.
out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output.
name(str|None): The default value is None. Normally there is no need for users to set this property. For more information, please refer to :ref:`api_guide_Name`.
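A corresponding sketch for logical_not on a float64 tensor (zero maps to True, any non-zero value to False; a usage sketch, not the docstring's official example):
import paddle

x = paddle.to_tensor([0.0, 2.5, -1.0], dtype='float64')
print(paddle.logical_not(x).numpy())     # [ True False False]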
......
......@@ -23,6 +23,10 @@ import paddle
import paddle.fluid as fluid
from paddle.static import Program, program_guard
SUPPORTED_DTYPES = [
bool, np.int8, np.int16, np.int32, np.int64, np.float32, np.float64
]
TEST_META_OP_DATA = [{
'op_str': 'logical_and',
'binary_op': True
......@@ -110,13 +114,13 @@ def run_static(x_np, y_np, op_str, use_npu=False, binary_op=True):
place = paddle.NPUPlace(0)
exe = fluid.Executor(place)
with fluid.program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=x_np.shape, dtype='bool')
x = paddle.static.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
op = getattr(paddle, op_str)
feed_list = {'x': x_np}
if not binary_op:
res = op(x)
else:
y = paddle.static.data(name='y', shape=y_np.shape, dtype='bool')
y = paddle.static.data(name='y', shape=y_np.shape, dtype=y_np.dtype)
feed_list['y'] = y_np
res = op(x, y)
exe.run(startup_program)
......@@ -130,17 +134,20 @@ def run_dygraph(x_np, y_np, op_str, use_npu=False, binary_op=True):
place = paddle.NPUPlace(0)
paddle.disable_static(place)
op = getattr(paddle, op_str)
x = paddle.to_tensor(x_np)
x = paddle.to_tensor(x_np, dtype=x_np.dtype)
if not binary_op:
dygraph_result = op(x)
else:
y = paddle.to_tensor(y_np)
y = paddle.to_tensor(y_np, dtype=y_np.dtype)
dygraph_result = op(x, y)
return dygraph_result
def np_data_generator(np_shape, *args, **kwargs):
return np.random.choice(a=[True, False], size=np_shape).astype(bool)
def np_data_generator(np_shape, dtype, *args, **kwargs):
if dtype == bool:
return np.random.choice(a=[True, False], size=np_shape).astype(bool)
else:
return np.random.randn(*np_shape).astype(dtype)
def test(unit_test, use_npu=False, test_error=False):
......@@ -152,40 +159,46 @@ def test(unit_test, use_npu=False, test_error=False):
if test_error:
META_DATA = dict(TEST_META_WRONG_SHAPE_DATA)
for shape_data in META_DATA.values():
meta_data['x_np'] = np_data_generator(shape_data['x_shape'])
meta_data['y_np'] = np_data_generator(shape_data['y_shape'])
if meta_data['binary_op'] and test_error:
# catch C++ Exception
unit_test.assertRaises(BaseException, run_static, **meta_data)
unit_test.assertRaises(BaseException, run_dygraph, **meta_data)
continue
static_result = run_static(**meta_data)
dygraph_result = run_dygraph(**meta_data)
if meta_data['binary_op']:
np_result = np_op(meta_data['x_np'], meta_data['y_np'])
else:
np_result = np_op(meta_data['x_np'])
unit_test.assertTrue((static_result == np_result).all())
unit_test.assertTrue((dygraph_result.numpy() == np_result).all())
for data_type in SUPPORTED_DTYPES:
meta_data['x_np'] = np_data_generator(
shape_data['x_shape'], dtype=data_type)
meta_data['y_np'] = np_data_generator(
shape_data['y_shape'], dtype=data_type)
if meta_data['binary_op'] and test_error:
# catch C++ Exception
unit_test.assertRaises(BaseException, run_static,
**meta_data)
unit_test.assertRaises(BaseException, run_dygraph,
**meta_data)
continue
static_result = run_static(**meta_data)
dygraph_result = run_dygraph(**meta_data)
if meta_data['binary_op']:
np_result = np_op(meta_data['x_np'], meta_data['y_np'])
else:
np_result = np_op(meta_data['x_np'])
unit_test.assertTrue((static_result == np_result).all())
unit_test.assertTrue((dygraph_result.numpy() == np_result).all(
))
def test_type_error(unit_test, use_npu, type_str_map):
def check_type(op_str, x, y, binary_op):
op = getattr(paddle, op_str)
error_type = TypeError
error_type = ValueError
if isinstance(x, np.ndarray):
x = paddle.to_tensor(x)
y = paddle.to_tensor(y)
error_type = BaseException
if binary_op:
if type_str_map['x'] != 'bool' or type_str_map['y'] != 'bool':
if type_str_map['x'] != type_str_map['y']:
unit_test.assertRaises(error_type, op, x=x, y=y)
if not fluid.in_dygraph_mode():
error_type = TypeError
unit_test.assertRaises(error_type, op, x=x, y=y, out=1)
else:
if type_str_map['x'] != 'bool':
unit_test.assertRaises(error_type, op, x=x)
if not fluid.in_dygraph_mode():
error_type = TypeError
unit_test.assertRaises(error_type, op, x=x, out=1)
place = paddle.CPUPlace()
......@@ -212,12 +225,10 @@ def test_type_error(unit_test, use_npu, type_str_map):
def type_map_factory():
x_type_list = ['float32', 'float64', 'int32', 'int64', 'bool']
y_type_list = ['float32', 'float64', 'int32', 'int64', 'bool']
return [{
'x': x_type,
'y': y_type
} for x_type in x_type_list for y_type in y_type_list]
} for x_type in SUPPORTED_DTYPES for y_type in SUPPORTED_DTYPES]
class TestCPU(unittest.TestCase):
......
......@@ -21,6 +21,10 @@ import paddle
import paddle.fluid as fluid
from paddle.static import Program, program_guard
SUPPORTED_DTYPES = [
bool, np.int8, np.int16, np.int32, np.int64, np.float32, np.float64
]
TEST_META_OP_DATA = [{
'op_str': 'logical_and',
'binary_op': True
......@@ -111,13 +115,13 @@ def run_static(x_np, y_np, op_str, use_gpu=False, binary_op=True):
place = paddle.CUDAPlace(0)
exe = fluid.Executor(place)
with fluid.program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=x_np.shape, dtype='bool')
x = paddle.static.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
op = getattr(paddle, op_str)
feed_list = {'x': x_np}
if not binary_op:
res = op(x)
else:
y = paddle.static.data(name='y', shape=y_np.shape, dtype='bool')
y = paddle.static.data(name='y', shape=y_np.shape, dtype=y_np.dtype)
feed_list['y'] = y_np
res = op(x, y)
exe.run(startup_program)
......@@ -131,17 +135,20 @@ def run_dygraph(x_np, y_np, op_str, use_gpu=False, binary_op=True):
place = paddle.CUDAPlace(0)
paddle.disable_static(place)
op = getattr(paddle, op_str)
x = paddle.to_tensor(x_np)
x = paddle.to_tensor(x_np, dtype=x_np.dtype)
if not binary_op:
dygraph_result = op(x)
else:
y = paddle.to_tensor(y_np)
y = paddle.to_tensor(y_np, dtype=y_np.dtype)
dygraph_result = op(x, y)
return dygraph_result
def np_data_generator(np_shape, *args, **kwargs):
return np.random.choice(a=[True, False], size=np_shape).astype(bool)
def np_data_generator(np_shape, dtype, *args, **kwargs):
if dtype == bool:
return np.random.choice(a=[True, False], size=np_shape).astype(bool)
else:
return np.random.randn(*np_shape).astype(dtype)
def test(unit_test, use_gpu=False, test_error=False):
......@@ -153,40 +160,46 @@ def test(unit_test, use_gpu=False, test_error=False):
if test_error:
META_DATA = dict(TEST_META_WRONG_SHAPE_DATA)
for shape_data in META_DATA.values():
meta_data['x_np'] = np_data_generator(shape_data['x_shape'])
meta_data['y_np'] = np_data_generator(shape_data['y_shape'])
if meta_data['binary_op'] and test_error:
# catch C++ Exception
unit_test.assertRaises(BaseException, run_static, **meta_data)
unit_test.assertRaises(BaseException, run_dygraph, **meta_data)
continue
static_result = run_static(**meta_data)
dygraph_result = run_dygraph(**meta_data)
if meta_data['binary_op']:
np_result = np_op(meta_data['x_np'], meta_data['y_np'])
else:
np_result = np_op(meta_data['x_np'])
unit_test.assertTrue((static_result == np_result).all())
unit_test.assertTrue((dygraph_result.numpy() == np_result).all())
for data_type in SUPPORTED_DTYPES:
meta_data['x_np'] = np_data_generator(
shape_data['x_shape'], dtype=data_type)
meta_data['y_np'] = np_data_generator(
shape_data['y_shape'], dtype=data_type)
if meta_data['binary_op'] and test_error:
# catch C++ Exception
unit_test.assertRaises(BaseException, run_static,
**meta_data)
unit_test.assertRaises(BaseException, run_dygraph,
**meta_data)
continue
static_result = run_static(**meta_data)
dygraph_result = run_dygraph(**meta_data)
if meta_data['binary_op']:
np_result = np_op(meta_data['x_np'], meta_data['y_np'])
else:
np_result = np_op(meta_data['x_np'])
unit_test.assertTrue((static_result == np_result).all())
unit_test.assertTrue((dygraph_result.numpy() == np_result).all(
))
def test_type_error(unit_test, use_gpu, type_str_map):
def check_type(op_str, x, y, binary_op):
op = getattr(paddle, op_str)
error_type = TypeError
error_type = ValueError
if isinstance(x, np.ndarray):
x = paddle.to_tensor(x)
y = paddle.to_tensor(y)
error_type = BaseException
if binary_op:
if type_str_map['x'] != 'bool' or type_str_map['y'] != 'bool':
if type_str_map['x'] != type_str_map['y']:
unit_test.assertRaises(error_type, op, x=x, y=y)
if not fluid.in_dygraph_mode():
error_type = TypeError
unit_test.assertRaises(error_type, op, x=x, y=y, out=1)
else:
if type_str_map['x'] != 'bool':
unit_test.assertRaises(error_type, op, x=x)
if not fluid.in_dygraph_mode():
error_type = TypeError
unit_test.assertRaises(error_type, op, x=x, out=1)
place = paddle.CPUPlace()
......@@ -213,12 +226,10 @@ def test_type_error(unit_test, use_gpu, type_str_map):
def type_map_factory():
x_type_list = ['float32', 'float64', 'int32', 'int64', 'bool']
y_type_list = ['float32', 'float64', 'int32', 'int64', 'bool']
return [{
'x': x_type,
'y': y_type
} for x_type in x_type_list for y_type in y_type_list]
} for x_type in SUPPORTED_DTYPES for y_type in SUPPORTED_DTYPES]
class TestCPU(unittest.TestCase):
......
......@@ -25,6 +25,10 @@ import paddle
from op_test_xpu import XPUOpTest
from paddle.static import Program, program_guard
SUPPORTED_DTYPES = [
bool, np.int8, np.int16, np.int32, np.int64, np.float32, np.float64
]
TEST_META_OP_DATA = [{
'op_str': 'logical_and',
'binary_op': True
......@@ -110,13 +114,13 @@ def run_static_xpu(x_np, y_np, op_str, binary_op=True):
place = paddle.XPUPlace(0)
exe = fluid.Executor(place)
with fluid.program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=x_np.shape, dtype='bool')
x = paddle.static.data(name='x', shape=x_np.shape, dtype=x_np.dtype)
op = getattr(paddle, op_str)
feed_list = {'x': x_np}
if not binary_op:
res = op(x)
else:
y = paddle.static.data(name='y', shape=y_np.shape, dtype='bool')
y = paddle.static.data(name='y', shape=y_np.shape, dtype=y_np.dtype)
feed_list['y'] = y_np
res = op(x, y)
exe.run(startup_program)
......@@ -128,17 +132,20 @@ def run_dygraph_xpu(x_np, y_np, op_str, binary_op=True):
place = paddle.XPUPlace(0)
paddle.disable_static(place)
op = getattr(paddle, op_str)
x = paddle.to_tensor(x_np)
x = paddle.to_tensor(x_np, dtype=x_np.dtype)
if not binary_op:
dygraph_result = op(x)
else:
y = paddle.to_tensor(y_np)
y = paddle.to_tensor(y_np, dtype=y_np.dtype)
dygraph_result = op(x, y)
return dygraph_result
def np_data_generator(np_shape, *args, **kwargs):
return np.random.choice(a=[True, False], size=np_shape).astype(bool)
def np_data_generator(np_shape, dtype, *args, **kwargs):
if dtype == bool:
return np.random.choice(a=[True, False], size=np_shape).astype(bool)
else:
return np.random.randn(*np_shape).astype(dtype)
def test_xpu(unit_test, test_error=False):
......@@ -149,40 +156,44 @@ def test_xpu(unit_test, test_error=False):
if test_error:
META_DATA = dict(TEST_META_WRONG_SHAPE_DATA)
for shape_data in META_DATA.values():
meta_data['x_np'] = np_data_generator(shape_data['x_shape'])
meta_data['y_np'] = np_data_generator(shape_data['y_shape'])
if meta_data['binary_op'] and test_error:
# catch C++ Exception
unit_test.assertRaises(BaseException, run_static_xpu,
**meta_data)
continue
static_result = run_static_xpu(**meta_data)
dygraph_result = run_dygraph_xpu(**meta_data)
if meta_data['binary_op']:
np_result = np_op(meta_data['x_np'], meta_data['y_np'])
else:
np_result = np_op(meta_data['x_np'])
unit_test.assertTrue((static_result == np_result).all())
unit_test.assertTrue((dygraph_result.numpy() == np_result).all())
for data_type in SUPPORTED_DTYPES:
meta_data['x_np'] = np_data_generator(
shape_data['x_shape'], dtype=data_type)
meta_data['y_np'] = np_data_generator(
shape_data['y_shape'], dtype=data_type)
if meta_data['binary_op'] and test_error:
# catch C++ Exception
unit_test.assertRaises(BaseException, run_static_xpu,
**meta_data)
continue
static_result = run_static_xpu(**meta_data)
dygraph_result = run_dygraph_xpu(**meta_data)
if meta_data['binary_op']:
np_result = np_op(meta_data['x_np'], meta_data['y_np'])
else:
np_result = np_op(meta_data['x_np'])
unit_test.assertTrue((static_result == np_result).all())
unit_test.assertTrue((dygraph_result.numpy() == np_result).all(
))
def test_type_error(unit_test, type_str_map):
def check_type(op_str, x, y, binary_op):
op = getattr(paddle, op_str)
error_type = TypeError
error_type = ValueError
if isinstance(x, np.ndarray):
x = paddle.to_tensor(x)
y = paddle.to_tensor(y)
error_type = BaseException
if binary_op:
if type_str_map['x'] != 'bool' or type_str_map['y'] != 'bool':
if type_str_map['x'] != type_str_map['y']:
unit_test.assertRaises(error_type, op, x=x, y=y)
if not fluid.in_dygraph_mode():
error_type = TypeError
unit_test.assertRaises(error_type, op, x=x, y=y, out=1)
else:
if type_str_map['x'] != 'bool':
unit_test.assertRaises(error_type, op, x=x)
if not fluid.in_dygraph_mode():
error_type = TypeError
unit_test.assertRaises(error_type, op, x=x, out=1)
place = paddle.XPUPlace(0)
......@@ -208,12 +219,10 @@ def test_type_error(unit_test, type_str_map):
def type_map_factory():
x_type_list = ['float32', 'float64', 'int32', 'int64', 'bool']
y_type_list = ['float32', 'float64', 'int32', 'int64', 'bool']
return [{
'x': x_type,
'y': y_type
} for x_type in x_type_list for y_type in y_type_list]
} for x_type in SUPPORTED_DTYPES for y_type in SUPPORTED_DTYPES]
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
......