diff --git a/paddle/fluid/operators/controlflow/CMakeLists.txt b/paddle/fluid/operators/controlflow/CMakeLists.txt index 70937069d97cc3eba7a93787ee4f17e76e0fe976..0c18522fa32eae5f357da062fbd25fa92878cc08 100644 --- a/paddle/fluid/operators/controlflow/CMakeLists.txt +++ b/paddle/fluid/operators/controlflow/CMakeLists.txt @@ -21,4 +21,4 @@ endif() file(APPEND ${pybind_file} "USE_OP_ITSELF(less_than);\nUSE_OP_ITSELF(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP_ITSELF(logical_and);\nUSE_OP_ITSELF(logical_or);\nUSE_OP_ITSELF(logical_xor);\nUSE_OP_ITSELF(logical_not);\n") -file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n") +file(APPEND ${pybind_file} "USE_OP_ITSELF(bitwise_and);\nUSE_OP_ITSELF(bitwise_or);\nUSE_OP_ITSELF(bitwise_xor);\nUSE_OP_ITSELF(bitwise_not);\n") diff --git a/paddle/fluid/operators/controlflow/bitwise_op.cc b/paddle/fluid/operators/controlflow/bitwise_op.cc index 55cab03ea9e3f18f36043848914ac11fac1027c9..4dcbbc8568ff18a1313171f8f66f276d77f019a1 100644 --- a/paddle/fluid/operators/controlflow/bitwise_op.cc +++ b/paddle/fluid/operators/controlflow/bitwise_op.cc @@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/controlflow/bitwise_op.h" #include #include #include #include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" namespace paddle { namespace operators { @@ -75,11 +75,19 @@ It operates ``%s`` on Tensor ``X`` . } }; -class BitwiseOp : public framework::OperatorWithKernel { +template +class UnaryBitwiseOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; protected: + void InferShape(framework::InferShapeContext *context) const override { + OpComment comment; + OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type); + context->SetOutputDim("Out", context->GetInputDim("X")); + context->ShareLoD("X", "Out"); + } + framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); @@ -90,23 +98,9 @@ class BitwiseOp : public framework::OperatorWithKernel { }; template -class UnaryBitwiseOp : public BitwiseOp { - public: - using BitwiseOp::BitwiseOp; - - protected: - void InferShape(framework::InferShapeContext *context) const override { - OpComment comment; - OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type); - context->SetOutputDim("Out", context->GetInputDim("X")); - context->ShareLoD("X", "Out"); - } -}; - -template -class BinaryBitwiseOp : public BitwiseOp { +class BinaryBitwiseOp : public framework::OperatorWithKernel { public: - using BitwiseOp::BitwiseOp; + using framework::OperatorWithKernel::OperatorWithKernel; protected: void InferShape(framework::InferShapeContext *context) const override { @@ -130,6 +124,14 @@ class BinaryBitwiseOp : public BitwiseOp { } context->ShareLoD("X", "Out"); } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); + // BitwiseOp kernel's device type is decided by input tensor place + kt.place_ = ctx.Input("X")->place(); + return kt; + } }; } // namespace operators @@ -167,8 +169,3 @@ REGISTER_BINARY_BITWISE_OP(bitwise_and, "Out = X \\& Y"); REGISTER_BINARY_BITWISE_OP(bitwise_or, "Out = X | Y"); REGISTER_BINARY_BITWISE_OP(bitwise_xor, "Out = X ^\\wedge Y"); REGISTER_UNARY_BITWISE_OP(bitwise_not, "Out = \\sim X"); - -REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CPU, ops::BitwiseAndFunctor); -REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CPU, ops::BitwiseOrFunctor); -REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CPU, ops::BitwiseXorFunctor); -REGISTER_UNARY_BITWISE_KERNEL(bitwise_not, CPU, ops::BitwiseNotFunctor); diff --git a/paddle/fluid/operators/controlflow/bitwise_op.cu b/paddle/fluid/operators/controlflow/bitwise_op.cu deleted file mode 100644 index 5d98da2c027fb6ee681bbea3980f1dbf631d6431..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/controlflow/bitwise_op.cu +++ /dev/null @@ -1,74 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include "paddle/fluid/operators/controlflow/bitwise_op.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h" - -namespace paddle { -namespace operators { - -template -class BinaryBitwiseOpKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - using T = typename Functor::ELEM_TYPE; - - auto* x = ctx.Input("X"); - auto* y = ctx.Input("Y"); - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - - auto functor = Functor(); - std::vector ins = {x, y}; - std::vector outs = {out}; - const auto& cuda_ctx = - ctx.template device_context(); - paddle::operators::LaunchElementwiseCudaKernel(cuda_ctx, ins, &outs, -1, - functor); - } -}; - -template -class UnaryBitwiseOpKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - using T = typename Functor::ELEM_TYPE; - - auto* x = ctx.Input("X"); - auto* out = ctx.Output("Out"); - out->mutable_data(ctx.GetPlace()); - - auto functor = Functor(); - std::vector ins = {x}; - std::vector outs = {out}; - const auto& cuda_ctx = - ctx.template device_context(); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(cuda_ctx, ins, - &outs, functor); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = ::paddle::operators; -namespace plat = ::paddle::platform; - -REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CUDA, ops::BitwiseAndFunctor); -REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CUDA, ops::BitwiseOrFunctor); -REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CUDA, ops::BitwiseXorFunctor); -REGISTER_UNARY_BITWISE_KERNEL(bitwise_not, CUDA, ops::BitwiseNotFunctor); diff --git a/paddle/fluid/operators/controlflow/bitwise_op.h b/paddle/fluid/operators/controlflow/bitwise_op.h deleted file mode 100644 index 9e652f92007479684fcf8ec5e539312d8d729107..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/controlflow/bitwise_op.h +++ /dev/null @@ -1,112 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/elementwise/elementwise_op_function.h" -#include "paddle/fluid/platform/transform.h" - -namespace paddle { -namespace operators { - -#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \ - template \ - struct Bitwise##func##Functor { \ - using ELEM_TYPE = T; \ - HOSTDEVICE T operator()(const T a, const T b) const { return a expr b; } \ - }; \ - \ - template <> \ - struct Bitwise##func##Functor { \ - using ELEM_TYPE = bool; \ - HOSTDEVICE bool operator()(const bool a, const bool b) const { \ - return a bool_expr b; \ - } \ - }; - -BITWISE_BINARY_FUNCTOR(And, &, &&) -BITWISE_BINARY_FUNCTOR(Or, |, ||) -BITWISE_BINARY_FUNCTOR(Xor, ^, !=) -#undef BITWISE_BINARY_FUNCTOR - -template -struct BitwiseNotFunctor { - using ELEM_TYPE = T; - HOSTDEVICE T operator()(const T a) const { return ~a; } -}; - -template <> -struct BitwiseNotFunctor { - using ELEM_TYPE = bool; - HOSTDEVICE bool operator()(const bool a) const { return !a; } -}; - -template -class BinaryBitwiseOpKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - using T = typename Functor::ELEM_TYPE; - auto func = Functor(); - auto* x = context.Input("X"); - auto* y = context.Input("Y"); - auto* out = context.Output("Out"); - ElementwiseComputeEx(context, x, y, -1, func, - out); - } -}; - -template -class UnaryBitwiseOpKernel - : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - using T = typename Functor::ELEM_TYPE; - auto func = Functor(); - auto* x = context.Input("X"); - auto* out = context.Output("Out"); - platform::Transform trans; - trans(context.template device_context(), x->data(), - x->data() + x->numel(), out->mutable_data(context.GetPlace()), - func); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = ::paddle::operators; -namespace plat = ::paddle::platform; - -#define REGISTER_BINARY_BITWISE_KERNEL(op_type, dev, functor) \ - REGISTER_OP_##dev##_KERNEL( \ - op_type, \ - ops::BinaryBitwiseOpKernel>, \ - ops::BinaryBitwiseOpKernel>, \ - ops::BinaryBitwiseOpKernel>, \ - ops::BinaryBitwiseOpKernel>, \ - ops::BinaryBitwiseOpKernel>, \ - ops::BinaryBitwiseOpKernel>); - -#define REGISTER_UNARY_BITWISE_KERNEL(op_type, dev, functor) \ - REGISTER_OP_##dev##_KERNEL( \ - op_type, \ - ops::UnaryBitwiseOpKernel>, \ - ops::UnaryBitwiseOpKernel>, \ - ops::UnaryBitwiseOpKernel>, \ - ops::UnaryBitwiseOpKernel>, \ - ops::UnaryBitwiseOpKernel>, \ - ops::UnaryBitwiseOpKernel>); diff --git a/paddle/phi/kernels/bitwise_kernel.h b/paddle/phi/kernels/bitwise_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..17307004f360e10ada708fba276ec8de1a129259 --- /dev/null +++ b/paddle/phi/kernels/bitwise_kernel.h @@ -0,0 +1,44 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void BitwiseAndKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +template +void BitwiseOrKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +template +void BitwiseXorKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + DenseTensor* out); + +template +void BitwiseNotKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out); + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/bitwise_kernel.cc b/paddle/phi/kernels/cpu/bitwise_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..69f52790f77969e8bf29fcb50b777afe504215b7 --- /dev/null +++ b/paddle/phi/kernels/cpu/bitwise_kernel.cc @@ -0,0 +1,99 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/bitwise_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/bitwise_functors.h" +#include "paddle/phi/kernels/funcs/elementwise_base.h" + +// See Note [ Why still include the fluid headers? ] +#include "paddle/fluid/platform/transform.h" + +namespace phi { + +#define DEFINE_BITWISE_KERNEL(op_type) \ + template \ + void Bitwise##op_type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + funcs::Bitwise##op_type##Functor func; \ + funcs::ElementwiseCompute, T, T>( \ + dev_ctx, x, y, -1, func, out); \ + } + +DEFINE_BITWISE_KERNEL(And) +DEFINE_BITWISE_KERNEL(Or) +DEFINE_BITWISE_KERNEL(Xor) +#undef DEFINE_BITWISE_KERNEL + +template +void BitwiseNotKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + const T* x_data = x.data(); + T* out_data = dev_ctx.template Alloc(out); + size_t numel = x.numel(); + funcs::BitwiseNotFunctor func; + paddle::platform::Transform trans; + trans(dev_ctx, x_data, x_data + numel, out_data, func); +} + +} // namespace phi + +PD_REGISTER_KERNEL(bitwise_and, + CPU, + ALL_LAYOUT, + phi::BitwiseAndKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(bitwise_or, + CPU, + ALL_LAYOUT, + phi::BitwiseOrKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(bitwise_xor, + CPU, + ALL_LAYOUT, + phi::BitwiseXorKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(bitwise_not, + CPU, + ALL_LAYOUT, + phi::BitwiseNotKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} diff --git a/paddle/phi/kernels/funcs/bitwise_functors.h b/paddle/phi/kernels/funcs/bitwise_functors.h new file mode 100644 index 0000000000000000000000000000000000000000..db1fc59f534bcf752d1b010508b4a1adcb097651 --- /dev/null +++ b/paddle/phi/kernels/funcs/bitwise_functors.h @@ -0,0 +1,51 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +namespace phi { +namespace funcs { + +#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \ + template \ + struct Bitwise##func##Functor { \ + HOSTDEVICE T operator()(const T a, const T b) const { return a expr b; } \ + }; \ + \ + template <> \ + struct Bitwise##func##Functor { \ + HOSTDEVICE bool operator()(const bool a, const bool b) const { \ + return a bool_expr b; \ + } \ + }; + +BITWISE_BINARY_FUNCTOR(And, &, &&) +BITWISE_BINARY_FUNCTOR(Or, |, ||) +BITWISE_BINARY_FUNCTOR(Xor, ^, !=) +#undef BITWISE_BINARY_FUNCTOR + +template +struct BitwiseNotFunctor { + using ELEM_TYPE = T; + HOSTDEVICE T operator()(const T a) const { return ~a; } +}; + +template <> +struct BitwiseNotFunctor { + using ELEM_TYPE = bool; + HOSTDEVICE bool operator()(const bool a) const { return !a; } +}; + +} // namespace funcs +} // namespace phi diff --git a/paddle/phi/kernels/gpu/bitwise_kernel.cu b/paddle/phi/kernels/gpu/bitwise_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..e88ecef318a874d338094dfc9575b732cdd7680a --- /dev/null +++ b/paddle/phi/kernels/gpu/bitwise_kernel.cu @@ -0,0 +1,98 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/phi/kernels/bitwise_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/bitwise_functors.h" +#include "paddle/phi/kernels/funcs/broadcast_function.h" +namespace phi { + +#define DEFINE_BITWISE_KERNEL(op_type) \ + template \ + void Bitwise##op_type##Kernel(const Context& dev_ctx, \ + const DenseTensor& x, \ + const DenseTensor& y, \ + DenseTensor* out) { \ + dev_ctx.template Alloc(out); \ + funcs::Bitwise##op_type##Functor func; \ + std::vector ins = {&x, &y}; \ + std::vector outs = {out}; \ + funcs::BroadcastKernel( \ + dev_ctx, ins, &outs, -1, func); \ + } + +DEFINE_BITWISE_KERNEL(And) +DEFINE_BITWISE_KERNEL(Or) +DEFINE_BITWISE_KERNEL(Xor) +#undef DEFINE_BITWISE_KERNEL + +template +void BitwiseNotKernel(const Context& dev_ctx, + const DenseTensor& x, + DenseTensor* out) { + dev_ctx.template Alloc(out); + std::vector ins = {&x}; + std::vector outs = {out}; + funcs::BitwiseNotFunctor func; + funcs::BroadcastKernel( + dev_ctx, ins, &outs, -1, func); +} + +} // namespace phi + +PD_REGISTER_KERNEL(bitwise_and, + GPU, + ALL_LAYOUT, + phi::BitwiseAndKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(bitwise_or, + GPU, + ALL_LAYOUT, + phi::BitwiseOrKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(bitwise_xor, + GPU, + ALL_LAYOUT, + phi::BitwiseXorKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {} + +PD_REGISTER_KERNEL(bitwise_not, + GPU, + ALL_LAYOUT, + phi::BitwiseNotKernel, + bool, + uint8_t, + int8_t, + int16_t, + int, + int64_t) {}