//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/increment_op.h"

namespace paddle {
namespace operators {

class IncrementOp : public framework::OperatorWithKernel {
A
Abhinav Arora 已提交
21
 public:
F
fengjiayi 已提交
22 23 24 25 26 27
  IncrementOp(const std::string &type, const framework::VariableNameMap &inputs,
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext *ctx) const override {
A
Abhinav Arora 已提交
28 29 30 31
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of IncrementOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of IncrementOp should not be null.");
Y
Yang Yu 已提交
32
    PADDLE_ENFORCE_EQ(1, framework::product(ctx->GetInputDim("X")));
A
Abhinav Arora 已提交
33
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
F
fengjiayi 已提交
34
    ctx->ShareLoD("X", "Out");
A
Abhinav Arora 已提交
35
  }
36 37 38 39 40 41 42 43 44

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
    // IncrementOp kernel's device type is decided by input tensor place
    kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
    return kt;
  }
A
Abhinav Arora 已提交
45 46 47 48
};

// Declares the proto of the `increment` op: one input X, one output Out, and
// a float attribute `step` that defaults to 1.0.
class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  IncrementOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "(Tensor) The input tensor of increment operator");
    AddOutput("Out", "(Tensor) The output tensor of increment operator.");
    // The literal is a float so it matches the attribute's declared type.
    AddAttr<float>("step",
                   "(float, default 1.0) "
                   "The step size by which the "
                   "input tensor will be incremented.")
        .SetDefault(1.0f);
    AddComment(R"DOC(
Increment Operator.

The equation is: 
$$Out = X + step$$

)DOC");
  }
};

class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

Y
Yu Yang 已提交
72 73
  std::unique_ptr<framework::OpDesc> Apply() const override {
    auto *grad_op = new framework::OpDesc();
Y
Yang Yu 已提交
74 75 76 77
    grad_op->SetType("increment");
    grad_op->SetInput("X", Output("Out"));
    grad_op->SetOutput("Out", Input("X"));
    grad_op->SetAttr("step", -boost::get<float>(GetAttr("step")));
Y
Yu Yang 已提交
78
    return std::unique_ptr<framework::OpDesc>(grad_op);
A
Abhinav Arora 已提交
79 80 81 82 83 84 85
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
// Registers the `increment` op with its proto maker and backward maker, and
// its CPU kernels for float, double, int and int64_t inputs.
REGISTER_OPERATOR(increment, ops::IncrementOp, ops::IncrementOpMaker,
                  ops::IncrementGradOpMaker);
REGISTER_OP_CPU_KERNEL(
    increment, ops::IncrementKernel<paddle::platform::CPUDeviceContext, float>,
    ops::IncrementKernel<paddle::platform::CPUDeviceContext, double>,
    ops::IncrementKernel<paddle::platform::CPUDeviceContext, int>,
    ops::IncrementKernel<paddle::platform::CPUDeviceContext, int64_t>)