increment_op.cc 3.3 KB
Newer Older
F
fengjiayi 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
//   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include "paddle/fluid/framework/infershape_utils.h"
16
#include "paddle/fluid/framework/op_registry.h"
17 18
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
A
Abhinav Arora 已提交
19

W
wanghuancoder 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32
// Forward declarations of Paddle types that this translation unit names
// without needing their complete definitions here.
namespace paddle {

namespace imperative {
// Used as a template argument of IncrementGradOpMaker in REGISTER_OPERATOR.
class OpBase;
}  // namespace imperative

namespace framework {
// Used as a template argument of IncrementGradOpMaker in REGISTER_OPERATOR.
class OpDesc;
class InferShapeContext;
}  // namespace framework

namespace platform {
class CPUDeviceContext;
}  // namespace platform

}  // namespace paddle

A
Abhinav Arora 已提交
33 34 35
namespace paddle {
namespace operators {

F
fengjiayi 已提交
36
class IncrementOp : public framework::OperatorWithKernel {
A
Abhinav Arora 已提交
37
 public:
38 39
  IncrementOp(const std::string &type,
              const framework::VariableNameMap &inputs,
F
fengjiayi 已提交
40 41 42 43
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

44 45 46 47 48 49 50 51
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext &ctx) const override {
    framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
    // IncrementOp kernel's device type is decided by input tensor place
    kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
    return kt;
  }
A
Abhinav Arora 已提交
52 53 54 55
};

// Declares the proto of the `increment` operator:
//   input  "X"    - the tensor to increment,
//   output "Out"  - the incremented tensor,
//   attr   "step" - float amount added to X (default 1.0).
// The documented equation is Out = X + step.
class IncrementOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) The input tensor of increment operator");
    AddOutput("Out", "(Tensor) The output tensor of increment operator.");
    AddAttr<float>("step",
                   "(float, default 1.0) "
                   "The step size by which the "
                   "input tensor will be incremented.")
        // Float literal matches the attribute type (no double->float
        // conversion); the stored value is unchanged.
        .SetDefault(1.0f);
    AddComment(R"DOC(
Increment Operator.

The equation is: 
$$Out = X + step$$

)DOC");
  }
};

H
hong 已提交
74 75
// "Gradient" maker for `increment`, instantiated at registration with both
// framework::OpDesc and imperative::OpBase.
//
// It does not produce a derivative. The generated op is another `increment`
// whose input is the forward "Out" and whose output is the forward "X", with
// the step negated. It therefore recomputes X = Out - step.
// NOTE(review): presumably used to restore a counter-like variable during the
// backward pass; confirm against callers.
template <typename T>
class IncrementGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("increment");
    // Input and output are swapped relative to the forward op.
    grad_op->SetInput("X", this->Output("Out"));
    grad_op->SetOutput("Out", this->Input("X"));
    // Step in the opposite direction to undo the forward increment.
    grad_op->SetAttr("step", -BOOST_GET_CONST(float, this->GetAttr("step")));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
91 92
DECLARE_INFER_SHAPE_FUNCTOR(increment,
                            IncrementInferShapeFunctor,
93
                            PD_INFER_META(phi::IncrementInferMeta));
94 95 96
REGISTER_OPERATOR(increment,
                  ops::IncrementOp,
                  ops::IncrementOpMaker,
H
hong 已提交
97
                  ops::IncrementGradOpMaker<paddle::framework::OpDesc>,
98 99
                  ops::IncrementGradOpMaker<paddle::imperative::OpBase>,
                  IncrementInferShapeFunctor);