compare_op.cc 4.9 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yu Yang 已提交
14

F
From00 已提交
15
#include "paddle/fluid/framework/infershape_utils.h"
Y
Yi Wang 已提交
16
#include "paddle/fluid/framework/op_registry.h"
Z
Zhong Hui 已提交
17
#include "paddle/fluid/framework/op_version_registry.h"
18
#include "paddle/phi/common/place.h"
F
From00 已提交
19 20
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/binary.h"
21

Y
Yu Yang 已提交
22 23
namespace paddle {
namespace operators {
Y
Yiqun Liu 已提交
24

Y
Yu Yang 已提交
25 26 27
/// Builds the operator proto (inputs, attributes, output and doc string)
/// shared by every element-wise comparison operator.
///
/// @tparam OpComment  type exposing character-array members `type` (the
///                    operator name) and `equation` (the formula text); both
///                    are formatted into the descriptions generated below.
template <typename OpComment>
class CompareOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    OpComment comment;
    AddInput("X", string::Sprintf("the left hand operand of %s operator",
                                  comment.type));
    AddInput("Y", string::Sprintf("the right hand operand of %s operator",
                                  comment.type));
    AddAttr<int>(
        "axis",
        "The start dimension index for broadcasting Y onto X. [default -1]")
        .SetDefault(-1)
        .EqualGreaterThan(-1);
    // Keep the "[default ...]" text in sync with SetDefault(false) below.
    AddAttr<bool>("force_cpu",
                  "Force fill output variable to cpu "
                  "memory. Otherwise, fill output variable to the running "
                  "device [default false].")
        .SetDefault(false);
    AddOutput("Out", string::Sprintf("n-dim bool tensor. Each element is %s",
                                     comment.equation));
    AddComment(string::Sprintf(R"DOC(
It operates element-wise on X and Y, and returns the Out. Each of them is a
N-dim tensor. X and Y could be any type.  The each element of the Out tensor is
calculated by $%s$
)DOC",
                               comment.equation));
  }
};

template <typename OpComment>
Z
Zeng Jinle 已提交
56
class CompareOp : public framework::OperatorWithKernel {
Y
Yu Yang 已提交
57
 public:
Z
Zeng Jinle 已提交
58 59 60
  using framework::OperatorWithKernel::OperatorWithKernel;

 protected:
61
  framework::OpKernelType GetExpectedKernelType(
Y
Yiqun Liu 已提交
62
      const framework::ExecutionContext& ctx) const override {
63
    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
64
    // CompareOp kernel's device type is decided by input tensor place
J
JiayiFeng 已提交
65
    bool force_cpu = ctx.Attr<bool>("force_cpu");
66 67 68
    if (force_cpu) {
      kt.place_ = platform::CPUPlace();
    } else {
69
      if (ctx.Input<framework::LoDTensor>("X")->place().GetType() !=
70
          phi::AllocationType::GPUPINNED) {
71 72 73 74 75
        kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
      } else {
        kt.place_ = ctx.GetPlace();
      }
    }
76 77 78 79
    return kt;
  }
};

Y
Yu Yang 已提交
80 81 82
}  // namespace operators
}  // namespace paddle

Z
Zhong Hui 已提交
83 84 85 86
// Adds an op-version checkpoint for `op_type` recording a ModifyAttr entry
// for the `force_cpu` attribute with value `false` (presumably its new
// default; matches SetDefault(false) in CompareOpProtoMaker).
#define REGISTER_COMPARE_OP_VERSION(op_type)                               \
  REGISTER_OP_VERSION(op_type).AddCheckpoint(                              \
      "Upgrade compare ops, add a new attribute [force_cpu]",              \
      ::paddle::framework::compatible::OpVersionDesc().ModifyAttr(         \
          "force_cpu",                                                     \
          "In order to force fill output variable to gpu memory.", false));

F
From00 已提交
92 93 94 95 96 97 98 99 100 101 102 103 104 105 106
// Defines and registers one element-wise comparison operator.
//   op_type   - operator name; also used to build the helper identifiers
//               `_<op_type>Comment` and `<op_type>_InferShapeFunctor`.
//   _equation - formula text formatted into the operator's documentation.
// The generated comment struct's strings are read-only, hence `const char`.
// Shape inference is delegated to phi::CompareInferMeta, and the operator is
// registered with EmptyGradOpMaker for both OpDesc and OpBase.
// NOTE(review): the `DELCARE_` spelling must match the macro provided by
// infershape_utils.h; confirm before renaming.
#define REGISTER_COMPARE_OP(op_type, _equation)                          \
  struct _##op_type##Comment {                                           \
    static const char type[];                                            \
    static const char equation[];                                        \
  };                                                                     \
  const char _##op_type##Comment::type[]{#op_type};                      \
  const char _##op_type##Comment::equation[]{_equation};                 \
  DELCARE_INFER_SHAPE_FUNCTOR(op_type, op_type##_InferShapeFunctor,      \
                              PT_INFER_META(phi::CompareInferMeta));     \
  REGISTER_OPERATOR(                                                     \
      op_type, ::paddle::operators::CompareOp<_##op_type##Comment>,      \
      ::paddle::operators::CompareOpProtoMaker<_##op_type##Comment>,     \
      ::paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,  \
      ::paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>, \
      op_type##_InferShapeFunctor);                                      \
  REGISTER_COMPARE_OP_VERSION(op_type);
Y
Yu Yang 已提交
108

Q
qiaolongfei 已提交
109
// The six element-wise comparison operators. The second argument is the
// equation text shown in each operator's generated documentation.
REGISTER_COMPARE_OP(less_than,     "Out = X < Y");
REGISTER_COMPARE_OP(less_equal,    "Out = X <= Y");
REGISTER_COMPARE_OP(greater_than,  "Out = X > Y");
REGISTER_COMPARE_OP(greater_equal, "Out = X >= Y");
REGISTER_COMPARE_OP(equal,         "Out = X == Y");
REGISTER_COMPARE_OP(not_equal,     "Out = X != Y");