fake_quantize_op.cc 27.0 KB
Newer Older
视言's avatar
视言 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fake_quantize_op.h"
16
#include <algorithm>
视言's avatar
视言 已提交
17
#include <string>
18 19 20
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/platform/transform.h"
视言's avatar
视言 已提交
21 22 23 24

namespace paddle {
namespace operators {

25 26 27 28 29
// Ordering predicate that ranks two values by magnitude: yields true when
// |a| is strictly smaller than |b|. Used with std::max_element to locate the
// element with the largest absolute value.
template <typename T>
struct Compare {
 public:
  // The predicate holds no state, so the call operator is const; this lets it
  // be invoked through const objects and const references as well.
  bool operator()(const T a, const T b) const {
    return std::abs(a) < std::abs(b);
  }
};
30 31 32 33 34

// CPU specialization: writes to `out` the largest absolute value found in the
// `num` elements starting at `in`.
template <typename T>
struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx, const T* in,
                  const int num, T* out) {
    const T* first = in;
    const T* last = in + num;
    // Compare<T> orders by magnitude, so the "maximum" is the element with
    // the largest |value|; its sign is dropped afterwards.
    const T* largest = std::max_element(first, last, Compare<T>());
    *out = std::abs(*largest);
  }
};

template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;

41 42
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in_tensor, const int quant_axis,
                  T* out_abs_max) {
    // At present, channelwise quantization supports conv2d, depthwise_conv2d
    // conv2d_transpose and mul
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    auto* in_data = in_tensor.data<T>();
    auto in_dims = in_tensor.dims();
    const int64_t channel = in_dims[quant_axis];
    if (quant_axis == 0) {
      const int64_t channel_size = in_tensor.numel() / channel;
      for (int64_t i = 0; i < channel; i++) {
        auto* start = in_data + i * channel_size;
        auto* end = in_data + (i + 1) * channel_size;
        out_abs_max[i] =
            std::abs(*(std::max_element(start, end, Compare<T>())));
      }
    } else if (quant_axis == 1) {
      for (int64_t i = 0; i < channel; i++) {
        out_abs_max[i] = 0;
      }
      const int64_t step_i = in_tensor.numel() / in_dims[0];
      const int64_t step_j = in_tensor.numel() / (in_dims[0] * in_dims[1]);
      for (int64_t i = 0; i < in_dims[0]; i++) {
        for (int64_t j = 0; j < in_dims[1]; j++) {
          auto* start = in_data + i * step_i + j * step_j;
          auto* end = in_data + i * step_i + (j + 1) * step_j;
          T abs_max = std::abs(*(std::max_element(start, end, Compare<T>())));
          out_abs_max[j] = std::max(out_abs_max[j], abs_max);
        }
      }
78 79 80 81 82 83
    }
  }
};

template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;

84 85 86 87 88 89
// CPU specialization: clips `in` to [-s, s], where s is the first element of
// `scale`, then writes round(bin_cnt * inverse(s) * clipped) into `out`.
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    const T threshold = scale.data<T>()[0];
    const T reciprocal = inverse(threshold);

    // Step 1: element-wise clip of the input into the output buffer.
    const T* src_begin = in.data<T>();
    const T* src_end = src_begin + in.numel();
    T* dst = out->mutable_data<T>(ctx.GetPlace());
    platform::Transform<platform::CPUDeviceContext> clip;
    clip(ctx, src_begin, src_end, dst, ClipFunctor<T>(-threshold, threshold));

    // Step 2: scale and round the clipped values in place.
    auto flat = framework::EigenVector<T>::Flatten(*out);
    flat.device(*ctx.eigen_device()) = (bin_cnt * reciprocal * flat).round();
  }
};

template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;

101 102 103 104 105 106
// CPU specialization: clips `in` to [-s, s] (s = first element of `scale`),
// quantizes with round(bin_cnt * inverse(s) * x), and immediately maps the
// result back with * s / bin_cnt, writing the outcome to `out`.
template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    const T threshold = scale.data<T>()[0];
    const T reciprocal = inverse(threshold);

    // Step 1: element-wise clip of the input into the output buffer.
    const T* src_begin = in.data<T>();
    const T* src_end = src_begin + in.numel();
    T* dst = out->mutable_data<T>(ctx.GetPlace());
    platform::Transform<platform::CPUDeviceContext> clip;
    clip(ctx, src_begin, src_end, dst, ClipFunctor<T>(-threshold, threshold));

    // Step 2: quantize then dequantize the clipped values in place.
    auto flat = framework::EigenVector<T>::Flatten(*out);
    flat.device(*ctx.eigen_device()) =
        (bin_cnt * reciprocal * flat).round() * threshold /
        static_cast<T>(bin_cnt);
  }
};
template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
                                               float>;

120 121 122 123
// CPU specialization of channel-wise fake quantization.
//
// For every channel c along `quant_axis` (0 or 1) with scale s = scale[c],
// each element x of that channel is clipped to [-s, s] and replaced by
// round(bin_cnt * inverse(s) * x) in `out`.
//
// Params:
//   ctx        - CPU device context (supplies the place and Eigen device).
//   in         - input tensor.
//   scale      - per-channel scales; indexed up to in.dims()[quant_axis].
//   bin_cnt    - multiplier applied before rounding.
//   quant_axis - channel dimension; anything other than 0 or 1 raises
//                InvalidArgument.
//   out        - output tensor, written through mutable_data.
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int quant_axis,
                  framework::Tensor* out) {
    // At present, channelwise quantization supports conv2d, depthwise_conv2d
    // conv2d_transpose and mul
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    auto* scale_data = scale.data<T>();
    auto* in_data = in.data<T>();
    auto* out_data = out->mutable_data<T>(ctx.GetPlace());
    auto in_dims = in.dims();
    const int64_t channel = in_dims[quant_axis];
    platform::Transform<platform::CPUDeviceContext> trans;
    if (quant_axis == 0) {
      // Channels are contiguous: clip and quantize one channel at a time.
      const int64_t channel_size = in.numel() / channel;
      for (int64_t i = 0; i < channel; i++) {
        const T s = scale_data[i];
        const T inv_s = inverse(s);
        auto* start = in_data + i * channel_size;
        auto* end = in_data + (i + 1) * channel_size;
        trans(ctx, start, end, out_data + i * channel_size,
              ClipFunctor<T>(-s, s));
        framework::Tensor one_channel_out = out->Slice(i, i + 1);
        auto out_e = framework::EigenVector<T>::Flatten(one_channel_out);
        out_e.device(*ctx.eigen_device()) = (bin_cnt * inv_s * out_e).round();
      }
    } else if (quant_axis == 1) {
      // Channel j is split into in_dims[0] runs of step_j elements each.
      // Indices are int64_t to match the 64-bit extents they are compared
      // against (and the sibling FindChannelAbsMaxFunctor).
      const int64_t step_i = in.numel() / in_dims[0];
      const int64_t step_j = in.numel() / (in_dims[0] * in_dims[1]);
      for (int64_t i = 0; i < in_dims[0]; i++) {
        for (int64_t j = 0; j < in_dims[1]; j++) {
          const T s = scale_data[j];
          const T inv_s = inverse(s);
          auto* start = in_data + i * step_i + j * step_j;
          auto* end = in_data + i * step_i + (j + 1) * step_j;
          auto* cur_out_data = out_data + i * step_i + j * step_j;
          trans(ctx, start, end, cur_out_data, ClipFunctor<T>(-s, s));
          for (int64_t k = 0; k < step_j; k++) {
            cur_out_data[k] = std::round(bin_cnt * inv_s * cur_out_data[k]);
          }
        }
      }
    }
  }
};

template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
                                               float>;

178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205
// CPU specialization: maintains a circular buffer of recent scales
// (`scales_arr`, indexed by iter % window_size) and writes the scale to use
// for this step into `out_scale`.
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter, const int window_size,
                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
    T* window = scales_arr->mutable_data<T>(ctx.GetPlace());
    const int64_t step = iter.data<int64_t>()[0];
    const int slot = step % window_size;

    // Overwrite the slot for this step, remembering what it held before.
    const T evicted = window[slot];
    const T current = cur_scale.data<T>()[0];
    window[slot] = current;

    T result = last_scale.data<T>()[0];
    if (result < current) {
      // The new scale exceeds the previous maximum: it becomes the maximum.
      result = current;
    } else if (fabs(evicted - result) < 1e-6) {
      // The value just evicted matched the previous maximum, so the maximum
      // is recomputed from the buffer contents.
      // NOTE(review): while step < window_size only the first `step` slots
      // are scanned, which excludes the slot written above - confirm this is
      // intended.
      const int valid = (step > window_size) ? window_size : step;
      FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, window, valid,
                                                         &result);
    }
    out_scale->mutable_data<T>(ctx.GetPlace())[0] = result;
  }
};

template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;

206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229
// CPU specialization: advances the moving-average scale estimate by one step.
//   state' = rate * state + 1
//   accum' = rate * accum + cur_scale
//   scale' = accum' / state'
// The three results are written to the first element of `out_state`,
// `out_accum` and `out_scale` respectively.
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in_accum,
                  const framework::Tensor& in_state, const T* cur_scale,
                  const float rate, framework::Tensor* out_state,
                  framework::Tensor* out_accum, framework::Tensor* out_scale) {
    // Read every input before any output is touched.
    const T prev_accum = in_accum.data<T>()[0];
    const T prev_state = in_state.data<T>()[0];
    const T observed = cur_scale[0];

    const T next_state = rate * prev_state + 1;
    const T next_accum = rate * prev_accum + observed;
    const T averaged = next_accum / next_state;

    out_state->mutable_data<T>(ctx.GetPlace())[0] = next_state;
    out_accum->mutable_data<T>(ctx.GetPlace())[0] = next_accum;
    out_scale->mutable_data<T>(ctx.GetPlace())[0] = averaged;
  }
};

template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
                                               float>;

230
class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
视言's avatar
视言 已提交
231
 public:
232 233 234 235
  FakeQuantOrWithDequantAbsMaxOp(const std::string& type,
                                 const framework::VariableNameMap& inputs,
                                 const framework::VariableNameMap& outputs,
                                 const framework::AttributeMap& attrs)
视言's avatar
视言 已提交
236 237
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

238
  void InferShape(framework::InferShapeContext* ctx) const override {
239 240
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
                   "FakeQuantOrWithDequantAbsMaxOp");
241
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
242
                   "FakeQuantOrWithDequantAbsMaxOp");
243
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
244
                   "FakeQuantOrWithDequantAbsMaxOp");
视言's avatar
视言 已提交
245
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
246
    ctx->SetOutputDim("OutScale", {1});
视言's avatar
视言 已提交
247 248
    ctx->ShareLoD("X", /*->*/ "Out");
  }
249 250 251 252

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
253 254 255
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
256
  }
视言's avatar
视言 已提交
257 258
};

259 260
class FakeQuantOrWithDequantAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
视言's avatar
视言 已提交
261 262
 public:
  void Make() override {
263 264 265 266 267
    AddInput("X", "(Tensor) Input is float data type.");
    AddOutput("Out",
              "(Tensor) Output of quantized low level tensor, "
              "but also saved as float data type.");
    AddOutput("OutScale", "(Tensor) Current scale");
视言's avatar
视言 已提交
268 269
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
270
        .AddCustomChecker([](const int& bit_length) {
271 272 273 274 275
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
                                bit_length));
视言's avatar
视言 已提交
276 277
        });
    AddComment(R"DOC(
278
This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
279
FakeQuantAbsMaxOp operator is used in the dynamic quantization.
视言's avatar
视言 已提交
280

281
$$scale = max(abs(X))$$
282 283
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
视言's avatar
视言 已提交
284

285
FakeQuantDequantAbsMaxOp operator does the abs_max quantization and then dequantization.
286 287 288 289 290

$$scale = max(abs(X))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range) * scale / range$$

291 292 293
)DOC");
  }
};
视言's avatar
视言 已提交
294

Z
Zhen Wang 已提交
295 296 297 298 299
// Operator definition for channel-wise abs_max fake quantization: validates
// the X / Out / OutScale slots and derives the output shapes.
class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Out takes the shape (and LoD) of X; OutScale holds one entry per index
  // along the `quant_axis` dimension of X.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
                   "FakeChannelWiseQuantizeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                   "FakeChannelWiseQuantizeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                   "FakeChannelWiseQuantizeAbsMax");

    const int axis = ctx->Attrs().Get<int>("quant_axis");
    const auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", x_dims);
    ctx->SetOutputDim("OutScale", {x_dims[axis]});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // The kernel is chosen from the data type of X and the execution place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto x_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(x_dtype, ctx.GetPlace());
  }
};

// Declares the inputs, outputs, attributes and user-facing documentation for
// the channel-wise abs_max fake-quantize operator.
class FakeChannelWiseQuantizeAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Attribute validators, named up front so the declarations below read as
    // a flat list.
    const auto validate_quant_axis = [](const int& quant_axis) {
      PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
                        platform::errors::InvalidArgument(
                            "'quant_axis' should be 0 or 1, but "
                            "the received is %d",
                            quant_axis));
    };
    const auto validate_bit_length = [](const int& bit_length) {
      PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
                        platform::errors::InvalidArgument(
                            "'bit_length' should be between 1 and 16, but "
                            "the received is %d",
                            bit_length));
    };

    AddInput("X", "(Tensor) Input is float data type.");
    AddOutput("Out",
              "(Tensor) Output of quantized low level tensor, but also "
              "saved as float data type.");
    AddOutput("OutScale", "(Tensor) Current channel wise scale");

    AddAttr<int>("quant_axis",
                 "(int, default 0) The axis for quantization. For conv2d, "
                 "depthwise_conv2d, conv2d_transpose and mul, the quant_axis "
                 "is equal to the cout axis.")
        .SetDefault(0)
        .AddCustomChecker(validate_quant_axis);
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker(validate_bit_length);

    AddComment(R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value.

$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(\frac{X_c * range} {scale_c})$$
In above three formulas, the range value of c is as follow:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
)DOC");
  }
};

363 364 365 366 367 368 369
class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  FakeQuantizeRangeAbsMaxOp(const std::string& type,
                            const framework::VariableNameMap& inputs,
                            const framework::VariableNameMap& outputs,
                            const framework::AttributeMap& attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}
视言's avatar
视言 已提交
370

371
  void InferShape(framework::InferShapeContext* ctx) const override {
372 373 374 375 376
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                   "FakeQuantizeRangeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                   "FakeQuantizeRangeAbsMax");
377 378 379 380 381 382 383 384
    if (ctx->HasOutput("OutScales")) {
      int window_size = ctx->Attrs().Get<int>("window_size");
      ctx->SetOutputDim("OutScales", {window_size});
    }
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScale", {1});
    ctx->ShareLoD("X", /*->*/ "Out");
  }
视言's avatar
视言 已提交
385

386 387 388
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
389 390 391
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
392 393
  }
};
视言's avatar
视言 已提交
394

395 396 397 398 399 400 401 402 403 404 405 406 407 408 409
// Declares the inputs, outputs, attributes and user-facing documentation for
// the range_abs_max fake-quantize operator.
class FakeQuantizeRangeAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("InScale", "Last scale.");
    AddInput("Iter", "Global step iteration.").AsDispensable();
    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
    AddOutput("OutScale", " Current scale");
    AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable();
    AddAttr<int>("window_size", "(int, default 10000) window range size.")
        .SetDefault(10000);
    // bit_length defaults to 8 and is rejected outside [1, 16].
    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
                                bit_length));
        });
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    // Underscores inside the formulas are escaped so that identifiers such as
    // bit\_length are not rendered as subscripts, matching the other op docs
    // in this file.
    AddComment(R"DOC(
FakeQuantize operator is used in static quantization.

$$scale = max(max(abs(x)), history\_abs\_max)$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range)$$

)DOC");
  }
};

431 432
class FakeQuantOrWithDequantMovingAverageAbsMaxOp
    : public framework::OperatorWithKernel {
433
 public:
434 435 436 437
  FakeQuantOrWithDequantMovingAverageAbsMaxOp(
      const std::string& type, const framework::VariableNameMap& inputs,
      const framework::VariableNameMap& outputs,
      const framework::AttributeMap& attrs)
438 439 440
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext* ctx) const override {
441 442 443 444 445 446
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
                   "FakeQuantOrWithDequantMovingAverageAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                   "FakeQuantOrWithDequantMovingAverageAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                   "FakeQuantOrWithDequantMovingAverageAbsMax");
447 448 449 450 451 452 453 454 455 456 457 458 459 460
    if (ctx->HasOutput("OutState")) {
      ctx->SetOutputDim("OutState", {1});
    }
    if (ctx->HasOutput("OutAccum")) {
      ctx->SetOutputDim("OutAccum", {1});
    }
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScale", {1});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
461 462 463
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
464 465 466
  }
};

467
class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("InScale", "Last scale.");
    AddInput("InAccum", "Last accum.").AsDispensable();
    AddInput("InState", "Last state.").AsDispensable();
    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
    AddOutput("OutScale", " Current scale");
    AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
    AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
    AddAttr<float>("moving_rate", "(float, default 0.9) moving rate.")
        .SetDefault(0.9);
    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
484 485 486 487 488
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
                                bit_length));
489 490 491 492 493 494
        });
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
495
This is a Base Op which supports FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp.
496
FakeQuantMovingAverageAbsMaxOp operator is used in the static quantization.
497

Z
Zhen Wang 已提交
498 499
$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$range = 2^{bit\_length - 1} - 1$$
500 501
$$Out = round(X/scale * range)$$

502
FakeQuantDequantMovingAverageAbsMaxOp operator does the moving_average_abs_max quant and then dequant.
503 504 505 506 507

$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range) * scale / range$$

508 509 510 511
)DOC");
  }
};

Z
Zhen Wang 已提交
512 513 514 515 516
// Operator definition for moving_average_abs_max_scale: validates the
// X / OutScale slots and derives the output shapes.
class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // OutScale and, when present, OutState and OutAccum are single-element
  // tensors.
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
                   "MovingAverageAbsMaxScale");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                   "MovingAverageAbsMaxScale");

    // The optional outputs are shaped only when they are connected.
    for (const char* optional_output : {"OutState", "OutAccum"}) {
      if (ctx->HasOutput(optional_output)) {
        ctx->SetOutputDim(optional_output, {1});
      }
    }
    ctx->SetOutputDim("OutScale", {1});
  }

 protected:
  // The kernel is chosen from the data type of X and the execution place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto x_dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(x_dtype, ctx.GetPlace());
  }
};

// Declares the inputs, outputs, attributes and user-facing documentation for
// the moving_average_abs_max_scale operator.
class MovingAverageAbsMaxScaleOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    // Inputs.
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("InAccum", "Last accum.").AsDispensable();
    AddInput("InState", "Last state.").AsDispensable();

    // Outputs.
    AddOutput("OutScale", " Current scale");
    AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
    AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();

    // Attributes.
    AddAttr<float>("moving_rate", "(float, default 0.9) moving rate.")
        .SetDefault(0.9);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set true for inference only and "
                  "false for training. Some layers may run faster when this "
                  "is true.")
        .SetDefault(false);

    AddComment(R"DOC(
MovingAverageAbsMaxScale operator is only used for calculating the quantization scale.
And it will not quantize the input tensor.

$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$Out = X$$

)DOC");
  }
};

565 566 567 568 569 570
// Gradient operator for the quantize-dequantize ops: validates the Out@GRAD
// input and X@GRAD output slots and gives X@GRAD the shape of Out@GRAD.
class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    const std::string out_grad_name = framework::GradVarName("Out");
    const std::string x_grad_name = framework::GradVarName("X");

    OP_INOUT_CHECK(ctx->HasInput(out_grad_name), "Input", out_grad_name,
                   "FakeQuantDequantGradOp");
    OP_INOUT_CHECK(ctx->HasOutput(x_grad_name), "Output", x_grad_name,
                   "FakeQuantDequantGradOp");

    const auto out_grad_dims = ctx->GetInputDim(out_grad_name);
    ctx->SetOutputDim(x_grad_name, out_grad_dims);
  }

  // The kernel is chosen from the data type of the gradient of Out and the
  // execution place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const std::string grad_of_out = framework::GradVarName("Out");
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, grad_of_out),
        ctx.GetPlace());
  }
};

// Builds the description of the fake_quantize_dequantize_grad op for a
// forward quantize-dequantize op: wires the gradient of Out in, the gradient
// of X out, and copies the forward op's attributes.
template <typename T>
class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    const auto grad_of_out = framework::GradVarName("Out");
    const auto grad_of_x = framework::GradVarName("X");

    grad_op->SetType("fake_quantize_dequantize_grad");
    grad_op->SetInput(grad_of_out, this->OutputGrad("Out"));
    grad_op->SetOutput(grad_of_x, this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

视言's avatar
视言 已提交
602 603 604 605
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
606 607
using CPU = paddle::platform::CPUDeviceContext;

H
hong 已提交
608
REGISTER_OPERATOR(
609 610
    fake_quantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
H
hong 已提交
611 612
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
613 614
REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                       ops::FakeQuantizeAbsMaxKernel<CPU, float>);
视言's avatar
视言 已提交
615

616 617 618 619 620 621 622 623
REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
                  ops::FakeQuantOrWithDequantAbsMaxOp,
                  ops::FakeQuantOrWithDequantAbsMaxOpMaker,
                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
                       ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);

H
hong 已提交
624 625 626 627 628
REGISTER_OPERATOR(
    fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
    ops::FakeQuantizeRangeAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
629 630
REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
                       ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
Z
Zhen Wang 已提交
631

H
hong 已提交
632 633 634 635 636 637
REGISTER_OPERATOR(
    fake_quantize_moving_average_abs_max,
    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
638 639
REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                       ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
640

641 642 643 644 645
REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
646 647 648 649
REGISTER_OP_CPU_KERNEL(
    fake_quantize_dequantize_moving_average_abs_max,
    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);

H
hong 已提交
650 651 652 653 654
REGISTER_OPERATOR(
    fake_channel_wise_quantize_abs_max, ops::FakeChannelWiseQuantizeAbsMaxOp,
    ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
Z
Zhen Wang 已提交
655 656
REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
                       ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>);
Z
Zhen Wang 已提交
657

H
hong 已提交
658 659 660 661 662
REGISTER_OPERATOR(
    moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
    ops::MovingAverageAbsMaxScaleOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
Z
Zhen Wang 已提交
663 664
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
                       ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
665 666 667 668

REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
                       ops::FakeQuantDequantGradKernel<CPU, float>);