From 6cfcf6245a67eb39cf5667adb011069c76e55c03 Mon Sep 17 00:00:00 2001 From: Abhinav Arora Date: Sat, 18 Nov 2017 19:02:46 +0530 Subject: [PATCH] Adding logical operators for beam search and control flow (#5708) --- paddle/framework/data_type.h | 5 + paddle/operators/CMakeLists.txt | 5 + paddle/operators/logical_op.cc | 153 ++++++++++++++++++ paddle/operators/logical_op.cu | 24 +++ paddle/operators/logical_op.h | 93 +++++++++++ .../paddle/v2/fluid/tests/test_logical_op.py | 35 ++++ 6 files changed, 315 insertions(+) create mode 100644 paddle/operators/logical_op.cc create mode 100644 paddle/operators/logical_op.cu create mode 100644 paddle/operators/logical_op.h create mode 100644 python/paddle/v2/fluid/tests/test_logical_op.py diff --git a/paddle/framework/data_type.h b/paddle/framework/data_type.h index be144d8fc01..c54d2d4ddf0 100644 --- a/paddle/framework/data_type.h +++ b/paddle/framework/data_type.h @@ -46,6 +46,8 @@ inline std::type_index ToTypeIndex(DataType type) { return typeid(int); case DataType::INT64: return typeid(int64_t); + case DataType::BOOL: + return typeid(bool); default: PADDLE_THROW("Not support type %d", type); } @@ -66,6 +68,9 @@ inline void VisitDataType(DataType type, Visitor visitor) { case DataType::INT64: visitor.template operator()(); break; + case DataType::BOOL: + visitor.template operator()(); + break; default: PADDLE_THROW("Not supported"); } diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 46c2833030c..d0fe5b46351 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -87,6 +87,11 @@ function(op_library TARGET) file(APPEND ${pybind_file} "USE_OP(pool2d_cudnn);\n") endif() + if ("${TARGET}" STREQUAL "logical_op") + set(pybind_flag 1) + file(APPEND ${pybind_file} "USE_OP(logical_and);\n") + endif() + # pool_with_index_op contains several operators if ("${TARGET}" STREQUAL "pool_with_index_op") set(pybind_flag 1) diff --git a/paddle/operators/logical_op.cc b/paddle/operators/logical_op.cc new file mode 100644 index 00000000000..a37582c1d84 --- /dev/null +++ b/paddle/operators/logical_op.cc @@ -0,0 +1,153 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/logical_op.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { +template +class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + BinaryLogicalOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + OpComment comment; + AddInput("X", + string::Sprintf("(LoDTensor) Left hand operand of %s operator", + comment.type)); + AddInput("Y", + string::Sprintf("(LoDTensor) Right hand operand of %s operator", + comment.type)); + AddOutput("Out", string::Sprintf( + "(LoDTensor) n-dim bool tensor. Each element is %s", + comment.equation)); + AddComment(string::Sprintf(R"DOC(%s Operator + +It operates element-wise on X and Y, and returns the Out. X, Y and Out are N-dim boolean tensors. +Each element of Out is calculated by %s +)DOC", + comment.type, comment.equation)); + } +}; + +template +class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + UnaryLogicalOpProtoMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + OpComment comment; + AddInput("X", string::Sprintf("(LoDTensor) Operand of %s operator", + comment.type)); + AddOutput("Out", string::Sprintf( + "(LoDTensor) n-dim bool tensor. Each element is %s", + comment.equation)); + AddComment(string::Sprintf(R"DOC(%s Operator + +It operates element-wise on X, and returns the Out. X and Out are N-dim boolean tensors. +Each element of Out is calculated by %s +)DOC", + comment.type, comment.equation)); + } +}; + +template +class BinaryLogicalOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + OpComment comment; + PADDLE_ENFORCE(context->HasInput("X"), + "Input(X) of %s operator must not be null", comment.type); + PADDLE_ENFORCE(context->HasInput("Y"), + "Input(Y) of %s operator must not be null", comment.type); + auto dim_x = context->GetInputDim("X"); + auto dim_y = context->GetInputDim("Y"); + PADDLE_ENFORCE_EQ(framework::product(dim_x), framework::product(dim_y), + "The number of elements in X and Y should be same"); + + context->SetOutputDim("Out", context->GetInputDim("X")); + context->ShareLoD("X", "Out"); + } +}; + +template +class UnaryLogicalOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + OpComment comment; + PADDLE_ENFORCE(context->HasInput("X"), + "Input(X) of %s operator must not be null", comment.type); + auto dim_x = context->GetInputDim("X"); + + context->SetOutputDim("Out", context->GetInputDim("X")); + context->ShareLoD("X", "Out"); + } +}; + +class LogicalOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + framework::OpKernelType GetKernelType( + const framework::ExecutionContext &ctx) const override { + framework::OpKernelType kt = OperatorWithKernel::GetKernelType(ctx); + // LogicalOp kernel's device type is decided by input tensor place + kt.place_ = ctx.Input("X")->place(); + return kt; + } +}; + +} // namespace operators +} // namespace paddle + +#define REGISTER_BINARY_LOGICAL_OP(op_type, _equation) \ + struct _##op_type##Comment { \ + static char type[]; \ + static char equation[]; \ + }; \ + char _##op_type##Comment::type[]{#op_type}; \ + char _##op_type##Comment::equation[]{_equation}; \ + REGISTER_OPERATOR( \ + op_type, ::paddle::operators::LogicalOp, \ + ::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>, \ + ::paddle::operators::BinaryLogicalOpInferShape<_##op_type##Comment>, \ + ::paddle::framework::EmptyGradOpMaker); + +#define REGISTER_UNARY_LOGICAL_OP(op_type, _equation) \ + struct _##op_type##Comment { \ + static char type[]; \ + static char equation[]; \ + }; \ + char _##op_type##Comment::type[]{#op_type}; \ + char _##op_type##Comment::equation[]{_equation}; \ + REGISTER_OPERATOR( \ + op_type, ::paddle::operators::LogicalOp, \ + ::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>, \ + ::paddle::operators::UnaryLogicalOpInferShape<_##op_type##Comment>, \ + ::paddle::framework::EmptyGradOpMaker); + +REGISTER_BINARY_LOGICAL_OP(logical_and, "Out = X && Y"); +REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CPU, + paddle::operators::LogicalAndFunctor); +REGISTER_BINARY_LOGICAL_OP(logical_or, "Out = X && Y"); +REGISTER_BINARY_LOGICAL_KERNEL(logical_or, CPU, + paddle::operators::LogicalOrFunctor); +REGISTER_UNARY_LOGICAL_OP(logical_not, "Out = !X"); +REGISTER_UNARY_LOGICAL_KERNEL(logical_not, CPU, + paddle::operators::LogicalNotFunctor); +REGISTER_BINARY_LOGICAL_OP(logical_xor, "Out = (X || Y) && !(X && Y)"); +REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, CPU, + paddle::operators::LogicalXorFunctor); diff --git a/paddle/operators/logical_op.cu b/paddle/operators/logical_op.cu new file mode 100644 index 00000000000..d41239b2ca4 --- /dev/null +++ b/paddle/operators/logical_op.cu @@ -0,0 +1,24 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/logical_op.h" + +REGISTER_BINARY_LOGICAL_KERNEL(logical_and, GPU, + paddle::operators::LogicalAndFunctor); +REGISTER_BINARY_LOGICAL_KERNEL(logical_or, GPU, + paddle::operators::LogicalOrFunctor); +REGISTER_UNARY_LOGICAL_KERNEL(logical_not, GPU, + paddle::operators::LogicalNotFunctor); +REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, GPU, + paddle::operators::LogicalXorFunctor); diff --git a/paddle/operators/logical_op.h b/paddle/operators/logical_op.h new file mode 100644 index 00000000000..6e78a7d6ed8 --- /dev/null +++ b/paddle/operators/logical_op.h @@ -0,0 +1,93 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include +#include "paddle/framework/op_registry.h" +#include "paddle/platform/transform.h" + +namespace paddle { +namespace operators { + +template +struct LogicalAndFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a && b; } +}; + +template +struct LogicalOrFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { return a || b; } +}; + +template +struct LogicalNotFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a) const { return !a; } +}; + +template +struct LogicalXorFunctor { + using ELEM_TYPE = T; + HOSTDEVICE bool operator()(const T& a, const T& b) const { + return (a || b) && !(a && b); + } +}; + +template +class BinaryLogicalOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + using T = typename Functor::ELEM_TYPE; + auto* x = context.Input("X"); + auto* y = context.Input("Y"); + auto* out = context.Output("Out"); + Functor binary_func; + platform::Transform trans; + trans(context.device_context(), x->data(), x->data() + x->numel(), + y->data(), out->mutable_data(context.GetPlace()), + binary_func); + } +}; + +template +class UnaryLogicalOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + using T = typename Functor::ELEM_TYPE; + auto* x = context.Input("X"); + auto* out = context.Output("Out"); + Functor unary_func; + platform::Transform trans; + trans(context.device_context(), x->data(), x->data() + x->numel(), + out->mutable_data(context.GetPlace()), unary_func); + } +}; + +} // namespace operators +} // namespace paddle + +#define REGISTER_BINARY_LOGICAL_KERNEL(op_type, dev, functor) \ + REGISTER_OP_##dev##_KERNEL( \ + op_type, ::paddle::operators::BinaryLogicalOpKernel< \ + ::paddle::platform::dev##Place, functor>); + +#define REGISTER_UNARY_LOGICAL_KERNEL(op_type, dev, functor) \ + REGISTER_OP_##dev##_KERNEL( \ + op_type, ::paddle::operators::UnaryLogicalOpKernel< \ + ::paddle::platform::dev##Place, functor>); diff --git a/python/paddle/v2/fluid/tests/test_logical_op.py b/python/paddle/v2/fluid/tests/test_logical_op.py new file mode 100644 index 00000000000..ac90bf839cb --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_logical_op.py @@ -0,0 +1,35 @@ +import op_test +import unittest +import numpy as np + + +def create_test_class(op_type, callback, binary_op=True): + class Cls(op_test.OpTest): + def setUp(self): + a = np.random.choice(a=[True, False], size=(10, 7)).astype(bool) + if binary_op: + b = np.random.choice(a=[True, False], size=(10, 7)).astype(bool) + c = callback(a, b) + else: + c = callback(a) + self.outputs = {'Out': c} + self.op_type = op_type + if binary_op: + self.inputs = {'X': a, 'Y': b} + else: + self.inputs = {'X': a} + + def test_output(self): + self.check_output() + + Cls.__name__ = op_type + globals()[op_type] = Cls + + +create_test_class('logical_and', lambda _a, _b: np.logical_and(_a, _b)) +create_test_class('logical_or', lambda _a, _b: np.logical_or(_a, _b)) +create_test_class('logical_not', lambda _a: np.logical_not(_a), False) +create_test_class('logical_xor', lambda _a, _b: np.logical_xor(_a, _b)) + +if __name__ == '__main__': + unittest.main() -- GitLab