fake_quantize_op.h 17.6 KB
Newer Older
视言's avatar
视言 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>
18

视言's avatar
视言 已提交
19 20
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
Z
Zhen Wang 已提交
21
#include "paddle/fluid/framework/tensor_util.h"
22
#include "paddle/fluid/memory/malloc.h"
23
#include "paddle/fluid/platform/transform.h"
24 25
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
视言's avatar
视言 已提交
26 27 28 29

namespace paddle {
namespace operators {

30 31
// Returns the reciprocal of `s`.
//
// When `s` is at or below 1e-30 (after conversion to T), 1e-6 is added to the
// denominator before dividing; otherwise the result is exactly 1 / s.
template <typename T>
inline HOSTDEVICE T inverse(T s) {
  const T unit = static_cast<T>(1.0);
  if (s <= static_cast<T>(1e-30)) {
    // Tiny (or non-positive) input: pad the denominator with a small epsilon.
    const T epsilon = static_cast<T>(1e-6);
    return unit / (s + epsilon);
  }
  return unit / s;
}

37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53
// Rounds `x` to whichever of floor(x) / ceil(x) is closer.
//
// When both bounds are equally far away, floor(x) is returned if
// fmod(floor(x), 2) is zero (i.e. it is even); otherwise ceil(x) is returned.
template <typename T>
inline HOSTDEVICE T roundWithTiesToEven(T x) {
  const T below = floor(x);
  const T above = ceil(x);
  // Distances from x to each end of the enclosing interval [below, above].
  const T gap_below = x - below;
  const T gap_above = above - x;

  bool take_below;
  if (gap_below == gap_above) {
    // Exact tie: prefer the lower bound only when it is even.
    take_below = (fmod(below, 2.0F) == 0.0F);
  } else {
    take_below = (gap_below < gap_above);
  }
  return static_cast<T>(take_below ? below : above);
}

template <typename T>
class QuantTensorFunctor {
 public:
54 55
  explicit QuantTensorFunctor(const T bin_cnt, const T inv_s)
      : bin_cnt_(bin_cnt), inv_s_(inv_s) {}
56 57
  HOSTDEVICE T operator()(const T x) const {
    T out = bin_cnt_ * inv_s_ * x;
58
    out = roundWithTiesToEven(out);
59 60 61 62 63 64 65 66 67 68 69 70
    T max_bound = bin_cnt_;
    T min_bound = -bin_cnt_ - static_cast<T>(1);
    out = out > max_bound ? max_bound : out;
    out = out < min_bound ? min_bound : out;
    return out;
  }

 private:
  T bin_cnt_;
  T inv_s_;
};

71 72
// Functor interface; operator() is only declared here and no definition is
// visible in this header.
//
// Call shape: (ctx, in, num, out) - `in` is a read-only T buffer, `num` an
// element count, `out` a writable T pointer.
// NOTE(review): presumably writes max(|in[i]|) to *out - confirm against the
// per-device definition.
template <typename DeviceContext, typename T>
struct FindAbsMaxFunctor {
  void operator()(const DeviceContext &ctx,
                  const T *in,
                  const int num,
                  T *out);
};
视言's avatar
视言 已提交
75 76

template <typename DeviceContext, typename T>
77
struct ClipAndFakeQuantFunctor {
78 79 80 81 82 83
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  framework::Tensor *out);
84 85
};

86 87
template <typename DeviceContext, typename T>
struct ClipAndFakeQuantDequantFunctor {
88 89 90 91 92 93
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  int round_type,
                  framework::Tensor *out);
94 95
};

96 97
template <typename DeviceContext, typename T>
struct FindRangeAbsMaxFunctor {
98 99 100 101 102 103 104
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &cur_scale,
                  const framework::Tensor &last_scale,
                  const framework::Tensor &iter,
                  const int window_size,
                  framework::Tensor *scales_arr,
                  framework::Tensor *out_scale);
105 106
};

107 108
template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor {
109 110 111 112
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in_tensor,
                  const int quant_axis,
                  T *out_abs_max);
113 114 115 116
};

template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor {
117 118 119 120 121 122 123
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  const int quant_axis,
                  framework::Tensor *out);
124 125
};

H
huangxu96 已提交
126 127
template <typename DeviceContext, typename T>
struct ChannelClipFakeQuantDequantFunctor {
128 129 130 131 132 133 134
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  int round_type,
                  const int quant_axis,
                  framework::Tensor *out);
H
huangxu96 已提交
135 136
};

137 138
template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor {
139 140 141
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in_accum,
                  const framework::Tensor &in_state,
142 143
                  const T *cur_scale,
                  const float rate,
144 145 146
                  framework::Tensor *out_state,
                  framework::Tensor *out_accum,
                  framework::Tensor *out_scale);
147 148
};

149
template <typename DeviceContext, typename T>
150
class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
视言's avatar
视言 已提交
151
 public:
152 153 154 155 156
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto *out = context.Output<framework::Tensor>("Out");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    T *out_s = out_scale->mutable_data<T>(context.GetPlace());
157 158

    int bit_length = context.Attr<int>("bit_length");
159
    int round_type = context.Attr<int>("round_type");
160 161
    int bin_cnt = std::pow(2, bit_length - 1) - 1;

162 163
    auto &dev_ctx = context.template device_context<DeviceContext>();
    const T *in_data = in->data<T>();
164
    FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
165
    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
166 167 168 169 170
  }

  virtual ~FakeAbsMaxKernelBase() = default;

 protected:
171 172 173 174 175 176
  virtual void RunClipFunctor(const DeviceContext &dev_ctx,
                              const framework::Tensor &in,
                              const framework::Tensor &scale,
                              int bin_cnt,
                              int round_type,
                              framework::Tensor *out) const = 0;
177 178 179 180 181
};

// FakeAbsMaxKernelBase variant whose hook forwards to ClipAndFakeQuantFunctor.
template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    // All arguments are passed through unchanged.
    ClipAndFakeQuantFunctor<DeviceContext, T> clip_and_quant;
    clip_and_quant(dev_ctx, in, scale, bin_cnt, round_type, out);
  }
};

// FakeAbsMaxKernelBase variant whose hook forwards to
// ClipAndFakeQuantDequantFunctor.
template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeAbsMaxKernel
    : public FakeAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    // All arguments are passed through unchanged.
    ClipAndFakeQuantDequantFunctor<DeviceContext, T> clip_quant_dequant;
    clip_quant_dequant(dev_ctx, in, scale, bin_cnt, round_type, out);
  }
};
视言's avatar
视言 已提交
207

Z
Zhen Wang 已提交
208 209 210
// Kernel that applies ChannelClipAndFakeQuantFunctor to input "X" along the
// "quant_axis" attribute, writing to "Out".
//
// Unless the "is_test" attribute is set, "OutScale" is first refreshed with
// FindChannelAbsMaxFunctor; when "is_test" is set, "OutScale" is used as-is.
template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *x = context.Input<framework::Tensor>("X");

    auto *y = context.Output<framework::Tensor>("Out");
    auto *scale = context.Output<framework::Tensor>("OutScale");
    y->mutable_data<T>(context.GetPlace());

    const int bit_length = context.Attr<int>("bit_length");
    const int round_type = context.Attr<int>("round_type");
    // 2^(bit_length - 1) - 1, truncated to int.
    const int bin_cnt = std::pow(2, bit_length - 1) - 1;
    const int quant_axis = context.Attr<int>("quant_axis");
    const bool is_test = context.Attr<bool>("is_test");

    auto &dev_ctx = context.template device_context<DeviceContext>();
    if (!is_test) {
      // Recompute the scale from the current input.
      T *scale_buf = scale->mutable_data<T>(context.GetPlace());
      FindChannelAbsMaxFunctor<DeviceContext, T> find_channel_abs_max;
      find_channel_abs_max(dev_ctx, *x, quant_axis, scale_buf);
    }

    ChannelClipAndFakeQuantFunctor<DeviceContext, T> clip_and_quant;
    clip_and_quant(dev_ctx, *x, *scale, bin_cnt, round_type, quant_axis, y);
  }
};

H
huangxu96 已提交
235 236 237 238
// Kernel that always refreshes "OutScale" with FindChannelAbsMaxFunctor and
// then applies ChannelClipFakeQuantDequantFunctor to input "X" along the
// "quant_axis" attribute, writing to "Out".
template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *x = context.Input<framework::Tensor>("X");
    auto *y = context.Output<framework::Tensor>("Out");
    auto *scale = context.Output<framework::Tensor>("OutScale");
    T *scale_buf = scale->mutable_data<T>(context.GetPlace());
    auto &dev_ctx = context.template device_context<DeviceContext>();
    // The output buffer is requested at the device context's place.
    y->mutable_data<T>(dev_ctx.GetPlace());

    const int bit_length = context.Attr<int>("bit_length");
    const int round_type = context.Attr<int>("round_type");
    // 2^(bit_length - 1) - 1, truncated to int.
    const int bin_cnt = std::pow(2, bit_length - 1) - 1;
    const int quant_axis = context.Attr<int>("quant_axis");

    FindChannelAbsMaxFunctor<DeviceContext, T> find_channel_abs_max;
    find_channel_abs_max(dev_ctx, *x, quant_axis, scale_buf);

    ChannelClipFakeQuantDequantFunctor<DeviceContext, T> clip_quant_dequant;
    clip_quant_dequant(
        dev_ctx, *x, *scale, bin_cnt, round_type, quant_axis, y);
  }
};

260 261 262
// Kernel that applies ClipAndFakeQuantFunctor to input "X", writing to "Out".
//
// With the "is_test" attribute set, the scale is read from input "InScale"
// and nothing else is touched. Otherwise a fresh abs-max of "X" is passed,
// together with "InScale", "Iter" and the "window_size" attribute, to
// FindRangeAbsMaxFunctor, which fills "OutScales" / "OutScale"; "OutScale" is
// then used for the quantization step.
template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *x = context.Input<framework::Tensor>("X");
    auto *scale_in = context.Input<framework::Tensor>("InScale");

    auto *y = context.Output<framework::Tensor>("Out");
    y->mutable_data<T>(context.GetPlace());

    const bool is_test = context.Attr<bool>("is_test");
    const int bit_length = context.Attr<int>("bit_length");
    const int round_type = context.Attr<int>("round_type");
    // 2^(bit_length - 1) - 1, truncated to int.
    const int bin_cnt = std::pow(2, bit_length - 1) - 1;
    auto &dev_ctx = context.template device_context<DeviceContext>();

    // Inference path: quantize with the provided scale and stop.
    if (is_test) {
      ClipAndFakeQuantFunctor<DeviceContext, T> clip_and_quant;
      clip_and_quant(dev_ctx, *x, *scale_in, bin_cnt, round_type, y);
      return;
    }

    // Training path: update the scale before quantizing.
    auto *scale_out = context.Output<framework::Tensor>("OutScale");
    auto *scale_history = context.Output<framework::Tensor>("OutScales");
    auto *iter = context.Input<framework::Tensor>("Iter");

    const int window_size = context.Attr<int>("window_size");
    scale_out->mutable_data<T>(context.GetPlace());

    // Single-element scratch tensor that receives this step's abs-max.
    framework::Tensor step_scale;
    T *step_scale_buf = step_scale.mutable_data<T>({1}, context.GetPlace());
    FindAbsMaxFunctor<DeviceContext, T> find_abs_max;
    find_abs_max(dev_ctx, x->data<T>(), x->numel(), step_scale_buf);

    FindRangeAbsMaxFunctor<DeviceContext, T> find_range_abs_max;
    find_range_abs_max(dev_ctx,
                       step_scale,
                       *scale_in,
                       *iter,
                       window_size,
                       scale_history,
                       scale_out);

    ClipAndFakeQuantFunctor<DeviceContext, T> clip_and_quant;
    clip_and_quant(dev_ctx, *x, *scale_out, bin_cnt, round_type, y);
  }
};

307
template <typename DeviceContext, typename T>
308
class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
309
 public:
310 311 312 313
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto *in_scale = context.Input<framework::Tensor>("InScale");
    auto *out = context.Output<framework::Tensor>("Out");
314 315 316 317
    out->mutable_data<T>(context.GetPlace());

    bool is_test = context.Attr<bool>("is_test");
    int bit_length = context.Attr<int>("bit_length");
318
    int round_type = context.Attr<int>("round_type");
319
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
320
    auto &dev_ctx = context.template device_context<DeviceContext>();
321 322 323

    // testing
    if (is_test) {
324
      RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
325 326 327 328
      return;
    }

    // training
329 330
    auto *in_accum = context.Input<framework::Tensor>("InAccum");
    auto *in_state = context.Input<framework::Tensor>("InState");
331
    auto cur_scale = memory::Alloc(dev_ctx, sizeof(T));
332
    T *cur_scale_data = static_cast<T *>(cur_scale->ptr());
333

334 335
    FindAbsMaxFunctor<DeviceContext, T>()(
        dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
336

337 338 339
    auto *out_state = context.Output<framework::Tensor>("OutState");
    auto *out_accum = context.Output<framework::Tensor>("OutAccum");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
340 341 342 343 344
    out_state->mutable_data<T>(context.GetPlace());
    out_accum->mutable_data<T>(context.GetPlace());
    out_scale->mutable_data<T>(context.GetPlace());
    float moving_rate = context.Attr<float>("moving_rate");

345 346 347 348 349 350 351 352
    FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
                                                       *in_accum,
                                                       *in_state,
                                                       cur_scale_data,
                                                       moving_rate,
                                                       out_state,
                                                       out_accum,
                                                       out_scale);
353

354
    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
355
  }
356 357 358 359

  virtual ~FakeMovingAverageAbsMaxKernelBase() = default;

 protected:
360 361 362 363 364 365
  virtual void RunClipFunctor(const DeviceContext &dev_ctx,
                              const framework::Tensor &in,
                              const framework::Tensor &in_scale,
                              int bin_cnt,
                              int round_type,
                              framework::Tensor *out) const = 0;
366 367 368 369 370
};

// FakeMovingAverageAbsMaxKernelBase variant whose hook forwards to
// ClipAndFakeQuantFunctor.
template <typename DeviceContext, typename T>
class FakeQuantizeMovingAverageAbsMaxKernel
    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &in_scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    // All arguments are passed through unchanged.
    ClipAndFakeQuantFunctor<DeviceContext, T> clip_and_quant;
    clip_and_quant(dev_ctx, in, in_scale, bin_cnt, round_type, out);
  }
};

// FakeMovingAverageAbsMaxKernelBase variant whose hook forwards to
// ClipAndFakeQuantDequantFunctor.
template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &in_scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    // All arguments are passed through unchanged.
    ClipAndFakeQuantDequantFunctor<DeviceContext, T> clip_quant_dequant;
    clip_quant_dequant(dev_ctx, in, in_scale, bin_cnt, round_type, out);
  }
};

Z
Zhen Wang 已提交
398 399 400
// Kernel that tracks a scale for input "X" without quantizing it.
//
// If the op has an "Out" output, "X" is copied into it. With the "is_test"
// attribute set nothing further happens. Otherwise FindAbsMaxFunctor is run
// over "X" and its result is fed with "InAccum", "InState" and the
// "moving_rate" attribute to FindMovingAverageAbsMaxFunctor, which fills
// "OutState", "OutAccum" and "OutScale".
template <typename DeviceContext, typename T>
class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *x = context.Input<framework::Tensor>("X");
    auto &dev_ctx = context.template device_context<DeviceContext>();

    // "Out" is optional; when present it receives a copy of the input.
    if (context.HasOutput("Out")) {
      auto *y = context.Output<framework::Tensor>("Out");
      y->mutable_data<T>(context.GetPlace());
      framework::TensorCopy(*x, context.GetPlace(), dev_ctx, y);
    }

    // Inference path: no statistics are updated.
    const bool is_test = context.Attr<bool>("is_test");
    if (is_test) {
      return;
    }

    // Training path.
    auto *accum_in = context.Input<framework::Tensor>("InAccum");
    auto *state_in = context.Input<framework::Tensor>("InState");
    // Scratch storage of sizeof(T) bytes for this step's abs-max.
    auto step_scale = memory::Alloc(dev_ctx, sizeof(T));
    T *step_scale_buf = static_cast<T *>(step_scale->ptr());

    FindAbsMaxFunctor<DeviceContext, T> find_abs_max;
    find_abs_max(dev_ctx, x->data<T>(), x->numel(), step_scale_buf);

    auto *state_out = context.Output<framework::Tensor>("OutState");
    auto *accum_out = context.Output<framework::Tensor>("OutAccum");
    auto *scale_out = context.Output<framework::Tensor>("OutScale");
    state_out->mutable_data<T>(context.GetPlace());
    accum_out->mutable_data<T>(context.GetPlace());
    scale_out->mutable_data<T>(context.GetPlace());
    const float moving_rate = context.Attr<float>("moving_rate");

    FindMovingAverageAbsMaxFunctor<DeviceContext, T> update_moving_average;
    update_moving_average(dev_ctx,
                          *accum_in,
                          *state_in,
                          step_scale_buf,
                          moving_rate,
                          state_out,
                          accum_out,
                          scale_out);
  }
};

445
template <typename DeviceContext, typename T>
446
class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
447
 public:
448 449
  void Compute(const framework::ExecutionContext &context) const override {
    auto *d_out =
450 451
        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto x_grad_name = framework::GradVarName("X");
452 453 454 455 456 457
    auto *d_x = context.Output<framework::LoDTensor>(x_grad_name);
    PADDLE_ENFORCE_NOT_NULL(d_x,
                            platform::errors::PreconditionNotMet(
                                "StrightThroughEstimatorGradKernel "
                                "doesn't have the output named %s.",
                                x_grad_name));
458 459 460 461 462 463 464

    // Initialize dx as same as d_out
    d_x->mutable_data<T>(context.GetPlace());
    framework::TensorCopy(*d_out, context.GetPlace(), d_x);
  }
};

视言's avatar
视言 已提交
465 466
}  // namespace operators
}  // namespace paddle