/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>

#if defined(PADDLE_WITH_CUDA)
#include <cuda_fp16.h>
#endif

#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"

namespace phi {
namespace funcs {

using float16 = phi::dtype::float16;

/**
 * brief: Maps an FC element type to the packed vector type used by the
 * vectorized bias-add kernels. The output and bias buffers are
 * reinterpret_cast to this type, so one packed element covers several
 * consecutive scalars. float16 specializations are chosen per platform.
 **/
template <typename T>
struct FcTypeTraits;

// float -> float4: four floats per packed element.
template <>
struct FcTypeTraits<float> {
  using Type = float4;
};

// double -> double4: four doubles per packed element.
template <>
struct FcTypeTraits<double> {
  using Type = double4;
};

#if defined(PADDLE_WITH_CUDA)
// CUDA builds: float16 is packed as the native half2 (two halves per packed
// element), which the half2-specific bias-add kernels operate on.
template <>
struct FcTypeTraits<float16> {
  typedef half2 Type;
};
#else
// Non-CUDA builds: a plain four-lane float16 aggregate, so the generic
// bias_relu_v4 kernel (which accesses .x/.y/.z/.w) can be reused for float16.
struct float16_4 {
  float16 x, y, z, w;
};

template <>
struct FcTypeTraits<float16> {
  typedef float16_4 Type;
};
#endif

// Element-wise max(x, 0) computed in x's own type. Used instead of
// fmaxf(0.f, x) because fmaxf would narrow double lanes to float. As with
// fmaxf(0.f, x), a NaN input yields 0.
template <typename T>
__device__ __forceinline__ T FcReluScalar(const T& x) {
  const T zero = static_cast<T>(0);
  return x > zero ? x : zero;
}

/**
 * brief: Vectorized in-place bias add with optional ReLU.
 * T is a packed 4-lane type (see FcTypeTraits) with members x, y, z, w.
 * One thread handles one packed element: data[tid] += bias[tid % K], where K
 * is the bias length in packed elements. Threads with tid >= num do nothing.
 **/
template <typename T, bool DoRelu>
__global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= num) {
    return;
  }
  const T b = bias[tid % K];
  const T in = data[tid];
  T out;
  out.x = in.x + b.x;
  out.y = in.y + b.y;
  out.z = in.z + b.z;
  out.w = in.w + b.w;
  if (DoRelu) {
    out.x = FcReluScalar(out.x);
    out.y = FcReluScalar(out.y);
    out.z = FcReluScalar(out.z);
    out.w = FcReluScalar(out.w);
  }
  data[tid] = out;
}

/**
 * brief: Scalar in-place bias add (+ optional ReLU), used when the row length
 * N cannot be processed with packed vectors.
 * Launch with one block per row (block b handles data[b * N, (b + 1) * N))
 * and BlockDim threads per block; threads stride over the N columns.
 **/
template <typename T, bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
  // 64-bit row offset: blockIdx.x * N can exceed INT_MAX for large outputs.
  const int64_t offset = static_cast<int64_t>(blockIdx.x) * N;
  T* row = data + offset;

  for (int i = threadIdx.x; i < N; i += BlockDim) {
    T temp;
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
    temp = __ldg(row + i) + __ldg(bias + i);
#else
    temp = row[i] + bias[i];
#endif
    if (DoRelu) {
      // Branch-free ReLU: multiply by the (temp > 0) indicator.
      row[i] = static_cast<int>(temp > 0) * temp;
    } else {
      row[i] = temp;
    }
  }
}

/**
 * brief: Adds the length-N bias B to each of the M rows of the row-major
 * M x N matrix Y in place, optionally followed by ReLU, on `stream`.
 * When N is a multiple of 4, rows are processed as packed 4-lane vectors
 * (bias_relu_v4). Otherwise, one block per row runs the scalar
 * InplaceAddReluKernel.
 * NOTE(review): the packed path reinterprets Y and B as
 * FcTypeTraits<T>::Type, which presumably relies on allocator alignment —
 * confirm for offset/sub-tensor pointers.
 **/
template <typename T>
void AddReluKernel(
    gpuStream_t stream, const int M, const int N, T* Y, const T* B, bool relu) {
  // Empty output: nothing to add. A zero-sized grid would otherwise be
  // rejected as an invalid launch configuration.
  if (M <= 0 || N <= 0) {
    return;
  }
  constexpr int kThreads = 256;
  if (N % 4 == 0) {
    const int num = M * N / 4;
    const int blocks = (num + kThreads - 1) / kThreads;
    typedef typename FcTypeTraits<T>::Type trans_type;
    auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
    auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
    if (relu) {
      bias_relu_v4<trans_type, true><<<blocks, kThreads, 0, stream>>>(
          num, bias_ptr_v4, data_ptr_v4, N / 4);
    } else {
      bias_relu_v4<trans_type, false><<<blocks, kThreads, 0, stream>>>(
          num, bias_ptr_v4, data_ptr_v4, N / 4);
    }
  } else {
    // One block per row; the kernel strides over the columns with kThreads.
    const int blocks = M;
    if (relu) {
      InplaceAddReluKernel<T, true, kThreads>
          <<<blocks, kThreads, 0, stream>>>(N, B, Y);
    } else {
      InplaceAddReluKernel<T, false, kThreads>
          <<<blocks, kThreads, 0, stream>>>(N, B, Y);
    }
  }
}

#if defined(PADDLE_WITH_CUDA)
/**
 * brief: Vectorized in-place bias add (+ optional ReLU) for half2 data.
 * Each iteration loads Half2VecSize consecutive half2 values of `data` and the
 * matching Half2VecSize values of `bias` (bias index = linear index % K), adds
 * them, and stores the result back. Threads walk the `num` half2 elements with
 * a grid-stride loop, so any grid size >= 1 covers the whole input.
 * Preconditions (met by DispatchBiasAddReluKernelHalf2VecSize): `num` and `K`
 * are multiples of Half2VecSize, so a vector access never runs past the end
 * of `data` or past the end of the bias row.
 **/
template <bool DoRelu, int Half2VecSize>
__global__ void bias_relu_v4_half2(const int num,
                                   const half2* bias,
                                   half2* data,
                                   int K) {
  using LoadT = phi::AlignedVector<half2, Half2VecSize>;
  LoadT data_vec;
  LoadT bias_vec;
  const int32_t global_thread_idx = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t grid_stride = gridDim.x * blockDim.x;

  for (int32_t linear_idx = global_thread_idx * Half2VecSize; linear_idx < num;
       linear_idx += grid_stride * Half2VecSize) {
    phi::Load<half2, Half2VecSize>(&data[linear_idx], &data_vec);
    const int bias_idx = linear_idx % K;
    phi::Load<half2, Half2VecSize>(&bias[bias_idx], &bias_vec);

#pragma unroll
    for (int unroll_idx = 0; unroll_idx < Half2VecSize; unroll_idx++) {
      // Bias add: packed half2 add on sm_53+, per-lane half add otherwise.
#if __CUDA_ARCH__ >= 530
      data_vec[unroll_idx] =
          __hadd2(data_vec[unroll_idx], bias_vec[unroll_idx]);
#else
      data_vec[unroll_idx].x =
          __hadd(data_vec[unroll_idx].x, bias_vec[unroll_idx].x);
      data_vec[unroll_idx].y =
          __hadd(data_vec[unroll_idx].y, bias_vec[unroll_idx].y);
#endif

      // ReLU: __hmax2 on sm_80+, multiply by the (x > 0) mask on sm_53+,
      // and a float round-trip per lane on older architectures.
      if (DoRelu) {
#if __CUDA_ARCH__ >= 800
        data_vec[unroll_idx] = __hmax2(__half2(0, 0), data_vec[unroll_idx]);
#elif __CUDA_ARCH__ >= 530
        data_vec[unroll_idx] = __hmul2(
            __hgt2(data_vec[unroll_idx], __half2(0, 0)), data_vec[unroll_idx]);
#else
        data_vec[unroll_idx].x =
            static_cast<int>(static_cast<float>(data_vec[unroll_idx].x) > 0) *
            static_cast<float>(data_vec[unroll_idx].x);
        data_vec[unroll_idx].y =
            static_cast<int>(static_cast<float>(data_vec[unroll_idx].y) > 0) *
            static_cast<float>(data_vec[unroll_idx].y);
#endif
      }
    }
    phi::Store<half2, Half2VecSize>(data_vec, &data[linear_idx]);
  }
}

/**
 * brief: Scalar in-place bias add (+ optional ReLU) for CUDA half data, used
 * by the float16 AddReluKernel when the row length N is odd.
 * Launch with one block per row and BlockDim threads per block; threads
 * stride over the N columns of their row.
 **/
template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N,
                                     const half* bias,
                                     half* data) {
  // 64-bit row offset: blockIdx.x * N can exceed INT_MAX for large outputs.
  const int64_t offset = static_cast<int64_t>(blockIdx.x) * N;
  half* row = data + offset;
  for (int i = threadIdx.x; i < N; i += BlockDim) {
    half temp;
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
    temp = __hadd(__ldg(row + i), __ldg(bias + i));
#else
    temp = __hadd(row[i], bias[i]);
#endif
    if (DoRelu) {
#if __CUDA_ARCH__ >= 800
      row[i] = __hmax(0, temp);
#elif __CUDA_ARCH__ >= 530
      row[i] = __hmul(__hgt(temp, 0), temp);
#else
      row[i] = static_cast<int>(static_cast<float>(temp) > 0) *
               static_cast<float>(temp);
#endif
    } else {
      row[i] = temp;
    }
  }
}

/**
 * brief: Launches bias_relu_v4_half2 over a rows x cols float16 matrix Y,
 * adding the length-cols bias B to every row and optionally applying ReLU.
 * Y and B are reinterpreted as half2, so cols must be even. cols / 2 must
 * also be a multiple of Half2VecSize (the dispatcher chooses Half2VecSize to
 * satisfy this).
 **/
template <int Half2VecSize>
void LaunchBiasAddReluHalf2Kernel(cudaStream_t stream,
                                  const int32_t rows,
                                  const int32_t cols,
                                  float16* Y,
                                  const float16* B,
                                  bool relu) {
  const int threads = 256;
  // Number of Half2VecSize-wide half2 vectors to process.
  const int vec_num = rows * cols / (Half2VecSize * 2);
  // Empty output: skip the launch, since a zero-sized grid is invalid.
  if (vec_num <= 0) {
    return;
  }
  const int half2_num = rows * cols / 2;
  const int blocks = (vec_num + threads - 1) / threads;
  // Reinterpret the float16 buffers as half2 for the vectorized kernel.
  typedef typename FcTypeTraits<float16>::Type trans_type;
  auto* bias_half2_ptr = reinterpret_cast<const trans_type*>(B);
  auto* data_half2_ptr = reinterpret_cast<trans_type*>(Y);
  if (relu) {
    bias_relu_v4_half2<true, Half2VecSize><<<blocks, threads, 0, stream>>>(
        half2_num, bias_half2_ptr, data_half2_ptr, cols / 2);
  } else {
    bias_relu_v4_half2<false, Half2VecSize><<<blocks, threads, 0, stream>>>(
        half2_num, bias_half2_ptr, data_half2_ptr, cols / 2);
  }
}

/**
 * brief: Picks the widest half2 vector width (4, 2 or 1 half2 per access)
 * that evenly divides a row and launches the half2 bias-add kernel with it.
 * A 128-bit access holds 8 halves, i.e. 4 half2, so the maximum width is 4.
 * Requires cols to be even; the float16 AddReluKernel checks this before
 * calling.
 **/
void DispatchBiasAddReluKernelHalf2VecSize(cudaStream_t stream,
                                           const int32_t rows,
                                           const int32_t cols,
                                           float16* Y,
                                           const float16* B,
                                           bool relu) {
  if (cols % 8 == 0) {
    // 8 halves per row chunk -> 4 half2 per vector access.
    LaunchBiasAddReluHalf2Kernel<4>(stream, rows, cols, Y, B, relu);
    return;
  }
  if (cols % 4 == 0) {
    // 4 halves per row chunk -> 2 half2 per vector access.
    LaunchBiasAddReluHalf2Kernel<2>(stream, rows, cols, Y, B, relu);
    return;
  }
  LaunchBiasAddReluHalf2Kernel<1>(stream, rows, cols, Y, B, relu);
}

/**
 * brief: float16 specialization of AddReluKernel for CUDA builds.
 * If N is even, uses the half2-vectorized path, with the vector width chosen
 * from N. If N is odd, runs the scalar half kernel with one block per row.
 **/
template <>
void AddReluKernel(gpuStream_t stream,
                   const int M,
                   const int N,
                   float16* Y,
                   const float16* B,
                   bool relu) {
  // Empty output: nothing to add, and a zero-sized grid would be rejected.
  if (M <= 0 || N <= 0) {
    return;
  }
  if (N % 2 == 0) {
    DispatchBiasAddReluKernelHalf2VecSize(stream, M, N, Y, B, relu);
    return;
  }
  constexpr int kThreads = 256;
  const int blocks = M;
  // Reinterpret phi float16 as CUDA half so the native half intrinsics can
  // be used (same 16-bit storage is assumed, as in the vectorized path).
  auto* halfB = reinterpret_cast<const half*>(B);
  auto* halfY = reinterpret_cast<half*>(Y);
  if (relu) {
    InplaceAddReluKernel<true, kThreads>
        <<<blocks, kThreads, 0, stream>>>(N, halfB, halfY);
  } else {
    InplaceAddReluKernel<false, kThreads>
        <<<blocks, kThreads, 0, stream>>>(N, halfB, halfY);
  }
}
#else
/**
 * brief: Scalar in-place bias add (+ optional ReLU) for float16 on non-CUDA
 * builds. Launch with one block per row and BlockDim threads per block;
 * threads stride over the N columns of their row.
 **/
template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N,
                                     const float16* bias,
                                     float16* data) {
  // 64-bit row offset: blockIdx.x * N can exceed INT_MAX for large outputs.
  const int64_t offset = static_cast<int64_t>(blockIdx.x) * N;
  float16* row = data + offset;
  for (int i = threadIdx.x; i < N; i += BlockDim) {
    const float16 temp = row[i] + bias[i];
    if (DoRelu) {
      row[i] = fmaxf(0.f, temp);
    } else {
      row[i] = temp;
    }
  }
}

/**
 * brief: float16 specialization of AddReluKernel for non-CUDA builds.
 * If N is a multiple of 4, processes rows as float16_4 packed vectors via
 * bias_relu_v4. Otherwise, runs the scalar float16 kernel with one block per
 * row.
 **/
template <>
void AddReluKernel(gpuStream_t stream,
                   const int M,
                   const int N,
                   float16* Y,
                   const float16* B,
                   bool relu) {
  // Empty output: nothing to add, and a zero-sized grid would be rejected.
  if (M <= 0 || N <= 0) {
    return;
  }
  constexpr int kThreads = 256;
  if (N % 4 == 0) {
    const int num = M * N / 4;
    const int blocks = (num + kThreads - 1) / kThreads;
    typedef typename FcTypeTraits<float16>::Type trans_type;
    auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
    auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
    if (relu) {
      bias_relu_v4<trans_type, true><<<blocks, kThreads, 0, stream>>>(
          num, bias_ptr_v4, data_ptr_v4, N / 4);
    } else {
      bias_relu_v4<trans_type, false><<<blocks, kThreads, 0, stream>>>(
          num, bias_ptr_v4, data_ptr_v4, N / 4);
    }
  } else {
    // One block per row; the kernel strides over the columns with kThreads.
    const int blocks = M;
    if (relu) {
      InplaceAddReluKernel<true, kThreads>
          <<<blocks, kThreads, 0, stream>>>(N, B, Y);
    } else {
      InplaceAddReluKernel<false, kThreads>
          <<<blocks, kThreads, 0, stream>>>(N, B, Y);
    }
  }
}
#endif

/**
 * brief: Fully-connected forward on GPU: Y = X * W, then (if B is non-null)
 * Y += B broadcast over rows, optionally followed by ReLU.
 * GEMM is called with M, N, K and leading dimensions K (X), N (W), N (Y),
 * i.e. X is M x K, W is K x N and Y is M x N. B, when given, has N entries.
 * Weight padding is not supported here and is rejected with PermissionDenied.
 **/
template <typename DeviceContext, typename T>
void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
                                             const int M,
                                             const int N,
                                             const int K,
                                             const T* X,
                                             const T* W,
                                             T* Y,
                                             const T* B,
                                             bool relu,
                                             bool padding_weights) {
  PADDLE_ENFORCE_EQ(padding_weights,
                    false,
                    errors::PermissionDenied(
                        "Weight padding in fc can not be used in GPU scope."));
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
  blas.GEMM(false,
            false,
            M,
            N,
            K,
            static_cast<T>(1.0),
            X,
            K,
            W,
            N,
            static_cast<T>(0.0),
            Y,
            N);
  // No bias: the GEMM result is the final output, and ReLU is not applied.
  if (B == nullptr) {
    return;
  }

  // Y (M x N) += B, plus optional ReLU, on the context's stream.
  AddReluKernel(context.stream(), M, N, Y, B, relu);
}

// Explicit instantiations of FCFunctor on GPUContext for float16, float and
// double.
template class FCFunctor<GPUContext, float16>;
template class FCFunctor<GPUContext, float>;
template class FCFunctor<GPUContext, double>;

}  // namespace funcs
}  // namespace phi