// fake_quantize_op.cu.h
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifndef PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_
#define PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_
#endif  // PADDLE_FLUID_OPERATORS_FAKE_QUANTIZE_OP_CU_H_

#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace paddle {
namespace operators {

// Maps an element type to the type used for intermediate arithmetic in the
// kernels of this file. By default arithmetic happens in T itself; float16
// values are promoted to float.
template <typename T>
struct QuantizeDataType {
  using type = T;
};

template <>
struct QuantizeDataType<paddle::platform::float16> {
  using type = float;
};

// Reduces max(|in[i]|) over the elements visited by one block and writes the
// block's result to out[blockIdx.x].
//
// Element coverage:
//   * gridDim.x > 1: each thread walks in[] with stride blockDim.x * gridDim.x,
//     so every index in [0, n) is visited by some thread.
//   * gridDim.x == 1: each thread loads at most one element (index
//     threadIdx.x), so only the first blockDim.x elements are considered.
// Needs blockDim.x * sizeof(T) bytes of dynamic shared memory. The halving
// reduction below assumes blockDim.x is a power of two.
template <typename T>
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
  extern __shared__ char* shared_max_data_tmp[];
  T* block_max = reinterpret_cast<T*>(shared_max_data_tmp);

  const int lane = threadIdx.x;
  const int first = threadIdx.x + blockIdx.x * blockDim.x;

  // Phase 1: every thread produces one candidate and parks it in shared memory.
  if (gridDim.x <= 1) {
    if (first >= n) {
      block_max[lane] = T(0);
    } else {
      block_max[lane] = abs(in[first]);
    }
  } else {
    T running = T(0);
    for (int idx = first; idx < n; idx += blockDim.x * gridDim.x) {
      T magnitude = abs(in[idx]);
      if (magnitude > running) {
        running = magnitude;
      }
    }
    block_max[lane] = running;
  }
  __syncthreads();

  // Phase 2: pairwise tree reduction; the active half shrinks every round.
  for (int half = blockDim.x / 2; half > 0; half >>= 1) {
    if (lane < half) {
      const T other = block_max[lane + half];
      if (block_max[lane] < other) {
        block_max[lane] = other;
      }
    }
    __syncthreads();
  }

  if (lane == 0) {
    out[blockIdx.x] = block_max[0];
  }
}

// Computes max(|in[i]|) for i in [0, num) into the single device value *out.
//
// Two passes, both enqueued on ctx.stream():
//   1. `grid` blocks reduce the input to `grid` partial maxima stored in a
//      temporary device tensor;
//   2. one block reduces those partial maxima into *out.
// grid is capped at `block` (1024) so that pass 2 fits in a single block.
// Inputs larger than grid * block are still fully covered because the kernel
// uses a strided loop whenever grid > 1.
template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
                  const int num, T* out) {
    int block = 1024;
    int grid = (block - 1 + num) / block;
    grid = (grid > block) ? block : grid;
    // An empty input would produce a zero-sized grid, which is an invalid
    // launch configuration. One block instead finds no elements and reports
    // the kernel's initial maximum, T(0).
    if (grid < 1) {
      grid = 1;
    }

    framework::Tensor max;
    T* max_data = max.mutable_data<T>(phi::make_ddim({grid}), ctx.GetPlace());
    // The kernel stores one T per thread in dynamic shared memory.
    const size_t shared_bytes = block * sizeof(T);
    FindAbsMaxKernel<T><<<grid, block, shared_bytes, ctx.stream()>>>(
        in, num, max_data);
    FindAbsMaxKernel<T><<<1, block, shared_bytes, ctx.stream()>>>(
        max_data, grid, out);
  }
};

template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
template struct FindAbsMaxFunctor<platform::CUDADeviceContext,
                                  paddle::platform::float16>;

// Per-channel max(|x|) along axis 0.
//
// One block per channel (gridDim.x == c). Channel blockIdx.x is the contiguous
// slice of n / c elements starting at blockIdx.x * (n / c); its maximum is
// written to out[blockIdx.x]. Needs blockDim.x * sizeof(T) bytes of dynamic
// shared memory, and the halving reduction assumes blockDim.x is a power of
// two.
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis0(const T* in, const int n,
                                                  const int c, T* out) {
  int tid = threadIdx.x;
  int channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;

  extern __shared__ char* shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);

  // Each thread scans its strided share of the channel.
  T local_max_data = T(0);
  for (int i = tid; i < channel_size; i += blockDim.x) {
    // |x| is computed in the promoted type (float for float16).
    T tmp = static_cast<T>(
        fabs(static_cast<typename QuantizeDataType<T>::type>(in_c[i])));
    if (tmp > local_max_data) {
      local_max_data = tmp;
    }
  }
  shared_max_data[tid] = local_max_data;
  __syncthreads();
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

// Per-channel max(|x|) along axis 1 for an input laid out as [cin, cout, wh],
// where wh = n / (cin * cout) is the product of the trailing dimensions.
//
// Launch layout: gridDim.x == cout, and block `bid` owns output channel bid.
// Thread `tid` scans row tid of axis 0 relative to `in`, i.e. the wh elements
// at in + tid * (cout * wh) + bid * wh. The caller may offset `in` so that
// each launch covers a slice of at most blockDim.x rows. Needs
// blockDim.x * sizeof(T) bytes of dynamic shared memory. The reduction below
// handles any blockDim.x, not only powers of two.
//
// The block result is max-folded into out[bid], so out must be initialized
// (e.g. zeroed) before the first launch.
template <typename T>
__global__ void FindChannelAbsMaxKernelQuantAxis1(const T* in, const int n,
                                                  const int cin, const int cout,
                                                  T* out) {
  extern __shared__ char* shared_max_data_tmp[];
  auto shared_max_data = reinterpret_cast<T*>(shared_max_data_tmp);
  int cout_wh_size = n / cin;
  int wh_size = n / (cin * cout);

  int tid = threadIdx.x;
  int bid = blockIdx.x;
  const T* in_current = in + tid * cout_wh_size + bid * wh_size;
  T local_max_data = T(0);
  for (int i = 0; i < wh_size; i++) {
    // |x| is computed in the promoted type (float for float16).
    T tmp = static_cast<T>(
        fabs(static_cast<typename QuantizeDataType<T>::type>(in_current[i])));
    if (tmp > local_max_data) {
      local_max_data = tmp;
    }
  }
  shared_max_data[tid] = local_max_data;
  __syncthreads();

  // Tree reduction over `len` live entries; (len + 1) / 2 rounding keeps the
  // unpaired middle element when len is odd.
  int len = blockDim.x;
  for (int i = (len + 1) / 2; i > 0; len = i, i = (i + 1) / 2) {
    if (tid < i && tid + i < len &&
        shared_max_data[tid] < shared_max_data[tid + i]) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    if (i == 1) {
      i = 0;  // break the loop: (1 + 1) / 2 would stay at 1 forever
    }
    __syncthreads();
  }
  if (tid == 0 && shared_max_data[0] > out[bid]) {
    out[bid] = shared_max_data[0];
  }
}

// Computes the per-channel max(|x|) of in_tensor into out_abs_max, a device
// buffer holding one T per channel. All work is enqueued on ctx.stream().
//
// quant_axis == 0: channel i is the contiguous slice of num / in_dims[0]
//   elements starting at i * (num / in_dims[0]); one block per channel.
// quant_axis == 1: the input is viewed as [cin, cout, ...] and out_abs_max has
//   cout entries. Rows of axis 0 are processed in chunks of at most 1024 (one
//   thread per row). Each chunk's result is max-folded into out_abs_max, which
//   is zeroed first.
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_tensor, const int quant_axis,
                  T* out_abs_max) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));
    const int num = in_tensor.numel();
    auto in_dims = in_tensor.dims();
    const T* in_data = in_tensor.data<T>();
    if (quant_axis == 0) {
      int cout = in_dims[0];
      int grid = cout;
      int block = 1024;
      // A zero-sized grid is an invalid launch, and there is nothing to write.
      if (grid > 0) {
        FindChannelAbsMaxKernelQuantAxis0<
            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
            in_data, num, cout, out_abs_max);
      }
    } else if (quant_axis == 1) {
      int cin = in_dims[0];
      int cout = in_dims[1];
      int grid = cout;
      int max_threads = 1024;
      if (grid <= 0) {
        return;  // no output channels: nothing to initialize or compute
      }

      // The kernel max-folds into out_abs_max, so it must start at zero. The
      // memset is issued on ctx.stream() so it is ordered before the kernels
      // below, however that stream relates to the default stream.
#ifdef PADDLE_WITH_HIP
      hipMemsetAsync(out_abs_max, 0, sizeof(T) * cout, ctx.stream());
#else
      cudaMemsetAsync(out_abs_max, 0, sizeof(T) * cout, ctx.stream());
#endif  // PADDLE_WITH_HIP

      for (int i = 0; i < cin / max_threads; i++) {
        int block = max_threads;
        FindChannelAbsMaxKernelQuantAxis1<
            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
            in_data, num, cin, cout, out_abs_max);
        // This launch consumed `block` rows of axis 0, each num / cin
        // elements long. Skip all of them; advancing by a single row would
        // make the chunks overlap and leave the tail rows unscanned.
        in_data += block * (num / cin);
      }

      int block = cin % max_threads;
      if (block > 0) {
        FindChannelAbsMaxKernelQuantAxis1<
            T><<<grid, block, block * sizeof(T), ctx.stream()>>>(
            in_data, num, cin, cout, out_abs_max);
      }
    }
  }
};

template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;

// Clips every element of `in` to [-s, s] with s = scale[0] and maps it onto
// the integer grid: out[i] = round(bin_cnt * inverse(s) * clip(in[i])).
// inverse() is declared outside this file (presumably ~1/s; verify there).
// Arithmetic uses QuantizeDataType<T>::type (float for float16). The strided
// loop covers all n elements for any launch shape.
template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
                                   const int bin_cnt, const int n, T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;

  using ComputeDataType = typename QuantizeDataType<T>::type;

  ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
  ComputeDataType inv_s = inverse(s);
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
    ComputeDataType v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt_t * inv_s * v;
    out[i] = static_cast<T>(round(v));
  }
}

// Fake quant-dequant with a single scale s = scale[0]: clips each element to
// [-s, s], quantizes it to round(bin_cnt * inverse(s) * x), and maps it back
// with q * s / bin_cnt. inverse() is declared outside this file (presumably
// ~1/s; verify there). Arithmetic uses QuantizeDataType<T>::type (float for
// float16). The strided loop covers all n elements for any launch shape.
template <typename T>
__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
                                          const int bin_cnt, const int n,
                                          T* out) {
  int bid = threadIdx.x + blockIdx.x * blockDim.x;

  using ComputeDataType = typename QuantizeDataType<T>::type;

  ComputeDataType s = static_cast<ComputeDataType>(scale[0]);
  ComputeDataType inv_s = inverse(s);
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
    x = x > s ? s : x;
    x = x < -s ? -s : x;
    x = bin_cnt_t * inv_s * x;
    x = round(x);
    out[i] = static_cast<T>((x * s) / bin_cnt_t);
  }
}

// Quantizes `in` into `out` (allocated here on ctx's place) using the single
// scale value in `scale`. The per-element work is done by ClipAndQuantKernel,
// which is enqueued on ctx.stream().
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    int num = in.numel();
    int block = 1024;
    int grid = (block - 1 + num) / block;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // Empty input: `out` is allocated, but a zero-sized grid would be an
    // invalid launch configuration.
    if (grid <= 0) {
      return;
    }

    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, out_data);
  }
};

template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;

// Fake quant-dequant of `in` into `out` (allocated here on ctx's place) using
// the single scale value in `scale`. The per-element work is done by
// ClipAndQuantDequantKernel, which is enqueued on ctx.stream().
template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    int num = in.numel();
    int block = 1024;
    int grid = (block - 1 + num) / block;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // Empty input: `out` is allocated, but a zero-sized grid would be an
    // invalid launch configuration.
    if (grid <= 0) {
      return;
    }

    ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, out_data);
  }
};

// ChannelClipAndQuantKernel for quant_axis is 0
// One block per channel (gridDim.x == c). Channel blockIdx.x is the contiguous
// slice of n / c elements starting at blockIdx.x * (n / c). It is clipped to
// [-s, s] with s = scale[blockIdx.x] and quantized to
// round(bin_cnt * inverse(s) * x). Arithmetic uses QuantizeDataType<T>::type
// (float for float16).
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxis0(const T* in, const T* scale,
                                                    const int bin_cnt,
                                                    const int64_t n,
                                                    const int c, T* out) {
  int tid = threadIdx.x;

  int64_t channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;

  using ComputeDataType = typename QuantizeDataType<T>::type;

  ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x]);
  ComputeDataType inv_s = inverse(s);
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  for (int64_t i = tid; i < channel_size; i += blockDim.x) {
    ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
    ComputeDataType v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt_t * inv_s * v;
    out_c[i] = static_cast<T>(round(v));
  }
}

// ChannelClipAndQuantKernel for quant_axis is N
// Quantizes every element with the scale of its index along the quantized
// axis: element i uses s = scale[(i / quant_stride) % nScale]. Here
// quant_stride is the element distance between consecutive indices of that
// axis and nScale is its extent. Each element is clipped to [-s, s] and
// quantized to round(bin_cnt * inverse(s) * x), with arithmetic in
// QuantizeDataType<T>::type. The strided loop covers all n elements for any
// launch shape.
template <typename T>
__global__ void ChannelClipAndQuantKernelQuantAxisN(
    const T* in, const T* scale, const int bin_cnt, const int64_t n,
    const int nScale, const int quant_stride, T* out) {
  // Index and stride are computed in 64 bits so they cannot wrap for large n.
  int64_t idx = static_cast<int64_t>(blockDim.x) * blockIdx.x + threadIdx.x;
  const int64_t stride = static_cast<int64_t>(blockDim.x) * gridDim.x;

  using ComputeDataType = typename QuantizeDataType<T>::type;
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  for (int64_t i = idx; i < n; i += stride) {
    ComputeDataType s =
        static_cast<ComputeDataType>(scale[(i / quant_stride) % nScale]);
    ComputeDataType inv_s = inverse(s);
    ComputeDataType x = static_cast<ComputeDataType>(in[i]);
    ComputeDataType v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt_t * inv_s * v;
    out[i] = static_cast<T>(round(v));
  }
}

// Channel-wise fake quantization of `in` into `out` (allocated here), with one
// scale per index of quant_axis (which must be 0 or 1). All kernels are
// enqueued on ctx.stream().
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int quant_axis,
                  framework::Tensor* out) {
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    int64_t num = in.numel();
    auto in_dims = in.dims();
    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // Empty input: nothing to quantize. Returning here also prevents a zero
    // block_size below (host-side division by zero) and zero-sized grids.
    if (num == 0) {
      return;
    }

    if (quant_axis == 0) {
      int grid = in_dims[0];
      int block = 1024;
      ChannelClipAndQuantKernelQuantAxis0<T><<<grid, block, 0, ctx.stream()>>>(
          in_data, scale_data, bin_cnt, num, in_dims[0], out_data);
    } else {
      // Element distance between consecutive indices of quant_axis.
      int quant_stride = 1;
      for (int i = quant_axis + 1; i < in_dims.size(); i++) {
        quant_stride *= in_dims[i];
      }
      int64_t block_size =
          std::min(num, static_cast<int64_t>(ctx.GetMaxThreadsPerBlock() / 4));
      int64_t max_threads =
          ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
      const int64_t max_blocks = std::max(((max_threads - 1) / block_size + 1),
                                          static_cast<int64_t>(1));

      const int64_t grid_size =
          std::min(max_blocks, (num + block_size - 1) / block_size);

      // Launch on ctx.stream() rather than the default stream, so the kernel
      // is ordered with the rest of the work issued on this context.
      ChannelClipAndQuantKernelQuantAxisN<T>
          <<<grid_size, block_size, 0, ctx.stream()>>>(
              in_data, scale_data, bin_cnt, num, in_dims[quant_axis],
              quant_stride, out_data);
    }
  }
};

template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
                                               float>;

// Single-thread helper for the sliding-window abs-max update (the caller in
// this file launches it as <<<1, 1>>>).
//
// * Stores *cur_scale into the ring buffer at scale_arr[iter % window_size].
// * Writes max(*last_scale, *cur_scale) to *out_scale.
// * Sets *need_find_max to 1 if the evicted entry equals *last_scale
//   (|removed - last| < 1e-6). In that case the previous maximum may have left
//   the window, and *out_size receives min(iter, window_size), the number of
//   leading scale_arr entries to rescan.
// * Otherwise sets *need_find_max to 0 and leaves *out_size untouched.
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
                                            const T* last_scale,
                                            const int64_t* iter,
                                            const int window_size, T* scale_arr,
                                            T* out_scale, int* need_find_max,
                                            int* out_size) {
  int it = iter[0];
  int idx = it % window_size;
  T removed = scale_arr[idx];
  T cur = cur_scale[0];
  scale_arr[idx] = cur;
  T max = last_scale[0];
  out_scale[0] = max < cur ? cur : max;
  if (fabs(static_cast<typename QuantizeDataType<T>::type>(removed - max)) <
      1e-6) {
    need_find_max[0] = 1;
    out_size[0] = it > window_size ? window_size : it;
  } else {
    need_find_max[0] = 0;
  }
}

// Sliding-window abs-max update, all on the device behind ctx.
//
// FindRangeAbsMaxAndFillArray writes cur_scale into the ring buffer
// scales_arr and sets out_scale to max(cur_scale, last_scale). If the entry it
// evicted matched last_scale, the maximum is recomputed over the first
// min(iter, window_size) buffer entries. That decision is copied back to the
// host, so this call blocks until the work queued on ctx has finished
// (ctx.Wait()).
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter, const int window_size,
                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
    const auto gpu_place = ctx.GetPlace();

    T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);

    framework::Tensor need_find_max, out_size;
    int* find_max = need_find_max.mutable_data<int>({1}, gpu_place);
    int* out_size_data = out_size.mutable_data<int>({1}, gpu_place);

    FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
        cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
        window_size, scale_arr, out_scale_data, find_max, out_size_data);

    int g_find_max;
    memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max,
                 sizeof(int), ctx.stream());
    ctx.Wait();
    if (g_find_max) {
      int len;
      memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data,
                   sizeof(int), ctx.stream());
      ctx.Wait();
      // With an empty range (iter == 0) there is nothing to rescan. Keep the
      // max(cur_scale, last_scale) already written by the kernel instead of
      // reducing over zero elements, which would launch a zero-sized grid and
      // overwrite out_scale with 0.
      if (len > 0) {
        FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(
            ctx, scale_arr, len, out_scale_data);
      }
    }
  }
};

// Moving-average abs-max update on single device values:
//   state' = rate * state + 1
//   accum' = rate * accum + cur_scale
//   scale' = accum' / state'
// Every thread writes the same three outputs; the caller in this file
// launches it as <<<1, 1>>>.
template <typename T>
__global__ void FindMovingAverageAbsMaxKernel(const T* in_state,
                                              const T* in_accum,
                                              const T* cur_scale, const T rate,
                                              T* out_state, T* out_accum,
                                              T* out_scale) {
  const T next_state = rate * in_state[0] + T(1.0f);
  const T next_accum = rate * in_accum[0] + cur_scale[0];
  out_state[0] = next_state;
  out_accum[0] = next_accum;
  out_scale[0] = next_accum / next_state;
}

template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;

// Allocates out_state / out_accum / out_scale on ctx's place and enqueues the
// single-thread moving-average update on ctx.stream(). `rate` is converted to
// T before the launch. cur_scale must point to device memory, since the
// kernel dereferences it.
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_accum,
                  const framework::Tensor& in_state, const T* cur_scale,
                  const float rate, framework::Tensor* out_state,
                  framework::Tensor* out_accum, framework::Tensor* out_scale) {
    const auto place = ctx.GetPlace();

    T* state_ptr = out_state->mutable_data<T>(place);
    T* accum_ptr = out_accum->mutable_data<T>(place);
    T* scale_ptr = out_scale->mutable_data<T>(place);

    const T typed_rate = static_cast<T>(rate);
    FindMovingAverageAbsMaxKernel<T><<<1, 1, 0, ctx.stream()>>>(
        in_state.data<T>(), in_accum.data<T>(), cur_scale, typed_rate,
        state_ptr, accum_ptr, scale_ptr);
  }
};

// ChannelClipAndQuantDequantKernel for quant_axis is 0
// One block per channel (gridDim.x == c). Channel blockIdx.x is the contiguous
// slice of n / c elements starting at blockIdx.x * (n / c). Each element is
// clipped to [-s, s] with s = scale[blockIdx.x], quantized to
// round(bin_cnt * inverse(s) * x), then mapped back with q * s / bin_cnt.
// Arithmetic uses QuantizeDataType<T>::type like the other quant kernels in
// this file (identical results for float, and float16 computes in float).
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis0(
    const T* in, const T* scale, const int bin_cnt, const int n, const int c,
    T* out) {
  int tid = threadIdx.x;

  int channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;

  using ComputeDataType = typename QuantizeDataType<T>::type;

  ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x]);
  ComputeDataType inv_s = inverse(s);
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  for (int i = tid; i < channel_size; i += blockDim.x) {
    ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
    ComputeDataType v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt_t * inv_s * v;
    out_c[i] = static_cast<T>(round(v) * s / bin_cnt_t);
  }
}

// ChannelClipAndQuantDequantKernel for quant_axis is 1
// The input is viewed as [cin, cout, wh] with wh = n / (cin * cout). Block b
// handles the contiguous wh-element slice starting at b * wh and uses
// s = scale[b % cout], so a launch with cin * cout blocks covers the tensor.
// Each element is clipped to [-s, s], quantized to
// round(bin_cnt * inverse(s) * x), then mapped back with q * s / bin_cnt.
// Arithmetic uses QuantizeDataType<T>::type like the other quant kernels in
// this file (identical results for float, and float16 computes in float).
template <typename T>
__global__ void ChannelClipAndQuantDequantKernelQuantAxis1(
    const T* in, const T* scale, const int bin_cnt, const int n, const int cin,
    const int cout, T* out) {
  using ComputeDataType = typename QuantizeDataType<T>::type;

  ComputeDataType s = static_cast<ComputeDataType>(scale[blockIdx.x % cout]);
  ComputeDataType inv_s = inverse(s);
  ComputeDataType bin_cnt_t = static_cast<ComputeDataType>(bin_cnt);

  int wh_size = n / (cin * cout);
  const T* in_c = in + blockIdx.x * wh_size;
  T* out_c = out + blockIdx.x * wh_size;

  for (int i = threadIdx.x; i < wh_size; i += blockDim.x) {
    ComputeDataType x = static_cast<ComputeDataType>(in_c[i]);
    ComputeDataType v = x > s ? s : x;
    v = v < -s ? -s : v;
    v = bin_cnt_t * inv_s * v;
    out_c[i] = static_cast<T>(round(v) * s / bin_cnt_t);
  }
}

// Channel-wise fake quant-dequant of `in` into `out` (allocated here), with
// one scale per index of quant_axis (which must be 0 or 1). Kernels are
// enqueued on ctx.stream().
template <typename T>
struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int quant_axis,
                  framework::Tensor* out) {
    // At present, channelwise quantization supports conv2d, depthwise_conv2d
    // conv2d_transpose and mul
    PADDLE_ENFORCE_EQ(
        quant_axis == 0 || quant_axis == 1, true,
        platform::errors::InvalidArgument("'quant_axis' should be 0 or 1, but "
                                          "the received is %d",
                                          quant_axis));

    int num = in.numel();
    auto in_dims = in.dims();

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // Empty input: `out` is allocated but there is nothing to compute, and a
    // zero-sized grid would be an invalid launch configuration.
    if (num == 0) {
      return;
    }

    if (quant_axis == 0) {
      // One block per axis-0 channel.
      int grid = in_dims[0];
      int block = 1024;
      ChannelClipAndQuantDequantKernelQuantAxis0<
          T><<<grid, block, 0, ctx.stream()>>>(in_data, scale_data, bin_cnt,
                                               num, in_dims[0], out_data);
    } else if (quant_axis == 1) {
      // One block per (axis-0, axis-1) index pair.
      int grid = in_dims[0] * in_dims[1];
      int block = 1024;

      ChannelClipAndQuantDequantKernelQuantAxis1<
          T><<<grid, block, 0, ctx.stream()>>>(
          in_data, scale_data, bin_cnt, num, in_dims[0], in_dims[1], out_data);
    }
  }
};

template struct ChannelClipFakeQuantDequantFunctor<platform::CUDADeviceContext,
                                                   float>;

}  // namespace operators
}  // namespace paddle