/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/minus_op.h"

#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace paddle {
namespace operators {

class MinusOp : public framework::OperatorWithKernel {
 public:
26 27 28
  MinusOp(const std::string &type, const framework::VariableNameMap &inputs,
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
Y
Yu Yang 已提交
29 30
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

31
  void InferShape(framework::InferShapeContext *ctx) const override {
32 33 34 35 36 37 38 39 40
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("X"), true,
        platform::errors::NotFound("Input(X) of MinusOp is not found."));
    PADDLE_ENFORCE_EQ(
        ctx->HasInput("Y"), true,
        platform::errors::NotFound("Input(Y) of MinusOp is not found."));
    PADDLE_ENFORCE_EQ(
        ctx->HasOutput("Out"), true,
        platform::errors::NotFound("Output(Out) of MinusOp is not found."));
41

Q
Qiao Longfei 已提交
42 43
    auto x_dims = ctx->GetInputDim("X");
    auto y_dims = ctx->GetInputDim("Y");
Y
Yu Yang 已提交
44

P
phlrain 已提交
45 46
    if (ctx->IsRuntime() ||
        (framework::product(x_dims) > 0 && framework::product(y_dims) > 0)) {
P
phlrain 已提交
47 48 49 50
      PADDLE_ENFORCE_EQ(
          x_dims, y_dims,
          "Minus operator must take two tensor with same num of elements");
    }
Q
Qiao Longfei 已提交
51 52
    ctx->SetOutputDim("Out", x_dims);
    ctx->ShareLoD("X", /*->*/ "Out");
Y
Yu Yang 已提交
53 54 55 56 57
  }
};

// Declares the interface of the `minus` operator: the two operand tensors,
// the result tensor, and the user-facing operator description.
class MinusOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Description shown to users of the operator.
    const char *const doc_text = R"DOC(
Minus Operator.

Equation:

    $Out = X - Y$

Both the input `X` and `Y` can carry the LoD (Level of Details) information,
or not. But the output only shares the LoD information with input `X`.

)DOC";

    // Operands: Out = X - Y.
    AddInput("X", "The left tensor of minus operator.");
    AddInput("Y", "The right tensor of minus operator.");
    // Result.
    AddOutput("Out", "The output tensor of minus operator.");

    AddComment(doc_text);
  }
};
76

77
class MinusGradDescMaker : public framework::GradOpDescMakerBase {
Y
Yu Yang 已提交
78
 public:
79 80
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;

Y
Yu Yang 已提交
81 82
  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
    std::vector<std::unique_ptr<framework::OpDesc>> ops;
83
    auto x_g = this->InputGrad("X");
Y
Yu Yang 已提交
84
    if (!x_g.empty()) {
Y
Yu Yang 已提交
85
      auto *x_g_op = new framework::OpDesc();
Y
Yu Yang 已提交
86
      x_g_op->SetType("scale");
87
      x_g_op->SetInput("X", this->OutputGrad("Out"));
Y
Yu Yang 已提交
88 89 90 91 92
      x_g_op->SetOutput("Out", x_g);
      x_g_op->SetAttr("scale", 1.0f);
      ops.emplace_back(x_g_op);
    }

93
    auto y_g = this->InputGrad("Y");
Y
Yu Yang 已提交
94
    if (!y_g.empty()) {
Y
Yu Yang 已提交
95
      auto *y_g_op = new framework::OpDesc();
Y
Yu Yang 已提交
96
      y_g_op->SetType("scale");
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
      y_g_op->SetInput("X", this->OutputGrad("Out"));
      y_g_op->SetOutput("Out", y_g);
      y_g_op->SetAttr("scale", -1.0f);
      ops.emplace_back(y_g_op);
    }

    return ops;
  }
};

// Imperative (dygraph, OpBase) gradient maker for `minus`; mirrors
// MinusGradDescMaker.
//
// Since Out = X - Y, X@GRAD = Out@GRAD and Y@GRAD = -Out@GRAD. Each one is
// emitted as a `scale` op with scale 1.0 or -1.0. An op is skipped when
// InputGrad() returns an empty list for that input.
class MinusGradMaker : public imperative::GradOpBaseMakerBase {
 public:
  using imperative::GradOpBaseMakerBase::GradOpBaseMakerBase;

  std::vector<std::unique_ptr<imperative::OpBase>> operator()() const override {
    std::vector<std::unique_ptr<imperative::OpBase>> ops;

    auto x_g = this->InputGrad("X");
    if (!x_g.empty()) {
      // Own the op before inserting it so it cannot leak if emplace_back's
      // reallocation throws.
      std::unique_ptr<imperative::OpBase> x_g_op(new imperative::OpBase());
      x_g_op->SetType("scale");
      x_g_op->SetInput("X", this->OutputGrad("Out"));
      x_g_op->SetOutput("Out", x_g);
      x_g_op->SetAttr("scale", 1.0f);
      ops.emplace_back(std::move(x_g_op));
    }

    auto y_g = this->InputGrad("Y");
    if (!y_g.empty()) {
      std::unique_ptr<imperative::OpBase> y_g_op(new imperative::OpBase());
      y_g_op->SetType("scale");
      y_g_op->SetInput("X", this->OutputGrad("Out"));
      y_g_op->SetOutput("Out", y_g);
      y_g_op->SetAttr("scale", -1.0f);
      ops.emplace_back(std::move(y_g_op));
    }

    return ops;
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Registers the `minus` operator with its shape-inference class, proto maker,
// and both gradient makers (OpDesc-based for static graphs, OpBase-based for
// imperative mode).
REGISTER_OPERATOR(minus, ops::MinusOp, ops::MinusOpMaker,
                  ops::MinusGradDescMaker, ops::MinusGradMaker);
// Only a float CPU kernel is registered in this file.
REGISTER_OP_CPU_KERNEL(
    minus, ops::MinusKernel<paddle::platform::CPUDeviceContext, float>);