/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <string>

#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/fake_quantize_op.cu.h"
#include "paddle/fluid/operators/quantize_linear_op.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"

using float16 = paddle::platform::float16;

namespace paddle {
namespace operators {

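// Dequantizes a whole tensor with a single (per-tensor) scale:
// out[i] = in[i] * scale[0] / max_range. The grid-stride loop lets a
// bounded grid cover tensors of any size.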
template <typename T>
__global__ void KeDequantize(
    const T* in, const T* scale, T max_range, int64_t num, T* out) {
  int64_t idx = threadIdx.x + blockIdx.x * blockDim.x;
  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
    out[i] = in[i] * scale[0] / max_range;
  }
}

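// Per-channel dequantization for an arbitrary quantization axis. Element i
// belongs to channel (i / quant_stride) % n_scales, where quant_stride is
// the number of contiguous elements that share one scale value.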
template <typename T>
__global__ void DequantizeOneScaleQuantAxisN(const T* in,
                                             const T* scale,
                                             const T max_range,
                                             const int64_t num,
                                             const int n_scales,
                                             const int quant_stride,
                                             T* out) {
  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
  for (int64_t i = idx; i < num; i += blockDim.x * gridDim.x) {
    T s = scale[(i / quant_stride) % n_scales];
    out[i] = in[i] * s / max_range;
  }
}

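// Host-side launcher for per-tensor dequantization on the GPU.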
template <typename T>
struct DequantizeFunctor<phi::GPUContext, T> {
  void operator()(const phi::GPUContext& dev_ctx,
                  const phi::DenseTensor* in,
                  const phi::DenseTensor* scale,
                  T max_range,
                  phi::DenseTensor* out) {
    const T* in_data = in->data<T>();
    const T* scale_factor = scale->data<T>();
    T* out_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));

    int64_t num = in->numel();
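    // Launch heuristic: cap the block at a quarter of the device's max
    // threads per block, and cap the grid so the total thread count stays
    // near the device's physical capacity; the grid-stride loop in the
    // kernel covers any remaining elements.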
    int64_t block_size = std::min(
        num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
    int64_t max_threads =
        dev_ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
    const int64_t max_blocks =
        std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
    const int64_t grid_size =
        std::min(max_blocks, (num + block_size - 1) / block_size);
    KeDequantize<T><<<grid_size, block_size, 0, dev_ctx.stream()>>>(
        in_data, scale_factor, max_range, num, out_data);
  }
};

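// Host-side launcher for per-channel dequantization along quant_axis.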
template <typename T>
struct ChannelDequantizeFunctorV2<phi::GPUContext, T> {
  void operator()(const phi::GPUContext& dev_ctx,
                  const phi::DenseTensor* in,
                  const phi::DenseTensor* scale,
                  T max_range,
                  const int quant_axis,
                  phi::DenseTensor* out) {
    auto in_dims = in->dims();
    const T* in_data = in->data<T>();
    T* out_data = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
    int64_t num = in->numel();
    const T* scale_factor = scale->data<T>();
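    // Same launch configuration as DequantizeFunctor above.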
    int64_t block_size = std::min(
        num, static_cast<int64_t>(dev_ctx.GetMaxThreadsPerBlock() / 4));
    int64_t max_threads =
        dev_ctx.GetMaxPhysicalThreadCount();  // SM * block_per_SM
    const int64_t max_blocks =
        std::max(((max_threads - 1) / block_size + 1), static_cast<int64_t>(1));
    const int64_t grid_size =
        std::min(max_blocks, (num + block_size - 1) / block_size);

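    // quant_stride is the product of the dimensions after quant_axis,
    // i.e. how many consecutive elements share a single scale value.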
    int quant_stride = 1;
    for (int i = quant_axis + 1; i < in_dims.size(); i++) {
      quant_stride *= in_dims[i];
    }

    DequantizeOneScaleQuantAxisN<T>
        <<<grid_size, block_size, 0, dev_ctx.stream()>>>(in_data,
                                                         scale_factor,
                                                         max_range,
                                                         num,
                                                         in_dims[quant_axis],
                                                         quant_stride,
                                                         out_data);
  }
};

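// Explicit instantiations for the element types the CUDA kernels support.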
template struct DequantizeFunctor<phi::GPUContext, phi::dtype::float16>;
template struct DequantizeFunctor<phi::GPUContext, float>;
template struct DequantizeFunctor<phi::GPUContext, double>;
template struct ChannelDequantizeFunctorV2<phi::GPUContext, float16>;
template struct ChannelDequantizeFunctorV2<phi::GPUContext, float>;
template struct ChannelDequantizeFunctorV2<phi::GPUContext, double>;

}  // namespace operators
}  // namespace paddle

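// Register the CUDA kernels for the quantize_linear and dequantize_linear
// operators.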
namespace ops = paddle::operators;
using CUDA = phi::GPUContext;
REGISTER_OP_CUDA_KERNEL(dequantize_linear,
                        ops::DeQuantizeLinearKernel<CUDA, float>,
                        ops::DeQuantizeLinearKernel<CUDA, float16>,
                        ops::DeQuantizeLinearKernel<CUDA, int8_t>,
                        ops::DeQuantizeLinearKernel<CUDA, double>);

REGISTER_OP_CUDA_KERNEL(quantize_linear,
                        ops::QuantizeLinearKernel<CUDA, float>,
                        ops::QuantizeLinearKernel<CUDA, float16>);