fake_quantize_op.cu 19.7 KB
Newer Older
视言's avatar
视言 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
16
#include "paddle/fluid/memory/memcpy.h"
视言's avatar
视言 已提交
17 18 19 20 21 22 23
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Computes per-block maxima of |in[i]| for i in [0, n).
//
// Launch requirements (as used by the callers in this file):
//   * 1-D grid and 1-D block; blockDim.x must be a power of two because the
//     tree reduction below halves the active range on every step.
//   * blockDim.x * sizeof(T) bytes of dynamic shared memory.
//   * `out` must hold gridDim.x elements; block b writes its maximum to out[b].
//
// Every thread walks the input with a stride of blockDim.x * gridDim.x, for
// any grid size. The previous single-block path only looked at in[bid], so a
// one-block launch with n > blockDim.x silently ignored the tail of the input.
// Threads that see no element (including the n == 0 case) contribute T(0).
template <typename T>
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;
  int tid = threadIdx.x;

  // Dynamic shared memory is declared with a non-template element type and
  // reinterpreted as T so that the declaration is identical for every T.
  extern __shared__ char* shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);

  // Per-thread partial maximum, kept in a local until the loop finishes.
  T local_max = T(0);
  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    T tmp = abs(in[i]);
    if (tmp > local_max) {
      local_max = tmp;
    }
  }
  shared_max_data[tid] = local_max;
  __syncthreads();

  // Tree reduction in shared memory; the barrier is reached by all threads.
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
// CUDA specialization of FindAbsMaxFunctor.
//
// Reduces |in[0 .. num)| to a single maximum written to out[0] using two
// launches of FindAbsMaxKernel on ctx.stream(): the first produces one partial
// maximum per block into a temporary tensor, the second reduces those
// partials with a single block.
//
//   ctx  device context providing the stream and the allocation place
//   in   device pointer to `num` input elements
//   num  number of input elements
//   out  device pointer receiving the result in out[0]
template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
                  const int num, T* out) {
    const int block = 1024;
    // Ceil-div written so that it cannot overflow for num close to INT_MAX.
    int grid = num / block + ((num % block != 0) ? 1 : 0);
    // The second pass runs with a single block of `block` threads, so at most
    // `block` partial results may be produced by the first pass.
    grid = (grid > block) ? block : grid;
    // A grid of zero blocks is an invalid launch configuration. With one
    // block and num == 0 the kernel reads nothing and writes T(0), so the
    // result for empty input stays 0 without leaving a CUDA error behind.
    grid = (grid < 1) ? 1 : grid;
    const size_t shared_bytes = block * sizeof(T);

    framework::Tensor max;
    T* max_data =
        max.mutable_data<T>(framework::make_ddim({grid}), ctx.GetPlace());
    FindAbsMaxKernel<T><<<grid, block, shared_bytes, ctx.stream()>>>(
        in, num, max_data);
    FindAbsMaxKernel<T><<<1, block, shared_bytes, ctx.stream()>>>(
        max_data, grid, out);
  }
};

template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template struct FindAbsMaxFunctor<platform::CUDADeviceContext,
                                  paddle::platform::float16>;
视言's avatar
视言 已提交
79

80
// Per-channel abs-max where the channel index is blockIdx.x.
//
// Block b scans the n / c consecutive elements starting at
// in + b * (n / c) and writes their maximum absolute value to out[b].
// Needs blockDim.x * sizeof(T) bytes of dynamic shared memory; the halving
// reduction below assumes blockDim.x is a power of two.
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
                                                  const int c, T* out) {
  extern __shared__ T shared_max_data[];

  const int lane = threadIdx.x;
  const int per_channel = n / c;
  const T* channel = in + blockIdx.x * per_channel;

  // Each thread folds its strided slice of the channel into a local first.
  T lane_max = T(0);
  for (int k = lane; k < per_channel; k += blockDim.x) {
    T magnitude = fabs(channel[k]);
    if (magnitude > lane_max) {
      lane_max = magnitude;
    }
  }
  shared_max_data[lane] = lane_max;
  __syncthreads();

  // Pairwise reduction of the per-thread maxima.
  for (int span = blockDim.x / 2; span > 0; span >>= 1) {
    if (lane < span) {
      const T other = shared_max_data[lane + span];
      if (shared_max_data[lane] < other) {
        shared_max_data[lane] = other;
      }
    }
    __syncthreads();
  }

  if (lane == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
// Per-channel abs-max where the channel index is blockIdx.x and each thread
// handles one slice along the leading dimension.
//
// Thread t of block b scans the n / (cin * cout) elements starting at
// in + t * (n / cin) + b * (n / (cin * cout)). The block maximum is merged
// into out[b]: out[b] is only ever raised, never lowered, so the caller must
// initialise it before the first launch. Needs blockDim.x * sizeof(T) bytes
// of dynamic shared memory; blockDim.x may be any positive value.
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
                                                  const int cin, const int cout,
                                                  T* out) {
  extern __shared__ T shared_max_data[];

  const int lane = threadIdx.x;
  const int channel = blockIdx.x;
  const int row_stride = n / cin;
  const int inner = n / (cin * cout);
  const T* slice = in + lane * row_stride + channel * inner;

  T lane_max = T(0);
  for (int k = 0; k < inner; ++k) {
    T magnitude = fabs(slice[k]);
    if (magnitude > lane_max) {
      lane_max = magnitude;
    }
  }
  shared_max_data[lane] = lane_max;
  __syncthreads();

  // Reduction that tolerates a non-power-of-two thread count: the active
  // range shrinks to ceil(active / 2) each round and the upper partner is
  // bounds-checked against the current range.
  int active = blockDim.x;
  while (active > 1) {
    const int half = (active + 1) / 2;
    if (lane < half && lane + half < active &&
        shared_max_data[lane] < shared_max_data[lane + half]) {
      shared_max_data[lane] = shared_max_data[lane + half];
    }
    __syncthreads();
    active = half;
  }

  if (lane == 0 && shared_max_data[0] > out[channel]) {
    out[channel] = shared_max_data[0];
  }
}

142 143
// CUDA specialization of FindChannelAbsMaxFunctor.
//
// Writes one abs-max value per channel of `in_tensor` to `out_abs_max`.
//   quant_axis == 0: one block per index of dimension 0
//                    (out_abs_max needs in_dims[0] elements).
//   quant_axis == 1: one block per index of dimension 1
//                    (out_abs_max needs in_dims[1] elements); dimension 0 is
//                    mapped to threads and processed in chunks of at most
//                    1024 rows, each chunk merging into out_abs_max.
// Any other quant_axis raises InvalidArgument.
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_tensor, const int quant_axis,
                  T* out_abs_max) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    const int num = in_tensor.numel();
    auto in_dims = in_tensor.dims();
    const T* in_data = in_tensor.data<T>();
    if (quant_axis == 0) {
      int cout = in_dims[0];
      int grid = cout;
      int block = 1024;
      FindChannelAbsMaxKernelQuantAxis0<
          T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
          in_data, num, cout, out_abs_max);
    } else if (quant_axis == 1) {
      int cin = in_dims[0];
      int cout = in_dims[1];
      int grid = cout;
      int max_threads = 1024;
      // Number of elements in one row along dimension 0.
      const int row_size = num / cin;

      // The kernel only raises out_abs_max, so it has to start from zero.
      // The memset is issued on ctx.stream() so that it is ordered before
      // the kernels launched on that same stream below.
#ifdef PADDLE_WITH_HIP
      hipMemsetAsync(out_abs_max, 0, sizeof(T) * cout, ctx.stream());
#else
      cudaMemsetAsync(out_abs_max, 0, sizeof(T) * cout, ctx.stream());
#endif

      for (int i = 0; i < cin / max_threads; i++) {
        int block = max_threads;
        FindChannelAbsMaxKernelQuantAxis1<
            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
            in_data, num, cin, cout, out_abs_max);
        // Each launch consumed `block` rows (one per thread), so skip that
        // many rows. Advancing by a single row made consecutive chunks
        // overlap and left the last rows unscanned whenever cin > 1024.
        in_data += static_cast<int64_t>(block) * row_size;
      }

      // Remaining rows that do not fill a whole chunk.
      int block = cin % max_threads;
      if (block > 0) {
        FindChannelAbsMaxKernelQuantAxis1<
            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
            in_data, num, cin, cout, out_abs_max);
      }
    }
  }
};

template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;

视言's avatar
视言 已提交
194
// Clips every input element to [-scale[0], scale[0]] and stores
// round(bin_cnt * inverse(scale[0]) * clipped) in out.
// Works for any 1-D launch configuration: threads stride over [0, n) by
// blockDim.x * gridDim.x. No shared memory is used.
template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
                                   const int bin_cnt, const int n, T* out) {
  const int first = threadIdx.x + blockIdx.x * blockDim.x;

  const T bound = scale[0];
  const T inv_bound = inverse(bound);

  for (int idx = first; idx < n; idx += blockDim.x * gridDim.x) {
    T value = in[idx];
    if (value > bound) {
      value = bound;
    }
    if (value < -bound) {
      value = -bound;
    }
    value = bin_cnt * inv_bound * value;
    out[idx] = round(value);
  }
}

211 212 213 214 215 216 217 218
// Quantize-then-dequantize: clips every input element to
// [-scale[0], scale[0]], scales it by bin_cnt * inverse(scale[0]), rounds it
// (the rounding is carried out in float), and maps it back with
// (rounded * scale[0]) / bin_cnt.
// Works for any 1-D launch configuration: threads stride over [0, n) by
// blockDim.x * gridDim.x. No shared memory is used.
template <typename T>
__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
                                          const int bin_cnt, const int n,
                                          T* out) {
  const int first = threadIdx.x + blockIdx.x * blockDim.x;

  const T bound = scale[0];
  const T inv_bound = inverse(bound);
  // bin_cnt converted once so that all arithmetic below is done in T.
  const T levels = static_cast<T>(bin_cnt);

  for (int idx = first; idx < n; idx += blockDim.x * gridDim.x) {
    T value = in[idx];
    if (value > bound) {
      value = bound;
    }
    if (value < -bound) {
      value = -bound;
    }
    value = levels * inv_bound * value;
    value = static_cast<T>(round(static_cast<float>(value)));
    out[idx] = (value * bound) / levels;
  }
}

232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
// CUDA specialization of ClipAndFakeQuantFunctor: allocates `out` on the
// context's place and launches ClipAndQuantKernel over all elements of `in`
// on ctx.stream(), using scale.data<T>() as the (single-element) scale.
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    const int num = in.numel();
    const int threads_per_block = 1024;
    // One thread per element, rounded up to whole blocks.
    const int num_blocks = (threads_per_block - 1 + num) / threads_per_block;

    const T* src = in.data<T>();
    const T* scale_ptr = scale.data<T>();
    T* dst = out->mutable_data<T>(ctx.GetPlace());

    ClipAndQuantKernel<T><<<num_blocks, threads_per_block, 0, ctx.stream()>>>(
        src, scale_ptr, bin_cnt, num, dst);
  }
};

template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;

252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269
// CUDA specialization of ClipAndFakeQuantDequantFunctor: allocates `out` on
// the context's place and launches ClipAndQuantDequantKernel over all
// elements of `in` on ctx.stream().
template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    const int num = in.numel();
    const int threads_per_block = 1024;
    // One thread per element, rounded up to whole blocks.
    const int num_blocks = (threads_per_block - 1 + num) / threads_per_block;

    const T* src = in.data<T>();
    const T* scale_ptr = scale.data<T>();
    T* dst = out->mutable_data<T>(ctx.GetPlace());

    ClipAndQuantDequantKernel<
        T><<<num_blocks, threads_per_block, 0, ctx.stream()>>>(
        src, scale_ptr, bin_cnt, num, dst);
  }
};

270
// ChannelClipAndQuantKernel for quant_axis is 0.
// Block b handles the n / c consecutive elements starting at offset
// b * (n / c), clipping them to [-scale[b], scale[b]] and storing
// round(bin_cnt * inverse(scale[b]) * clipped).
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
                                                    const int bin_cnt,
                                                    const int n, const int c,
                                                    T* out) {
  const int per_channel = n / c;
  const int channel_offset = blockIdx.x * per_channel;
  const T* src = in + channel_offset;
  T* dst = out + channel_offset;

  const T bound = scale[blockIdx.x];
  const T inv_bound = inverse(bound);

  for (int k = threadIdx.x; k < per_channel; k += blockDim.x) {
    T value = src[k];
    if (value > bound) {
      value = bound;
    }
    if (value < -bound) {
      value = -bound;
    }
    value = bin_cnt * inv_bound * value;
    dst[k] = round(value);
  }
}

294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315
// ChannelClipAndQuantKernel for quant_axis is 1.
// Block b handles the n / (cin * cout) consecutive elements starting at
// offset b * (n / (cin * cout)) and uses scale[b % cout] as its clip bound.
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis1(const T* in, const T* scale,
                                                    const int bin_cnt,
                                                    const int n, const int cin,
                                                    const int cout, T* out) {
  const T bound = scale[blockIdx.x % cout];
  const T inv_bound = inverse(bound);

  const int inner = n / (cin * cout);
  const int block_offset = blockIdx.x * inner;
  const T* src = in + block_offset;
  T* dst = out + block_offset;

  for (int k = threadIdx.x; k < inner; k += blockDim.x) {
    T value = src[k];
    if (value > bound) {
      value = bound;
    }
    if (value < -bound) {
      value = -bound;
    }
    value = bin_cnt * inv_bound * value;
    dst[k] = round(value);
  }
}

316 317 318 319
// CUDA specialization of ChannelClipAndFakeQuantFunctor.
// Allocates `out` on the context's place and dispatches to the per-axis
// quantization kernel on ctx.stream():
//   quant_axis == 0: in_dims[0] blocks of 1024 threads
//   quant_axis == 1: in_dims[0] * in_dims[1] blocks of 1024 threads
// Any other quant_axis raises InvalidArgument.
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int quant_axis,
                  framework::Tensor* out) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    const int num = in.numel();
    const auto dims = in.dims();
    const T* src = in.data<T>();
    const T* scale_ptr = scale.data<T>();
    T* dst = out->mutable_data<T>(ctx.GetPlace());

    const int threads_per_block = 1024;
    switch (quant_axis) {
      case 0: {
        const int num_blocks = dims[0];
        ChannelClipAndQuantKernelQuantAxis0<
            T><<<num_blocks, threads_per_block, 0, ctx.stream()>>>(
            src, scale_ptr, bin_cnt, num, dims[0], dst);
        break;
      }
      case 1: {
        const int num_blocks = dims[0] * dims[1];
        ChannelClipAndQuantKernelQuantAxis1<
            T><<<num_blocks, threads_per_block, 0, ctx.stream()>>>(
            src, scale_ptr, bin_cnt, num, dims[0], dims[1], dst);
        break;
      }
      default:
        break;
    }
  }
};

template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
                                               float>;

视言's avatar
视言 已提交
351
// Single-thread bookkeeping step for the range-abs-max window.
//
// Stores cur_scale[0] into slot iter[0] % window_size of scale_arr and writes
// max(last_scale[0], cur_scale[0]) to out_scale[0]. If the value that was
// overwritten is within 1e-6 of last_scale[0], need_find_max[0] is set to 1
// and out_size[0] to min(iter[0], window_size); otherwise need_find_max[0]
// is set to 0 and out_size is left untouched.
// The kernel does no thread indexing; this file launches it as <<<1, 1>>>.
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
                                            const T* last_scale,
                                            const int64_t* iter,
                                            const int window_size, T* scale_arr,
                                            T* out_scale, int* need_find_max,
                                            int* out_size) {
  const int step = iter[0];
  const int slot = step % window_size;

  // Remember the value being evicted before overwriting its slot.
  const T evicted = scale_arr[slot];
  const T current = cur_scale[0];
  scale_arr[slot] = current;

  const T previous_max = last_scale[0];
  if (previous_max < current) {
    out_scale[0] = current;
  } else {
    out_scale[0] = previous_max;
  }

  const bool evicted_was_max = fabs(evicted - previous_max) < 1e-6;
  if (!evicted_was_max) {
    need_find_max[0] = 0;
    return;
  }
  need_find_max[0] = 1;
  out_size[0] = step > window_size ? window_size : step;
}

// CUDA specialization of FindRangeAbsMaxFunctor.
//
// Runs FindRangeAbsMaxAndFillArray to update the scale window and the output
// scale, then copies the kernel's "need to rescan" flag back to the host.
// When the flag is set, the window length is fetched as well and
// FindAbsMaxFunctor recomputes out_scale from the first `len` window entries.
// Both device-to-host copies are followed by ctx.Wait() before the copied
// value is read on the host.
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter, const int window_size,
                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
    const auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());

    T* window_data = scales_arr->mutable_data<T>(gpu_place);
    T* result_data = out_scale->mutable_data<T>(gpu_place);

    // One-element device scratch tensors for the kernel's two int outputs.
    framework::Tensor need_find_max, out_size;
    int* d_rescan_flag = need_find_max.mutable_data<int>({1}, gpu_place);
    int* d_window_len = out_size.mutable_data<int>({1}, gpu_place);

    FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
        cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
        window_size, window_data, result_data, d_rescan_flag, d_window_len);

    int h_rescan_flag;
    memory::Copy(platform::CPUPlace(), &h_rescan_flag, gpu_place,
                 d_rescan_flag, sizeof(int), ctx.stream());
    ctx.Wait();
    if (!h_rescan_flag) {
      return;
    }

    int len;
    memory::Copy(platform::CPUPlace(), &len, gpu_place, d_window_len,
                 sizeof(int), ctx.stream());
    ctx.Wait();
    FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, window_data, len,
                                                        result_data);
  }
};

template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
视言's avatar
视言 已提交
409

410 411 412 413 414 415 416
// CUDA specialization of FindMovingAverageAbsMaxFunctor.
//
// Copies the scalar accumulator, state and current scale to the host, applies
//   state = rate * state + 1
//   accum = rate * accum + cur_scale
//   scale = accum / state
// on the host, and copies the three results back into out_accum, out_state
// and out_scale. ctx.Wait() is called after each batch of copies, so the
// host-side scalars stay alive until the transfers have completed.
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_accum,
                  const framework::Tensor& in_state, const T* cur_scale,
                  const float rate, framework::Tensor* out_state,
                  framework::Tensor* out_accum, framework::Tensor* out_scale) {
    const auto gpu_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());

    // Device -> host: fetch the three scalars.
    T h_accum;
    T h_state;
    T h_scale;
    memory::Copy(platform::CPUPlace(), &h_accum, gpu_place, in_accum.data<T>(),
                 sizeof(T), ctx.stream());
    memory::Copy(platform::CPUPlace(), &h_state, gpu_place, in_state.data<T>(),
                 sizeof(T), ctx.stream());
    memory::Copy(platform::CPUPlace(), &h_scale, gpu_place, cur_scale,
                 sizeof(T), ctx.stream());
    ctx.Wait();

    // Host-side update.
    const T decay = static_cast<T>(rate);
    h_state = decay * h_state + static_cast<T>(1.0);
    h_accum = decay * h_accum + h_scale;
    h_scale = h_accum / h_state;

    // Host -> device: publish the updated scalars.
    T* d_accum = out_accum->mutable_data<T>(gpu_place);
    memory::Copy(gpu_place, d_accum, platform::CPUPlace(), &h_accum, sizeof(T),
                 ctx.stream());
    T* d_state = out_state->mutable_data<T>(gpu_place);
    memory::Copy(gpu_place, d_state, platform::CPUPlace(), &h_state, sizeof(T),
                 ctx.stream());
    T* d_scale = out_scale->mutable_data<T>(gpu_place);
    memory::Copy(gpu_place, d_scale, platform::CPUPlace(), &h_scale, sizeof(T),
                 ctx.stream());
    ctx.Wait();
  }
};

H
huangxu96 已提交
445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528
// ChannelClipAndQuantDequantKernel for quant_axis is 0.
// Block b handles the n / c consecutive elements starting at offset
// b * (n / c): each is clipped to [-scale[b], scale[b]], scaled by
// bin_cnt * inverse(scale[b]), rounded, and mapped back with
// rounded * scale[b] / bin_cnt.
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0(
    const T* in, const T* scale, const int bin_cnt, const int n, const int c,
    T* out) {
  const int per_channel = n / c;
  const int channel_offset = blockIdx.x * per_channel;
  const T* src = in + channel_offset;
  T* dst = out + channel_offset;

  const T bound = scale[blockIdx.x];
  const T inv_bound = inverse(bound);

  for (int k = threadIdx.x; k < per_channel; k += blockDim.x) {
    T value = src[k];
    if (value > bound) {
      value = bound;
    }
    if (value < -bound) {
      value = -bound;
    }
    value = bin_cnt * inv_bound * value;
    dst[k] = round(value) * bound / bin_cnt;
  }
}

// ChannelClipAndQuantDequantKernel for quant_axis is 1.
// Block b handles the n / (cin * cout) consecutive elements starting at
// offset b * (n / (cin * cout)) and uses scale[b % cout] as its clip bound;
// each element is quantized and mapped back with rounded * scale / bin_cnt.
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis1(
    const T* in, const T* scale, const int bin_cnt, const int n, const int cin,
    const int cout, T* out) {
  const T bound = scale[blockIdx.x % cout];
  const T inv_bound = inverse(bound);

  const int inner = n / (cin * cout);
  const int block_offset = blockIdx.x * inner;
  const T* src = in + block_offset;
  T* dst = out + block_offset;

  for (int k = threadIdx.x; k < inner; k += blockDim.x) {
    T value = src[k];
    if (value > bound) {
      value = bound;
    }
    if (value < -bound) {
      value = -bound;
    }
    value = bin_cnt * inv_bound * value;
    dst[k] = round(value) * bound / bin_cnt;
  }
}

// CUDA specialization of ChannelClipFakeQuantDequantFunctor.
// Allocates `out` on the context's place and dispatches to the per-axis
// quantize-dequantize kernel on ctx.stream():
//   quant_axis == 0: in_dims[0] blocks of 1024 threads
//   quant_axis == 1: in_dims[0] * in_dims[1] blocks of 1024 threads
// Any other quant_axis raises InvalidArgument.
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int quant_axis,
                  framework::Tensor* out) {
    // At present, channelwise quantization supports conv2d, depthwise_conv2d
    // conv2d_transpose and mul
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    const int num = in.numel();
    const auto dims = in.dims();

    const T* src = in.data<T>();
    const T* scale_ptr = scale.data<T>();
    T* dst = out->mutable_data<T>(ctx.GetPlace());

    const int threads_per_block = 1024;
    switch (quant_axis) {
      case 0: {
        const int num_blocks = dims[0];
        ChannelClipAndQuantDequantKernelQuantAxis0<
            T><<<num_blocks, threads_per_block, 0, ctx.stream()>>>(
            src, scale_ptr, bin_cnt, num, dims[0], dst);
        break;
      }
      case 1: {
        const int num_blocks = dims[0] * dims[1];
        ChannelClipAndQuantDequantKernelQuantAxis1<
            T><<<num_blocks, threads_per_block, 0, ctx.stream()>>>(
            src, scale_ptr, bin_cnt, num, dims[0], dims[1], dst);
        break;
      }
      default:
        break;
    }
  }
};

template struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext,
                                                   float>;
529

视言's avatar
视言 已提交
530 531 532
}  // namespace operators
}  // namespace paddle

533 534
// Short aliases used by the kernel registrations below.
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;
using float16 = paddle::platform::float16;

// Ops registered for float only.
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_abs_max,
    ops::FakeQuantizeAbsMaxKernel<CUDA, float>);

// Ops registered for both float and float16.
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_dequantize_abs_max,
    ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float>,
    ops::FakeQuantizeDequantizeAbsMaxKernel<CUDA, float16>);

REGISTER_OP_CUDA_KERNEL(
    fake_channel_wise_quantize_abs_max,
    ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);

REGISTER_OP_CUDA_KERNEL(
    fake_quantize_range_abs_max,
    ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);

REGISTER_OP_CUDA_KERNEL(
    fake_quantize_moving_average_abs_max,
    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);

REGISTER_OP_CUDA_KERNEL(
    moving_average_abs_max_scale,
    ops::MovingAverageAbsMaxScaleKernel<CUDA, float>,
    ops::MovingAverageAbsMaxScaleKernel<CUDA, float16>);

REGISTER_OP_CUDA_KERNEL(
    fake_quantize_dequantize_moving_average_abs_max,
    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>,
    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float16>);

REGISTER_OP_CUDA_KERNEL(
    stright_throuth_estimator_grad,
    ops::StrightThroughEstimatorGradKernel<CUDA, float>,
    ops::StrightThroughEstimatorGradKernel<CUDA, float16>);

REGISTER_OP_CUDA_KERNEL(
    fake_channel_wise_quantize_dequantize_abs_max,
    ops::FakeChannelWiseQuantizeDequantizeAbsMaxKernel<CUDA, float>);