/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/controlflow/logical_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {
template <typename OpComment>
class BinaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    OpComment comment;
    AddInput("X", string::Sprintf("Left hand operand of %s operator. Must be "
                                  "a LoDTensor or Tensor of type bool.",
                                  comment.type));
    AddInput("Y", string::Sprintf("Right hand operand of %s operator. Must be "
                                  "a LoDTensor or Tensor of type bool.",
                                  comment.type));
    AddOutput("Out", string::Sprintf("n-dim bool LoDTensor or Tensor"));
    AddComment(string::Sprintf(R"DOC(%s Operator

It operates element-wise on X and Y, and returns Out. X, Y and Out are N-dim boolean LoDTensors or Tensors.
Each element of Out is calculated by %s
)DOC",
                               comment.type, comment.equation));
  }
};

template <typename OpComment>
class UnaryLogicalOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    OpComment comment;
    AddInput("X", string::Sprintf("Operand of %s operator. Must be "
                                  "a LoDTensor or Tensor of type bool.",
                                  comment.type));
    AddOutput("Out", string::Sprintf("n-dim bool LoDTensor or Tensor."));
    AddComment(string::Sprintf(R"DOC(%s Operator

It operates element-wise on X and returns Out. X and Out are N-dim boolean LoDTensors or Tensors.
Each element of Out is calculated by %s
)DOC",
                               comment.type, comment.equation));
  }
};

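// Shared base class for all logical operators. The kernel data type is chosen
// by the default OperatorWithKernel rules, but the device (place) is forced
// to follow input X, so the kernel runs wherever X lives.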
class LogicalOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
    // LogicalOp kernel's device type is decided by input tensor place
    kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
    return kt;
  }
};

template <typename OpComment>
class UnaryLogicalOp : public LogicalOp {
 public:
  using LogicalOp::LogicalOp;

 protected:
  void InferShape(framework::InferShapeContext *context) const override {
    OpComment comment;
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
    context->SetOutputDim("Out", context->GetInputDim("X"));
    context->ShareLoD("X", "Out");
  }
};

template <typename OpComment>
class BinaryLogicalOp : public LogicalOp {
 public:
  using LogicalOp::LogicalOp;

 protected:
  void InferShape(framework::InferShapeContext *context) const override {
    OpComment comment;
    OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", comment.type);
    OP_INOUT_CHECK(context->HasInput("Y"), "Input", "Y", comment.type);
    auto dim_x = context->GetInputDim("X");
    auto dim_y = context->GetInputDim("Y");

    int product_x = framework::product(dim_x);
    int product_y = framework::product(dim_y);
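    // Element counts can only be compared when both shapes are fully known:
    // at graph-building time an unknown dimension is -1, which makes the
    // product negative, so in that case the check is deferred to runtime.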
    bool check = context->IsRuntime() || (product_x >= 0 && product_y >= 0);
    if (check) {
      PADDLE_ENFORCE_EQ(product_x, product_y,
                        platform::errors::InvalidArgument(
                            "The number of elements in X and Y should be same, "
                            "but received %d != %d",
                            product_x, product_y));
    }

    context->SetOutputDim("Out", context->GetInputDim("X"));
    context->ShareLoD("X", "Out");
  }
};

}  // namespace operators
}  // namespace paddle

#define REGISTER_BINARY_LOGICAL_OP(op_type, _equation)                     \
  struct _##op_type##Comment {                                             \
    static char type[];                                                    \
    static char equation[];                                                \
  };                                                                       \
  char _##op_type##Comment::type[]{#op_type};                              \
  char _##op_type##Comment::equation[]{_equation};                         \
  REGISTER_OPERATOR(                                                       \
      op_type, ::paddle::operators::BinaryLogicalOp<_##op_type##Comment>,  \
      ::paddle::operators::BinaryLogicalOpProtoMaker<_##op_type##Comment>, \
      ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,    \
      ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
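// For readability, a rough sketch of what the macro above expands to for a
// concrete op such as logical_and (illustrative only, not compiled here):
//
//   struct _logical_andComment {
//     static char type[];
//     static char equation[];
//   };
//   char _logical_andComment::type[]{"logical_and"};
//   char _logical_andComment::equation[]{"$$Out = X \\&\\& Y$$"};
//   REGISTER_OPERATOR(
//       logical_and,
//       ::paddle::operators::BinaryLogicalOp<_logical_andComment>,
//       ::paddle::operators::BinaryLogicalOpProtoMaker<_logical_andComment>,
//       ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
//       ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
//
// The two EmptyGradOpMaker arguments register the op without a gradient,
// since logical operators are not differentiable.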

#define REGISTER_UNARY_LOGICAL_OP(op_type, _equation)                     \
  struct _##op_type##Comment {                                            \
    static char type[];                                                   \
    static char equation[];                                               \
  };                                                                      \
  char _##op_type##Comment::type[]{#op_type};                             \
  char _##op_type##Comment::equation[]{_equation};                        \
  REGISTER_OPERATOR(                                                      \
      op_type, ::paddle::operators::UnaryLogicalOp<_##op_type##Comment>,  \
      ::paddle::operators::UnaryLogicalOpProtoMaker<_##op_type##Comment>, \
      ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,   \
      ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

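// Operator and CPU kernel registrations. REGISTER_*_LOGICAL_KERNEL comes from
// logical_op.h (included above) and binds each op to its element-wise functor
// on the named device; the CUDA kernels are registered in the corresponding
// .cu file.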
REGISTER_BINARY_LOGICAL_OP(logical_and, "$$Out = X \\&\\& Y$$");
REGISTER_BINARY_LOGICAL_KERNEL(logical_and, CPU,
                               paddle::operators::LogicalAndFunctor);
REGISTER_BINARY_LOGICAL_OP(logical_or, "$$Out = X || Y$$");
REGISTER_BINARY_LOGICAL_KERNEL(logical_or, CPU,
                               paddle::operators::LogicalOrFunctor);
REGISTER_UNARY_LOGICAL_OP(logical_not, "$$Out = !X$$");
REGISTER_UNARY_LOGICAL_KERNEL(logical_not, CPU,
                              paddle::operators::LogicalNotFunctor);
REGISTER_BINARY_LOGICAL_OP(logical_xor,
                           "$$Out = (X || Y) \\&\\& !(X \\&\\& Y)$$");
REGISTER_BINARY_LOGICAL_KERNEL(logical_xor, CPU,
                               paddle::operators::LogicalXorFunctor);