/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

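// Reciprocal of a quantization scale; the epsilon keeps the result finite
// when the scale underflows to (near) zero.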
template <typename T>
inline HOSTDEVICE T inverse(T s) {
  T eps = static_cast<T>(1e-6);
  T one = static_cast<T>(1.0);
  return s <= static_cast<T>(1e-30) ? one / (s + eps) : one / s;
}

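// Rounds to the nearest integer, breaking ties to the even value
// (IEEE 754 roundTiesToEven): 0.5 -> 0, 1.5 -> 2, 2.5 -> 2.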
template <typename T>
inline HOSTDEVICE T roundWithTiesToEven(T x) {
  T xLower = floor(x);
  T xUpper = ceil(x);
  // x lies in [xLower, xUpper]. Choose the closer of the two bounds,
  // breaking ties to the even one.
  T dLower = x - xLower;
  T dUpper = xUpper - x;
  return static_cast<T>(
      (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper)
          ? xLower
          : xUpper);
}

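// Quantizes a single value: out = round(x * bin_cnt * inv_s), clipped to
// [-bin_cnt - 1, bin_cnt]. inv_s is the precomputed reciprocal of the scale.
// A sketch of typical elementwise use (names assumed for illustration):
//   platform::Transform<DeviceContext> trans;
//   trans(ctx, in, in + n, out, QuantTensorFunctor<T>(bin_cnt, inverse(s)));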
template <typename T>
class QuantTensorFunctor {
 public:
  explicit QuantTensorFunctor(const T bin_cnt, const T inv_s)
      : bin_cnt_(bin_cnt), inv_s_(inv_s) {}
  HOSTDEVICE T operator()(const T x) const {
    T out = bin_cnt_ * inv_s_ * x;
    out = roundWithTiesToEven(out);
    T max_bound = bin_cnt_;
    T min_bound = -bin_cnt_ - static_cast<T>(1);
    out = out > max_bound ? max_bound : out;
    out = out < min_bound ? min_bound : out;
    return out;
  }

 private:
  T bin_cnt_;
  T inv_s_;
};

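// Writes max(|in[i]|) over num elements to out. This and the functors below
// are defined per device in the corresponding .cc/.cu implementation files.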
template <typename DeviceContext, typename T>
struct FindAbsMaxFunctor {
  void operator()(const DeviceContext &ctx, const T *in, const int num, T *out);
};

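// Fake-quantizes a tensor: clips each element to [-scale, scale] and rounds
// it onto the bin_cnt integer grid, keeping the result in floating point.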
template <typename DeviceContext, typename T>
struct ClipAndFakeQuantFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  framework::Tensor *out);
};

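// As ClipAndFakeQuantFunctor, but also dequantizes, i.e. rescales the rounded
// values back to the original range to simulate quantization error.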
template <typename DeviceContext, typename T>
struct ClipAndFakeQuantDequantFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  int round_type,
                  framework::Tensor *out);
};

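// Range abs-max: maintains a window (scales_arr) of recent abs-max values
// and emits the maximum over that window as the current scale.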
template <typename DeviceContext, typename T>
struct FindRangeAbsMaxFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &cur_scale,
                  const framework::Tensor &last_scale,
                  const framework::Tensor &iter,
                  const int window_size,
                  framework::Tensor *scales_arr,
                  framework::Tensor *out_scale);
};

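// Per-channel abs-max along quant_axis; writes one scale per channel to
// out_abs_max.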
template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in_tensor,
                  const int quant_axis,
                  T *out_abs_max);
};

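// Channel-wise variant of ClipAndFakeQuantFunctor: each slice along
// quant_axis is quantized with its own scale.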
template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  const int quant_axis,
                  framework::Tensor *out);
};

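// Channel-wise fake quantize-dequantize: each slice along quant_axis is
// quantized with its own scale, then rescaled back.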
template <typename DeviceContext, typename T>
struct ChannelClipFakeQuantDequantFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  int round_type,
                  const int quant_axis,
                  framework::Tensor *out);
};

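// Updates an exponential moving average of the abs-max scale, roughly:
//   state = rate * state + 1; accum = rate * accum + cur_scale;
//   scale = accum / state.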
template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in_accum,
                  const framework::Tensor &in_state,
                  const T *cur_scale,
                  const float rate,
                  framework::Tensor *out_state,
                  framework::Tensor *out_accum,
                  framework::Tensor *out_scale);
};

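// Base kernel for abs-max fake quantization: computes the scale as the
// tensor-wide abs-max, then delegates the clip/round step to RunClipFunctor
// (quantize-only or quantize-dequantize in the subclasses below).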
template <typename DeviceContext, typename T>
class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto *out = context.Output<framework::Tensor>("Out");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    T *out_s = out_scale->mutable_data<T>(context.GetPlace());

    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;

    auto &dev_ctx = context.template device_context<DeviceContext>();
    const T *in_data = in->data<T>();
    FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
  }

  virtual ~FakeAbsMaxKernelBase() = default;

 protected:
  virtual void RunClipFunctor(const DeviceContext &dev_ctx,
                              const framework::Tensor &in,
                              const framework::Tensor &scale,
                              int bin_cnt,
                              int round_type,
                              framework::Tensor *out) const = 0;
};

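// Quantize-only: the output stays on the integer grid (stored as float).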
template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    ClipAndFakeQuantFunctor<DeviceContext, T>()(
        dev_ctx, in, scale, bin_cnt, round_type, out);
  }
};

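// Quantize-dequantize: the output is rescaled back to the input range.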
template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeAbsMaxKernel
    : public FakeAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(
        dev_ctx, in, scale, bin_cnt, round_type, out);
  }
};

template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");

    auto *out = context.Output<framework::Tensor>("Out");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    out->mutable_data<T>(context.GetPlace());

    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    int quant_axis = context.Attr<int>("quant_axis");
    bool is_test = context.Attr<bool>("is_test");

    auto &dev_ctx = context.template device_context<DeviceContext>();
    if (!is_test) {
      T *out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
      FindChannelAbsMaxFunctor<DeviceContext, T>()(
          dev_ctx, *in, quant_axis, out_scale_data);
    }
    ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
        dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
  }
};

template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto *out = context.Output<framework::Tensor>("Out");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    T *out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
    auto &dev_ctx = context.template device_context<DeviceContext>();
    out->mutable_data<T>(dev_ctx.GetPlace());

    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    int quant_axis = context.Attr<int>("quant_axis");

    FindChannelAbsMaxFunctor<DeviceContext, T>()(
        dev_ctx, *in, quant_axis, out_scale_data);

    ChannelClipFakeQuantDequantFunctor<DeviceContext, T>()(
        dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
  }
};

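// Range abs-max fake quantization. Training: the scale comes from the current
// abs-max combined with a window of recent abs-max values. Inference: the
// precomputed InScale is used directly.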
template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto *in_scale = context.Input<framework::Tensor>("InScale");

    auto *out = context.Output<framework::Tensor>("Out");
    out->mutable_data<T>(context.GetPlace());

    bool is_test = context.Attr<bool>("is_test");
    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    auto &dev_ctx = context.template device_context<DeviceContext>();

    // testing
    if (is_test) {
      ClipAndFakeQuantFunctor<DeviceContext, T>()(
          dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
      return;
    }

    // training
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    auto *out_scales = context.Output<framework::Tensor>("OutScales");
    auto *iter = context.Input<framework::Tensor>("Iter");

    int window_size = context.Attr<int>("window_size");
    out_scale->mutable_data<T>(context.GetPlace());

    framework::Tensor cur_scale;
    T *cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace());
    FindAbsMaxFunctor<DeviceContext, T>()(
        dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
    FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
                                               cur_scale,
                                               *in_scale,
                                               *iter,
                                               window_size,
                                               out_scales,
                                               out_scale);
    ClipAndFakeQuantFunctor<DeviceContext, T>()(
        dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
  }
};

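// Base kernel for moving-average abs-max fake quantization. Training: the
// scale is an exponential moving average of per-batch abs-max values.
// Inference: the precomputed InScale is used. Subclasses pick the clip step.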
template <typename DeviceContext, typename T>
class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto *in_scale = context.Input<framework::Tensor>("InScale");
    auto *out = context.Output<framework::Tensor>("Out");
    out->mutable_data<T>(context.GetPlace());

    bool is_test = context.Attr<bool>("is_test");
    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    auto &dev_ctx = context.template device_context<DeviceContext>();

    // testing
    if (is_test) {
      RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
      return;
    }

    // training
    auto *in_accum = context.Input<framework::Tensor>("InAccum");
    auto *in_state = context.Input<framework::Tensor>("InState");
    auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
    T *cur_scale_data = static_cast<T *>(cur_scale->ptr());

    FindAbsMaxFunctor<DeviceContext, T>()(
        dev_ctx, in->data<T>(), in->numel(), cur_scale_data);

    auto *out_state = context.Output<framework::Tensor>("OutState");
    auto *out_accum = context.Output<framework::Tensor>("OutAccum");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    out_state->mutable_data<T>(context.GetPlace());
    out_accum->mutable_data<T>(context.GetPlace());
    out_scale->mutable_data<T>(context.GetPlace());
    float moving_rate = context.Attr<float>("moving_rate");

    FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
                                                       *in_accum,
                                                       *in_state,
                                                       cur_scale_data,
                                                       moving_rate,
                                                       out_state,
                                                       out_accum,
                                                       out_scale);

    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
  }

  virtual ~FakeMovingAverageAbsMaxKernelBase() = default;

 protected:
  virtual void RunClipFunctor(const DeviceContext &dev_ctx,
                              const framework::Tensor &in,
                              const framework::Tensor &in_scale,
                              int bin_cnt,
                              int round_type,
                              framework::Tensor *out) const = 0;
};

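// Quantize-only variant.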
template <typename DeviceContext, typename T>
class FakeQuantizeMovingAverageAbsMaxKernel
    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &in_scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    ClipAndFakeQuantFunctor<DeviceContext, T>()(
        dev_ctx, in, in_scale, bin_cnt, round_type, out);
  }
};

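// Quantize-dequantize variant.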
template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &in_scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(
        dev_ctx, in, in_scale, bin_cnt, round_type, out);
  }
};

template <typename DeviceContext, typename T>
class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto &dev_ctx = context.template device_context<DeviceContext>();

    if (context.HasOutput("Out")) {
      auto *out = context.Output<framework::Tensor>("Out");
      out->mutable_data<T>(context.GetPlace());
      framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
    }

    bool is_test = context.Attr<bool>("is_test");
    // testing
    if (is_test) {
      return;
    }

    // training
    auto *in_accum = context.Input<framework::Tensor>("InAccum");
    auto *in_state = context.Input<framework::Tensor>("InState");
    auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
    T *cur_scale_data = static_cast<T *>(cur_scale->ptr());

    FindAbsMaxFunctor<DeviceContext, T>()(
        dev_ctx, in->data<T>(), in->numel(), cur_scale_data);

    auto *out_state = context.Output<framework::Tensor>("OutState");
    auto *out_accum = context.Output<framework::Tensor>("OutAccum");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    out_state->mutable_data<T>(context.GetPlace());
    out_accum->mutable_data<T>(context.GetPlace());
    out_scale->mutable_data<T>(context.GetPlace());
    float moving_rate = context.Attr<float>("moving_rate");

    FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
                                                       *in_accum,
                                                       *in_state,
                                                       cur_scale_data,
                                                       moving_rate,
                                                       out_state,
                                                       out_accum,
                                                       out_scale);
  }
};

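// Straight-through estimator (STE) gradient: the backward pass treats
// quantization as the identity, so dX is simply a copy of dOut.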
template <typename DeviceContext, typename T>
class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *d_out =
        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto x_grad_name = framework::GradVarName("X");
    auto *d_x = context.Output<framework::LoDTensor>(x_grad_name);
    PADDLE_ENFORCE_NOT_NULL(d_x,
                            platform::errors::PreconditionNotMet(
                                "StrightThroughEstimatorGradKernel "
                                "doesn't have the output named %s.",
                                x_grad_name));

    // Initialize dx as a copy of d_out (the gradient passes through unchanged)
    d_x->mutable_data<T>(context.GetPlace());
    framework::TensorCopy(*d_out, context.GetPlace(), d_x);
  }
};

}  // namespace operators
}  // namespace paddle