未验证 提交 8c237973 编写于 作者: A Aurelius84 提交者: GitHub

[Phi] Migrate logical_and/or/not/xor into Phi (#39942)

* [Phi] Migrate logical_and/or/not/xor into Phi

* fix unittest

* fix function name
上级 4da841e0
......@@ -20,5 +20,5 @@ else()
endif()
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(equal_all);\nUSE_NO_KERNEL_OP(read_from_array);\n")
file(APPEND ${pybind_file} "USE_OP(logical_and);\nUSE_OP(logical_or);\nUSE_OP(logical_xor);\nUSE_OP(logical_not);\n")
file(APPEND ${pybind_file} "USE_OP_ITSELF(logical_and);\nUSE_OP_ITSELF(logical_or);\nUSE_OP_ITSELF(logical_xor);\nUSE_OP_ITSELF(logical_not);\n")
file(APPEND ${pybind_file} "USE_OP(bitwise_and);\nUSE_OP(bitwise_or);\nUSE_OP(bitwise_xor);\nUSE_OP(bitwise_not);\n")
......@@ -9,11 +9,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/logical_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
namespace paddle {
namespace operators {
......@@ -145,15 +145,7 @@ class BinaryLogicalOp : public LogicalOp {
::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_BINARY_LOGICAL_OP(logical_and, "$$Out = X \\&\\& Y$$");
REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CPU,
paddle::operators::LogicalAndFunctor);
REGISTER_BINARY_LOGICAL_OP(logical_or, "$$Out = X || Y$$");
REGISTER_BINARY_LOGICAL_KERNEL(logical_or, CPU,
paddle::operators::LogicalOrFunctor);
REGISTER_UNARY_LOGICAL_OP(logical_not, "$$Out = !X$$");
REGISTER_UNARY_LOGICAL_KERNEL(logical_not, CPU,
paddle::operators::LogicalNotFunctor);
REGISTER_BINARY_LOGICAL_OP(logical_xor,
"$$Out = (X || Y) \\&\\& !(X \\&\\& Y)$$");
REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, CPU,
paddle::operators::LogicalXorFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/logical_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
namespace paddle {
namespace operators {
template <typename Functor>
class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
using InT = typename Functor::ELEMENT_TYPE;
using OutT = bool;
auto functor = Functor();
std::vector<const framework::Tensor*> ins;
std::vector<framework::Tensor*> outs;
const auto& cuda_ctx =
ctx.template device_context<platform::CUDADeviceContext>();
int axis = PackTensorsIntoVector<OutT>(ctx, &ins, &outs);
if (ins.size() == 1) {
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kUnary,
InT, OutT>(
cuda_ctx, ins, &outs, axis, functor);
} else {
paddle::operators::LaunchElementwiseCudaKernel<ElementwiseType::kBinary,
InT, OutT>(
cuda_ctx, ins, &outs, axis, functor);
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_LOGICAL_CUDA_KERNEL(op_name, func) \
REGISTER_OP_CUDA_KERNEL( \
op_name, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<bool>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<int8_t>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<int16_t>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<int>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<int64_t>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<float>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<double>>);
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, LogicalOrFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, LogicalAndFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, LogicalXorFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_not, LogicalNotFunctor)
#undef REGISTER_LOGICAL_CUDA_KERNEL
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <math.h>
#include <type_traits>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
namespace operators {
#define LOGICAL_BINARY_FUNCTOR(func_name, op) \
template <typename T> \
struct func_name { \
using ELEMENT_TYPE = T; \
HOSTDEVICE bool operator()(const T a, const T b) const { \
return static_cast<bool>(a) op static_cast<bool>(b); \
} \
};
LOGICAL_BINARY_FUNCTOR(LogicalOrFunctor, ||)
LOGICAL_BINARY_FUNCTOR(LogicalAndFunctor, &&)
LOGICAL_BINARY_FUNCTOR(LogicalXorFunctor, ^)
#undef LOGICAL_BINARY_FUNCTOR
template <typename T>
struct LogicalNotFunctor {
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T a) const { return !a; }
};
template <typename DeviceContext, typename Functor>
class BinaryLogicalOpKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEMENT_TYPE;
auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
Functor binary_func;
ElementwiseComputeEx<Functor, DeviceContext, T, bool>(context, x, y, -1,
binary_func, out);
}
};
template <typename DeviceContext, typename Functor>
class UnaryLogicalOpKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEMENT_TYPE;
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
Functor unary_func;
platform::Transform<DeviceContext> trans;
trans(context.template device_context<DeviceContext>(), x->data<T>(),
x->data<T>() + x->numel(),
out->mutable_data<bool>(context.GetPlace()), unary_func);
}
};
} // namespace operators
} // namespace paddle
#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, ::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<bool>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int8_t>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int16_t>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int64_t>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<float>>, \
::paddle::operators::BinaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<double>>);
#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor) \
REGISTER_OP_##dev##_KERNEL( \
op_type, ::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<bool>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int8_t>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int16_t>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<int64_t>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<float>>, \
::paddle::operators::UnaryLogicalOpKernel< \
::paddle::platform::dev##DeviceContext, functor<double>>);
......@@ -9,7 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/controlflow/logical_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/logical_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/elementwise.h"
#include "paddle/phi/kernels/funcs/logical_functor.h"
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/transform.h"
namespace phi {
#define DEFINE_LOGICAL_BINARY_KERNEL(type) \
template <typename T, typename Context> \
void Logical##type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
funcs::Logical##type##Functor<T> binary_func; \
ElementwiseCompute<funcs::Logical##type##Functor<T>, T, bool>( \
dev_ctx, x, y, -1, binary_func, out); \
}
DEFINE_LOGICAL_BINARY_KERNEL(And)
DEFINE_LOGICAL_BINARY_KERNEL(Or)
DEFINE_LOGICAL_BINARY_KERNEL(Xor)
#undef DEFINE_LOGICAL_BINARY_KERNEL
template <typename T, typename Context>
void LogicalNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
auto* out_ptr = dev_ctx.template Alloc<bool>(out);
funcs::LogicalNotFunctor<T> unary_func;
paddle::platform::Transform<Context> trans;
trans(dev_ctx, x.data<T>(), x.data<T>() + x.numel(), out_ptr, unary_func);
}
} // namespace phi
#define REGISTER_LOGICAL_CPU_KERNEL(logical_and, func_type) \
PD_REGISTER_KERNEL(logical_and, \
CPU, \
ALL_LAYOUT, \
phi::Logical##func_type##Kernel, \
float, \
double, \
bool, \
int64_t, \
int, \
int8_t, \
int16_t) {}
REGISTER_LOGICAL_CPU_KERNEL(logical_and, And)
REGISTER_LOGICAL_CPU_KERNEL(logical_or, Or)
REGISTER_LOGICAL_CPU_KERNEL(logical_not, Not)
REGISTER_LOGICAL_CPU_KERNEL(logical_xor, Xor)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
namespace phi {
namespace funcs {
#define LOGICAL_BINARY_FUNCTOR(func_name, op) \
template <typename T> \
struct func_name { \
using ELEMENT_TYPE = T; \
HOSTDEVICE bool operator()(const T a, const T b) const { \
return static_cast<bool>(a) op static_cast<bool>(b); \
} \
};
LOGICAL_BINARY_FUNCTOR(LogicalOrFunctor, ||)
LOGICAL_BINARY_FUNCTOR(LogicalAndFunctor, &&)
LOGICAL_BINARY_FUNCTOR(LogicalXorFunctor, ^)
#undef LOGICAL_BINARY_FUNCTOR
template <typename T>
struct LogicalNotFunctor {
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T a) const { return !a; }
};
} // namespace funcs
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/logical_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/logical_functor.h"
#include "paddle/phi/kernels/gpu/elementwise.h"
namespace phi {
#define DEFINE_LOGICAL_BINARY_KERNEL(type) \
template <typename T, typename Context> \
void Logical##type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out) { \
using InT = typename funcs::Logical##type##Functor<T>::ELEMENT_TYPE; \
using OutT = bool; \
dev_ctx.template Alloc<bool>(out); \
funcs::Logical##type##Functor<T> binary_func; \
std::vector<const DenseTensor*> ins = {&x, &y}; \
std::vector<DenseTensor*> outs = {out}; \
funcs::BroadcastKernel<ElementwiseType::kBinary, InT, OutT>( \
dev_ctx, ins, &outs, -1, binary_func); \
}
DEFINE_LOGICAL_BINARY_KERNEL(And)
DEFINE_LOGICAL_BINARY_KERNEL(Or)
DEFINE_LOGICAL_BINARY_KERNEL(Xor)
#undef DEFINE_LOGICAL_BINARY_KERNEL
template <typename T, typename Context>
void LogicalNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
using InT = typename funcs::LogicalNotFunctor<T>::ELEMENT_TYPE;
using OutT = bool;
dev_ctx.template Alloc<bool>(out);
funcs::LogicalNotFunctor<T> unary_func;
std::vector<const DenseTensor*> ins = {&x};
std::vector<DenseTensor*> outs = {out};
funcs::BroadcastKernel<ElementwiseType::kUnary, InT, OutT>(
dev_ctx, ins, &outs, -1, unary_func);
}
} // namespace phi
#define REGISTER_LOGICAL_CUDA_KERNEL(logical_and, func_type) \
PD_REGISTER_KERNEL(logical_and, \
GPU, \
ALL_LAYOUT, \
phi::Logical##func_type##Kernel, \
float, \
double, \
bool, \
int64_t, \
int, \
int8_t, \
int16_t) {}
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, And)
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, Or)
REGISTER_LOGICAL_CUDA_KERNEL(logical_not, Not)
REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, Xor)
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
#define DECLEAR_LOGICAL_BINARY_KERNEL(type) \
template <typename T, typename Context> \
void Logical##type##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
const DenseTensor& y, \
DenseTensor* out);
DECLEAR_LOGICAL_BINARY_KERNEL(And)
DECLEAR_LOGICAL_BINARY_KERNEL(Or)
DECLEAR_LOGICAL_BINARY_KERNEL(Xor)
#undef DECLEAR_LOGICAL_BINARY_KERNEL
template <typename T, typename Context>
void LogicalNotKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
} // namespace phi
......@@ -55,7 +55,7 @@ class TestDiffOp(unittest.TestCase):
def test_dygraph(self):
for place in self.places:
paddle.disable_static(place)
paddle.disable_static()
x = paddle.to_tensor(self.input, place=place)
if self.prepend is not None:
self.prepend = paddle.to_tensor(self.prepend, place=place)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册