未验证 提交 03eb792d 编写于 作者: zhouweiwei2014's avatar zhouweiwei2014 提交者: GitHub

【Phi】Migrate bitwise_and/bitwise_or/bitwise_xor/bitwise_not op into phi (#40031)

* Migrate bitwise_and/or/xor/not op into phi

* fix CI
上级 d9dd840f
...@@ -21,4 +21,4 @@ endif() ...@@ -21,4 +21,4 @@ endif()
file(APPEND ${pybind_file} "USE_OP_ITSELF(less_than);\nUSE_OP_ITSELF(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n") file(APPEND ${pybind_file} "USE_OP_ITSELF(less_than);\nUSE_OP_ITSELF(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n")
file(APPEND ${pybind_file} "USE_OP_ITSELF(logical_and);\nUSE_OP_ITSELF(logical_or);\nUSE_OP_ITSELF(logical_xor);\nUSE_OP_ITSELF(logical_not);\n") file(APPEND ${pybind_file} "USE_OP_ITSELF(logical_and);\nUSE_OP_ITSELF(logical_or);\nUSE_OP_ITSELF(logical_xor);\nUSE_OP_ITSELF(logical_not);\n")
file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n") file(APPEND ${pybind_file} "USE_OP_ITSELF(bitwise_and);\nUSE_OP_ITSELF(bitwise_or);\nUSE_OP_ITSELF(bitwise_xor);\nUSE_OP_ITSELF(bitwise_not);\n")
...@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,11 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/controlflow/bitwise_op.h"
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -75,11 +75,19 @@ It operates ``%s`` on Tensor ``X`` . ...@@ -75,11 +75,19 @@ It operates ``%s`` on Tensor ``X`` .
} }
}; };
class BitwiseOp : public framework::OperatorWithKernel { template <typename OpComment>
class UnaryBitwiseOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext *context) const override {
OpComment comment;
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override { const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx); framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
...@@ -90,23 +98,9 @@ class BitwiseOp : public framework::OperatorWithKernel { ...@@ -90,23 +98,9 @@ class BitwiseOp : public framework::OperatorWithKernel {
}; };
template <typename OpComment> template <typename OpComment>
class UnaryBitwiseOp : public BitwiseOp { class BinaryBitwiseOp : public framework::OperatorWithKernel {
public:
using BitwiseOp::BitwiseOp;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OpComment comment;
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
};
template <typename OpComment>
class BinaryBitwiseOp : public BitwiseOp {
public: public:
using BitwiseOp::BitwiseOp; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContext *context) const override { void InferShape(framework::InferShapeContext *context) const override {
...@@ -130,6 +124,14 @@ class BinaryBitwiseOp : public BitwiseOp { ...@@ -130,6 +124,14 @@ class BinaryBitwiseOp : public BitwiseOp {
} }
context->ShareLoD("X", "Out"); context->ShareLoD("X", "Out");
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
// BitwiseOp kernel's device type is decided by input tensor place
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
return kt;
}
}; };
} // namespace operators } // namespace operators
...@@ -167,8 +169,3 @@ REGISTER_BINARY_BITWISE_OP(bitwise_and, "Out = X \\& Y"); ...@@ -167,8 +169,3 @@ REGISTER_BINARY_BITWISE_OP(bitwise_and, "Out = X \\& Y");
REGISTER_BINARY_BITWISE_OP(bitwise_or, "Out = X | Y"); REGISTER_BINARY_BITWISE_OP(bitwise_or, "Out = X | Y");
REGISTER_BINARY_BITWISE_OP(bitwise_xor, "Out = X ^\\wedge Y"); REGISTER_BINARY_BITWISE_OP(bitwise_xor, "Out = X ^\\wedge Y");
REGISTER_UNARY_BITWISE_OP(bitwise_not, "Out = \\sim X"); REGISTER_UNARY_BITWISE_OP(bitwise_not, "Out = \\sim X");
REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CPU, ops::BitwiseAndFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CPU, ops::BitwiseOrFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CPU, ops::BitwiseXorFunctor);
REGISTER_UNARY_BITWISE_KERNEL(bitwise_not, CPU, ops::BitwiseNotFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/bitwise_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
namespace paddle {
namespace operators {
template <typename Functor>
class BinaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using T = typename Functor::ELEM_TYPE;
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto functor = Functor();
std::vector<const framework::Tensor*> ins = {x, y};
std::vector<framework::Tensor*> outs = {out};
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary, T,
T>(cuda_ctx, ins, &outs, -1,
functor);
}
};
template <typename Functor>
class UnaryBitwiseOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using T = typename Functor::ELEM_TYPE;
auto* x = ctx.Input<framework::Tensor>("X");
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto functor = Functor();
std::vector<const framework::Tensor*> ins = {x};
std::vector<framework::Tensor*> outs = {out};
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(cuda_ctx, ins,
&outs, functor);
}
};
} // namespace operators
} // namespace paddle
namespace ops = ::paddle::operators;
namespace plat = ::paddle::platform;
REGISTER_BINARY_BITWISE_KERNEL(bitwise_and, CUDA, ops::BitwiseAndFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_or, CUDA, ops::BitwiseOrFunctor);
REGISTER_BINARY_BITWISE_KERNEL(bitwise_xor, CUDA, ops::BitwiseXorFunctor);
REGISTER_UNARY_BITWISE_KERNEL(bitwise_not, CUDA, ops::BitwiseNotFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \
template <typename T> \
struct Bitwise##func##Functor { \
using ELEM_TYPE = T; \
HOSTDEVICE T operator()(const T a, const T b) const { return a expr b; } \
}; \
\
template <> \
struct Bitwise##func##Functor<bool> { \
using ELEM_TYPE = bool; \
HOSTDEVICE bool operator()(const bool a, const bool b) const { \
return a bool_expr b; \
} \
};
BITWISE_BINARY_FUNCTOR(And, &, &&)
BITWISE_BINARY_FUNCTOR(Or, |, ||)
BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
#undef BITWISE_BINARY_FUNCTOR
template <typename T>
struct BitwiseNotFunctor {
using ELEM_TYPE = T;
HOSTDEVICE T operator()(const T a) const { return ~a; }
};
template <>
struct BitwiseNotFunctor<bool> {
using ELEM_TYPE = bool;
HOSTDEVICE bool operator()(const bool a) const { return !a; }
};
template <typename DeviceContext, typename Functor>
class BinaryBitwiseOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
auto func = Functor();
auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
ElementwiseComputeEx<Functor, DeviceContext, T>(context, x, y, -1, func,
out);
}
};
template <typename DeviceContext, typename Functor>
class UnaryBitwiseOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
auto func = Functor();
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
platform::Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x->data<T>(),
x->data<T>() + x->numel(), out->mutable_data<T>(context.GetPlace()),
func);
}
};
} // namespace operators
} // namespace paddle
namespace ops = ::paddle::operators;
namespace plat = ::paddle::platform;
#define REGISTER_BINARY_BITWISE_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<bool>>, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<uint8_t>>, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int8_t>>, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int16_t>>, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int>>, \
ops::BinaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int64_t>>);
#define REGISTER_UNARY_BITWISE_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<bool>>, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<uint8_t>>, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int8_t>>, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int16_t>>, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int>>, \
ops::UnaryBitwiseOpKernel<plat::dev##DeviceContext, functor<int64_t>>);
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void BitwiseAndKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
template <typename T, typename Context>
void BitwiseOrKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
template <typename T, typename Context>
void BitwiseXorKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
template <typename T, typename Context>
void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/bitwise_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/bitwise_functors.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/transform.h"
namespace phi {
#define DEFINE_BITWISE_KERNEL(op_type) \
template <typename T, typename Context> \
void Bitwise##op_type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
funcs::Bitwise##op_type##Functor<T> func; \
funcs::ElementwiseCompute<funcs::Bitwise##op_type##Functor<T>, T, T>( \
dev_ctx, x, y, -1, func, out); \
}
DEFINE_BITWISE_KERNEL(And)
DEFINE_BITWISE_KERNEL(Or)
DEFINE_BITWISE_KERNEL(Xor)
#undef DEFINE_BITWISE_KERNEL
template <typename T, typename Context>
void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
const T* x_data = x.data<T>();
T* out_data = dev_ctx.template Alloc<T>(out);
size_t numel = x.numel();
funcs::BitwiseNotFunctor<T> func;
paddle::platform::Transform<Context> trans;
trans(dev_ctx, x_data, x_data + numel, out_data, func);
}
} // namespace phi
PD_REGISTER_KERNEL(bitwise_and,
CPU,
ALL_LAYOUT,
phi::BitwiseAndKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PD_REGISTER_KERNEL(bitwise_or,
CPU,
ALL_LAYOUT,
phi::BitwiseOrKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PD_REGISTER_KERNEL(bitwise_xor,
CPU,
ALL_LAYOUT,
phi::BitwiseXorKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PD_REGISTER_KERNEL(bitwise_not,
CPU,
ALL_LAYOUT,
phi::BitwiseNotKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
namespace funcs {
#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \
template <typename T> \
struct Bitwise##func##Functor { \
HOSTDEVICE T operator()(const T a, const T b) const { return a expr b; } \
}; \
\
template <> \
struct Bitwise##func##Functor<bool> { \
HOSTDEVICE bool operator()(const bool a, const bool b) const { \
return a bool_expr b; \
} \
};
BITWISE_BINARY_FUNCTOR(And, &, &&)
BITWISE_BINARY_FUNCTOR(Or, |, ||)
BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
#undef BITWISE_BINARY_FUNCTOR
template <typename T>
struct BitwiseNotFunctor {
using ELEM_TYPE = T;
HOSTDEVICE T operator()(const T a) const { return ~a; }
};
template <>
struct BitwiseNotFunctor<bool> {
using ELEM_TYPE = bool;
HOSTDEVICE bool operator()(const bool a) const { return !a; }
};
} // namespace funcs
} // namespace phi
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/bitwise_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/bitwise_functors.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
namespace phi {
#define DEFINE_BITWISE_KERNEL(op_type) \
template <typename T, typename Context> \
void Bitwise##op_type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
dev_ctx.template Alloc<T>(out); \
funcs::Bitwise##op_type##Functor<T> func; \
std::vector<const DenseTensor*> ins = {&x, &y}; \
std::vector<DenseTensor*> outs = {out}; \
funcs::BroadcastKernel<ElementwiseType::kBinary, T, T>( \
dev_ctx, ins, &outs, -1, func); \
}
DEFINE_BITWISE_KERNEL(And)
DEFINE_BITWISE_KERNEL(Or)
DEFINE_BITWISE_KERNEL(Xor)
#undef DEFINE_BITWISE_KERNEL
template <typename T, typename Context>
void BitwiseNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
dev_ctx.template Alloc<T>(out);
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
funcs::BitwiseNotFunctor<T> func;
funcs::BroadcastKernel<ElementwiseType::kUnary, T, T>(
dev_ctx, ins, &outs, -1, func);
}
} // namespace phi
PD_REGISTER_KERNEL(bitwise_and,
GPU,
ALL_LAYOUT,
phi::BitwiseAndKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PD_REGISTER_KERNEL(bitwise_or,
GPU,
ALL_LAYOUT,
phi::BitwiseOrKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PD_REGISTER_KERNEL(bitwise_xor,
GPU,
ALL_LAYOUT,
phi::BitwiseXorKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
PD_REGISTER_KERNEL(bitwise_not,
GPU,
ALL_LAYOUT,
phi::BitwiseNotKernel,
bool,
uint8_t,
int8_t,
int16_t,
int,
int64_t) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册