/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <string>

#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/transform.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

namespace paddle {
namespace operators {

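// Return the multiplicative inverse of a scale, guarding against division by
// (near) zero with a small epsilon.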
template <typename T>
inline HOSTDEVICE T inverse(T s) {
  T eps = static_cast<T>(1e-6);
  T one = static_cast<T>(1.0);
  return s <= static_cast<T>(1e-30) ? one / (s + eps) : one / s;
}

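// Round to the nearest integer, breaking ties to the even neighbor ("banker's
// rounding"), computed via floor/ceil so it does not depend on the current
// FPU rounding mode.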
template <typename T>
inline HOSTDEVICE T roundWithTiesToEven(T x) {
  T xLower = floor(x);
  T xUpper = ceil(x);
  // x lies in [xLower, xUpper]; choose the closer of the two bounds,
  // breaking ties toward the even one.
  T dLower = x - xLower;
  T dUpper = xUpper - x;
  return static_cast<T>(
      (dLower == dUpper ? fmod(xLower, 2.0F) == 0.0F : dLower < dUpper)
          ? xLower
          : xUpper);
}

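// Quantize one value: out = round(x * bin_cnt * inv_s), clipped to
// [-bin_cnt - 1, bin_cnt]. inv_s must be the precomputed inverse of the scale.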
template <typename T>
class QuantTensorFunctor {
 public:
  explicit QuantTensorFunctor(const T bin_cnt, const T inv_s)
      : bin_cnt_(bin_cnt), inv_s_(inv_s) {}
  HOSTDEVICE T operator()(const T x) const {
    T out = bin_cnt_ * inv_s_ * x;
    out = roundWithTiesToEven(out);
    T max_bound = bin_cnt_;
    T min_bound = -bin_cnt_ - static_cast<T>(1);
    out = out > max_bound ? max_bound : out;
    out = out < min_bound ? min_bound : out;
    return out;
  }

 private:
  T bin_cnt_;
  T inv_s_;
};

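// Find the maximum absolute value among the `num` elements of `in` and write
// it to `out`.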
template <typename DeviceContext, typename T>
struct FindAbsMaxFunctor {
  void operator()(const DeviceContext &ctx, const T *in, const int num, T *out);
};

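// Fake-quantize `in` onto the integer grid [-bin_cnt - 1, bin_cnt] using a
// single `scale`; `round_type` selects the rounding strategy.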
template <typename DeviceContext, typename T>
struct ClipAndFakeQuantFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  framework::Tensor *out);
};

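// Fake quantize-dequantize: quantize as above, then map back to the
// floating-point range, so `out` equals `in` up to quantization error.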
template <typename DeviceContext, typename T>
struct ClipAndFakeQuantDequantFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  int round_type,
                  framework::Tensor *out);
};

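// Maintain the abs-max scale over a sliding window of the last `window_size`
// batches; `scales_arr` holds the per-batch history.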
template <typename DeviceContext, typename T>
struct FindRangeAbsMaxFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &cur_scale,
                  const framework::Tensor &last_scale,
                  const framework::Tensor &iter,
                  const int window_size,
                  framework::Tensor *scales_arr,
                  framework::Tensor *out_scale);
};

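// Compute the per-channel maximum absolute value of `in_tensor` along
// `quant_axis`.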
template <typename DeviceContext, typename T>
struct FindChannelAbsMaxFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in_tensor,
                  const int quant_axis,
                  T *out_abs_max);
};

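// Per-channel variant of ClipAndFakeQuantFunctor: each slice along
// `quant_axis` is quantized with its own scale.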
template <typename DeviceContext, typename T>
struct ChannelClipAndFakeQuantFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  const int quant_axis,
                  framework::Tensor *out);
};

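// Per-channel variant of ClipAndFakeQuantDequantFunctor.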
template <typename DeviceContext, typename T>
struct ChannelClipFakeQuantDequantFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  int round_type,
                  const int quant_axis,
                  framework::Tensor *out);
};

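// Update the moving-average abs-max statistics, roughly:
//   state = rate * in_state + 1, accum = rate * in_accum + cur_scale,
//   scale = accum / state.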
template <typename DeviceContext, typename T>
struct FindMovingAverageAbsMaxFunctor {
  void operator()(const DeviceContext &ctx,
                  const framework::Tensor &in_accum,
                  const framework::Tensor &in_state,
                  const T *cur_scale,
                  const float rate,
                  framework::Tensor *out_state,
                  framework::Tensor *out_accum,
                  framework::Tensor *out_scale);
};

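// Base class for abs-max fake quantization kernels: computes the abs-max
// scale of X, then delegates the actual clip/quantize step to RunClipFunctor.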
template <typename DeviceContext, typename T>
class FakeAbsMaxKernelBase : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto *out = context.Output<framework::Tensor>("Out");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    T *out_s = out_scale->mutable_data<T>(context.GetPlace());

    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
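    // Largest representable magnitude, e.g. 127 for bit_length = 8.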
    int bin_cnt = std::pow(2, bit_length - 1) - 1;

    auto &dev_ctx = context.template device_context<DeviceContext>();
    const T *in_data = in->data<T>();
    FindAbsMaxFunctor<DeviceContext, T>()(dev_ctx, in_data, in->numel(), out_s);
    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
  }

  virtual ~FakeAbsMaxKernelBase() = default;

 protected:
  virtual void RunClipFunctor(const DeviceContext &dev_ctx,
                              const framework::Tensor &in,
                              const framework::Tensor &scale,
                              int bin_cnt,
                              int round_type,
                              framework::Tensor *out) const = 0;
};

template <typename DeviceContext, typename T>
class FakeQuantizeAbsMaxKernel : public FakeAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    ClipAndFakeQuantFunctor<DeviceContext, T>()(
        dev_ctx, in, scale, bin_cnt, round_type, out);
  }
};

template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeAbsMaxKernel
    : public FakeAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(
        dev_ctx, in, scale, bin_cnt, round_type, out);
  }
};

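// Channel-wise fake quantization: recomputes per-channel abs-max scales
// during training and reuses the stored scales when is_test is set.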
template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeAbsMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");

    auto *out = context.Output<framework::Tensor>("Out");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    out->mutable_data<T>(context.GetPlace());

    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    int quant_axis = context.Attr<int>("quant_axis");
    bool is_test = context.Attr<bool>("is_test");

    auto &dev_ctx = context.template device_context<DeviceContext>();
    if (!is_test) {
      T *out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
      FindChannelAbsMaxFunctor<DeviceContext, T>()(
          dev_ctx, *in, quant_axis, out_scale_data);
    }
    ChannelClipAndFakeQuantFunctor<DeviceContext, T>()(
        dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
  }
};

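// Channel-wise fake quantize-dequantize: simulates per-channel quantization
// error in the forward pass while keeping floating-point outputs.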
template <typename DeviceContext, typename T>
class FakeChannelWiseQuantizeDequantizeAbsMaxKernel
    : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto *out = context.Output<framework::Tensor>("Out");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    T *out_scale_data = out_scale->mutable_data<T>(context.GetPlace());
    auto &dev_ctx = context.template device_context<DeviceContext>();
    out->mutable_data<T>(dev_ctx.GetPlace());

    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    int quant_axis = context.Attr<int>("quant_axis");

    FindChannelAbsMaxFunctor<DeviceContext, T>()(
        dev_ctx, *in, quant_axis, out_scale_data);

    ChannelClipFakeQuantDequantFunctor<DeviceContext, T>()(
        dev_ctx, *in, *out_scale, bin_cnt, round_type, quant_axis, out);
  }
};

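// Range abs-max fake quantization: at inference the stored InScale is used
// directly; during training the scale is refreshed from a sliding window of
// per-batch abs-max values.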
template <typename DeviceContext, typename T>
class FakeQuantizeRangeAbsMaxKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto *in_scale = context.Input<framework::Tensor>("InScale");

    auto *out = context.Output<framework::Tensor>("Out");
    out->mutable_data<T>(context.GetPlace());

    bool is_test = context.Attr<bool>("is_test");
    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    auto &dev_ctx = context.template device_context<DeviceContext>();

    // testing
    if (is_test) {
      ClipAndFakeQuantFunctor<DeviceContext, T>()(
          dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
      return;
    }

    // training
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    auto *out_scales = context.Output<framework::Tensor>("OutScales");
    auto *iter = context.Input<framework::Tensor>("Iter");

    int window_size = context.Attr<int>("window_size");
    out_scale->mutable_data<T>(context.GetPlace());

    framework::Tensor cur_scale;
    T *cur_scale_data = cur_scale.mutable_data<T>({1}, context.GetPlace());
    FindAbsMaxFunctor<DeviceContext, T>()(
        dev_ctx, in->data<T>(), in->numel(), cur_scale_data);
    FindRangeAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
                                               cur_scale,
                                               *in_scale,
                                               *iter,
                                               window_size,
                                               out_scales,
                                               out_scale);
    ClipAndFakeQuantFunctor<DeviceContext, T>()(
        dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
  }
};

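// Base class for moving-average abs-max kernels: at inference InScale is used
// directly; during training the scale tracks a moving average of per-batch
// abs-max values before RunClipFunctor is applied.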
template <typename DeviceContext, typename T>
class FakeMovingAverageAbsMaxKernelBase : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto *in_scale = context.Input<framework::Tensor>("InScale");
    auto *out = context.Output<framework::Tensor>("Out");
    out->mutable_data<T>(context.GetPlace());

    bool is_test = context.Attr<bool>("is_test");
    int bit_length = context.Attr<int>("bit_length");
    int round_type = context.Attr<int>("round_type");
    int bin_cnt = std::pow(2, bit_length - 1) - 1;
    auto &dev_ctx = context.template device_context<DeviceContext>();

    // testing
    if (is_test) {
      RunClipFunctor(dev_ctx, *in, *in_scale, bin_cnt, round_type, out);
      return;
    }

    // training
    auto *in_accum = context.Input<framework::Tensor>("InAccum");
    auto *in_state = context.Input<framework::Tensor>("InState");

    phi::DenseTensor tmp_scale;
    tmp_scale.Resize(phi::make_dim(1));
    T *cur_scale_data = dev_ctx.template Alloc<T>(&tmp_scale);

    FindAbsMaxFunctor<DeviceContext, T>()(
        dev_ctx, in->data<T>(), in->numel(), cur_scale_data);

    auto *out_state = context.Output<framework::Tensor>("OutState");
    auto *out_accum = context.Output<framework::Tensor>("OutAccum");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    out_state->mutable_data<T>(context.GetPlace());
    out_accum->mutable_data<T>(context.GetPlace());
    out_scale->mutable_data<T>(context.GetPlace());
    float moving_rate = context.Attr<float>("moving_rate");

    FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
                                                       *in_accum,
                                                       *in_state,
                                                       cur_scale_data,
                                                       moving_rate,
                                                       out_state,
                                                       out_accum,
                                                       out_scale);

    RunClipFunctor(dev_ctx, *in, *out_scale, bin_cnt, round_type, out);
  }

  virtual ~FakeMovingAverageAbsMaxKernelBase() = default;

 protected:
  virtual void RunClipFunctor(const DeviceContext &dev_ctx,
                              const framework::Tensor &in,
                              const framework::Tensor &in_scale,
                              int bin_cnt,
                              int round_type,
                              framework::Tensor *out) const = 0;
};

template <typename DeviceContext, typename T>
class FakeQuantizeMovingAverageAbsMaxKernel
    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &in_scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    ClipAndFakeQuantFunctor<DeviceContext, T>()(
        dev_ctx, in, in_scale, bin_cnt, round_type, out);
  }
};

template <typename DeviceContext, typename T>
class FakeQuantizeDequantizeMovingAverageAbsMaxKernel
    : public FakeMovingAverageAbsMaxKernelBase<DeviceContext, T> {
 protected:
  void RunClipFunctor(const DeviceContext &dev_ctx,
                      const framework::Tensor &in,
                      const framework::Tensor &in_scale,
                      int bin_cnt,
                      int round_type,
                      framework::Tensor *out) const override {
    ClipAndFakeQuantDequantFunctor<DeviceContext, T>()(
        dev_ctx, in, in_scale, bin_cnt, round_type, out);
  }
};

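// Maintains the moving-average abs-max scale without quantizing; Out, when
// present, is simply a copy of X.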
template <typename DeviceContext, typename T>
class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *in = context.Input<framework::Tensor>("X");
    auto &dev_ctx = context.template device_context<DeviceContext>();

    if (context.HasOutput("Out")) {
      auto *out = context.Output<framework::Tensor>("Out");
      out->mutable_data<T>(context.GetPlace());
      framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
    }

    bool is_test = context.Attr<bool>("is_test");
    // testing
    if (is_test) {
      return;
    }

    // training
    auto *in_accum = context.Input<framework::Tensor>("InAccum");
    auto *in_state = context.Input<framework::Tensor>("InState");
    phi::DenseTensor tmp_scale;
    tmp_scale.Resize(phi::make_dim(1));
    T *cur_scale_data = dev_ctx.template Alloc<T>(&tmp_scale);

    FindAbsMaxFunctor<DeviceContext, T>()(
        dev_ctx, in->data<T>(), in->numel(), cur_scale_data);

    auto *out_state = context.Output<framework::Tensor>("OutState");
    auto *out_accum = context.Output<framework::Tensor>("OutAccum");
    auto *out_scale = context.Output<framework::Tensor>("OutScale");
    out_state->mutable_data<T>(context.GetPlace());
    out_accum->mutable_data<T>(context.GetPlace());
    out_scale->mutable_data<T>(context.GetPlace());
    float moving_rate = context.Attr<float>("moving_rate");

    FindMovingAverageAbsMaxFunctor<DeviceContext, T>()(dev_ctx,
                                                       *in_accum,
                                                       *in_state,
                                                       cur_scale_data,
                                                       moving_rate,
                                                       out_state,
                                                       out_accum,
                                                       out_scale);
  }
};

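// Straight-through estimator (STE) gradient: the quantization op is treated
// as the identity in the backward pass, so dX is a copy of dOut.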
template <typename DeviceContext, typename T>
class StrightThroughEstimatorGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &context) const override {
    auto *d_out =
        context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
    auto x_grad_name = framework::GradVarName("X");
    auto *d_x = context.Output<framework::LoDTensor>(x_grad_name);
    PADDLE_ENFORCE_NOT_NULL(d_x,
                            platform::errors::PreconditionNotMet(
                                "StrightThroughEstimatorGradKernel "
                                "doesn't have the output named %s.",
                                x_grad_name));

    // Initialize d_x as a copy of d_out.
    d_x->mutable_data<T>(context.GetPlace());
    framework::TensorCopy(*d_out, context.GetPlace(), d_x);
  }
};

}  // namespace operators
}  // namespace paddle