diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 33390745cc8c96bc00b9eab84dfb637a8a76c2f9..a200b948dea45dd0ee9e5ced5fbc38e1eb4349b7 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -208,7 +208,7 @@ function(op_library TARGET) endif() # Define operators that don't need pybind here. - foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "nccl_op" + foreach(manual_pybind_op "compare_all_op" "compare_op" "logical_op" "bitwise_op" "nccl_op" "tensor_array_read_write_op" "tensorrt_engine_op" "conv_fusion_op" "fusion_transpose_flatten_concat_op" "fusion_conv_inception_op" "sync_batch_norm_op" "dgc_op" "fused_fc_elementwise_layernorm_op" diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index e645b379f3c06ae2e83c93b6f1a4d56f57f99d57..14912ac3a7d33322b7aa56996ef24e93b9e13ba9 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -7,8 +7,6 @@ set(pybind_file ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h.tmp CACHE INTE set(pybind_file_final ${PADDLE_BINARY_DIR}/paddle/fluid/pybind/pybind.h) file(WRITE ${pybind_file} "// Generated by the paddle/fluid/operators/CMakeLists.txt. DO NOT EDIT!\n\n") -copy_if_different(${pybind_file} ${pybind_file_final}) - add_subdirectory(math) add_subdirectory(eigen) add_subdirectory(controlflow) @@ -203,3 +201,5 @@ endif() if (WITH_GPU OR WITH_ASCEND_CL) cc_test(copy_cross_scope_test SRCS copy_cross_scope_test.cc DEPS op_registry copy_cross_scope_op scope device_context enforce executor) endif() + +copy_if_different(${pybind_file} ${pybind_file_final}) diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt index e23fb05833c0fa428b4f74785ff947a4c785648e..1a2df2a0c7ba34f67ecb7c2ade002fcb4475229f 100644 --- a/paddle/fluid/operators/controlflow/CMakeLists.txt +++ b/paddle/fluid/operators/controlflow/CMakeLists.txt @@ -19,4 +19,6 @@ else() target_link_libraries(conditional_block_infer_op conditional_block_op) endif() -file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n") +file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n") +file(APPEND ${pybind_file} "USE_OP(logical_and);\nUSE_OP(logical_or);\nUSE_OP(logical_xor);\nUSE_OP(logical_not);\n") +file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n") diff --git a/paddle/fluid/operators/controlflow/bitwise_op.cc b/paddle/fluid/operators/controlflow/bitwise_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..cfe0d99962190aa282b46e212d01df4b718d1305 --- /dev/null +++ b/paddle/fluid/operators/controlflow/bitwise_op.cc @@ -0,0 +1,174 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/
+
+#include "paddle/fluid/operators/controlflow/bitwise_op.h"
+#include <algorithm>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename OpComment>
+class BinaryBitwiseOpProtoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    OpComment comment;
+    AddInput("X", string::Sprintf(
+                      "Input Tensor of ``%s`` . It is "
+                      "an N-D Tensor of bool, uint8, int8, int16, int32, int64.",
+                      comment.type));
+    AddInput("Y", string::Sprintf(
+                      "Input Tensor of ``%s`` . It is "
+                      "an N-D Tensor of bool, uint8, int8, int16, int32, int64.",
+                      comment.type));
+    AddOutput("Out",
+              string::Sprintf("Result of ``%s`` . It is an N-D Tensor with "
+                              "the same data type as the input Tensor.",
+                              comment.type));
+    AddComment(string::Sprintf(R"DOC(
+It operates ``%s`` on Tensor ``X`` and ``Y`` .
+
+.. math::
+        %s
+
+.. note::
+    ``paddle.%s`` supports broadcasting. If you want to know more about broadcasting, please refer to :ref:`user_guide_broadcasting`.
+)DOC",
+                               comment.type, comment.equation, comment.type));
+  }
+};
+
+template <typename OpComment>
+class UnaryBitwiseOpProtoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    OpComment comment;
+    AddInput("X", string::Sprintf(
+                      "Input Tensor of ``%s`` . It is "
+                      "an N-D Tensor of bool, uint8, int8, int16, int32, int64.",
+                      comment.type));
+    AddOutput("Out",
+              string::Sprintf("Result of ``%s`` . It is an N-D Tensor with "
+                              "the same data type as the input Tensor.",
+                              comment.type));
+    AddComment(string::Sprintf(R"DOC(
+It operates ``%s`` on Tensor ``X`` .
+
+.. math::
+        %s
+
+)DOC",
+                               comment.type, comment.equation));
+  }
+};
+
+class BitwiseOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
+    // BitwiseOp kernel's device type is decided by input tensor place
+    kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
+    return kt;
+  }
+};
+
+template <typename OpComment>
+class UnaryBitwiseOp : public BitwiseOp {
+ public:
+  using BitwiseOp::BitwiseOp;
+
+ protected:
+  void InferShape(framework::InferShapeContext *context) const override {
+    OpComment comment;
+    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
+    context->SetOutputDim("Out", context->GetInputDim("X"));
+    context->ShareLoD("X", "Out");
+  }
+};
+
+template <typename OpComment>
+class BinaryBitwiseOp : public BitwiseOp {
+ public:
+  using BitwiseOp::BitwiseOp;
+
+ protected:
+  void InferShape(framework::InferShapeContext *context) const override {
+    OpComment comment;
+    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
+    OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", comment.type);
+    auto dim_x = context->GetInputDim("X");
+    auto dim_y = context->GetInputDim("Y");
+    if (dim_x == dim_y) {
+      context->SetOutputDim("Out", dim_x);
+    } else {
+      // Shapes differ: compute the broadcast output shape.
+      int max_dim = std::max(dim_x.size(), dim_y.size());
+      int axis = std::abs(dim_x.size() - dim_y.size());
+      std::vector<int> x_dims_array(max_dim);
+      std::vector<int> y_dims_array(max_dim);
+      std::vector<int> out_dims_array(max_dim);
+      GetBroadcastDimsArrays(dim_x, dim_y, x_dims_array.data(),
+                             y_dims_array.data(), out_dims_array.data(),
+                             max_dim, axis);
+      context->SetOutputDim("Out", framework::make_ddim(out_dims_array));
+    }
+    context->ShareLoD("X", "Out");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = ::paddle::operators;
+
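+// The two macros below stamp out a small Comment struct per operator; its
+// static `type` and `equation` strings feed the proto makers above, so each
+// op's docstring is generated from a single registration line. Bitwise ops
+// have no gradient, which is why both grad-maker slots use EmptyGradOpMaker.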
+#define REGISTER_BINARY_BITWISE_OP(op_type, _equation)                      \
+  struct _##op_type##Comment {                                              \
+    static char type[];                                                     \
+    static char equation[];                                                 \
+  };                                                                        \
+  char _##op_type##Comment::type[]{#op_type};                               \
+  char _##op_type##Comment::equation[]{_equation};                          \
+  REGISTER_OPERATOR(                                                        \
+      op_type, ops::BinaryBitwiseOp<_##op_type##Comment>,                   \
+      ops::BinaryBitwiseOpProtoMaker<_##op_type##Comment>,                  \
+      ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,     \
+      ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
+#define REGISTER_UNARY_BITWISE_OP(op_type, _equation)                       \
+  struct _##op_type##Comment {                                              \
+    static char type[];                                                     \
+    static char equation[];                                                 \
+  };                                                                        \
+  char _##op_type##Comment::type[]{#op_type};                               \
+  char _##op_type##Comment::equation[]{_equation};                          \
+  REGISTER_OPERATOR(                                                        \
+      op_type, ops::UnaryBitwiseOp<_##op_type##Comment>,                    \
+      ops::UnaryBitwiseOpProtoMaker<_##op_type##Comment>,                   \
+      ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,     \
+      ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_BINARY_BITWISE_OP(bitwise_and, "Out = X \\& Y");
+REGISTER_BINARY_BITWISE_OP(bitwise_or, "Out = X | Y");
+REGISTER_BINARY_BITWISE_OP(bitwise_xor, "Out = X ^\\wedge Y");
+REGISTER_UNARY_BITWISE_OP(bitwise_not, "Out = \\sim X");
+
+REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CPU, ops::BitwiseAndFunctor);
+REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CPU, ops::BitwiseOrFunctor);
+REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CPU, ops::BitwiseXorFunctor);
+REGISTER_UNARY_BITWISE_KERNEL(bitwise_not, CPU, ops::BitwiseNotFunctor);
diff --git a/paddle/fluid/operators/controlflow/bitwise_op.cu b/paddle/fluid/operators/controlflow/bitwise_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..b549f7e33005e33a2f73e0617beb2a8b12dd1245
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/bitwise_op.cu
@@ -0,0 +1,87 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. 
*/
+
+#include "paddle/fluid/operators/controlflow/bitwise_op.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
+
+namespace paddle {
+namespace operators {
+
+#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr)     \
+  template <typename T>                                   \
+  struct Bitwise##func##CUDAFunctor {                     \
+    using ELEM_TYPE = T;                                  \
+    HOSTDEVICE T operator()(const T* args) const {        \
+      return args[0] expr args[1];                        \
+    }                                                     \
+  };                                                      \
+                                                          \
+  template <>                                             \
+  struct Bitwise##func##CUDAFunctor<bool> {               \
+    using ELEM_TYPE = bool;                               \
+    HOSTDEVICE bool operator()(const bool* args) const {  \
+      return args[0] bool_expr args[1];                   \
+    }                                                     \
+  };
+
+BITWISE_BINARY_FUNCTOR(And, &, &&)
+BITWISE_BINARY_FUNCTOR(Or, |, ||)
+BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
+#undef BITWISE_BINARY_FUNCTOR
+
+template <typename T>
+struct BitwiseNotCUDAFunctor {
+  using ELEM_TYPE = T;
+  HOSTDEVICE T operator()(const T* args) const { return ~args[0]; }
+};
+
+template <>
+struct BitwiseNotCUDAFunctor<bool> {
+  using ELEM_TYPE = bool;
+  HOSTDEVICE bool operator()(const bool* args) const { return !args[0]; }
+};
+
+template <typename Functor>
+class BinaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
+    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
+ public:
+  using T = typename Functor::ELEM_TYPE;
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto functor = Functor();
+    std::vector<const framework::Tensor*> ins;
+    std::vector<framework::Tensor*> outs;
+    const auto& cuda_ctx =
+        ctx.template device_context<platform::CUDADeviceContext>();
+    int axis = PackTensorsIntoVector<T>(ctx, &ins, &outs);
+
+    if (ins.size() == 1) {
+      LaunchElementwiseCudaKernel<ElementwiseType::kUnary, T, T>(
+          cuda_ctx, ins, &outs, axis, functor);
+    } else {
+      LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T, T>(
+          cuda_ctx, ins, &outs, axis, functor);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = ::paddle::operators;
+namespace plat = ::paddle::platform;
+// Note: bitwise_not also registers the binary kernel; with one input it
+REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CUDA, ops::BitwiseAndCUDAFunctor);
+REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CUDA, ops::BitwiseOrCUDAFunctor);
+REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CUDA, ops::BitwiseXorCUDAFunctor);
+REGISTER_BINARY_BITWISE_KERNEL(bitwise_not, CUDA, ops::BitwiseNotCUDAFunctor);
diff --git a/paddle/fluid/operators/controlflow/bitwise_op.h b/paddle/fluid/operators/controlflow/bitwise_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..92abe4cd3b1c3630ed9c2652f2ff8a49f033f13b
--- /dev/null
+++ b/paddle/fluid/operators/controlflow/bitwise_op.h
@@ -0,0 +1,112 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. 
*/
+
+#pragma once
+#include <algorithm>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
+#include "paddle/fluid/platform/transform.h"
+
+namespace paddle {
+namespace operators {
+
+#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr)                         \
+  template <typename T>                                                       \
+  struct Bitwise##func##Functor {                                             \
+    using ELEM_TYPE = T;                                                      \
+    HOSTDEVICE T operator()(const T& a, const T& b) const { return a expr b; } \
+  };                                                                          \
+                                                                              \
+  template <>                                                                 \
+  struct Bitwise##func##Functor<bool> {                                       \
+    using ELEM_TYPE = bool;                                                   \
+    HOSTDEVICE bool operator()(const bool& a, const bool& b) const {          \
+      return a bool_expr b;                                                   \
+    }                                                                         \
+  };
+
+BITWISE_BINARY_FUNCTOR(And, &, &&)
+BITWISE_BINARY_FUNCTOR(Or, |, ||)
+BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
+#undef BITWISE_BINARY_FUNCTOR
+
+template <typename T>
+struct BitwiseNotFunctor {
+  using ELEM_TYPE = T;
+  HOSTDEVICE T operator()(const T& a) const { return ~a; }
+};
+
+template <>
+struct BitwiseNotFunctor<bool> {
+  using ELEM_TYPE = bool;
+  HOSTDEVICE bool operator()(const bool& a) const { return !a; }
+};
+
+template <typename DeviceContext, typename Functor>
+class BinaryBitwiseOpKernel
+    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    using T = typename Functor::ELEM_TYPE;
+    auto func = Functor();
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* y = context.Input<framework::Tensor>("Y");
+    auto* out = context.Output<framework::Tensor>("Out");
+    ElementwiseComputeEx<Functor, DeviceContext, T>(context, x, y, -1, func,
+                                                    out);
+  }
+};
+
+template <typename DeviceContext, typename Functor>
+class UnaryBitwiseOpKernel
+    : public framework::OpKernel<typename Functor::ELEM_TYPE> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    using T = typename Functor::ELEM_TYPE;
+    auto func = Functor();
+    auto* x = context.Input<framework::Tensor>("X");
+    auto* out = context.Output<framework::Tensor>("Out");
+    platform::Transform<DeviceContext> trans;
+    trans(context.template device_context<DeviceContext>(), x->data<T>(),
+          x->data<T>() + x->numel(), out->mutable_data<T>(context.GetPlace()),
+          func);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = ::paddle::operators;
+namespace plat = ::paddle::platform;
+
+#define REGISTER_BINARY_BITWISE_KERNEL(op_type, dev, functor)                 \
+  REGISTER_OP_##dev##_KERNEL(                                                 \
+      op_type,                                                                \
+      ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<bool>>,    \
+      ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<uint8_t>>, \
+      ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int8_t>>,  \
+      ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int16_t>>, \
+      ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int>>,     \
+      ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int64_t>>);
+
+#define REGISTER_UNARY_BITWISE_KERNEL(op_type, dev, functor)                  \
+  REGISTER_OP_##dev##_KERNEL(                                                 \
+      op_type,                                                                \
+      ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<bool>>,     \
+      ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<uint8_t>>,  \
+      ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int8_t>>,   \
+      ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int16_t>>,  \
+      ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int>>,      \
+      ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int64_t>>);
diff --git a/paddle/fluid/operators/controlflow/unity_build_rule.cmake b/paddle/fluid/operators/controlflow/unity_build_rule.cmake
index 6ed8f8a75374eaba122e7a3b3d935079a81756ee..f75785bd961c2543a20877d6b68d84471df96f41 100644
--- a/paddle/fluid/operators/controlflow/unity_build_rule.cmake
+++ b/paddle/fluid/operators/controlflow/unity_build_rule.cmake
@@ -12,9 +12,11 @@ register_unity_group(cc
     fetch_op.cc
     get_places_op.cc
     logical_op.cc
+    bitwise_op.cc
     tensor_array_read_write_op.cc
     while_op.cc)
 register_unity_group(cu
     logical_op.cu
+    bitwise_op.cu
     compare_op.cu
     compare_all_op.cu)
diff --git a/paddle/scripts/paddle_build.bat b/paddle/scripts/paddle_build.bat
index c4a93f0d4a1e9f689b1510fe037e2d0e397e01d1..4e501e727207d60ac93745b29ea8b837acea898f 100644
--- a/paddle/scripts/paddle_build.bat
+++ b/paddle/scripts/paddle_build.bat
@@ -78,7 +78,7 @@ if not defined PYTHON_ROOT set 
PYTHON_ROOT=C:\Python37 rem -------set cache build directory----------- rmdir build\python /s/q rmdir build\paddle\third_party\externalError /s/q -rmdir build\paddle\fluid\pybind /s/q +rem rmdir build\paddle\fluid\pybind /s/q rmdir build\paddle_install_dir /s/q rmdir build\paddle_inference_install_dir /s/q rmdir build\paddle_inference_c_install_dir /s/q diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index b5315a5d19ac7ef85f9c138218ba679082c39335..cc8a43c572c05d749c31ed3cda68f360839ca007 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -108,6 +108,10 @@ from .tensor.logic import logical_and # noqa: F401 from .tensor.logic import logical_not # noqa: F401 from .tensor.logic import logical_or # noqa: F401 from .tensor.logic import logical_xor # noqa: F401 +from .tensor.logic import bitwise_and # noqa: F401 +from .tensor.logic import bitwise_not # noqa: F401 +from .tensor.logic import bitwise_or # noqa: F401 +from .tensor.logic import bitwise_xor # noqa: F401 from .tensor.logic import not_equal # noqa: F401 from .tensor.logic import allclose # noqa: F401 from .tensor.logic import equal_all # noqa: F401 @@ -371,6 +375,10 @@ __all__ = [ # noqa 'max', 'norm', 'logical_or', + 'bitwise_and', + 'bitwise_or', + 'bitwise_xor', + 'bitwise_not', 'mm', 'flip', 'histogram', diff --git a/python/paddle/fluid/dygraph/math_op_patch.py b/python/paddle/fluid/dygraph/math_op_patch.py index a014e0a722ab3280b2738b45de65ef29d189c5bd..83804e80c2309b5b9c1df9907ceb9a37b29ba6b0 100644 --- a/python/paddle/fluid/dygraph/math_op_patch.py +++ b/python/paddle/fluid/dygraph/math_op_patch.py @@ -319,10 +319,13 @@ def monkey_patch_math_varbase(): else: import paddle.tensor # Tensor method from module paddle.tensor - tensor_methods = paddle.tensor.tensor_method_func - for method_name in tensor_methods: + for method_name in paddle.tensor.tensor_method_func: if hasattr(core.VarBase, method_name): continue method_impl = getattr(paddle.tensor, method_name, None) if method_impl: setattr(core.VarBase, method_name, method_impl) + for magic_method, origin_method in paddle.tensor.magic_method_func: + impl = getattr(paddle.tensor, origin_method, None) + if impl: setattr(core.VarBase, magic_method, impl) + _already_patch_varbase = True diff --git a/python/paddle/fluid/layers/math_op_patch.py b/python/paddle/fluid/layers/math_op_patch.py index 2a57c1a907aaccf2f1a511fb11b617cc11143606..9433e0e5ee0e5f5bc203b865a8f9475fbae7063b 100644 --- a/python/paddle/fluid/layers/math_op_patch.py +++ b/python/paddle/fluid/layers/math_op_patch.py @@ -364,10 +364,13 @@ def monkey_patch_variable(): setattr(Variable, method_name, method_impl) else: import paddle.tensor - variabel_methods = paddle.tensor.tensor_method_func - for method_name in variabel_methods: + for method_name in paddle.tensor.tensor_method_func: if hasattr(Variable, method_name): continue method_impl = getattr(paddle.tensor, method_name, None) if method_impl: setattr(Variable, method_name, method_impl) + for magic_method, origin_method in paddle.tensor.magic_method_func: + impl = getattr(paddle.tensor, origin_method, None) + if impl: setattr(Variable, magic_method, impl) + _already_patch_variable = True diff --git a/python/paddle/fluid/tests/unittests/test_bitwise_op.py b/python/paddle/fluid/tests/unittests/test_bitwise_op.py new file mode 100644 index 0000000000000000000000000000000000000000..ead78d75c3dc4663306d473845ac57db48407d02 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_bitwise_op.py @@ -0,0 +1,354 @@ +# Copyright (c) 2021 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +import paddle +from op_test import OpTest + +paddle.enable_static() + + +################## TEST OP: BitwiseAnd ################## +class TestBitwiseAnd(OpTest): + def setUp(self): + self.op_type = "bitwise_and" + self.init_dtype() + self.init_shape() + self.init_bound() + + x = np.random.randint( + self.low, self.high, self.x_shape, dtype=self.dtype) + y = np.random.randint( + self.low, self.high, self.y_shape, dtype=self.dtype) + out = np.bitwise_and(x, y) + + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + pass + + def init_dtype(self): + self.dtype = np.int32 + + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [2, 3, 4, 5] + + def init_bound(self): + self.low = -100 + self.high = 100 + + +class TestBitwiseAndUInt8(TestBitwiseAnd): + def init_dtype(self): + self.dtype = np.uint8 + + def init_bound(self): + self.low = 0 + self.high = 100 + + +class TestBitwiseAndInt8(TestBitwiseAnd): + def init_dtype(self): + self.dtype = np.int8 + + def init_shape(self): + self.x_shape = [4, 5] + self.y_shape = [2, 3, 4, 5] + + +class TestBitwiseAndInt16(TestBitwiseAnd): + def init_dtype(self): + self.dtype = np.int16 + + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [4, 1] + + +class TestBitwiseAndInt64(TestBitwiseAnd): + def init_dtype(self): + self.dtype = np.int64 + + def init_shape(self): + self.x_shape = [1, 4, 1] + self.y_shape = [2, 3, 4, 5] + + +class TestBitwiseAndBool(TestBitwiseAnd): + def setUp(self): + self.op_type = "bitwise_and" + self.init_shape() + + x = np.random.choice([True, False], self.x_shape) + y = np.random.choice([True, False], self.y_shape) + out = np.bitwise_and(x, y) + + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': out} + + +################## TEST OP: BitwiseOr ################## +class TestBitwiseOr(OpTest): + def setUp(self): + self.op_type = "bitwise_or" + self.init_dtype() + self.init_shape() + self.init_bound() + + x = np.random.randint( + self.low, self.high, self.x_shape, dtype=self.dtype) + y = np.random.randint( + self.low, self.high, self.y_shape, dtype=self.dtype) + out = np.bitwise_or(x, y) + + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + pass + + def init_dtype(self): + self.dtype = np.int32 + + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [2, 3, 4, 5] + + def init_bound(self): + self.low = -100 + self.high = 100 + + +class TestBitwiseOrUInt8(TestBitwiseOr): + def init_dtype(self): + self.dtype = np.uint8 + + def init_bound(self): + self.low = 0 + self.high = 100 + + +class TestBitwiseOrInt8(TestBitwiseOr): + def init_dtype(self): + self.dtype = np.int8 + + def init_shape(self): + self.x_shape = [4, 5] + self.y_shape = [2, 3, 4, 5] + + +class 
TestBitwiseOrInt16(TestBitwiseOr): + def init_dtype(self): + self.dtype = np.int16 + + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [4, 1] + + +class TestBitwiseOrInt64(TestBitwiseOr): + def init_dtype(self): + self.dtype = np.int64 + + def init_shape(self): + self.x_shape = [1, 4, 1] + self.y_shape = [2, 3, 4, 5] + + +class TestBitwiseOrBool(TestBitwiseOr): + def setUp(self): + self.op_type = "bitwise_or" + self.init_shape() + + x = np.random.choice([True, False], self.x_shape) + y = np.random.choice([True, False], self.y_shape) + out = np.bitwise_or(x, y) + + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': out} + + +################## TEST OP: BitwiseXor ################## +class TestBitwiseXor(OpTest): + def setUp(self): + self.op_type = "bitwise_xor" + self.init_dtype() + self.init_shape() + self.init_bound() + + x = np.random.randint( + self.low, self.high, self.x_shape, dtype=self.dtype) + y = np.random.randint( + self.low, self.high, self.y_shape, dtype=self.dtype) + out = np.bitwise_xor(x, y) + + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + pass + + def init_dtype(self): + self.dtype = np.int32 + + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [2, 3, 4, 5] + + def init_bound(self): + self.low = -100 + self.high = 100 + + +class TestBitwiseXorUInt8(TestBitwiseXor): + def init_dtype(self): + self.dtype = np.uint8 + + def init_bound(self): + self.low = 0 + self.high = 100 + + +class TestBitwiseXorInt8(TestBitwiseXor): + def init_dtype(self): + self.dtype = np.int8 + + def init_shape(self): + self.x_shape = [4, 5] + self.y_shape = [2, 3, 4, 5] + + +class TestBitwiseXorInt16(TestBitwiseXor): + def init_dtype(self): + self.dtype = np.int16 + + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [4, 1] + + +class TestBitwiseXorInt64(TestBitwiseXor): + def init_dtype(self): + self.dtype = np.int64 + + def init_shape(self): + self.x_shape = [1, 4, 1] + self.y_shape = [2, 3, 4, 5] + + +class TestBitwiseXorBool(TestBitwiseXor): + def setUp(self): + self.op_type = "bitwise_xor" + self.init_shape() + + x = np.random.choice([True, False], self.x_shape) + y = np.random.choice([True, False], self.y_shape) + out = np.bitwise_xor(x, y) + + self.inputs = {'X': x, 'Y': y} + self.outputs = {'Out': out} + + +################## TEST OP: BitwiseNot ################## +class TestBitwiseNot(OpTest): + def setUp(self): + self.op_type = "bitwise_not" + self.init_dtype() + self.init_shape() + self.init_bound() + + x = np.random.randint( + self.low, self.high, self.x_shape, dtype=self.dtype) + out = np.bitwise_not(x) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + pass + + def init_dtype(self): + self.dtype = np.int32 + + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + + def init_bound(self): + self.low = -100 + self.high = 100 + + +class TestBitwiseNotUInt8(TestBitwiseNot): + def init_dtype(self): + self.dtype = np.uint8 + + def init_bound(self): + self.low = 0 + self.high = 100 + + +class TestBitwiseNotInt8(TestBitwiseNot): + def init_dtype(self): + self.dtype = np.int8 + + def init_shape(self): + self.x_shape = [4, 5] + + +class TestBitwiseNotInt16(TestBitwiseNot): + def init_dtype(self): + self.dtype = np.int16 + + def init_shape(self): + self.x_shape = [2, 3, 4, 5] + self.y_shape = [4, 1] + + +class 
TestBitwiseNotInt64(TestBitwiseNot): + def init_dtype(self): + self.dtype = np.int64 + + def init_shape(self): + self.x_shape = [1, 4, 1] + + +class TestBitwiseNotBool(TestBitwiseNot): + def setUp(self): + self.op_type = "bitwise_not" + self.init_shape() + + x = np.random.choice([True, False], self.x_shape) + out = np.bitwise_not(x) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch.py b/python/paddle/fluid/tests/unittests/test_math_op_patch.py index fc5e613decddea2f7e2cd5a0e5b672d9bbd8dcfb..b2afda9ed3f2549394815daef7dd9809aa86a23c 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch.py @@ -19,6 +19,7 @@ from decorator_helper import prog_scope import paddle import paddle.fluid as fluid import numpy +import numpy as np class TestMathOpPatches(unittest.TestCase): @@ -270,6 +271,71 @@ class TestMathOpPatches(unittest.TestCase): fetch_list=[b]) self.assertTrue(numpy.allclose(a_np.astype('float32'), b_np)) + @prog_scope() + def test_bitwise_and(self): + x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + out_np = x_np & y_np + + x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32") + y = paddle.static.data(name="y", shape=[2, 3, 5], dtype="int32") + z = x & y + + exe = fluid.Executor() + out = exe.run(fluid.default_main_program(), + feed={"x": x_np, + "y": y_np}, + fetch_list=[z]) + self.assertTrue(np.array_equal(out[0], out_np)) + + @prog_scope() + def test_bitwise_or(self): + x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + out_np = x_np | y_np + + x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32") + y = paddle.static.data(name="y", shape=[2, 3, 5], dtype="int32") + z = x | y + + exe = fluid.Executor() + out = exe.run(fluid.default_main_program(), + feed={"x": x_np, + "y": y_np}, + fetch_list=[z]) + self.assertTrue(np.array_equal(out[0], out_np)) + + @prog_scope() + def test_bitwise_xor(self): + x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + y_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + out_np = x_np ^ y_np + + x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32") + y = paddle.static.data(name="y", shape=[2, 3, 5], dtype="int32") + z = x ^ y + + exe = fluid.Executor() + out = exe.run(fluid.default_main_program(), + feed={"x": x_np, + "y": y_np}, + fetch_list=[z]) + self.assertTrue(np.array_equal(out[0], out_np)) + + @prog_scope() + def test_bitwise_not(self): + x_np = np.random.randint(-100, 100, [2, 3, 5]).astype("int32") + out_np = ~x_np + + x = paddle.static.data(name="x", shape=[2, 3, 5], dtype="int32") + z = ~x + + exe = fluid.Executor() + out = exe.run(fluid.default_main_program(), + feed={"x": x_np}, + fetch_list=[z]) + self.assertTrue(np.array_equal(out[0], out_np)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py index 4b097f6359f8862d128c568f4de0776c46190a4e..4ad6261293d260e911806ef2f0b7771e1dadad66 100644 --- a/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py +++ b/python/paddle/fluid/tests/unittests/test_math_op_patch_var_base.py @@ -141,6 +141,31 @@ class 
TestMathOpPatchesVarBase(unittest.TestCase):
         res = a % b
         self.assertTrue(np.array_equal(res.numpy(), a_np % b_np))
 
+    # for bitwise and/or/xor/not
+    def test_bitwise(self):
+        paddle.disable_static()
+
+        x_np = np.random.randint(-100, 100, [2, 3, 5])
+        y_np = np.random.randint(-100, 100, [2, 3, 5])
+        x = paddle.to_tensor(x_np)
+        y = paddle.to_tensor(y_np)
+
+        out_np = x_np & y_np
+        out = x & y
+        self.assertTrue(np.array_equal(out.numpy(), out_np))
+
+        out_np = x_np | y_np
+        out = x | y
+        self.assertTrue(np.array_equal(out.numpy(), out_np))
+
+        out_np = x_np ^ y_np
+        out = x ^ y
+        self.assertTrue(np.array_equal(out.numpy(), out_np))
+
+        out_np = ~x_np
+        out = ~x
+        self.assertTrue(np.array_equal(out.numpy(), out_np))
 
     # for logical compare
     def test_equal(self):
         a_np = np.asarray([1, 2, 3, 4, 5])
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 206aa62adfb779ce89598d5cb84a576bbbe12492..2cb3f54063452c7d95d60c97b879e7563ad1f783 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -54,6 +54,10 @@ from .logic import logical_and  # noqa: F401
 from .logic import logical_not  # noqa: F401
 from .logic import logical_or  # noqa: F401
 from .logic import logical_xor  # noqa: F401
+from .logic import bitwise_and  # noqa: F401
+from .logic import bitwise_or  # noqa: F401
+from .logic import bitwise_xor  # noqa: F401
+from .logic import bitwise_not  # noqa: F401
 from .logic import not_equal  # noqa: F401
 from .logic import allclose  # noqa: F401
 from .logic import equal_all  # noqa: F401
@@ -352,4 +356,16 @@ tensor_method_func = [ #noqa
     'imag',
-    'trunc'
-    'digamma'
+    'trunc',
+    'digamma',
+    'bitwise_and',
+    'bitwise_or',
+    'bitwise_xor',
+    'bitwise_not',
+]
+
+# This list is used in math_op_patch.py to bind magic methods.
+magic_method_func = [
+    ('__and__', 'bitwise_and'),
+    ('__or__', 'bitwise_or'),
+    ('__xor__', 'bitwise_xor'),
+    ('__invert__', 'bitwise_not'),
 ]
diff --git a/python/paddle/tensor/logic.py b/python/paddle/tensor/logic.py
index f948eeb9a48ebe723e5e22a916cdc624580b7a10..4851c2487bf69296683e5dced47ec6d191203f43 100644
--- a/python/paddle/tensor/logic.py
+++ b/python/paddle/tensor/logic.py
@@ -16,7 +16,7 @@ from ..fluid.layer_helper import LayerHelper
 from ..fluid.data_feeder import check_type, check_variable_and_dtype
 from ..fluid.layers.layer_function_generator import templatedoc
 from .. 
import fluid
-from ..fluid.framework import in_dygraph_mode
+from ..fluid.framework import core, in_dygraph_mode, Variable
 from ..framework import VarBase as Tensor
 
 # TODO: define logic functions of a tensor
@@ -437,3 +437,140 @@ def is_tensor(x):
     """
     return isinstance(x, Tensor)
+
+
+def _bitwise_op(op_name, x, y, out=None, name=None, binary_op=True):
+    if in_dygraph_mode():
+        # In dygraph mode, call the C++ op directly.
+        op = getattr(core.ops, op_name)
+        if binary_op:
+            return op(x, y)
+        else:
+            return op(x)
+
+    check_variable_and_dtype(
+        x, "x", ["bool", "uint8", "int8", "int16", "int32", "int64"], op_name)
+    if y is not None:
+        check_variable_and_dtype(
+            y, "y", ["bool", "uint8", "int8", "int16", "int32", "int64"],
+            op_name)
+    if out is not None:
+        check_type(out, "out", Variable, op_name)
+
+    helper = LayerHelper(op_name, **locals())
+    if binary_op:
+        assert x.dtype == y.dtype
+
+    if out is None:
+        out = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    if binary_op:
+        helper.append_op(
+            type=op_name, inputs={"X": x,
+                                  "Y": y}, outputs={"Out": out})
+    else:
+        helper.append_op(type=op_name, inputs={"X": x}, outputs={"Out": out})
+
+    return out
+
+
+@templatedoc()
+def bitwise_and(x, y, out=None, name=None):
+    """
+    ${comment}
+
+    Args:
+        x (Tensor): ${x_comment}
+        y (Tensor): ${y_comment}
+        out (Tensor): ${out_comment}
+
+    Returns:
+        Tensor: ${out_comment}
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            x = paddle.to_tensor([-5, -1, 1])
+            y = paddle.to_tensor([4, 2, -3])
+            res = paddle.bitwise_and(x, y)
+            print(res)  # [0, 2, 1]
+    """
+    return _bitwise_op(
+        op_name="bitwise_and", x=x, y=y, name=name, out=out, binary_op=True)
+
+
+@templatedoc()
+def bitwise_or(x, y, out=None, name=None):
+    """
+    ${comment}
+
+    Args:
+        x (Tensor): ${x_comment}
+        y (Tensor): ${y_comment}
+        out (Tensor): ${out_comment}
+
+    Returns:
+        Tensor: ${out_comment}
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            x = paddle.to_tensor([-5, -1, 1])
+            y = paddle.to_tensor([4, 2, -3])
+            res = paddle.bitwise_or(x, y)
+            print(res)  # [-1, -1, -3]
+    """
+    return _bitwise_op(
+        op_name="bitwise_or", x=x, y=y, name=name, out=out, binary_op=True)
+
+
+@templatedoc()
+def bitwise_xor(x, y, out=None, name=None):
+    """
+    ${comment}
+
+    Args:
+        x (Tensor): ${x_comment}
+        y (Tensor): ${y_comment}
+        out (Tensor): ${out_comment}
+
+    Returns:
+        Tensor: ${out_comment}
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            x = paddle.to_tensor([-5, -1, 1])
+            y = paddle.to_tensor([4, 2, -3])
+            res = paddle.bitwise_xor(x, y)
+            print(res)  # [-1, -3, -4]
+    """
+    return _bitwise_op(
+        op_name="bitwise_xor", x=x, y=y, name=name, out=out, binary_op=True)
+
+
+@templatedoc()
+def bitwise_not(x, out=None, name=None):
+    """
+    ${comment}
+
+    Args:
+        x (Tensor): ${x_comment}
+        out (Tensor): ${out_comment}
+
+    Returns:
+        Tensor: ${out_comment}
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            x = paddle.to_tensor([-5, -1, 1])
+            res = paddle.bitwise_not(x)
+            print(res)  # [4, 0, -2]
+    """
+
+    return _bitwise_op(
+        op_name="bitwise_not", x=x, y=None, name=name, out=out, binary_op=False)
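
Usage sketch: how the Python APIs and operator overloads added above are
expected to behave, based on the docstrings and the unit tests in this diff.

    .. code-block:: python

        import paddle

        x = paddle.to_tensor([-5, -1, 1])
        y = paddle.to_tensor([4, 2, -3])

        # Functional form.
        print(paddle.bitwise_and(x, y))  # [0, 2, 1]
        print(paddle.bitwise_or(x, y))   # [-1, -1, -3]
        print(paddle.bitwise_xor(x, y))  # [-1, -3, -4]
        print(paddle.bitwise_not(x))     # [4, 0, -2]

        # Operator form, bound through magic_method_func in
        # python/paddle/tensor/__init__.py.
        print(x & y)  # same as paddle.bitwise_and(x, y)
        print(x | y)  # same as paddle.bitwise_or(x, y)
        print(x ^ y)  # same as paddle.bitwise_xor(x, y)
        print(~x)     # same as paddle.bitwise_not(x)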