/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <string>
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"

namespace paddle {
namespace operators {

// Computes, for each block, the maximum of |in[i]| over the elements that
// block visits, and writes it to out[blockIdx.x].
//
// Launch requirements:
//   * blockDim.x must be a power of two: the shared-memory tree reduction
//     below halves the active range every step.
//   * Dynamic shared memory of at least blockDim.x * sizeof(T) bytes.
//   * `out` must hold gridDim.x elements. A single-block launch therefore
//     produces the maximum over all n elements in out[0].
//
// Each thread walks the input with a stride equal to the total number of
// launched threads, so every launch shape covers all n elements (the
// previous single-block path only looked at the first blockDim.x elements).
// Threads that see no element contribute 0, which is neutral for a maximum
// of absolute values. If n <= 0 every block writes 0.
template <typename T>
__global__ void FindAbsMaxKernel(const T* in, const int n, T* out) {
  const int tid = threadIdx.x;
  const int gid = threadIdx.x + blockIdx.x * blockDim.x;
  const int stride = blockDim.x * gridDim.x;

  extern __shared__ T shared_max_data[];

  // Per-thread partial maximum kept in a register, then published once.
  T local_max = T(0);
  for (int i = gid; i < n; i += stride) {
    T tmp = fabs(in[i]);
    if (tmp > local_max) {
      local_max = tmp;
    }
  }
  shared_max_data[tid] = local_max;
  __syncthreads();

  // Tree reduction in shared memory; all threads reach every barrier.
  for (int i = blockDim.x / 2; i > 0; i >>= 1) {
    if (tid < i && (shared_max_data[tid] < shared_max_data[tid + i])) {
      shared_max_data[tid] = shared_max_data[tid + i];
    }
    __syncthreads();
  }
  if (tid == 0) {
    out[blockIdx.x] = shared_max_data[0];
  }
}

// CUDA specialization: writes max_i |in[i]| over i in [0, num) to out[0]
// using two launches of FindAbsMaxKernel on ctx.stream():
//   pass 1: up to `block` blocks each produce one partial maximum;
//   pass 2: a single block folds the partial maxima into out[0].
// `in` and `out` are passed straight to kernels, so they must be
// device-accessible. The work is only enqueued; this call does not
// synchronize the stream. For num <= 0 the result written is 0.
template <typename T>
struct FindAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx, const T* in,
                  const int num, T* out) {
    // Power of two, as required by FindAbsMaxKernel's reduction.
    const int block = 1024;
    int grid = (block - 1 + num) / block;
    // A zero-sized grid is an invalid launch configuration; always launch at
    // least one block so an empty input still yields a well-defined 0.
    grid = grid < 1 ? 1 : grid;
    // Pass 2 reduces `grid` partial results with one block of `block`
    // threads, so the number of partial results may not exceed `block`.
    grid = grid > block ? block : grid;

    // Shared memory must cover one T per thread of the block.
    const size_t smem_bytes = block * sizeof(T);

    framework::Tensor block_max;
    T* block_max_data = block_max.mutable_data<T>(framework::make_ddim({grid}),
                                                  ctx.GetPlace());
    FindAbsMaxKernel<T><<<grid, block, smem_bytes, ctx.stream()>>>(
        in, num, block_max_data);
    FindAbsMaxKernel<T><<<1, block, smem_bytes, ctx.stream()>>>(
        block_max_data, grid, out);
  }
};

template struct FindAbsMaxFunctor<platform::CUDADeviceContext, float>;

// Fake-quantizes n values: each in[i] is clipped to [-s, s] with
// s = scale[0], scaled by bin_cnt / s and rounded; the result goes to out[i].
//
// Works with any 1-D launch: each thread strides over the input by the total
// number of launched threads. The element index is the loop variable `i`
// (the previous version indexed with the thread's starting index, which was
// only correct when exactly one pass per thread was needed).
//
// NOTE(review): if scale[0] == 0 the factor bin_cnt / s is infinite and the
// clipped value is 0, so the output is NaN (inf * 0); confirm callers
// guarantee a positive scale.
template <typename T>
__global__ void ClipAndQuantKernel(const T* in, const T* scale,
                                   const int bin_cnt, const int n, T* out) {
  const int gid = threadIdx.x + blockIdx.x * blockDim.x;
  const int stride = blockDim.x * gridDim.x;

  const T s = scale[0];
  // Loop-invariant quantization factor (same expression as before).
  const T factor = bin_cnt / s;
  for (int i = gid; i < n; i += stride) {
    T x = in[i];
    T v = x > s ? s : x;
    v = v < -s ? -s : v;
    out[i] = round(factor * v);
  }
}

// Single-thread kernel (launch <<<1, 1>>>) that advances the sliding window
// of recent scales by one step:
//   * writes cur_scale[0] into slot iter[0] % window_size of scale_arr,
//     remembering the value it evicts;
//   * writes the new window maximum to out_scale[0] when it can be derived
//     from last_scale[0] (assumed to be the previous window maximum) and
//     cur_scale[0];
//   * otherwise sets need_find_max[0] = 1 and out_size[0] to the number of
//     leading slots the host must rescan.
// need_find_max[0] is always written; out_size[0] only when a rescan is
// requested.
template <typename T>
__global__ void FindRangeAbsMaxAndFillArray(const T* cur_scale,
                                            const T* last_scale,
                                            const int64_t* iter,
                                            const int window_size, T* scale_arr,
                                            T* out_scale, int* need_find_max,
                                            int* out_size) {
  // Keep the iteration counter 64-bit; truncating to int overflows on long
  // training runs.
  const int64_t it = iter[0];
  const int idx = static_cast<int>(it % window_size);
  T removed = scale_arr[idx];
  T cur = cur_scale[0];
  scale_arr[idx] = cur;
  T max = last_scale[0];

  if (max < cur) {
    // The newest scale exceeds everything that was in the window, so it is
    // the new maximum regardless of what was evicted; no rescan needed.
    out_scale[0] = cur;
    need_find_max[0] = 0;
  } else if (fabs(removed - max) < 1e-6) {
    // The evicted entry was (approximately) the maximum: request a rescan.
    // Slots [0, it] have been written once slot `it % window_size` is
    // filled (assuming iter counts from 0 — TODO confirm), so the scan must
    // include the slot just written; the old bound `it` excluded it and
    // requested an empty scan on the first iteration.
    out_scale[0] = max;
    need_find_max[0] = 1;
    const int64_t filled = it + 1;
    out_size[0] =
        static_cast<int>(filled > window_size ? window_size : filled);
  } else {
    out_scale[0] = max;
    need_find_max[0] = 0;
  }
}

// CUDA specialization of the sliding-window abs-max tracker.
//
// Runs FindRangeAbsMaxAndFillArray on ctx.stream() to update `scales_arr`
// and `out_scale`. If that kernel reports that the evicted entry was the
// maximum, reads the rescan length back to the host and recomputes the
// maximum of the first `len` window slots into `out_scale` with
// FindAbsMaxFunctor.
//
// This call blocks the host until the update kernel has finished, because
// the host-side decision depends on the kernel's output flags.
template <typename T>
struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& cur_scale,
                  const framework::Tensor& last_scale,
                  const framework::Tensor& iter, const int window_size,
                  framework::Tensor* scales_arr, framework::Tensor* out_scale) {
    // Copy the place: GetPlace() returns a temporary, so binding a reference
    // to the value extracted from it would dangle.
    auto gpu_place = boost::get<platform::CUDAPlace>(ctx.GetPlace());
    T* scale_arr = scales_arr->mutable_data<T>(gpu_place);
    T* out_scale_data = out_scale->mutable_data<T>(gpu_place);

    // One-int device flags written by the update kernel. Give them an
    // explicit shape instead of relying on the tensors' default dims.
    framework::Tensor need_find_max, out_size;
    int* find_max =
        need_find_max.mutable_data<int>(framework::make_ddim({1}), gpu_place);
    int* out_size_data =
        out_size.mutable_data<int>(framework::make_ddim({1}), gpu_place);

    FindRangeAbsMaxAndFillArray<T><<<1, 1, 0, ctx.stream()>>>(
        cur_scale.data<T>(), last_scale.data<T>(), iter.data<int64_t>(),
        window_size, scale_arr, out_scale_data, find_max, out_size_data);

    // The kernel was enqueued on ctx.stream(); a copy issued without that
    // stream is not guaranteed to be ordered after it, so wait explicitly
    // before the host reads the flags.
    ctx.Wait();

    int g_find_max = 0;
    memory::Copy(platform::CPUPlace(), &g_find_max, gpu_place, find_max,
                 sizeof(int), 0);
    if (g_find_max) {
      int len = 0;
      memory::Copy(platform::CPUPlace(), &len, gpu_place, out_size_data,
                   sizeof(int), 0);
      // A zero-length rescan would replace the kernel's result with 0.
      if (len > 0) {
        FindAbsMaxFunctor<platform::CUDADeviceContext, T>()(
            ctx, scale_arr, len, out_scale_data);
      }
    }
  }
};

template struct FindRangeAbsMaxFunctor<platform::CUDADeviceContext, float>;

// CUDA specialization: fake-quantizes every element of `in` into `out`
// (allocated on ctx's place) with ClipAndQuantKernel, using the scalar
// scale in scale.data<T>()[0] and `bin_cnt` quantization levels per sign.
// The kernel is only enqueued on ctx.stream(); no synchronization here.
template <typename T>
struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& ctx,
                  const framework::Tensor& in, const framework::Tensor& scale,
                  const int bin_cnt, framework::Tensor* out) {
    // The kernel takes an int element count.
    int num = in.numel();
    const int block = 1024;

    const T* in_data = in.data<T>();
    const T* scale_data = scale.data<T>();
    T* out_data = out->mutable_data<T>(ctx.GetPlace());

    // Nothing to quantize; a zero-block grid would be an invalid launch.
    if (num <= 0) {
      return;
    }
    // One thread per element.
    int grid = (block - 1 + num) / block;

    ClipAndQuantKernel<T><<<grid, block, 0, ctx.stream()>>>(
        in_data, scale_data, bin_cnt, num, out_data);
  }
};

template struct ClipAndFakeQuantFunctor<platform::CUDADeviceContext, float>;

}  // namespace operators
}  // namespace paddle

// Register the float32 CUDA kernels for the two fake-quantize operators.
namespace ops = paddle::operators;
using CUDACtx = paddle::platform::CUDADeviceContext;

REGISTER_OP_CUDA_KERNEL(fake_quantize_abs_max,
                        ops::FakeQuantizeAbsMaxKernel<CUDACtx, float>);
REGISTER_OP_CUDA_KERNEL(fake_quantize_range_abs_max,
                        ops::FakeQuantizeRangeAbsMaxKernel<CUDACtx, float>);