/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/scale_op.h"
16
#include <string>
Y
Yu Yang 已提交
17

W
wanghuancoder 已提交
18 19 20 21 22 23 24 25 26 27 28 29 30
// Forward declarations of the platform, imperative and framework classes that
// this translation unit refers to by name.
namespace paddle {
namespace platform {
class CPUDeviceContext;
}  // namespace platform

namespace imperative {
class OpBase;
}  // namespace imperative

namespace framework {
class OpDesc;
class InferShapeContext;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace operators {

class ScaleOp : public framework::OperatorWithKernel {
 public:
36 37 38
  ScaleOp(const std::string &type, const framework::VariableNameMap &inputs,
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
Y
Yu Yang 已提交
39 40
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

41
  void InferShape(framework::InferShapeContext *ctx) const override {
42 43
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "scale");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "scale");
44 45 46 47 48

    if (ctx->IsRuntime() && ctx->HasInput("ScaleTensor")) {
      auto scale = ctx->Inputs("ScaleTensor");
      PADDLE_ENFORCE_EQ(scale.size(), 1,
                        platform::errors::InvalidArgument(
49 50 51
                            "Input(ScaleTensor) size must be 1, "
                            "but received size is %d.",
                            scale.size()));
52 53
    }

Q
Qiao Longfei 已提交
54 55
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->ShareLoD("X", /*->*/ "Out");
Y
Yu Yang 已提交
56 57 58 59 60
  }
};

// Declares the proto of the `scale` operator:
//   inputs : X (required), ScaleTensor (dispensable; overrides attr(scale))
//   output : Out
//   attrs  : scale (float, default 1), bias (float, default 0),
//            bias_after_scale (bool, default true)
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input tensor of scale operator.");
    AddInput("ScaleTensor",
             "(Tensor) If provided, use this as "
             "scale factor, this has a higher priority than "
             "attr(scale), the shape of this tensor MUST BE 1.")
        .AsDispensable();
    AddOutput("Out", "(Tensor) Output tensor of scale operator.");
    // The raw string below is the user-visible operator documentation, so it
    // must contain only the doc text itself.
    AddComment(R"DOC(
**Scale operator**

Apply scaling and bias addition to the input tensor.

if bias_after_scale=True:

$$Out = scale*X + bias$$

else:

$$Out = scale*(X + bias)$$
)DOC");
    // Float literals match the attribute type and avoid an implicit
    // double -> float conversion of the defaults (values are unchanged).
    AddAttr<float>("scale", "The scaling factor of the scale operator.")
        .SetDefault(1.0f);
    AddAttr<float>("bias", "The bias of the scale operator.").SetDefault(0.0f);
    AddAttr<bool>(
        "bias_after_scale",
        "Apply bias addition after or before scaling. It is useful for "
        "numeric stability in some circumstances.")
        .SetDefault(true);
  }
};

// Variable-type inference for `scale`: Out's variable type and data type are
// synchronized with X's through SyncTypeAndDataType.
class ScaleOpVarTypeInference : public framework::VarTypeInference {
 public:
  void operator()(framework::InferVarTypeContext *ctx) const override {
    const std::string src_slot = "X";
    const std::string dst_slot = "Out";
    ctx->SyncTypeAndDataType(src_slot, dst_slot);
  }
};

// Builds the backward op of `scale`.
//
// Both forward forms, scale*X + bias and scale*(X + bias), have derivative
// `scale` with respect to X. The gradient is therefore another `scale` op:
// it is applied to dOut with the same scale factor and zero bias.
// T is the op representation the grad maker emits (OpDesc or OpBase, per the
// registrations of this op).
template <typename T>
class ScaleGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("scale");
    grad_op->SetInput("X", this->OutputGrad("Out"));
    // When the forward op took its scale factor from ScaleTensor (which takes
    // precedence over attr(scale)), the backward op must use the same tensor.
    // HasInput already yields a boolean, so test it directly.
    if (this->HasInput("ScaleTensor")) {
      grad_op->SetInput("ScaleTensor", this->Input("ScaleTensor"));
    }
    grad_op->SetOutput("Out", this->InputGrad("X"));
    grad_op->SetAttr("scale", this->GetAttr("scale"));
    // The bias is a constant offset and vanishes under differentiation.
    grad_op->SetAttr("bias", 0.0f);
    grad_op->SetAttr("bias_after_scale", true);
  }
};

// Declares ScaleOpInplaceInferer with the {"X", "Out"} pair. This marks Out as
// a candidate to reuse X's buffer in place, which is plausible because
// InferShape gives Out exactly X's dims.
// NOTE(review): whether the buffer is actually reused is decided by the
// framework; see DECLARE_INPLACE_OP_INFERER for the exact semantics.
DECLARE_INPLACE_OP_INFERER(ScaleOpInplaceInferer, {"X", "Out"});
Y
Yu Yang 已提交
119 120 121 122 123
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

H
hong 已提交
124 125 126
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker,
                  ops::ScaleGradMaker<paddle::framework::OpDesc>,
                  ops::ScaleGradMaker<paddle::imperative::OpBase>,
127
                  ops::ScaleOpVarTypeInference, ops::ScaleOpInplaceInferer);
Q
QI JUN 已提交
128 129 130
REGISTER_OP_CPU_KERNEL(
    scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
    ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
131 132 133
    ops::ScaleKernel<paddle::platform::CPUDeviceContext, uint8_t>,
    ops::ScaleKernel<paddle::platform::CPUDeviceContext, int8_t>,
    ops::ScaleKernel<paddle::platform::CPUDeviceContext, int16_t>,
Q
QI JUN 已提交
134 135
    ops::ScaleKernel<paddle::platform::CPUDeviceContext, int>,
    ops::ScaleKernel<paddle::platform::CPUDeviceContext, int64_t>);