elementwise_sub_op.cc 5.6 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
G
gongweibao 已提交
14

W
Wu Yi 已提交
15
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
W
wanghuancoder 已提交
16

17
#include <string>
W
wanghuancoder 已提交
18

W
Wu Yi 已提交
19
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
20 21 22

// Forward declarations of the framework, imperative and platform types that
// this translation unit names: the OpDesc / OpBase grad-maker instantiations
// and the CPU kernel registrations (including the complex element types).
namespace paddle {

namespace framework {
class OpDesc;
}  // namespace framework

namespace imperative {
class OpBase;
}  // namespace imperative

namespace platform {
template <typename T>
struct complex;
class CPUDeviceContext;
}  // namespace platform

}  // namespace paddle

40 41 42
namespace paddle {
namespace operators {

43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64
/// Operator maker for `elementwise_sub` (Out = X - Y).
///
/// Supplies the Sub-specific operator name, equation, input descriptions and
/// one-line functionality summary to the shared ElementwiseOpMaker base.
class ElementwiseSubOpMaker : public ElementwiseOpMaker {
 protected:
  std::string GetName() const override { return "Sub"; }
  std::string GetEquation() const override { return "Out = X - Y"; }

  // The dtype lists in the X/Y descriptions mirror the CPU kernels registered
  // in this file: int, int64_t, float, double, complex<float> and
  // complex<double>.
  void AddInputX() override {
    AddInput("X",
             "(Variable), Tensor or LoDTensor of any dimensions. Its dtype "
             "should be int32, int64, float32, float64, complex64, "
             "complex128.");
  }

  void AddInputY() override {
    AddInput("Y",
             "(Variable), Tensor or LoDTensor of any dimensions. Its dtype "
             "should be int32, int64, float32, float64, complex64, "
             "complex128.");
  }

  // The base-class hook is spelled "Funtionality"; the name must match it
  // exactly for `override` to compile, so it is intentionally left as-is.
  std::string GetOpFuntionality() const override {
    return "Subtract two tensors element-wise";
  }
};

H
hong 已提交
65 66
/// Builds the `elementwise_sub_grad_grad` operator from `elementwise_sub_grad`.
///
/// Wiring of the generated op:
///   inputs  Y     <- input "Y" of the grad op
///           DOut  <- input GradVarName("Out") of the grad op
///           DDX   <- gradient of the grad op's output GradVarName("X")
///           DDY   <- gradient of the grad op's output GradVarName("Y")
///   output  DDOut -> gradient of the grad op's input GradVarName("Out")
/// Every attribute of the grad op is copied onto the new op.
template <typename T>
class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    const std::string dout_name = framework::GradVarName("Out");
    const std::string dx_name = framework::GradVarName("X");
    const std::string dy_name = framework::GradVarName("Y");

    op->SetType("elementwise_sub_grad_grad");

    // Values taken directly from the grad op's inputs.
    op->SetInput("Y", this->Input("Y"));
    op->SetInput("DOut", this->Input(dout_name));
    // Incoming second-order gradients (w.r.t. the grad op's outputs).
    op->SetInput("DDX", this->OutputGrad(dx_name));
    op->SetInput("DDY", this->OutputGrad(dy_name));

    op->SetAttrMap(this->Attrs());

    op->SetOutput("DDOut", this->InputGrad(dout_name));
  }
};

}  // namespace operators
}  // namespace paddle

87
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub);
88
REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_sub, Sub);
89

90 91
namespace ops = paddle::operators;

H
hong 已提交
92
REGISTER_OPERATOR(
93 94
    elementwise_sub_grad, ops::ElementwiseOpGrad,
    ops::ElementwiseGradOpInplaceInferer, ops::ElementwiseGradNoBufVarsInferer,
H
hong 已提交
95 96
    ops::ElementwiseSubDoubleGradMaker<paddle::framework::OpDesc>,
    ops::ElementwiseSubDoubleGradMaker<paddle::imperative::OpBase>);
97
REGISTER_OPERATOR(elementwise_sub_grad_grad,
98
                  ops::ElementwiseOpDoubleGradWithoutDXDY,
99 100
                  ops::ElementwiseDoubleGradOpInplaceInferer,
                  ops::ElementwiseDoubleGradNoBufVarsInferer);
101

G
gongweibao 已提交
102 103
REGISTER_OP_CPU_KERNEL(
    elementwise_sub,
Q
QI JUN 已提交
104 105 106
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, double>,
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
107 108
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
109
                              paddle::platform::complex<float>>,
110
    ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
111
                              paddle::platform::complex<double>>);
G
gongweibao 已提交
112 113
REGISTER_OP_CPU_KERNEL(
    elementwise_sub_grad,
Q
QI JUN 已提交
114 115 116
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, double>,
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
117 118
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
119
                                  paddle::platform::complex<float>>,
120
    ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
121
                                  paddle::platform::complex<double>>);
122 123 124 125 126 127 128 129 130
REGISTER_OP_CPU_KERNEL(
    elementwise_sub_grad_grad,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        float>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        double>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
                                        int>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
131 132
                                        int64_t>,
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
133
                                        paddle::platform::complex<float>>,
134
    ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
135
                                        paddle::platform::complex<double>>);
136 137 138 139 140 141 142 143 144

// Op-version checkpoint recording that elementwise_sub gained the "Scale_y"
// attribute; 1.0f is the value handed to NewAttr (presumably its default —
// confirm against OpVersionDesc::NewAttr).
REGISTER_OP_VERSION(elementwise_sub)
    .AddCheckpoint(
        "Register elementwise_sub for adding the attribute of Scale_y",
        paddle::framework::compatible::OpVersionDesc().NewAttr(
            "Scale_y",
            "In order to support the function of scaling the "
            "input Y when using the operator of elementwise_sub.",
            1.0f));