fake_quantize_op.cu 12.0 KB
Newer Older
视言's avatar
视言 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
16
#include "paddle/fluid/memory/memcpy.h"
视言's avatar
视言 已提交
17 18 19 20 21 22 23
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Block-wise absolute-maximum reduction.
//
// Each block writes the maximum of |in[i]| over the elements it visits to
// out[blockIdx.x]. Elements are visited with a stride of
// blockDim.x * gridDim.x for every launch shape, so the whole input [0, n) is
// covered even by a single-block launch with n > blockDim.x.
//
// Launch requirements:
//   * blockDim.x must be a power of two (the tree reduction halves it).
//   * At least blockDim.x * sizeof(T) bytes of dynamic shared memory.
//   * `out` must hold gridDim.x elements.
// If n <= 0 every block writes T(0). NaN inputs never win the `>` comparison
// and are therefore ignored.
template <typename T>
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
  const int tid = threadIdx.x;
  extern __shared__ T shared_max_data[];

  // Per-thread partial maximum kept in a register; abs values are >= 0, so
  // T(0) is a neutral starting point.
  T local_max = T(0);
  const int stride = blockDim.x * gridDim.x;
  for (int i = blockIdx.x * blockDim.x + tid; i < n; i += stride) {
    T tmp = fabs(in[i]);
    if (tmp > local_max) {
      local_max = tmp;
    }
  }
  shared_max_data[tid] = local_max;
  __syncthreads();

  // Tree reduction over the block's partial maxima.
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && shared_max_data[tid] < shared_max_data[tid + i]) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75
// Computes max(|in[0..num)|) into out[0], enqueued on ctx's stream.
//
// Inputs of at most `block` elements are reduced by one single-block launch
// written directly into `out`. Larger inputs take two passes: up to `block`
// blocks write per-block maxima into a temporary buffer, which a single-block
// launch then reduces into `out`. num <= 0 writes 0 to out[0] and never
// launches a zero-sized grid.
template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
                  const int num, T* out) {
    const int block = 1024;
    // The reduction kernel needs one T of shared memory per thread.
    const size_t shared_bytes = block * sizeof(T);
    int grid = (block - 1 + num) / block;
    // Cap the first pass so the second (single-block) pass has at most
    // `block` partial maxima to reduce.
    grid = (grid > block) ? block : grid;

    if (grid <= 1) {
      // Everything fits in one block (this also covers num <= 0, where the
      // kernel writes 0); no temporary buffer or second pass is needed.
      FindAbsMaxKernel<T><<<1, block, shared_bytes, ctx.stream()>>>(in, num,
                                                                    out);
      return;
    }

    framework::Tensor block_max;
    T* block_max_data = block_max.mutable_data<T>(framework::make_ddim({grid}),
                                                  ctx.GetPlace());
    FindAbsMaxKernel<T><<<grid, block, shared_bytes, ctx.stream()>>>(
        in, num, block_max_data);
    FindAbsMaxKernel<T><<<1, block, shared_bytes, ctx.stream()>>>(
        block_max_data, grid, out);
  }
};

template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;
视言's avatar
视言 已提交
76

77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115
// Per-channel absolute maximum.
//
// `in` is treated as `c` contiguous channels of n / c elements each; block b
// reduces channel b and stores max |x| of that channel in out[b]. Launch with
// one block per channel, a power-of-two blockDim.x, and at least
// blockDim.x * sizeof(T) bytes of dynamic shared memory.
template <typename T>
__global__ void FindChannelAbsMaxKernel(const T* in, const int n, const int c,
                                        T* out) {
  extern __shared__ T shared_max_data[];
  const int lane = threadIdx.x;
  const int per_channel = n / c;
  const T* channel_in = in + blockIdx.x * per_channel;

  // Thread-local running maximum over this thread's slice of the channel.
  T running = T(0);
  int pos = lane;
  while (pos < per_channel) {
    const T magnitude = fabs(channel_in[pos]);
    if (magnitude > running) running = magnitude;
    pos += blockDim.x;
  }
  shared_max_data[lane] = running;
  __syncthreads();

  // Halving tree reduction in shared memory.
  for (int half = blockDim.x >> 1; half > 0; half >>= 1) {
    if (lane < half) {
      const T other = shared_max_data[lane + half];
      if (shared_max_data[lane] < other) shared_max_data[lane] = other;
    }
    __syncthreads();
  }

  if (lane == 0) out[blockIdx.x] = shared_max_data[0];
}

// Writes the per-channel abs-max of `in` (num elements split into `channel`
// contiguous channels) into out[0..channel), enqueued on ctx's stream.
template <typename T>
struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
                  const int num, const int channel, T* out) {
    // One block per channel; a zero-sized grid is an invalid launch
    // configuration and there is nothing to write.
    if (channel <= 0) return;

    const int block = 1024;
    // Shared memory is sized from `block` so the two cannot drift apart.
    FindChannelAbsMaxKernel<T><<<channel, block, block * sizeof(T),
                                 ctx.stream()>>>(in, num, channel, out);
  }
};

template struct FindChannelAbsMaxFunctor<platform::CUDADeviceContext, float>;

视言's avatar
视言 已提交
116
// Clips every element of `in` to [-s, s], with s = scale[0], and maps it onto
// the integer grid [-bin_cnt, bin_cnt]:
//   out[i] = round(bin_cnt / s * clip(in[i], -s, s))
// Elements are visited with a stride of blockDim.x * gridDim.x, so any launch
// shape covers all n elements.
template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
                                   const int bin_cnt, const int n, T* out) {
  const int bid = threadIdx.x + blockIdx.x * blockDim.x;
  const T s = scale[0];
  // With s == 0 every clipped value is 0, but bin_cnt / s is inf and
  // inf * 0 is NaN. Using a zero factor makes e.g. an all-zero tensor (whose
  // abs-max is 0) quantize to 0 instead of NaN. For s != 0 the hoisted factor
  // is the same bin_cnt / s the original evaluated per element.
  const T factor = (s == T(0)) ? T(0) : bin_cnt / s;
  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    T v = in[i];
    v = v > s ? s : v;
    v = v < -s ? -s : v;
    out[i] = round(factor * v);
  }
}

132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
// Fake-quantizes and immediately dequantizes every element of `in`:
//   q      = round(bin_cnt / s * clip(in[i], -s, s)),  s = scale[0]
//   out[i] = q * s / bin_cnt
// so the output keeps the input's range but only takes 2 * bin_cnt + 1
// distinct values. Elements are visited with a stride of
// blockDim.x * gridDim.x.
template <typename T>
__global__ void ClipAndQuantDequantKernel(const T* in, const T* scale,
                                          const int bin_cnt, const int n,
                                          T* out) {
  const int bid = threadIdx.x + blockIdx.x * blockDim.x;
  const T s = scale[0];
  // s == 0 would give bin_cnt / 0 = inf and inf * 0 = NaN; a zero factor
  // yields 0 instead, which is the exact clipped value in that case.
  const T factor = (s == T(0)) ? T(0) : bin_cnt / s;
  for (int i = bid; i < n; i += blockDim.x * gridDim.x) {
    T v = in[i];
    v = v > s ? s : v;
    v = v < -s ? -s : v;
    out[i] = round(factor * v) * s / bin_cnt;
  }
}

149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168
// Host launcher for ClipAndQuantKernel: quantizes every element of `in` with
// the single scale scale[0] into `out` (allocated on ctx's place), enqueued on
// ctx's stream.
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    int num = in.numel();
    int block = 1024;
    int grid = (block - 1 + num) / block;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // An empty input gives grid == 0, an invalid launch configuration; the
    // (empty) output has already been allocated above.
    if (grid <= 0) return;

    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, out_data);
  }
};

template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;

169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189
// Host launcher for ClipAndQuantDequantKernel: quantizes and dequantizes every
// element of `in` with scale[0] into `out` (allocated on ctx's place),
// enqueued on ctx's stream.
template <typename T>
struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    int num = in.numel();
    int block = 1024;
    int grid = (block - 1 + num) / block;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // An empty input gives grid == 0, an invalid launch configuration; the
    // (empty) output has already been allocated above.
    if (grid <= 0) return;

    ClipAndQuantDequantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, out_data);
  }
};

template struct ClipAndFakeQuantDequantFunctor<platform::CUDADeviceContext,
                                               float>;

190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231
// Per-channel fake quantization. `in`/`out` are treated as `c` contiguous
// channels of n / c elements; block b quantizes channel b with its own scale
// s = scale[b]:
//   out_c[i] = round(bin_cnt / s * clip(in_c[i], -s, s))
// Launch with one block per channel. If n is not a multiple of c, the trailing
// n % c elements are not written.
template <typename T>
__global__ void ChannelClipAndQuantKernel(const T* in, const T* scale,
                                          const int bin_cnt, const int n,
                                          const int c, T* out) {
  const int tid = threadIdx.x;

  const int channel_size = n / c;
  const T* in_c = in + blockIdx.x * channel_size;
  T* out_c = out + blockIdx.x * channel_size;

  const T s = scale[blockIdx.x];
  // A channel whose scale is 0 (e.g. all-zero weights) would otherwise compute
  // inf * 0 = NaN; every clipped value is 0 there, so use a zero factor.
  const T factor = (s == T(0)) ? T(0) : bin_cnt / s;
  for (int i = tid; i < channel_size; i += blockDim.x) {
    T v = in_c[i];
    v = v > s ? s : v;
    v = v < -s ? -s : v;
    out_c[i] = round(factor * v);
  }
}

// Host launcher for ChannelClipAndQuantKernel: one block per channel, each
// using scale[channel_index]. The output is allocated on ctx's place and the
// kernel is enqueued on ctx's stream.
template <typename T>
struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, const int channel,
                  framework::Tensor* out) {
    int num = in.numel();
    int block = 1024;
    int grid = channel;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // Nothing to quantize; also avoids a zero-sized grid (an invalid launch
    // configuration) when there are no channels.
    if (num <= 0 || grid <= 0) return;

    ChannelClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, channel, out_data);
  }
};

template struct ChannelClipAndFakeQuantFunctor<platform::CUDADeviceContext,
                                               float>;

视言's avatar
视言 已提交
232
// Single-thread kernel (launched <<<1, 1>>>) that maintains a sliding window
// of per-step scales together with the window's maximum.
//
// Stores cur_scale[0] into scale_arr[iter % window_size]. That overwrites the
// entry from window_size steps earlier, or an unfilled slot during the first
// window. It then writes:
//   * cur > last max: out_scale = cur and *need_find_max = 0. cur is the new
//     maximum regardless of which entry was evicted.
//   * else, if the evicted entry equals the last max (within 1e-6): the
//     maximum may have left the window. out_scale = last max (provisional),
//     *need_find_max = 1, and *out_size = the number of filled slots
//     (min(iter + 1, window_size)), so the host can rescan them.
//   * otherwise: out_scale = last max and *need_find_max = 0.
// *out_size is only written when *need_find_max == 1.
// NOTE(review): assumes window_size > 0 and iter[0] >= 0 — confirm at callers.
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
                                            const T* last_scale,
                                            const int64_t* iter,
                                            const int window_size, T* scale_arr,
                                            T* out_scale, int* need_find_max,
                                            int* out_size) {
  // Keep the step counter 64-bit for the modulo so long runs don't wrap.
  const int64_t it = iter[0];
  const int idx = static_cast<int>(it % window_size);
  const T removed = scale_arr[idx];
  const T cur = cur_scale[0];
  scale_arr[idx] = cur;
  const T max = last_scale[0];

  if (max < cur) {
    out_scale[0] = cur;
    need_find_max[0] = 0;
  } else if (fabs(removed - max) < static_cast<T>(1e-6)) {
    out_scale[0] = max;
    need_find_max[0] = 1;
    // Slots [0, it] are filled until the window wraps. Including slot `it`
    // (just written) keeps the rescan length >= 1 at it == 0.
    out_size[0] = it < window_size ? static_cast<int>(it) + 1 : window_size;
  } else {
    out_scale[0] = max;
    need_find_max[0] = 0;
  }
}

// Host side of the sliding-window ("range") abs-max scale.
//
// Launches FindRangeAbsMaxAndFillArray to update `scales_arr` and a
// provisional `out_scale`, then copies its two status ints back to the host
// with a single copy and a single ctx.Wait(). If the kernel reports that the
// previous maximum may have been evicted, out_scale is recomputed as the
// abs-max of the first `len` window entries. The call blocks the host until
// ctx's stream has produced the status flags.
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter, const int window_size,
                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
    const auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());

    T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);

    // flags[0] = need_find_max, flags[1] = number of window entries to
    // rescan. The shape is set explicitly: a default-constructed Tensor has
    // no elements, so mutable_data without dims may allocate nothing for the
    // kernel to write into.
    framework::Tensor flags;
    int* flags_data =
        flags.mutable_data<int>(framework::make_ddim({2}), gpu_place);

    FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
        cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
        window_size, scale_arr, out_scale_data, flags_data, flags_data + 1);

    // One device->host round trip for both flags instead of two.
    int host_flags[2] = {0, 0};
    memory::Copy(platform::CPUPlace(), host_flags, gpu_place, flags_data,
                 2 * sizeof(int), ctx.stream());
    ctx.Wait();

    const int need_find_max = host_flags[0];
    const int len = host_flags[1];  // only meaningful when need_find_max != 0
    // A zero-length rescan would replace the provisional scale with 0.
    if (need_find_max && len > 0) {
      FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(ctx, scale_arr, len,
                                                          out_scale_data);
    }
  }
};

template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;
视言's avatar
视言 已提交
290

291 292 293 294 295 296 297 298 299 300 301 302
// Moving-average abs-max scale update, computed on the host from three device
// scalars:
//   out_state = rate * in_state + 1
//   out_accum = rate * in_accum + cur_scale[0]
//   out_scale = out_accum / out_state
// The scalars travel over ctx's stream, and the call waits on ctx twice:
// after reading the inputs and after writing the outputs.
template <typename T>
struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in_accum,
                  const framework::Tensor& in_state, const T* cur_scale,
                  const float rate, framework::Tensor* out_state,
                  framework::Tensor* out_accum, framework::Tensor* out_scale) {
    const auto dev = boost::get<platform::CUDAPlace>(ctx.GetPlace());
    platform::CPUPlace host;
    const auto stream = ctx.stream();

    // Device -> host: fetch the running statistics and this step's scale.
    T running_accum;
    T running_state;
    T step_scale;
    memory::Copy(host, &running_accum, dev, in_accum.data<T>(), sizeof(T),
                 stream);
    memory::Copy(host, &running_state, dev, in_state.data<T>(), sizeof(T),
                 stream);
    memory::Copy(host, &step_scale, dev, cur_scale, sizeof(T), stream);
    ctx.Wait();

    // Decayed running sums (weight `rate` on the previous value).
    running_state = rate * running_state + 1;
    running_accum = rate * running_accum + step_scale;
    const T averaged_scale = running_accum / running_state;

    // Host -> device: publish the updated statistics.
    memory::Copy(dev, out_accum->mutable_data<T>(dev), host, &running_accum,
                 sizeof(T), stream);
    memory::Copy(dev, out_state->mutable_data<T>(dev), host, &running_state,
                 sizeof(T), stream);
    memory::Copy(dev, out_scale->mutable_data<T>(dev), host, &averaged_scale,
                 sizeof(T), stream);
    ctx.Wait();
  }
};

template struct FindMovingAverageAbsMaxFunctor<platform::CUDADeviceContext,
                                               float>;

视言's avatar
视言 已提交
327 328 329
}  // namespace operators
}  // namespace paddle

330 331 332 333
namespace ops = paddle::operators;
using CUDA = paddle::platform::CUDADeviceContext;

// float32 CUDA kernel registrations for the ops in this file. The kernel
// class templates are not defined here; presumably they come from
// fake_quantize_op.h — confirm.
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_abs_max,
    ops::FakeQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
    fake_channel_wise_quantize_abs_max,
    ops::FakeChannelWiseQuantizeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_range_abs_max,
    ops::FakeQuantizeRangeAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_moving_average_abs_max,
    ops::FakeQuantizeMovingAverageAbsMaxKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
    moving_average_abs_max_scale,
    ops::MovingAverageAbsMaxScaleKernel<CUDA, float>);
REGISTER_OP_CUDA_KERNEL(
    fake_quantize_dequantize_moving_average_abs_max,
    ops::FakeQuantizeDequantizeMovingAverageAbsMaxKernel<CUDA, float>);