fake_quantize_op.cu.h 25.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// Header protection. An #ifndef guard whose #endif sits immediately after the
// #define guards nothing, so `#pragma once` protects the whole header. The
// guard macro is still defined in case other code tests for it.
#pragma once
#ifndef PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_
#define PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_
#endif  // PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_

#include <string>
20

21 22 23 24 25 26 27
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {

28 29 30 31 32 33 34 35 36 37
// Type trait selecting the arithmetic type used inside the quantization
// kernels. By default values are computed in their own storage type.
template <typename T>
struct QuantizeDataType {
  typedef T type;
};

// float16 values are promoted to float for the intermediate arithmetic.
template <>
struct QuantizeDataType<paddle::platform::float16> {
  typedef float type;
};

38
// Block-level absolute-maximum reduction.
//
// Every thread folds |in[i]| over a grid-stride loop into a private running
// maximum (starting at T(0)). The block then reduces those partial maxima in
// shared memory, and thread 0 writes the block result to out[blockIdx.x].
//
// Launch requirements:
//   * blockDim.x must be a power of two (the tree reduction halves it).
//   * At least blockDim.x * sizeof(T) bytes of dynamic shared memory.
// An empty range (n <= 0) produces T(0).
template <typename T>
__global__ void FindAbsMaxKernel(const T *in, const int n, T *out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;

  extern __shared__ char *shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T *>(shared_max_data_tmp);

  // The grid-stride loop is used for every launch shape, so a single-block
  // launch still covers all n elements instead of only the first blockDim.x.
  T local_max_data = T(0);
  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    T tmp = abs(in[i]);
    if (tmp > local_max_data) {
      local_max_data = tmp;
    }
  }
  shared_max_data[tid] = local_max_data;
  __syncthreads();

  // Tree reduction in shared memory; every thread reaches each barrier.
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

// GPU specialization of FindAbsMaxFunctor: writes max |in[i]| over
// i in [0, num) into out[0].
//
// Two passes on ctx.stream():
//   1. up to `block` blocks each produce one partial maximum into a
//      temporary buffer;
//   2. a single block reduces those partial maxima into out[0].
template <typename T>
struct FindAbsMaxFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext &ctx,
                  const T *in,
                  const int num,
                  T *out) {
    const int block = 1024;
    int grid = (block - 1 + num) / block;
    // Keep grid within [1, block]: the second pass reduces all partial
    // results with one block, and a zero-sized grid (num == 0) is an invalid
    // launch configuration.
    grid = (grid > block) ? block : grid;
    grid = (grid < 1) ? 1 : grid;

    framework::Tensor max;
    T *max_data = max.mutable_data<T>(phi::make_ddim({grid}), ctx.GetPlace());
    // Dynamic shared memory holds one T per thread.
    FindAbsMaxKernel<T>
        <<<grid, block, block * sizeof(T), ctx.stream()>>>(in, num, max_data);
    FindAbsMaxKernel<T>
        <<<1, block, block * sizeof(T), ctx.stream()>>>(max_data, grid, out);
  }
};

template struct FindAbsMaxFunctor<phi::GPUContext, float>;
template struct FindAbsMaxFunctor<phi::GPUContext, paddle::platform::float16>;
95 96

// Per-slice absolute maximum along dimension 0.
//
// Block blockIdx.x handles the contiguous slice of n / c elements starting at
// in + blockIdx.x * (n / c). Its threads stride over the slice by blockDim.x,
// keep a running |x| maximum (computed via QuantizeDataType<T>::type), merge
// the partial results in shared memory, and thread 0 stores the slice maximum
// into out[blockIdx.x].
// Requires a power-of-two blockDim.x and blockDim.x * sizeof(T) bytes of
// dynamic shared memory.
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T *in,
                                                  const int n,
                                                  const int c,
                                                  T *out) {
  using ComputeT = typename QuantizeDataType<T>::type;
  extern __shared__ char *shared_max_data_tmp[];
  T *smem = reinterpret_cast<T *>(shared_max_data_tmp);

  const int lane = threadIdx.x;
  const int per_channel = n / c;
  const T *channel_in = in + blockIdx.x * per_channel;

  T running_max = T(0);
  int k = lane;
  while (k < per_channel) {
    const T magnitude =
        static_cast<T>(fabs(static_cast<ComputeT>(channel_in[k])));
    running_max = (magnitude > running_max) ? magnitude : running_max;
    k += blockDim.x;
  }
  smem[lane] = running_max;
  __syncthreads();

  for (int half = blockDim.x / 2; half > 0; half >>= 1) {
    if (lane < half && smem[lane] < smem[lane + half]) {
      smem[lane] = smem[lane + half];
    }
    __syncthreads();
  }

  if (lane == 0) {
    out[blockIdx.x] = smem[0];
  }
}

// Per-channel absolute maximum along dimension 1.
//
// Indexing treats `in` as a row-major [cin, cout, wh] array with
// wh = n / (cin * cout). Block blockIdx.x handles index blockIdx.x of
// dimension 1. Thread threadIdx.x scans the wh elements of row threadIdx.x
// (rows are n / cin elements apart). The per-thread maxima are merged in
// shared memory with a reduction that tolerates any blockDim.x.
//
// out[blockIdx.x] is only ever raised, never lowered, so the caller must
// initialize it before the first launch; several launches can then be
// accumulated into the same output.
// Requires blockDim.x * sizeof(T) bytes of dynamic shared memory.
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(
    const T *in, const int n, const int cin, const int cout, T *out) {
  using ComputeT = typename QuantizeDataType<T>::type;
  extern __shared__ char *shared_max_data_tmp[];
  T *smem = reinterpret_cast<T *>(shared_max_data_tmp);

  const int row_stride = n / cin;
  const int inner_size = n / (cin * cout);
  const int lane = threadIdx.x;
  const int channel = blockIdx.x;

  const T *src = in + lane * row_stride + channel * inner_size;
  T running_max = T(0);
  for (int k = 0; k < inner_size; ++k) {
    const T magnitude = static_cast<T>(fabs(static_cast<ComputeT>(src[k])));
    if (magnitude > running_max) {
      running_max = magnitude;
    }
  }
  smem[lane] = running_max;
  __syncthreads();

  // Halve the active range (rounding up) until one value remains; the
  // `lane + half < active` check handles odd sizes.
  int active = blockDim.x;
  while (active > 1) {
    const int half = (active + 1) / 2;
    if (lane < half && lane + half < active &&
        smem[lane] < smem[lane + half]) {
      smem[lane] = smem[lane + half];
    }
    __syncthreads();
    active = half;
  }

  if (lane == 0 && smem[0] > out[channel]) {
    out[channel] = smem[0];
  }
}

// GPU specialization of FindChannelAbsMaxFunctor: writes the per-channel
// maximum of |x| over `in_tensor` into out_abs_max.
//   quant_axis == 0: one result per index of dimension 0 (in_dims[0] values).
//   quant_axis == 1: one result per index of dimension 1 (in_dims[1] values);
//                    dimension 0 is processed in chunks of up to 1024 rows,
//                    one row per thread, and the chunks are max-merged.
// All work, including zeroing the output, is enqueued on ctx.stream().
template <typename T>
struct FindChannelAbsMaxFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext &ctx,
                  const framework::Tensor &in_tensor,
                  const int quant_axis,
                  T *out_abs_max) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1,
        true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    const int num = in_tensor.numel();
    auto in_dims = in_tensor.dims();
    const T *in_data = in_tensor.data<T>();
    if (quant_axis == 0) {
      int cout = in_dims[0];
      int grid = cout;
      int block = 1024;
      FindChannelAbsMaxKernelQuantAxis0<T>
          <<<grid, block, block * sizeof(T), ctx.stream()>>>(
              in_data, num, cout, out_abs_max);
    } else if (quant_axis == 1) {
      int cin = in_dims[0];
      int cout = in_dims[1];
      int grid = cout;
      int max_threads = 1024;

      // The kernels below only raise out_abs_max[c], so zero it first. The
      // memset runs on ctx.stream() so it is ordered before those kernels.
#ifdef PADDLE_WITH_HIP
      hipMemsetAsync(out_abs_max, 0, sizeof(T) * cout, ctx.stream());
#else
      cudaMemsetAsync(out_abs_max, 0, sizeof(T) * cout, ctx.stream());
#endif  // PADDLE_WITH_HIP

      // Each launch covers `max_threads` consecutive rows of dimension 0
      // (thread t reads row t relative to in_data), so move in_data past all
      // of them before the next chunk.
      const int64_t row_size = num / cin;
      for (int i = 0; i < cin / max_threads; i++) {
        int block = max_threads;
        FindChannelAbsMaxKernelQuantAxis1<T>
            <<<grid, block, block * sizeof(T), ctx.stream()>>>(
                in_data, num, cin, cout, out_abs_max);
        in_data += max_threads * row_size;
      }

      // Remaining cin % max_threads rows, one thread each.
      int block = cin % max_threads;
      if (block > 0) {
        FindChannelAbsMaxKernelQuantAxis1<T>
            <<<grid, block, block * sizeof(T), ctx.stream()>>>(
                in_data, num, in_dims[0], in_dims[1], out_abs_max);
      }
    }
  }
};

template struct FindChannelAbsMaxFunctor<phi::GPUContext, float>;
218 219

// Element-wise fake quantization with a single scalar scale s = scale[0].
//   round_type == 0: bin_cnt / s * x is rounded half-to-even and clamped to
//                    [-bin_cnt - 1, bin_cnt].
//   otherwise:       x is clipped to [-s, s], scaled by bin_cnt / s and
//                    rounded with round().
// Arithmetic runs in QuantizeDataType<T>::type. The grid-stride loop lets any
// launch shape cover all n elements.
template <typename T>
__global__ void ClipAndQuantKernel(const T *in,
                                   const T *scale,
                                   const int bin_cnt,
                                   const int round_type,
                                   const int n,
                                   T *out) {
  using ComputeDataType = typename QuantizeDataType<T>::type;

  const ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
  const ComputeDataType inv_s = inverse(s);
  const ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  const int first = threadIdx.x + blockIdx.x * blockDim.x;
  for (int idx = first; idx < n; idx += blockDim.x * gridDim.x) {
    const ComputeDataType x = static_cast<ComputeDataType>(in[idx]);
    ComputeDataType q;
    if (round_type != 0) {
      ComputeDataType v = x > s ? s : x;
      v = v < -s ? -s : v;
      q = round(bin_cnt_t * inv_s * v);
    } else {
      q = roundWithTiesToEven(bin_cnt_t * inv_s * x);
      const ComputeDataType upper = bin_cnt_t;
      const ComputeDataType lower =
          -bin_cnt_t - static_cast<ComputeDataType>(1);
      q = q > upper ? upper : q;
      q = q < lower ? lower : q;
    }
    out[idx] = static_cast<T>(q);
  }
}

// Element-wise fake quantize + dequantize with a single scalar scale
// s = scale[0]. The quantized integer level q is computed as follows:
//   round_type == 0: bin_cnt / s * x rounded half-to-even, clamped to
//                    [-bin_cnt - 1, bin_cnt];
//   otherwise:       x clipped to [-s, s], scaled by bin_cnt / s, round().
// The output is q * s / bin_cnt. Arithmetic runs in
// QuantizeDataType<T>::type over a grid-stride loop.
template <typename T>
__global__ void ClipAndQuantDequantKernel(const T *in,
                                          const T *scale,
                                          const int bin_cnt,
                                          const int round_type,
                                          const int n,
                                          T *out) {
  using ComputeDataType = typename QuantizeDataType<T>::type;

  const ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
  const ComputeDataType inv_s = inverse(s);
  const ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  const int first = threadIdx.x + blockIdx.x * blockDim.x;
  for (int idx = first; idx < n; idx += blockDim.x * gridDim.x) {
    const ComputeDataType x = static_cast<ComputeDataType>(in[idx]);
    ComputeDataType q;
    if (round_type == 0) {
      q = roundWithTiesToEven(bin_cnt_t * inv_s * x);
      const ComputeDataType upper = bin_cnt_t;
      const ComputeDataType lower =
          -bin_cnt_t - static_cast<ComputeDataType>(1);
      q = q > upper ? upper : q;
      q = q < lower ? lower : q;
    } else {
      ComputeDataType v = x > s ? s : x;
      v = v < -s ? -s : v;
      q = round(bin_cnt_t * inv_s * v);
    }
    // Both rounding modes share the same dequantization step.
    out[idx] = static_cast<T>((q * s) / bin_cnt_t);
  }
}

// GPU specialization of ClipAndFakeQuantFunctor: fake-quantizes every element
// of `in` with the scalar scale scale[0] via ClipAndQuantKernel and writes the
// result to `out`, which is allocated on ctx.GetPlace(). The kernel is
// enqueued on ctx.stream().
template <typename T>
struct ClipAndFakeQuantFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  framework::Tensor *out) {
    int num = in.numel();
    int block = 1024;
    int grid = (block - 1 + num) / block;

    const T *in_data = in.data<T>();
    const T *scale_data = scale.data<T>();
    T *out_data = out->mutable_data<T>(ctx.GetPlace());

    // An empty input yields grid == 0, which is an invalid launch
    // configuration. The (empty) output is already allocated, so stop here.
    if (grid <= 0) {
      return;
    }

    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, round_type, num, out_data);
  }
};

template struct ClipAndFakeQuantFunctor<phi::GPUContext, float>;
312 313

// GPU specialization of ClipAndFakeQuantDequantFunctor: quantizes and then
// dequantizes every element of `in` with the scalar scale scale[0] via
// ClipAndQuantDequantKernel, writing the result to `out` (allocated on
// ctx.GetPlace()). The kernel is enqueued on ctx.stream().
template <typename T>
struct ClipAndFakeQuantDequantFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  framework::Tensor *out) {
    int num = in.numel();
    int block = 1024;
    int grid = (block - 1 + num) / block;

    const T *in_data = in.data<T>();
    const T *scale_data = scale.data<T>();
    T *out_data = out->mutable_data<T>(ctx.GetPlace());

    // An empty input yields grid == 0, which is an invalid launch
    // configuration. The (empty) output is already allocated, so stop here.
    if (grid <= 0) {
      return;
    }

    ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, round_type, num, out_data);
  }
};

// ChannelClipAndQuantKernel for quant_axis is 0.
//
// Block blockIdx.x handles the contiguous slice of n / c elements starting at
// offset blockIdx.x * (n / c), quantized with s = scale[blockIdx.x]. Threads
// stride over the slice by blockDim.x.
//   round_type == 0: bin_cnt / s * x rounded half-to-even, clamped to
//                    [-bin_cnt - 1, bin_cnt];
//   otherwise:       x clipped to [-s, s], scaled by bin_cnt / s, round().
// Arithmetic runs in QuantizeDataType<T>::type.
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T *in,
                                                    const T *scale,
                                                    const int bin_cnt,
                                                    const int round_type,
                                                    const int64_t n,
                                                    const int c,
                                                    T *out) {
  using ComputeDataType = typename QuantizeDataType<T>::type;

  const int64_t channel_size = n / c;
  const int64_t offset = blockIdx.x * channel_size;
  const T *in_c = in + offset;
  T *out_c = out + offset;

  const ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x]);
  const ComputeDataType inv_s = inverse(s);
  const ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  for (int64_t k = threadIdx.x; k < channel_size; k += blockDim.x) {
    const ComputeDataType x = static_cast<ComputeDataType>(in_c[k]);
    ComputeDataType q;
    if (round_type == 0) {
      q = roundWithTiesToEven(bin_cnt_t * inv_s * x);
      const ComputeDataType upper = bin_cnt_t;
      const ComputeDataType lower =
          -bin_cnt_t - static_cast<ComputeDataType>(1);
      q = q > upper ? upper : q;
      q = q < lower ? lower : q;
    } else {
      ComputeDataType v = x > s ? s : x;
      v = v < -s ? -s : v;
      q = round(bin_cnt_t * inv_s * v);
    }
    out_c[k] = static_cast<T>(q);
  }
}

// ChannelClipAndQuantKernel for quant_axis is N.
//
// Element i is quantized with scale[(i / quant_stride) % nScale], where
// quant_stride is the number of elements between consecutive indices of the
// quantization axis and nScale is that axis' extent. A grid-stride loop
// covers all n elements.
//   round_type == 0: bin_cnt / s * x rounded half-to-even, clamped to
//                    [-bin_cnt - 1, bin_cnt];
//   otherwise:       x clipped to [-s, s], scaled by bin_cnt / s, round().
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxisN(const T *in,
                                                    const T *scale,
                                                    const int bin_cnt,
                                                    const int round_type,
                                                    const int64_t n,
                                                    const int nScale,
                                                    const int quant_stride,
                                                    T *out) {
  using ComputeDataType = typename QuantizeDataType<T>::type;
  const ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
  const int64_t first = blockDim.x * blockIdx.x + threadIdx.x;

  for (int64_t pos = first; pos < n; pos += blockDim.x * gridDim.x) {
    const int64_t channel = (pos / quant_stride) % nScale;
    const ComputeDataType s = static_cast<ComputeDataType>(scale[channel]);
    const ComputeDataType inv_s = inverse(s);
    const ComputeDataType x = static_cast<ComputeDataType>(in[pos]);

    ComputeDataType q;
    if (round_type != 0) {
      ComputeDataType v = x > s ? s : x;
      v = v < -s ? -s : v;
      q = round(bin_cnt_t * inv_s * v);
    } else {
      q = roundWithTiesToEven(bin_cnt_t * inv_s * x);
      const ComputeDataType upper = bin_cnt_t;
      const ComputeDataType lower =
          -bin_cnt_t - static_cast<ComputeDataType>(1);
      q = q > upper ? upper : q;
      q = q < lower ? lower : q;
    }
    out[pos] = static_cast<T>(q);
  }
}

// GPU specialization of ChannelClipAndFakeQuantFunctor: quantizes `in` per
// channel along quant_axis (one scale value per channel) and writes the result
// to `out`, allocated on ctx.GetPlace().
//   quant_axis == 0: one block per index of dimension 0.
//   otherwise:       a grid-stride kernel indexing scale by
//                    (i / quant_stride) % in_dims[quant_axis].
// Every kernel is enqueued on ctx.stream(), so it is ordered with the
// producers of `in` and the consumers of `out`.
template <typename T>
struct ChannelClipAndFakeQuantFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  const int quant_axis,
                  framework::Tensor *out) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1,
        true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    int64_t num = in.numel();
    auto in_dims = in.dims();
    const T *in_data = in.data<T>();
    const T *scale_data = scale.data<T>();
    T *out_data = out->mutable_data<T>(ctx.GetPlace());

    // Nothing to quantize. Returning here also avoids a zero-sized grid and
    // the division by block_size (== 0) in the launch-size math below.
    if (num <= 0) {
      return;
    }

    if (quant_axis == 0) {
      int grid = in_dims[0];
      int block = 1024;
      ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
          in_data, scale_data, bin_cnt, round_type, num, in_dims[0], out_data);
    } else {
      // Elements between consecutive indices of the quantization axis.
      int quant_stride = 1;
      for (int i = quant_axis + 1; i < in_dims.size(); i++) {
        quant_stride *= in_dims[i];
      }
      int64_t block_size =
          std::min(num, static_cast<int64_t>(ctx.GetMaxThreadsPerBlock() / 4));
      int64_t max_threads =
          ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
      const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
                                          static_cast<int64_t>(1));

      const int64_t grid_size =
          std::min(max_blocks, (num + block_size - 1) / block_size);

      ChannelClipAndQuantKernelQuantAxisN<T>
          <<<grid_size, block_size, 0, ctx.stream()>>>(in_data,
                                                       scale_data,
                                                       bin_cnt,
                                                       round_type,
                                                       num,
                                                       in_dims[quant_axis],
                                                       quant_stride,
                                                       out_data);
    }
  }
};

template struct ChannelClipAndFakeQuantFunctor<phi::GPUContext, float>;
465 466

// Single-thread update of the sliding window of scales.
//
// Stores cur_scale[0] into slot (int)iter[0] % window_size of scale_arr and
// sets out_scale[0] = max(last_scale[0], cur_scale[0]). If the value evicted
// from that slot was within 1e-6 of last_scale[0], need_find_max[0] is set to
// 1 and out_size[0] to min(iter, window_size), so the host can recompute the
// window maximum. Otherwise need_find_max[0] is set to 0.
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T *cur_scale,
                                            const T *last_scale,
                                            const int64_t *iter,
                                            const int window_size,
                                            T *scale_arr,
                                            T *out_scale,
                                            int *need_find_max,
                                            int *out_size) {
  using ComputeT = typename QuantizeDataType<T>::type;

  const int step = iter[0];
  const int slot = step % window_size;
  const T evicted = scale_arr[slot];
  const T current = cur_scale[0];
  scale_arr[slot] = current;

  const T previous_max = last_scale[0];
  out_scale[0] = (previous_max < current) ? current : previous_max;

  const bool evicted_was_max =
      fabs(static_cast<ComputeT>(evicted - previous_max)) < 1e-6;
  need_find_max[0] = evicted_was_max ? 1 : 0;
  if (evicted_was_max) {
    out_size[0] = (step > window_size) ? window_size : step;
  }
}

// GPU specialization of FindRangeAbsMaxFunctor.
//
// Pushes cur_scale into the window `scales_arr` and writes
// max(last_scale, cur_scale) to out_scale. If the evicted window entry was the
// previous maximum, the number of valid window entries is read back to the
// host and the maximum is recomputed over that prefix of the window with
// FindAbsMaxFunctor. Blocks the host (ctx.Wait()) while reading the flags.
template <typename T>
struct FindRangeAbsMaxFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext &ctx,
                  const framework::Tensor &cur_scale,
                  const framework::Tensor &last_scale,
                  const framework::Tensor &iter,
                  const int window_size,
                  framework::Tensor *scales_arr,
                  framework::Tensor *out_scale) {
    const auto place = ctx.GetPlace();

    T *window = scales_arr->mutable_data<T>(place);
    T *out_scale_data = out_scale->mutable_data<T>(place);

    framework::Tensor need_find_max, out_size;
    int *find_max_flag = need_find_max.mutable_data<int>({1}, place);
    int *window_len = out_size.mutable_data<int>({1}, place);

    FindRangeAbsMaxAndFillArray<T>
        <<<1, 1, 0, ctx.stream()>>>(cur_scale.data<T>(),
                                    last_scale.data<T>(),
                                    iter.data<int64_t>(),
                                    window_size,
                                    window,
                                    out_scale_data,
                                    find_max_flag,
                                    window_len);

    int host_flag = 0;
    memory::Copy(platform::CPUPlace(),
                 &host_flag,
                 place,
                 find_max_flag,
                 sizeof(int),
                 ctx.stream());
    ctx.Wait();
    if (host_flag == 0) {
      // The previous maximum is still in the window; out_scale is final.
      return;
    }

    int len = 0;
    memory::Copy(platform::CPUPlace(),
                 &len,
                 place,
                 window_len,
                 sizeof(int),
                 ctx.stream());
    ctx.Wait();
    FindAbsMaxFunctor<phi::GPUContext, T>()(ctx, window, len, out_scale_data);
  }
};

// Single-thread moving-average scale update:
//   state' = rate * state + 1
//   accum' = rate * accum + cur_scale
//   scale  = accum' / state'
// All inputs are read before any output is written.
template <typename T>
__global__ void FindMovingAverageAbsMaxKernel(const T *in_state,
                                              const T *in_accum,
                                              const T *cur_scale,
                                              const T rate,
                                              T *out_state,
                                              T *out_accum,
                                              T *out_scale) {
  const T next_state = rate * in_state[0] + T(1.0f);
  const T next_accum = rate * in_accum[0] + cur_scale[0];
  out_state[0] = next_state;
  out_accum[0] = next_accum;
  out_scale[0] = next_accum / next_state;
}

L
Leo Chen 已提交
557
// Explicit instantiation of the GPU FindRangeAbsMaxFunctor for float.
template struct FindRangeAbsMaxFunctor<phi::GPUContext, float>;
558 559

// GPU specialization of FindMovingAverageAbsMaxFunctor: allocates the three
// outputs on ctx.GetPlace() and launches a single-thread kernel on
// ctx.stream() that applies the moving-average update with decay `rate`.
template <typename T>
struct FindMovingAverageAbsMaxFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext &ctx,
                  const framework::Tensor &in_accum,
                  const framework::Tensor &in_state,
                  const T *cur_scale,
                  const float rate,
                  framework::Tensor *out_state,
                  framework::Tensor *out_accum,
                  framework::Tensor *out_scale) {
    const auto place = ctx.GetPlace();
    const T decay = static_cast<T>(rate);

    T *state_out = out_state->mutable_data<T>(place);
    T *accum_out = out_accum->mutable_data<T>(place);
    T *scale_out = out_scale->mutable_data<T>(place);

    const T *state_in = in_state.data<T>();
    const T *accum_in = in_accum.data<T>();

    FindMovingAverageAbsMaxKernel<T><<<1, 1, 0, ctx.stream()>>>(
        state_in, accum_in, cur_scale, decay, state_out, accum_out, scale_out);
  }
};

// ChannelClipAndQuantDequantKernel for quant_axis is 0.
//
// Element i is quantized and dequantized with s = scale[(i / wh_size) % cout],
// where wh_size is the number of contiguous elements per channel. A
// grid-stride loop covers all num elements.
//   round_type == 0: q = bin_cnt / s * x rounded half-to-even, clamped to
//                    [-bin_cnt - 1, bin_cnt];
//   otherwise:       q = round(bin_cnt / s * clip(x, -s, s)).
// The output is q * s / bin_cnt. As in the other clip kernels, arithmetic
// runs in QuantizeDataType<T>::type, so float16 inputs are computed in float.
// For T = float this matches the previous pure-T arithmetic exactly.
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0(const T *in,
                                                           const T *scale,
                                                           const int bin_cnt,
                                                           const int round_type,
                                                           const int wh_size,
                                                           const int num,
                                                           const int cout,
                                                           T *out) {
  using ComputeDataType = typename QuantizeDataType<T>::type;
  const ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;

  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
    ComputeDataType s =
        static_cast<ComputeDataType>(scale[(i / wh_size) % cout]);
    ComputeDataType inv_s = inverse(s);
    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
    if (round_type == 0) {
      x = bin_cnt_t * inv_s * x;
      x = roundWithTiesToEven(x);
      ComputeDataType max_bound = bin_cnt_t;
      ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
      x = x > max_bound ? max_bound : x;
      x = x < min_bound ? min_bound : x;
      out[i] = static_cast<T>((x * s) / bin_cnt_t);
    } else {
      ComputeDataType v = x > s ? s : x;
      v = v < -s ? -s : v;
      v = bin_cnt_t * inv_s * v;
      out[i] = static_cast<T>(round(v) * s / bin_cnt_t);
    }
  }
}

// ChannelClipAndQuantDequantKernel for quant_axis is 1.
//
// Element i is quantized and dequantized with s = scale[(i / wh_size) % cout],
// where wh_size is the number of contiguous elements per (dim0, dim1) index
// pair and cout is the extent of dimension 1. A grid-stride loop covers all
// num elements.
//   round_type == 0: q = bin_cnt / s * x rounded half-to-even, clamped to
//                    [-bin_cnt - 1, bin_cnt];
//   otherwise:       q = round(bin_cnt / s * clip(x, -s, s)).
// The output is q * s / bin_cnt. As in the other clip kernels, arithmetic
// runs in QuantizeDataType<T>::type, so float16 inputs are computed in float.
// For T = float this matches the previous pure-T arithmetic exactly.
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis1(const T *in,
                                                           const T *scale,
                                                           const int bin_cnt,
                                                           const int round_type,
                                                           const int wh_size,
                                                           const int num,
                                                           const int cout,
                                                           T *out) {
  using ComputeDataType = typename QuantizeDataType<T>::type;
  const ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;

  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
    ComputeDataType s =
        static_cast<ComputeDataType>(scale[(i / wh_size) % cout]);
    ComputeDataType inv_s = inverse(s);
    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
    if (round_type == 0) {
      x = bin_cnt_t * inv_s * x;
      x = roundWithTiesToEven(x);
      ComputeDataType max_bound = bin_cnt_t;
      ComputeDataType min_bound = -bin_cnt_t - static_cast<ComputeDataType>(1);
      x = x > max_bound ? max_bound : x;
      x = x < min_bound ? min_bound : x;
      out[i] = static_cast<T>((x * s) / bin_cnt_t);
    } else {
      ComputeDataType v = x > s ? s : x;
      v = v < -s ? -s : v;
      v = bin_cnt_t * inv_s * v;
      out[i] = static_cast<T>(round(v) * s / bin_cnt_t);
    }
  }
}

// GPU specialization of ChannelClipFakeQuantDequantFunctor: per-channel
// quantize + dequantize of `in` along quant_axis (0 or 1), writing the result
// to `out` (allocated on ctx.GetPlace()). Kernels are enqueued on
// ctx.stream() with a grid sized from the device's thread capacity.
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext &ctx,
                  const framework::Tensor &in,
                  const framework::Tensor &scale,
                  const int bin_cnt,
                  const int round_type,
                  const int quant_axis,
                  framework::Tensor *out) {
    // At present, channelwise quantization supports conv2d, depthwise_conv2d
    // conv2d_transpose and mul
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1,
        true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    int num = in.numel();
    auto in_dims = in.dims();

    const T *in_data = in.data<T>();
    const T *scale_data = scale.data<T>();
    T *out_data = out->mutable_data<T>(ctx.GetPlace());

    // Nothing to compute for an empty tensor. Returning here also avoids the
    // division by block_size (== 0) and a zero-sized grid below.
    if (num <= 0) {
      return;
    }

    int64_t block_size =
        std::min(static_cast<int64_t>(num),
                 static_cast<int64_t>(ctx.GetMaxThreadsPerBlock() / 4));

    int64_t max_threads = ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
    const int64_t max_blocks =
        std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
    const int64_t grid_size =
        std::min(max_blocks, (num + block_size - 1) / block_size);

    if (quant_axis == 0) {
      // Contiguous elements per index of dimension 0.
      const int window_size = num / in_dims[0];
      ChannelClipAndQuantDequantKernelQuantAxis0<T>
          <<<grid_size, block_size, 0, ctx.stream()>>>(in_data,
                                                       scale_data,
                                                       bin_cnt,
                                                       round_type,
                                                       window_size,
                                                       num,
                                                       in_dims[0],
                                                       out_data);
    } else if (quant_axis == 1) {
      // Contiguous elements per (dim0, dim1) index pair.
      const int window_size = num / (in_dims[0] * in_dims[1]);

      ChannelClipAndQuantDequantKernelQuantAxis1<T>
          <<<grid_size, block_size, 0, ctx.stream()>>>(in_data,
                                                       scale_data,
                                                       bin_cnt,
                                                       round_type,
                                                       window_size,
                                                       num,
                                                       in_dims[1],
                                                       out_data);
    }
  }
};

template struct ChannelClipFakeQuantDequantFunctor<phi::GPUContext, float>;
716 717 718

}  // namespace operators
}  // namespace paddle