fake_quantize_op.cc 23.9 KB
Newer Older
视言's avatar
视言 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/fake_quantize_op.h"

#include <algorithm>
#include <cmath>
#include <string>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/clip_op.h"
#include "paddle/fluid/platform/transform.h"

namespace paddle {
namespace operators {

24 25 26 27 28
// Comparator that orders values by absolute magnitude: returns true when
// |a| < |b|. Used with std::max_element to find the element with the
// largest absolute value. The call operator is const so the comparator can
// also be used through a const object or reference.
template <typename T>
struct Compare {
  bool operator()(const T a, const T b) const {
    return std::abs(a) < std::abs(b);
  }
};

// CPU specialization: writes max(|in[i]|) over i in [0, num) to *out.
// An empty (or negative-length) range has no maximum. In that case *out is
// set to 0 rather than dereferencing the `last` iterator that
// std::max_element returns for an empty range, which would be undefined
// behavior.
template <typename T>
struct FindAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx, const T* in,
                  const int num, T* out) {
    if (num <= 0) {
      *out = static_cast<T>(0);
      return;
    }
    const T* abs_max_it = std::max_element(in, in + num, Compare<T>());
    *out = std::abs(*abs_max_it);
  }
};

template struct FindAbsMaxFunctor<platform::CPUDeviceContext, float>;

// CPU specialization: treats `in` as `channel` consecutive slices of
// num / channel elements each, and writes max(|x|) over slice i to out[i].
// Nothing is written when `channel` is not positive. A slice with no
// elements (num < channel) yields 0 instead of dereferencing the end
// iterator of an empty range, which would be undefined behavior.
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx, const T* in,
                  const int num, const int channel, T* out) {
    if (channel <= 0) {
      return;  // also avoids the division by zero below
    }
    const int channel_size = num / channel;
    for (int i = 0; i < channel; i++) {
      if (channel_size <= 0) {
        out[i] = static_cast<T>(0);
        continue;
      }
      const T* start = in + i * channel_size;
      const T* end = start + channel_size;
      out[i] = std::abs(*std::max_element(start, end, Compare<T>()));
    }
  }
};

template struct FindChannelAbsMaxFunctor<platform::CPUDeviceContext, float>;

// CPU specialization: clips every element of `in` to [-s, s], where s is the
// first element of `scale`, and writes the result to `out`. It then replaces
// each clipped value x with round(bin_cnt * inverse(s) * x). `inverse` is
// declared outside this file; the op docs describe the result as
// round(X / scale * range). The values stay stored as T.
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    const T bound = scale.data<T>()[0];
    const T inv_bound = inverse(bound);
    const T* src_first = in.data<T>();
    const T* src_last = src_first + in.numel();
    T* dst = out->mutable_data<T>(ctx.GetPlace());

    // Pass 1: clip into [-bound, bound].
    platform::Transform<platform::CPUDeviceContext> clip_to_bound;
    clip_to_bound(ctx, src_first, src_last, dst,
                  ClipFunctor<T>(-bound, bound));

    // Pass 2: map the clipped values onto the integer grid.
    auto dst_e = framework::EigenVector<T>::Flatten(*out);
    dst_e.device(*ctx.eigen_device()) = (bin_cnt * inv_bound * dst_e).round();
  }
};

template struct ClipAndFakeQuantFunctor<platform::CPUDeviceContext, float>;

// CPU specialization: clips `in` to [-s, s] (s = first element of `scale`),
// quantizes each value to round(bin_cnt * inverse(s) * x), and then scales it
// back by s / bin_cnt. `out` therefore holds the dequantized
// approximation of the clipped input.
template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    const T bound = scale.data<T>()[0];
    const T inv_bound = inverse(bound);
    const T* src_first = in.data<T>();
    const T* src_last = src_first + in.numel();
    T* dst = out->mutable_data<T>(ctx.GetPlace());

    platform::Transform<platform::CPUDeviceContext> clip_to_bound;
    clip_to_bound(ctx, src_first, src_last, dst,
                  ClipFunctor<T>(-bound, bound));

    // Quantize, then immediately dequantize, in a single Eigen expression.
    auto dst_e = framework::EigenVector<T>::Flatten(*out);
    dst_e.device(*ctx.eigen_device()) =
        (bin_cnt * inv_bound * dst_e).round() * bound /
        static_cast<T>(bin_cnt);
  }
};
template struct ClipAndFakeQuantDequantFunctor<platform::CPUDeviceContext,
                                               float>;

// CPU specialization of per-channel fake quantization.
// Pass 1 clips the i-th block of in.numel() / channel consecutive elements
// to [-scale[i], scale[i]] and writes it into `out`.
// Pass 2 rescales the i-th dim-0 slice of `out` to
// round(bin_cnt * inverse(scale[i]) * x).
// The passes are kept separate on purpose: the clip blocks and the dim-0
// slices only coincide when out's first dimension equals `channel`.
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int channel,
                  framework::Tensor* out) {
    const T* scales = scale.data<T>();
    const T* src = in.data<T>();
    T* dst = out->mutable_data<T>(ctx.GetPlace());
    const int per_channel = in.numel() / channel;

    platform::Transform<platform::CPUDeviceContext> clip;
    for (int c = 0; c < channel; ++c) {
      const T bound = scales[c];
      const int offset = c * per_channel;
      clip(ctx, src + offset, src + offset + per_channel, dst + offset,
           ClipFunctor<T>(-bound, bound));
    }

    for (int c = 0; c < channel; ++c) {
      const T inv_bound = inverse(scales[c]);
      framework::Tensor slice = out->Slice(c, c + 1);
      auto slice_e = framework::EigenVector<T>::Flatten(slice);
      slice_e.device(*ctx.eigen_device()) =
          (bin_cnt * inv_bound * slice_e).round();
    }
  }
};

template struct ChannelClipAndFakeQuantFunctor<platform::CPUDeviceContext,
                                               float>;

// CPU specialization: keeps the per-step scales in `scales_arr`, used as a
// ring buffer indexed by iter % window_size. It writes the maximum over the
// current window to `out_scale`.
//
// The previous maximum (`last_scale`) is reused when possible. The window is
// rescanned only when the slot being overwritten held (within 1e-6) that
// previous maximum, and the new value does not exceed it.
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter, const int window_size,
                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
    T* scale_arr = scales_arr->mutable_data<T>(ctx.GetPlace());
    const int64_t it = iter.data<int64_t>()[0];
    const int idx = static_cast<int>(it % window_size);
    const T removed = scale_arr[idx];
    const T cur = cur_scale.data<T>()[0];
    scale_arr[idx] = cur;

    T max = last_scale.data<T>()[0];
    if (max < cur) {
      max = cur;
    } else if (std::abs(removed - max) < 1e-6) {
      // `it` is 0-based, so slots [0, it] have been written and the window
      // holds it + 1 entries until it is full. Using `it` alone would
      // exclude the value just stored at idx == it, and would pass an
      // empty range on the first step.
      const int size =
          static_cast<int>(std::min<int64_t>(it + 1, window_size));
      FindAbsMaxFunctor<platform::CPUDeviceContext, T>()(ctx, scale_arr, size,
                                                         &max);
    }
    out_scale->mutable_data<T>(ctx.GetPlace())[0] = max;
  }
};

template struct FindRangeAbsMaxFunctor<platform::CPUDeviceContext, float>;

// CPU specialization: one step of the moving-average abs-max update.
//   state' = rate * state + 1
//   accum' = rate * accum + cur_scale[0]
//   scale' = accum' / state'
// All inputs are read before any output is written, so in-place aliasing of
// inputs and outputs is safe.
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& ctx,
                  const framework::Tensor& in_accum,
                  const framework::Tensor& in_state, const T* cur_scale,
                  const float rate, framework::Tensor* out_state,
                  framework::Tensor* out_accum, framework::Tensor* out_scale) {
    const T prev_accum = in_accum.data<T>()[0];
    const T prev_state = in_state.data<T>()[0];
    const T step_scale = cur_scale[0];

    const T next_state = rate * prev_state + 1;
    const T next_accum = rate * prev_accum + step_scale;

    out_state->mutable_data<T>(ctx.GetPlace())[0] = next_state;
    out_accum->mutable_data<T>(ctx.GetPlace())[0] = next_accum;
    out_scale->mutable_data<T>(ctx.GetPlace())[0] = next_accum / next_state;
  }
};

template struct FindMovingAverageAbsMaxFunctor<platform::CPUDeviceContext,
                                               float>;

class FakeQuantOrWithDequantAbsMaxOp : public framework::OperatorWithKernel {
视言's avatar
视言 已提交
175
 public:
176 177 178 179
  FakeQuantOrWithDequantAbsMaxOp(const std::string& type,
                                 const framework::VariableNameMap& inputs,
                                 const framework::VariableNameMap& outputs,
                                 const framework::AttributeMap& attrs)
视言's avatar
视言 已提交
180 181
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

182
  void InferShape(framework::InferShapeContext* ctx) const override {
183 184
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
                   "FakeQuantOrWithDequantAbsMaxOp");
185
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
186
                   "FakeQuantOrWithDequantAbsMaxOp");
187
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
188
                   "FakeQuantOrWithDequantAbsMaxOp");
视言's avatar
视言 已提交
189
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
190
    ctx->SetOutputDim("OutScale", {1});
视言's avatar
视言 已提交
191 192
    ctx->ShareLoD("X", /*->*/ "Out");
  }
193 194 195 196

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
197 198 199
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
200
  }
视言's avatar
视言 已提交
201 202
};

203 204
class FakeQuantOrWithDequantAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
视言's avatar
视言 已提交
205 206
 public:
  void Make() override {
207 208 209 210 211
    AddInput("X", "(Tensor) Input is float data type.");
    AddOutput("Out",
              "(Tensor) Output of quantized low level tensor, "
              "but also saved as float data type.");
    AddOutput("OutScale", "(Tensor) Current scale");
视言's avatar
视言 已提交
212 213
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
214
        .AddCustomChecker([](const int& bit_length) {
215 216 217 218 219
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
                                bit_length));
视言's avatar
视言 已提交
220 221
        });
    AddComment(R"DOC(
222
This is a Base Op which supports FakeQuantAbsMaxOpMaker and FakeQuantDequantAbsMaxOpMaker.
223
FakeQuantAbsMaxOp operator is used in the dynamic quantization.
视言's avatar
视言 已提交
224

225
$$scale = max(abs(X))$$
226 227
$$range = 2^{bit_length - 1} - 1$$
$$Out = round(X/scale * range)$$
视言's avatar
视言 已提交
228

229
FakeQuantDequantAbsMaxOp operator does the abs_max quantization and then dequantization.
230 231 232 233 234

$$scale = max(abs(X))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range) * scale / range$$

235 236 237
)DOC");
  }
};
视言's avatar
视言 已提交
238

Z
Zhen Wang 已提交
239 240 241 242 243
// Operator definition for fake_channel_wise_quantize_abs_max. Out takes the
// dims and LoD of X. OutScale has X.dims()[0] elements, one scale per slice
// along dimension 0.
class FakeChannelWiseQuantizeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
                   "FakeChannelWiseQuantizeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                   "FakeChannelWiseQuantizeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                   "FakeChannelWiseQuantizeAbsMax");

    auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", x_dims);
    ctx->SetOutputDim("OutScale", {x_dims[0]});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  // The kernel is selected by the data type of X on the execution place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto x_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(x_type, ctx.GetPlace());
  }
};

// Proto/attribute maker for fake_channel_wise_quantize_abs_max. Declares
// input X, outputs Out and OutScale, and the `bit_length` attribute
// (default 8, validated to [1, 16]).
class FakeChannelWiseQuantizeAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddOutput("Out",
              "(Tensor) Output of quantized low level tensor, "
              "but also saved as float data type.");
    AddOutput("OutScale", "(Tensor) Current channel wise scale");
    AddAttr<int>("bit_length", "(int, default 8)")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
                                bit_length));
        });
    AddComment(R"DOC(
The scale of FakeChannelWiseQuantize operator is a vector.
In detail, each channel of the input X has a scale value.

$$scale_c = max(abs(X_c))$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out_c = round(\frac{X_c * range} {scale_c})$$
In the above three formulas, the range of c is as follows:
$$0 \leq c \lt \ the\ channel\ number\ of\ X$$
)DOC");
  }
};

class FakeQuantizeRangeAbsMaxOp : public framework::OperatorWithKernel {
 public:
  FakeQuantizeRangeAbsMaxOp(const std::string& type,
                            const framework::VariableNameMap& inputs,
                            const framework::VariableNameMap& outputs,
                            const framework::AttributeMap& attrs)
      : OperatorWithKernel(type, inputs, outputs, attrs) {}
视言's avatar
视言 已提交
301

302
  void InferShape(framework::InferShapeContext* ctx) const override {
303 304 305 306 307
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "FakeQuantizeRangeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                   "FakeQuantizeRangeAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                   "FakeQuantizeRangeAbsMax");
308 309 310 311 312 313 314 315
    if (ctx->HasOutput("OutScales")) {
      int window_size = ctx->Attrs().Get<int>("window_size");
      ctx->SetOutputDim("OutScales", {window_size});
    }
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScale", {1});
    ctx->ShareLoD("X", /*->*/ "Out");
  }
视言's avatar
视言 已提交
316

317 318 319
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
320 321 322
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
323 324
  }
};
视言's avatar
视言 已提交
325

326 327 328 329 330 331 332 333 334 335 336 337 338 339 340
// Proto/attribute maker for fake_quantize_range_abs_max.
// `window_size` must be positive: the CPU functor indexes its scale buffer
// with iter % window_size, so 0 would divide by zero.
// `bit_length` is validated to [1, 16].
class FakeQuantizeRangeAbsMaxOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("InScale", "Last scale.");
    AddInput("Iter", "Global step iteration.").AsDispensable();
    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
    AddOutput("OutScale", " Current scale");
    AddOutput("OutScales", "(Tensor) scale buffer.").AsDispensable();
    AddAttr<int>("window_size", "(int, default 10000) window range size.")
        .SetDefault(10000)
        .AddCustomChecker([](const int& window_size) {
          PADDLE_ENFORCE_EQ(window_size > 0, true,
                            platform::errors::InvalidArgument(
                                "'window_size' should be greater than 0, but "
                                "the received is %d",
                                window_size));
        });
    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
                                bit_length));
        });
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
FakeQuantize operator is used in static quantization.

$$scale = max(max(abs(x)), history\_abs\_max)$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range)$$

)DOC");
  }
};

class FakeQuantOrWithDequantMovingAverageAbsMaxOp
    : public framework::OperatorWithKernel {
364
 public:
365 366 367 368
  FakeQuantOrWithDequantMovingAverageAbsMaxOp(
      const std::string& type, const framework::VariableNameMap& inputs,
      const framework::VariableNameMap& outputs,
      const framework::AttributeMap& attrs)
369 370 371
      : OperatorWithKernel(type, inputs, outputs, attrs) {}

  void InferShape(framework::InferShapeContext* ctx) const override {
372 373 374 375 376 377
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
                   "FakeQuantOrWithDequantMovingAverageAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                   "FakeQuantOrWithDequantMovingAverageAbsMax");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                   "FakeQuantOrWithDequantMovingAverageAbsMax");
378 379 380 381 382 383 384 385 386 387 388 389 390 391
    if (ctx->HasOutput("OutState")) {
      ctx->SetOutputDim("OutState", {1});
    }
    if (ctx->HasOutput("OutAccum")) {
      ctx->SetOutputDim("OutAccum", {1});
    }
    ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
    ctx->SetOutputDim("OutScale", {1});
    ctx->ShareLoD("X", /*->*/ "Out");
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
392 393 394
    return framework::OpKernelType(
        OperatorWithKernel::IndicateVarDataType(ctx, "X"),
        ctx.device_context());
395 396 397
  }
};

398
class FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker
399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("InScale", "Last scale.");
    AddInput("InAccum", "Last accum.").AsDispensable();
    AddInput("InState", "Last state.").AsDispensable();
    AddOutput("Out", "(Tensor) Output of quantized low level tensor.");
    AddOutput("OutScale", " Current scale");
    AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
    AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
    AddAttr<float>("moving_rate", "(float, default 0.9) moving rate.")
        .SetDefault(0.9);
    AddAttr<int>("bit_length", "(int, default 8), quantization bit number.")
        .SetDefault(8)
        .AddCustomChecker([](const int& bit_length) {
415 416 417 418 419
          PADDLE_ENFORCE_EQ(bit_length >= 1 && bit_length <= 16, true,
                            platform::errors::InvalidArgument(
                                "'bit_length' should be between 1 and 16, but "
                                "the received is %d",
                                bit_length));
420 421 422 423 424 425
        });
    AddAttr<bool>("is_test",
                  "(bool, default false) Set to true for inference only, false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
426
This is a Base Op which supports FakeQuantMovingAverageAbsMaxOp and FakeQuantDequantMovingAverageAbsMaxOp.
427
FakeQuantMovingAverageAbsMaxOp operator is used in the static quantization.
428

Z
Zhen Wang 已提交
429 430
$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$range = 2^{bit\_length - 1} - 1$$
431 432
$$Out = round(X/scale * range)$$

433
FakeQuantDequantMovingAverageAbsMaxOp operator does the moving_average_abs_max quant and then dequant.
434 435 436 437 438

$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$
$$range = 2^{bit\_length - 1} - 1$$
$$Out = round(X/scale * range) * scale / range$$

439 440 441 442
)DOC");
  }
};

Z
Zhen Wang 已提交
443 444 445 446 447
// Operator definition for moving_average_abs_max_scale. It produces only the
// one-element OutScale, plus the optional one-element OutState / OutAccum;
// there is no Out tensor.
class MovingAverageAbsMaxScaleOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X",
                   "MovingAverageAbsMaxScale");
    OP_INOUT_CHECK(ctx->HasOutput("OutScale"), "Output", "OutScale",
                   "MovingAverageAbsMaxScale");

    const bool has_state = ctx->HasOutput("OutState");
    if (has_state) ctx->SetOutputDim("OutState", {1});
    const bool has_accum = ctx->HasOutput("OutAccum");
    if (has_accum) ctx->SetOutputDim("OutAccum", {1});
    ctx->SetOutputDim("OutScale", {1});
  }

 protected:
  // The kernel is selected by the data type of X on the execution place.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto x_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
    return framework::OpKernelType(x_type, ctx.GetPlace());
  }
};

// Proto/attribute maker for moving_average_abs_max_scale. The op declares no
// Out output, so the documentation no longer claims `Out = X`.
class MovingAverageAbsMaxScaleOpMaker
    : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor) Input is float data type.");
    AddInput("InAccum", "Last accum.").AsDispensable();
    AddInput("InState", "Last state.").AsDispensable();
    AddOutput("OutScale", " Current scale");
    AddOutput("OutState", "(Tensor) state buffer.").AsDispensable();
    AddOutput("OutAccum", "(Tensor) accum buffer.").AsDispensable();
    AddAttr<float>("moving_rate", "(float, default 0.9) moving rate.")
        .SetDefault(0.9);
    AddAttr<bool>("is_test",
                  "(bool, default false) Set true for inference only and false "
                  "for training. Some layers may run faster when this is true.")
        .SetDefault(false);
    AddComment(R"DOC(
MovingAverageAbsMaxScale operator is only used for calculating the quantization scale.
It does not quantize the input tensor and has no Out output; its outputs are
OutScale and the optional OutState and OutAccum buffers.

$$scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)$$

)DOC");
  }
};

// Backward op for the quantize-dequantize family. X@GRAD takes the dims of
// Out@GRAD. The kernel is selected by the data type of Out@GRAD on the
// execution place.
class FakeQuantDequantGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    const auto dout_name = framework::GradVarName("Out");
    const auto dx_name = framework::GradVarName("X");
    OP_INOUT_CHECK(ctx->HasInput(dout_name), "Input", dout_name,
                   "FakeQuantDequantGradOp");
    OP_INOUT_CHECK(ctx->HasOutput(dx_name), "Output", dx_name,
                   "FakeQuantDequantGradOp");

    ctx->SetOutputDim(dx_name, ctx->GetInputDim(dout_name));
  }

  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    const auto dout_name = framework::GradVarName("Out");
    auto dout_type = OperatorWithKernel::IndicateVarDataType(ctx, dout_name);
    return framework::OpKernelType(dout_type, ctx.GetPlace());
  }
};

// Builds a fake_quantize_dequantize_grad op that maps Out@GRAD to X@GRAD
// and inherits all attributes of the forward op.
template <typename T>
class FakeQuantDequantGradMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> grad_op) const override {
    const auto dout_slot = framework::GradVarName("Out");
    const auto dx_slot = framework::GradVarName("X");
    grad_op->SetType("fake_quantize_dequantize_grad");
    grad_op->SetInput(dout_slot, this->OutputGrad("Out"));
    grad_op->SetOutput(dx_slot, this->InputGrad("X"));
    grad_op->SetAttrMap(this->Attrs());
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
537 538
using CPU = paddle::platform::CPUDeviceContext;

H
hong 已提交
539
REGISTER_OPERATOR(
540 541
    fake_quantize_abs_max, ops::FakeQuantOrWithDequantAbsMaxOp,
    ops::FakeQuantOrWithDequantAbsMaxOpMaker,
H
hong 已提交
542 543
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
544 545
REGISTER_OP_CPU_KERNEL(fake_quantize_abs_max,
                       ops::FakeQuantizeAbsMaxKernel<CPU, float>);
视言's avatar
视言 已提交
546

547 548 549 550 551 552 553 554
REGISTER_OPERATOR(fake_quantize_dequantize_abs_max,
                  ops::FakeQuantOrWithDequantAbsMaxOp,
                  ops::FakeQuantOrWithDequantAbsMaxOpMaker,
                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_abs_max,
                       ops::FakeQuantizeDequantizeAbsMaxKernel<CPU, float>);

H
hong 已提交
555 556 557 558 559
REGISTER_OPERATOR(
    fake_quantize_range_abs_max, ops::FakeQuantizeRangeAbsMaxOp,
    ops::FakeQuantizeRangeAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
560 561
REGISTER_OP_CPU_KERNEL(fake_quantize_range_abs_max,
                       ops::FakeQuantizeRangeAbsMaxKernel<CPU, float>);
Z
Zhen Wang 已提交
562

H
hong 已提交
563 564 565 566 567 568
REGISTER_OPERATOR(
    fake_quantize_moving_average_abs_max,
    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
    ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
569 570
REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
                       ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
571

572 573 574 575 576
REGISTER_OPERATOR(fake_quantize_dequantize_moving_average_abs_max,
                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOp,
                  ops::FakeQuantOrWithDequantMovingAverageAbsMaxOpMaker,
                  ops::FakeQuantDequantGradMaker<paddle::framework::OpDesc>,
                  ops::FakeQuantDequantGradMaker<paddle::imperative::OpBase>);
577 578 579 580
REGISTER_OP_CPU_KERNEL(
    fake_quantize_dequantize_moving_average_abs_max,
    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CPU, float>);

H
hong 已提交
581 582 583 584 585
REGISTER_OPERATOR(
    fake_channel_wise_quantize_abs_max, ops::FakeChannelWiseQuantizeAbsMaxOp,
    ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
Z
Zhen Wang 已提交
586 587
REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
                       ops::FakeChannelWiseQuantizeAbsMaxKernel<CPU, float>);
Z
Zhen Wang 已提交
588

H
hong 已提交
589 590 591 592 593
REGISTER_OPERATOR(
    moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
    ops::MovingAverageAbsMaxScaleOpMaker,
    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
Z
Zhen Wang 已提交
594 595
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
                       ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
596 597 598 599

REGISTER_OPERATOR(fake_quantize_dequantize_grad, ops::FakeQuantDequantGradOp);
REGISTER_OP_CPU_KERNEL(fake_quantize_dequantize_grad,
                       ops::FakeQuantDequantGradKernel<CPU, float>);