Unverified commit c282db3a, authored by Jack Zhou, committed by GitHub

add broadcast feature for elementwise logical op

Parent 63eef763
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/logical_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
@@ -97,19 +99,19 @@ class BinaryLogicalOp : public LogicalOp {
OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", comment.type);
auto dim_x = context->GetInputDim("X");
auto dim_y = context->GetInputDim("Y");
int product_x = framework::product(dim_x);
int product_y = framework::product(dim_y);
bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0);
if (check) {
PADDLE_ENFORCE_EQ(product_x, product_y,
platform::errors::InvalidArgument(
"The number of elements in X and Y should be same, "
"but received %d != %d",
product_x, product_y));
if (dim_x == dim_y) {
context->SetOutputDim("Out", dim_x);
} else {
int max_dim = std::max(dim_x.size(), dim_y.size());
int axis = std::abs(dim_x.size() - dim_y.size());
std::vector<int> x_dims_array(max_dim);
std::vector<int> y_dims_array(max_dim);
std::vector<int> out_dims_array(max_dim);
GetBroadcastDimsArrays(dim_x, dim_y, x_dims_array.data(),
y_dims_array.data(), out_dims_array.data(),
max_dim, axis);
context->SetOutputDim("Out", framework::make_ddim(out_dims_array));
}
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
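The new InferShape branch right-aligns the two dimension arrays and, at each aligned position, requires the extents to match or one of them to be 1. A minimal Python sketch of this rule, assuming NumPy-style trailing alignment as in GetBroadcastDimsArrays (and ignoring Paddle's -1 "unknown at compile time" extents):

def broadcast_shape(dim_x, dim_y):
    # Pad the shorter shape with leading 1s so both align on trailing dims.
    max_dim = max(len(dim_x), len(dim_y))
    x = [1] * (max_dim - len(dim_x)) + list(dim_x)
    y = [1] * (max_dim - len(dim_y)) + list(dim_y)
    out = []
    for dx, dy in zip(x, y):
        if dx != dy and dx != 1 and dy != 1:
            raise ValueError('incompatible extents %d and %d' % (dx, dy))
        out.append(max(dx, dy))  # an extent of 1 broadcasts against the other
    return out

assert broadcast_shape([2, 3, 4, 5], [4, 1]) == [2, 3, 4, 5]
assert broadcast_shape([1, 4, 1], [2, 3, 4, 5]) == [2, 3, 4, 5]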
@@ -16,6 +16,7 @@ limitations under the License. */
#include <math.h>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
@@ -57,10 +58,8 @@ class BinaryLogicalOpKernel
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
Functor binary_func;
platform::Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x->data<T>(),
x->data<T>() + x->numel(), y->data<T>(),
out->mutable_data<bool>(context.GetPlace()), binary_func);
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, -1,
binary_func, out);
}
};
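The switch from platform::Transform, which walks both inputs as flat equal-length arrays, to ElementwiseComputeEx is what lets the kernel accept inputs of different shapes; the axis argument of -1 selects the trailing-dimension alignment sketched above. A rough NumPy illustration of the new behaviour (np.logical_and stands in for the functor; this is not the kernel's actual dispatch):

import numpy as np

x = np.random.choice([True, False], size=(2, 3, 1, 5))
y = np.random.choice([True, False], size=(3, 4, 1))
out = np.logical_and(x, y)  # broadcast handled inside the compute routine
assert out.shape == (2, 3, 4, 5)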
@@ -12086,6 +12086,13 @@ Examples:
def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
if in_dygraph_mode():
op = getattr(core.ops, op_name)
if binary_op:
return op(x, y)
else:
return op(x)
check_variable_and_dtype(x, "x", ["bool"], op_name)
if y is not None:
check_variable_and_dtype(y, "y", ["bool"], op_name)
@@ -12110,28 +12117,27 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
return out
@templatedoc()
def logical_and(x, y, out=None, name=None):
"""
:alias_main: paddle.logical_and
:alias: paddle.logical_and, paddle.tensor.logical_and, paddle.tensor.logic.logical_and
:old_api: paddle.fluid.layers.logical_and
``logical_and`` operator computes element-wise logical AND on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Variable``.
``logical_and`` operator computes element-wise logical AND on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Tensor``.
Each element of ``out`` is calculated by
.. math::
out = x \&\& y
.. note::
``paddle.logical_and`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
Args:
x(${x_type}): ${x_comment}.
y(${y_type}): ${y_comment}.
out(Variable): The ``Variable`` that specifies the output of the operator, which can be any ``Variable`` that has been created in the program. The default value is None, and a new ``Variable`` will be created to save the output.
name(str|None): The default value is None. Normally there is no need for users to set this property. For more information, please refer to :ref:`api_guide_Name`.
x (Tensor): the input tensor; its data type should be bool.
y (Tensor): the input tensor; its data type should be bool.
out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
${out_type}: ${out_comment}
N-D Tensor. A location into which the result is stored. Its shape is the broadcast shape of ``x`` and ``y``.
Examples:
.. code-block:: python
@@ -12140,43 +12146,38 @@ def logical_and(x, y, out=None, name=None):
import numpy as np
paddle.disable_static()
x_data = np.array([True, True, False, False], dtype=np.bool)
x_data = np.array([True], dtype=np.bool)
y_data = np.array([True, False, True, False], dtype=np.bool)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
res = paddle.logical_and(x, y)
print(res.numpy()) # [True False False False]
print(res.numpy()) # [True False True False]
"""
if x.shape != y.shape:
raise TypeError(
'Input tensors must be same shape, but received x \'s shape: %s, y \'s shape: %s '
% (x.shape, y.shape))
return _logical_op(
op_name="logical_and", x=x, y=y, name=name, out=out, binary_op=True)
@templatedoc()
def logical_or(x, y, out=None, name=None):
"""
:alias_main: paddle.logical_or
:alias: paddle.logical_or, paddle.tensor.logical_or, paddle.tensor.logic.logical_or
:old_api: paddle.fluid.layers.logical_or
``logical_or`` operator computes element-wise logical OR on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Variable``.
``logical_or`` operator computes element-wise logical OR on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Tensor``.
Each element of ``out`` is calculated by
.. math::
out = x || y
.. note::
``paddle.logical_or`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
Args:
x(${x_type}): ${x_comment}.
y(${y_type}): ${y_comment}.
out(Variable): The ``Variable`` that specifies the output of the operator, which can be any ``Variable`` that has been created in the program. The default value is None, and a new ``Variable`` will be created to save the output.
name(str|None): The default value is None. Normally there is no need for users to set this property. For more information, please refer to :ref:`api_guide_Name`.
x (Tensor): the input tensor; its data type should be bool.
y (Tensor): the input tensor; its data type should be bool.
out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
${out_type}: ${out_comment}
N-D Tensor. A location into which the result is stored. Its shape is the broadcast shape of ``x`` and ``y``.
Examples:
.. code-block:: python
@@ -12185,43 +12186,38 @@ def logical_or(x, y, out=None, name=None):
import numpy as np
paddle.disable_static()
x_data = np.array([True, True, False, False], dtype=np.bool)
y_data = np.array([True, False, True, False], dtype=np.bool)
x = paddle.to_variable(x_data)
y = paddle.to_variable(y_data)
x_data = np.array([True, False], dtype=np.bool).reshape(2, 1)
y_data = np.array([True, False, True, False], dtype=np.bool).reshape(2, 2)
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
res = paddle.logical_or(x, y)
print(res.numpy()) # [True True True False]
print(res.numpy()) # [[ True True] [ True False]]
"""
if x.shape != y.shape:
raise TypeError(
'Input tensors must be same shape, but received x \'s shape: %s, y \'s shape: %s '
% (x.shape, y.shape))
return _logical_op(
op_name="logical_or", x=x, y=y, name=name, out=out, binary_op=True)
@templatedoc()
def logical_xor(x, y, out=None, name=None):
"""
:alias_main: paddle.logical_xor
:alias: paddle.logical_xor, paddle.tensor.logical_xor, paddle.tensor.logic.logical_xor
:old_api: paddle.fluid.layers.logical_xor
``logical_xor`` operator computes element-wise logical XOR on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Variable``.
``logical_xor`` operator computes element-wise logical XOR on ``x`` and ``y``, and returns ``out``. ``x``, ``y`` and ``out`` are N-dim boolean ``Tensor``.
Each element of ``out`` is calculated by
.. math::
out = (x || y) \&\& !(x \&\& y)
.. note::
``paddle.logical_xor`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
Args:
x(${x_type}): ${x_comment}.
y(${y_type}): ${y_comment}.
out(Variable): The ``Variable`` that specifies the output of the operator, which can be any ``Variable`` that has been created in the program. The default value is None, and a new ``Variable`` will be created to save the output.
name(str|None): The default value is None. Normally there is no need for users to set this property. For more information, please refer to :ref:`api_guide_Name`.
x (Tensor): the input tensor; its data type should be bool.
y (Tensor): the input tensor; its data type should be bool.
out(Tensor): The ``Tensor`` that specifies the output of the operator, which can be any ``Tensor`` that has been created in the program. The default value is None, and a new ``Tensor`` will be created to save the output.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
${out_type}: ${out_comment}
N-D Tensor. A location into which the result is stored. Its shape is the broadcast shape of ``x`` and ``y``.
Examples:
.. code-block:: python
@@ -12230,17 +12226,13 @@ def logical_xor(x, y, out=None, name=None):
import numpy as np
paddle.disable_static()
x_data = np.array([True, True, False, False], dtype=np.bool)
y_data = np.array([True, False, True, False], dtype=np.bool)
x = paddle.to_variable(x_data)
y = paddle.to_variable(y_data)
x_data = np.array([True, False], dtype=np.bool).reshape([2, 1])
y_data = np.array([True, False, True, False], dtype=np.bool).reshape([2, 2])
x = paddle.to_tensor(x_data)
y = paddle.to_tensor(y_data)
res = paddle.logical_xor(x, y)
print(res.numpy()) # [False True True False]
print(res.numpy()) # [[False  True] [ True False]]
"""
if x.shape != y.shape:
raise TypeError(
'Input tensors must be same shape, but received x \'s shape: %s, y \'s shape: %s '
% (x.shape, y.shape))
return _logical_op(
op_name="logical_xor", x=x, y=y, name=name, out=out, binary_op=True)
@@ -21,59 +21,231 @@ import paddle
import paddle.fluid as fluid
from paddle.static import Program, program_guard
TEST_META_OP_DATA = [{
'op_str': 'logical_and',
'binary_op': True
}, {
'op_str': 'logical_or',
'binary_op': True
}, {
'op_str': 'logical_xor',
'binary_op': True
}, {
'op_str': 'logical_not',
'binary_op': False
}]
def create_test_class(op_type, callback, binary_op=True):
class Cls(op_test.OpTest):
def setUp(self):
a = np.random.choice(a=[True, False], size=(10, 7)).astype(bool)
if binary_op:
b = np.random.choice(a=[True, False], size=(10, 7)).astype(bool)
c = callback(a, b)
else:
c = callback(a)
self.outputs = {'Out': c}
self.op_type = op_type
if binary_op:
self.inputs = {'X': a, 'Y': b}
TEST_META_SHAPE_DATA = {
'XDimLargerThanYDim1': {
'x_shape': [2, 3, 4, 5],
'y_shape': [4, 5]
},
'XDimLargerThanYDim2': {
'x_shape': [2, 3, 4, 5],
'y_shape': [4, 1]
},
'XDimLargerThanYDim3': {
'x_shape': [2, 3, 4, 5],
'y_shape': [1, 4, 1]
},
'XDimLargerThanYDim4': {
'x_shape': [2, 3, 4, 5],
'y_shape': [3, 4, 1]
},
'XDimLargerThanYDim5': {
'x_shape': [2, 3, 1, 5],
'y_shape': [3, 1, 1]
},
'XDimLessThanYDim1': {
'x_shape': [4, 1],
'y_shape': [2, 3, 4, 5]
},
'XDimLessThanYDim2': {
'x_shape': [1, 4, 1],
'y_shape': [2, 3, 4, 5]
},
'XDimLessThanYDim3': {
'x_shape': [3, 4, 1],
'y_shape': [2, 3, 4, 5]
},
'XDimLessThanYDim4': {
'x_shape': [3, 1, 1],
'y_shape': [2, 3, 1, 5]
},
'XDimLessThanYDim5': {
'x_shape': [4, 5],
'y_shape': [2, 3, 4, 5]
},
'Axis1InLargerDim': {
'x_shape': [1, 4, 5],
'y_shape': [2, 3, 1, 5]
},
'EqualDim1': {
'x_shape': [10, 7],
'y_shape': [10, 7]
},
'EqualDim2': {
'x_shape': [1, 1, 4, 5],
'y_shape': [2, 3, 1, 5]
}
}
TEST_META_WRONG_SHAPE_DATA = {
'ErrorDim1': {
'x_shape': [2, 3, 4, 5],
'y_shape': [3, 4]
},
'ErrorDim2': {
'x_shape': [2, 3, 4, 5],
'y_shape': [4, 3]
}
}
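The two ErrorDim cases fail the alignment rule above: right-aligning [2, 3, 4, 5] with [3, 4] pairs extent 4 with 3, and with [4, 3] pairs 5 with 3. Since NumPy applies the same rule, the fixtures can be sanity-checked outside Paddle (a quick check, not part of the test file):

import numpy as np

def is_broadcastable(x_shape, y_shape):
    try:
        np.broadcast(np.empty(x_shape), np.empty(y_shape))
        return True
    except ValueError:
        return False

assert is_broadcastable([2, 3, 4, 5], [1, 4, 1])   # compatible pair from the table above
assert not is_broadcastable([2, 3, 4, 5], [3, 4])  # ErrorDim1
assert not is_broadcastable([2, 3, 4, 5], [4, 3])  # ErrorDim2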
def run_static(x_np, y_np, op_str, use_gpu=False, binary_op=True):
paddle.enable_static()
startup_program = fluid.Program()
main_program = fluid.Program()
place = paddle.CPUPlace()
if use_gpu and fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
exe = fluid.Executor(place)
with fluid.program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=x_np.shape, dtype='bool')
op = getattr(paddle, op_str)
feed_list = {'x': x_np}
if not binary_op:
res = op(x)
else:
y = paddle.static.data(name='y', shape=y_np.shape, dtype='bool')
feed_list['y'] = y_np
res = op(x, y)
exe.run(startup_program)
static_result = exe.run(main_program, feed=feed_list, fetch_list=[res])
return static_result
def run_dygraph(x_np, y_np, op_str, use_gpu=False, binary_op=True):
place = paddle.CPUPlace()
if use_gpu and fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
paddle.disable_static(place)
op = getattr(paddle, op_str)
x = paddle.to_tensor(x_np)
if not binary_op:
dygraph_result = op(x)
else:
y = paddle.to_tensor(y_np)
dygraph_result = op(x, y)
return dygraph_result
def np_data_generator(np_shape, *args, **kwargs):
return np.random.choice(a=[True, False], size=np_shape).astype(bool)
def test(unit_test, use_gpu=False, test_error=False):
for op_data in TEST_META_OP_DATA:
meta_data = dict(op_data)
meta_data['use_gpu'] = use_gpu
np_op = getattr(np, meta_data['op_str'])
META_DATA = dict(TEST_META_SHAPE_DATA)
if test_error:
META_DATA = dict(TEST_META_WRONG_SHAPE_DATA)
for shape_data in META_DATA.values():
meta_data['x_np'] = np_data_generator(shape_data['x_shape'])
meta_data['y_np'] = np_data_generator(shape_data['y_shape'])
if meta_data['binary_op'] and test_error:
# catch C++ Exception
unit_test.assertRaises(BaseException, run_static, **meta_data)
unit_test.assertRaises(BaseException, run_dygraph, **meta_data)
continue
static_result = run_static(**meta_data)
dygraph_result = run_dygraph(**meta_data)
if meta_data['binary_op']:
np_result = np_op(meta_data['x_np'], meta_data['y_np'])
else:
self.inputs = {'X': a}
def test_output(self):
self.check_output()
def test_error(self):
with program_guard(Program(), Program()):
# test 1 type error, x, y must be bool type
x = fluid.layers.data(name='x', shape=[2], dtype='bool')
y = fluid.layers.data(name='y', shape=[2], dtype='bool')
a = fluid.layers.data(name='a', shape=[2], dtype='int32')
op = eval("fluid.layers.%s" % self.op_type)
if self.op_type != "logical_not":
self.assertRaises(TypeError, op, x=x, y=y, out=1)
self.assertRaises(TypeError, op, x=x, y=a)
self.assertRaises(TypeError, op, x=a, y=y)
else:
self.assertRaises(TypeError, op, x=x, out=1)
self.assertRaises(TypeError, op, x=a)
# test 2 type error, x, y must be same shape
x_data = fluid.layers.data(
name='x_data', shape=[2], dtype='bool')
y_data = fluid.layers.data(
name='y_data', shape=[2, 2], dtype='bool')
if self.op_type != "logical_not":
self.assertRaises(TypeError, op, x=x_data, y=y_data, out=1)
self.assertRaises(TypeError, op, x=y_data, y=x_data)
globals()[op_type] = Cls
create_test_class('logical_and', lambda _a, _b: np.logical_and(_a, _b))
create_test_class('logical_or', lambda _a, _b: np.logical_or(_a, _b))
create_test_class('logical_not', lambda _a: np.logical_not(_a), False)
create_test_class('logical_xor', lambda _a, _b: np.logical_xor(_a, _b))
np_result = np_op(meta_data['x_np'])
unit_test.assertTrue((static_result == np_result).all())
unit_test.assertTrue((dygraph_result.numpy() == np_result).all())
def test_type_error(unit_test, use_gpu, type_str_map):
def check_type(op_str, x, y, binary_op):
op = getattr(paddle, op_str)
error_type = TypeError
if isinstance(x, np.ndarray):
x = paddle.to_tensor(x)
y = paddle.to_tensor(y)
error_type = BaseException
if binary_op:
if type_str_map['x'] != 'bool' or type_str_map['y'] != 'bool':
unit_test.assertRaises(error_type, op, x=x, y=y)
if not fluid.in_dygraph_mode():
unit_test.assertRaises(error_type, op, x=x, y=y, out=1)
else:
if type_str_map['x'] != 'bool':
unit_test.assertRaises(error_type, op, x=x)
if not fluid.in_dygraph_mode():
unit_test.assertRaises(error_type, op, x=x, out=1)
place = paddle.CPUPlace()
if use_gpu and fluid.core.is_compiled_with_cuda():
place = paddle.CUDAPlace(0)
for op_data in TEST_META_OP_DATA:
meta_data = dict(op_data)
binary_op = meta_data['binary_op']
paddle.disable_static(place)
x = np.random.choice(a=[0, 1], size=[10]).astype(type_str_map['x'])
y = np.random.choice(a=[0, 1], size=[10]).astype(type_str_map['y'])
check_type(meta_data['op_str'], x, y, binary_op)
paddle.enable_static()
startup_program = paddle.static.Program()
main_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(
name='x', shape=[10], dtype=type_str_map['x'])
y = paddle.static.data(
name='y', shape=[10], dtype=type_str_map['y'])
check_type(meta_data['op_str'], x, y, binary_op)
def type_map_factory():
x_type_list = ['float32', 'float64', 'int32', 'int64', 'bool']
y_type_list = ['float32', 'float64', 'int32', 'int64', 'bool']
return [{
'x': x_type,
'y': y_type
} for x_type in x_type_list for y_type in y_type_list]
class TestCPU(unittest.TestCase):
def test(self):
test(self)
def test_error(self):
test(self, False, True)
def test_type_error(self):
type_map_list = type_map_factory()
for type_map in type_map_list:
test_type_error(self, False, type_map)
class TestCUDA(unittest.TestCase):
def test(self):
test(self, True)
def test_error(self):
test(self, True, True)
def test_type_error(self):
type_map_list = type_map_factory()
for type_map in type_map_list:
test_type_error(self, True, type_map)
if __name__ == '__main__':
unittest.main()