minus_op.cc 4.3 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Y
Yu Yang 已提交
14

Y
Yi Wang 已提交
15
#include "paddle/fluid/operators/minus_op.h"
16

P
phlrain 已提交
17
#include <memory>
18
#include <string>
19
#include <utility>
20
#include <vector>
Y
Yu Yang 已提交
21 22 23 24 25 26

namespace paddle {
namespace operators {

class MinusOp : public framework::OperatorWithKernel {
 public:
27 28 29
  MinusOp(const std::string &type, const framework::VariableNameMap &inputs,
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
Y
Yu Yang 已提交
30 31
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

32
  void InferShape(framework::InferShapeContext *ctx) const override {
33 34 35 36 37 38 39 40 41
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of MinusOp is not found."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::NotFound("Input(Y) of MinusOp is not found."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::NotFound("Output(Out) of MinusOp is not found."));
42

Q
Qiao Longfei 已提交
43 44
    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
Y
Yu Yang 已提交
45

P
phlrain 已提交
46 47
    if (ctx->IsRuntime() ||
        (framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
P
phlrain 已提交
48 49 50 51
      PADDLE_ENFORCE_EQ(
          x_dims, y_dims,
          "Minus operator must take two tensor with same num of elements");
    }
Q
Qiao Longfei 已提交
52 53
    ctx->SetOutputDim("Out", x_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
Y
Yu Yang 已提交
54 55 56 57 58
  }
};

// Declares the `minus` operator's proto: inputs, output and user-facing docs.
class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Registers inputs X and Y, output Out, and the operator documentation.
  // The DOC raw string is emitted verbatim, so it must contain only doc text.
  void Make() override {
    AddInput("X", "The left tensor of minus operator.");
    AddInput("Y", "The right tensor of minus operator.");
    AddOutput("Out", "The output tensor of minus operator.");

    AddComment(R"DOC(
Minus Operator.

Equation:

    $Out = X - Y$

Both the input `X` and `Y` can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input `X`.

)DOC");
  }
};
77

78
class MinusGradDescMaker : public framework::GradOpDescMakerBase {
Y
Yu Yang 已提交
79
 public:
80 81
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;

Y
Yu Yang 已提交
82 83
  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
    std::vector<std::unique_ptr<framework::OpDesc>> ops;
84
    auto x_g = this->InputGrad("X");
Y
Yu Yang 已提交
85
    if (!x_g.empty()) {
Y
Yu Yang 已提交
86
      auto *x_g_op = new framework::OpDesc();
Y
Yu Yang 已提交
87
      x_g_op->SetType("scale");
88
      x_g_op->SetInput("X", this->OutputGrad("Out"));
Y
Yu Yang 已提交
89 90 91 92 93
      x_g_op->SetOutput("Out", x_g);
      x_g_op->SetAttr("scale", 1.0f);
      ops.emplace_back(x_g_op);
    }

94
    auto y_g = this->InputGrad("Y");
Y
Yu Yang 已提交
95
    if (!y_g.empty()) {
Y
Yu Yang 已提交
96
      auto *y_g_op = new framework::OpDesc();
Y
Yu Yang 已提交
97
      y_g_op->SetType("scale");
98 99 100 101 102 103 104 105 106 107 108 109 110 111
      y_g_op->SetInput("X", this->OutputGrad("Out"));
      y_g_op->SetOutput("Out", y_g);
      y_g_op->SetAttr("scale", -1.0f);
      ops.emplace_back(y_g_op);
    }

    return ops;
  }
};

class MinusGradMaker : public imperative::GradOpBaseMakerBase {
 public:
  using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase;

112
  std::shared_ptr<imperative::GradOpNode> operator()() const override {
113
    auto x_g = this->InputGrad("X");
114 115 116 117
    auto y_g = this->InputGrad("Y");

    auto node = this->NewGradNode();

118
    if (!x_g.empty()) {
119
      imperative::TracedGradOp op(node);
120 121 122 123
      op.SetType("scale");
      op.SetInput("X", this->OutputGrad("Out"));
      op.SetOutput("Out", x_g);
      op.SetAttr("scale", 1.0f);
124 125 126
    }

    if (!y_g.empty()) {
127
      imperative::TracedGradOp op(node);
128 129 130 131
      op.SetType("scale");
      op.SetInput("X", this->OutputGrad("Out"));
      op.SetOutput("Out", y_g);
      op.SetAttr("scale", -1.0f);
Y
Yu Yang 已提交
132 133
    }

134
    return node;
Y
Yu Yang 已提交
135 136 137 138 139 140 141
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
142 143
REGISTER_OPERATOR(minus, ops::MinusOp, ops::MinusOpMaker,
                  ops::MinusGradDescMaker, ops::MinusGradMaker);
Q
QI JUN 已提交
144 145
REGISTER_OP_CPU_KERNEL(
    minus, ops::MinusKernel<paddle::platform::CPUDeviceContext, float>);