/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/quantize_linear_op.h"
#include <algorithm>
#include <string>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/impl/clip_kernel_impl.h"

namespace paddle {
namespace operators {

// CPU specialization of the per-tensor dequantize functor.
//
// Every element of `in` is multiplied by the first entry of `scale` and then
// divided by `max_range`; the result is written element-wise into `out`
// through the Eigen device owned by `dev_ctx`.
template <typename T>
struct DequantizeFunctor<phi::CPUContext, T> {
  void operator()(const phi::CPUContext &dev_ctx,
                  const phi::DenseTensor *in,
                  const phi::DenseTensor *scale,
                  T max_range,
                  phi::DenseTensor *out) {
    // Flat 1-D Eigen views over the input and output buffers.
    auto src_flat = framework::EigenVector<T>::Flatten(*in);
    const T *scale_data = scale->data<T>();
    auto dst_flat = framework::EigenVector<T>::Flatten(*out);

    // Only the first scale entry is used: a single factor for all elements.
    const T tensor_scale = scale_data[0];
    auto &eigen_dev = *dev_ctx.eigen_device();
    dst_flat.device(eigen_dev) = src_flat * tensor_scale / max_range;
  }
};

template <typename T>
L
Leo Chen 已提交
44 45
struct ChannelDequantizeFunctorV2<phi::CPUContext, T> {
  void operator()(const phi::CPUContext &dev_ctx,
46 47
                  const phi::DenseTensor *in,
                  const phi::DenseTensor *scale,
48 49
                  T max_range,
                  const int quant_axis,
50
                  phi::DenseTensor *out) {
51 52 53 54
    // Dequant op is before quantized op
    // Dequantize the weight of quantized op
    auto in_dims = in->dims();
    const int64_t channel = in_dims[quant_axis];
55
    const T *scale_factor = scale->data<T>();
56 57 58
    if (quant_axis == 0) {
      for (int64_t i = 0; i < channel; i++) {
        T s = scale_factor[i];
59 60
        phi::DenseTensor one_channel_in = in->Slice(i, i + 1);
        phi::DenseTensor one_channel_out = out->Slice(i, i + 1);
61 62
        auto in_e = framework::EigenVector<T>::Flatten(one_channel_in);
        auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
63
        auto &dev = *dev_ctx.eigen_device();
64 65 66 67 68 69 70 71 72
        out_e.device(dev) = in_e * s / max_range;
      }
    } else if (quant_axis == 1) {
      int64_t out_iter = 1;
      for (int i = 0; i < quant_axis; i++) {
        out_iter *= in_dims[i];
      }
      int64_t step_i = in->numel() / out_iter;
      int64_t step_j = in->numel() / (out_iter * channel);
73
      auto *in_data = in->data<T>();
74
      auto *out_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
75 76
      for (int64_t i = 0; i < out_iter; i++) {
        for (int64_t j = 0; j < channel; j++) {
77 78
          auto *cur_in = in_data + i * step_i + j * step_j;
          auto *cur_out = out_data + i * step_i + j * step_j;
79 80 81 82 83 84 85 86 87 88 89 90
          T s = scale_factor[j];
          for (int64_t k = 0; k < step_j; k++) {
            *cur_out = (*cur_in) * s / max_range;
            ++cur_in;
            ++cur_out;
          }
        }
      }
    }
  }
};

// Explicit instantiations of the CPU dequantize functors for the element
// types float16, float and double.
template struct DequantizeFunctor<phi::CPUContext, phi::dtype::float16>;
template struct DequantizeFunctor<phi::CPUContext, float>;
template struct DequantizeFunctor<phi::CPUContext, double>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext,
                                           phi::dtype::float16>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext, float>;
template struct ChannelDequantizeFunctorV2<phi::CPUContext, double>;

class QuantizeLinearOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
102
  void InferShape(framework::InferShapeContext *ctx) const override {
103 104
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "QuantizeLinear");
    OP_INOUT_CHECK(ctx->HasInput("Scale"), "Input", "Scale", "QuantizeLinear");
105 106
    OP_INOUT_CHECK(
        ctx->HasInput("ZeroPoint"), "Input", "ZeroPoint", "QuantizeLinear");
107 108 109 110 111 112 113 114 115 116
    OP_INOUT_CHECK(ctx->HasOutput("Y"), "Output", "Y", "QuantizeLinear");
    ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
    int quant_axis = ctx->Attrs().Get<int>("quant_axis");
    if (ctx->HasOutput("OutScale")) {
      if (quant_axis < 0) {
        ctx->SetOutputDim("OutScale", {1});
      } else {
        ctx->SetOutputDim("OutScale", {ctx->GetInputDim("X")[quant_axis]});
      }
    }
117 118 119 120 121 122
    if (ctx->HasOutput("OutState")) {
      ctx->SetOutputDim("OutState", {1});
    }
    if (ctx->HasOutput("OutAccum")) {
      ctx->SetOutputDim("OutAccum", {1});
    }
123 124 125 126
    ctx->ShareLoD("X", /*->*/ "Y");
  }

 protected:
127
  phi::KernelKey GetExpectedKernelType(
128
      const framework::ExecutionContext &ctx) const override {
129 130
    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                          ctx.GetPlace());
131 132 133 134 135 136 137 138 139 140 141 142
  }
};

// Declares the inputs, outputs, attributes and documentation of the op proto
// used by the operators registered with this maker.
class QuantizeLinearOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("Scale", "(Tensor) Input is float data type.");
    AddInput("ZeroPoint", "(Tensor) Input is float data type.");
    AddOutput("Y",
              "(Tensor) Output of quantized low level tensor, "
              "but also saved as float data type.");
    AddInput("InAccum", "Last accum.")
        .AsDispensable()
        .AsExtra();  // only qat use
    AddInput("InState", "Last state.")
        .AsDispensable()
        .AsExtra();  // only qat use
    AddOutput("OutState", "(Tensor) state buffer.")
        .AsDispensable()
        .AsExtra();  // only qat use
    AddOutput("OutAccum", "(Tensor) accum buffer.")
        .AsDispensable()
        .AsExtra();  // only qat use
    AddOutput("OutScale", "(Tensor) Current scale")
        .AsDispensable()
        .AsExtra();  // only qat use
    AddAttr<int>("quant_axis",
                 "(int, default 0) The axis for quantization. "
                 "For conv2d, depthwise_conv2d, conv2d_transpose "
                 "and mul, the quant_axis is equal to the cout axis.")
        .SetDefault(0)
        .AddCustomChecker([](const int &quant_axis) {
          // -1 is accepted in addition to 0 and 1, so the message lists all
          // three legal values.
          PADDLE_ENFORCE_EQ(
              quant_axis == 0 || quant_axis == 1 || quant_axis == -1,
              true,
              platform::errors::InvalidArgument(
                  "'quant_axis' should be 0, 1 or -1, but "
                  "the received is %d",
                  quant_axis));
        });
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int &bit_length) {
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16,
                            true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
                                bit_length));
        });
    // Adjacent literals are concatenated verbatim, so each piece carries its
    // own trailing separator; otherwise the text reads "round(2.5)=21: ...".
    AddAttr<int>(
        "round_type",
        "(int, default 0) The round type of fp32 to int. "
        "0: rounding to nearest ties to even. Eg: round(1.5)=2, round(2.5)=2. "
        "1: rounding to nearest ties away from zero. Eg: round(1.5)=2, "
        "round(2.5)=3")
        .SetDefault(0)
        .AddCustomChecker([](const int &round_type) {
          PADDLE_ENFORCE_EQ(
              round_type == 0 || round_type == 1,
              true,
              platform::errors::InvalidArgument(
                  "'round_type' should be 0 or 1, 0 rounding to "
                  "nearest ties to even and 1 is rounding to nearest "
                  "ties away from zero.but the received is %d",
                  round_type));
        });
    // The description states the default that SetDefault actually installs.
    AddAttr<bool>("is_test",
                  "(bool, default true) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(true);
    AddAttr<bool>(
        "only_observer",
        "(bool, default false) Whether to only observer or not. If "
        "only_observer=false, it will calculate fake quant or dequant output. "
        "If only_observer=true, it will only calibrate scale information.")
        .SetDefault(false);
    AddComment(R"DOC(
The scale of QuantizeLinear operator is a vector.
In detail, each channel of the input X has a scale value.
$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(\frac{X_c * range} {scale_c})$$
In above three formulas, the range value of c is as follow:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
)DOC");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
L
Leo Chen 已提交
225
using CPU = phi::CPUContext;
226 227

REGISTER_OPERATOR(
228 229 230
    quantize_linear,
    ops::QuantizeLinearOp,
    ops::QuantizeLinearOpMaker,
231 232 233 234 235 236
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(quantize_linear, ops::QuantizeLinearKernel<CPU, float>);

REGISTER_OPERATOR(
237 238 239
    dequantize_linear,
    ops::QuantizeLinearOp,
    ops::QuantizeLinearOpMaker,
240 241 242 243
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(dequantize_linear,
244 245 246
                       ops::DeQuantizeLinearKernel<CPU, float>,
                       ops::DeQuantizeLinearKernel<CPU, int8_t>,
                       ops::DeQuantizeLinearKernel<CPU, double>);