fc_functor.cu 9.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <cstdint>

#if defined(PADDLE_WITH_CUDA)
#include <cuda_fp16.h>
#endif
16

17 18 19 20 21 22 23
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/fc_functor.h"

namespace phi {
namespace funcs {

M
ming1753 已提交
24 25
// Short local name for Paddle's 16-bit floating point element type.
typedef phi::dtype::float16 float16;

26 27 28 29 30 31 32 33 34 35 36 37 38
// Maps a scalar element type to the packed vector type that the vectorized
// bias(+ReLU) kernels load and store. The primary template is intentionally
// left undefined so unsupported element types fail to compile.
template <typename T>
struct FcTypeTraits;

// float elements are processed four at a time.
template <>
struct FcTypeTraits<float> {
  using Type = float4;
};

// double elements are processed four at a time.
template <>
struct FcTypeTraits<double> {
  using Type = double4;
};

M
ming1753 已提交
39 40 41 42 43 44 45 46
#if defined(PADDLE_WITH_CUDA)
#include <cuda_fp16.h>

// In CUDA builds, float16 elements are processed as half2 pairs.
template <>
struct FcTypeTraits<float16> {
  using Type = half2;
};
#else
M
ming1753 已提交
47 48 49 50 51 52 53 54
// Four float16 lanes packed together. In builds without PADDLE_WITH_CUDA this
// is the packed type for float16, so the generic 4-wide bias kernel (which
// reads and writes members x, y, z and w) can be reused.
struct float16_4 {
  float16 x;
  float16 y;
  float16 z;
  float16 w;
};

template <>
struct FcTypeTraits<float16> {
  using Type = float16_4;
};
M
ming1753 已提交
55
#endif
M
ming1753 已提交
56

57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97
// Adds a broadcast bias row to `data` in place, optionally followed by ReLU,
// operating on packed 4-lane vectors.
//
// T:    4-lane vector type with members x, y, z, w (see FcTypeTraits).
// num:  number of packed T elements in `data`; one thread per element.
// bias: packed bias row; element `tid` of `data` uses bias[tid % K].
// K:    number of packed T elements in one bias row.
// Launch with at least `num` threads in total; surplus threads exit early.
template <typename T, bool DoRelu>
__global__ void bias_relu_v4(const int num, const T* bias, T* data, int K) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < num) {
    int bias_idx = tid % K;
    const T bias_ptr = bias[bias_idx];
    const T in_ptr = data[tid];
    T packed_val;
    packed_val.x = in_ptr.x + bias_ptr.x;
    packed_val.y = in_ptr.y + bias_ptr.y;
    packed_val.z = in_ptr.z + bias_ptr.z;
    packed_val.w = in_ptr.w + bias_ptr.w;
    if (DoRelu) {
      // Clamp in the lane's own precision. fmaxf(0.f, v) would narrow double
      // lanes to float and back, losing precision and range for double4.
      // The comparison form maps NaN to 0, matching fmaxf's behavior.
      using Lane = decltype(packed_val.x);
      const Lane zero = static_cast<Lane>(0);
      packed_val.x = packed_val.x > zero ? packed_val.x : zero;
      packed_val.y = packed_val.y > zero ? packed_val.y : zero;
      packed_val.z = packed_val.z > zero ? packed_val.z : zero;
      packed_val.w = packed_val.w > zero ? packed_val.w : zero;
    }
    data[tid] = packed_val;
  }
}

// Scalar bias(+ReLU) kernel, one block per row.
//
// Block b updates row b of `data` (N contiguous elements starting at
// b * N): data[row][i] += bias[i], then optionally clamps at zero.
// Launch with gridDim.x == number of rows and blockDim.x == BlockDim.
template <typename T, bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N, const T* bias, T* data) {
  // 64-bit row offset: blockIdx.x * N can exceed INT_MAX for large outputs.
  const int64_t offset = static_cast<int64_t>(blockIdx.x) * N;

  for (int i = threadIdx.x; i < N; i += BlockDim) {
    T temp;
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
    // Only `bias` is read-only for the kernel's lifetime, so only it goes
    // through the non-coherent read-only cache; `data` is written below.
    temp = data[offset + i] + __ldg(bias + i);
#else
    temp = data[offset + i] + bias[i];
#endif
    if (DoRelu) {
      // Select instead of multiplying by a 0/1 mask: 0 * -inf and 0 * NaN are
      // NaN, which would leak through the ReLU.
      const T zero = static_cast<T>(0);
      data[offset + i] = temp > zero ? temp : zero;
    } else {
      data[offset + i] = temp;
    }
  }
}

M
ming1753 已提交
98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
// Adds bias B (length N) to every row of the M x N matrix Y in place,
// optionally applying ReLU, asynchronously on `stream`.
//
// The 4-wide vectorized kernel is used when N is a multiple of 4 and both
// pointers are aligned for the packed vector type. Otherwise the
// one-block-per-row scalar kernel is used.
template <typename T>
void AddReluKernel(
    gpuStream_t stream, const int M, const int N, T* Y, const T* B, bool relu) {
  // An empty output needs no work, and a zero-block launch would be an
  // invalid launch configuration.
  if (M <= 0 || N <= 0) {
    return;
  }
  typedef typename FcTypeTraits<T>::Type trans_type;
  const int threads = 256;
  // Vector loads/stores fault on misaligned addresses (e.g. a bias view that
  // does not start on a vector boundary), so check both base pointers. With
  // N % 4 == 0 every row of Y then starts on a vector boundary as well.
  const bool vectorizable =
      N % 4 == 0 &&
      reinterpret_cast<std::uintptr_t>(Y) % alignof(trans_type) == 0 &&
      reinterpret_cast<std::uintptr_t>(B) % alignof(trans_type) == 0;

  if (vectorizable) {
    // Count in 64 bits: M * N can overflow int even when M * N / 4 fits.
    // NOTE: the kernel indexes with int, so num itself must fit in int.
    const int64_t num64 = static_cast<int64_t>(M) * N / 4;
    const int num = static_cast<int>(num64);
    const int blocks = static_cast<int>((num64 + threads - 1) / threads);
    auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
    auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
    if (relu) {
      bias_relu_v4<trans_type, true><<<blocks, threads, 0, stream>>>(
          num, bias_ptr_v4, data_ptr_v4, N / 4);
    } else {
      bias_relu_v4<trans_type, false><<<blocks, threads, 0, stream>>>(
          num, bias_ptr_v4, data_ptr_v4, N / 4);
    }
  } else {
    // One block per row of Y.
    const int blocks = M;

    if (relu) {
      InplaceAddReluKernel<T, true, threads>
          <<<blocks, threads, 0, stream>>>(N, B, Y);
    } else {
      InplaceAddReluKernel<T, false, threads>
          <<<blocks, threads, 0, stream>>>(N, B, Y);
    }
  }
}

M
ming1753 已提交
129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151
#if defined(PADDLE_WITH_CUDA)
// half2 bias(+ReLU) kernel: adds a broadcast bias row to `data` in place,
// two half lanes per thread.
//
// num:  number of half2 elements in `data`; one thread per element.
// bias: half2 bias row; element `tid` of `data` uses bias[tid % K].
// K:    number of half2 elements in one bias row.
// Launch with at least `num` threads in total; surplus threads exit early.
template <bool DoRelu>
__global__ void bias_relu_v2(const int num,
                             const half2* bias,
                             half2* data,
                             int K) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;

  if (tid < num) {
    int bias_idx = tid % K;
    const half2 bias_ptr = bias[bias_idx];
    const half2 in_ptr = data[tid];
    half2 packed_val;
#if __CUDA_ARCH__ >= 530
    packed_val = __hadd2(bias_ptr, in_ptr);
#else
    packed_val.x = __hadd(bias_ptr.x, in_ptr.x);
    packed_val.y = __hadd(bias_ptr.y, in_ptr.y);
#endif
    if (DoRelu) {
#if __CUDA_ARCH__ >= 800
      // __hmax2 returns the non-NaN operand, so NaN lanes become 0.
      packed_val = __hmax2(__half2(0, 0), packed_val);
#elif __CUDA_ARCH__ >= 530
      // Per-lane select rather than multiplying by the __hgt2 mask: 0 * -inf
      // and 0 * NaN are NaN, and fp16 overflow to -inf is common.
      const half zero = __float2half(0.f);
      packed_val.x = __hgt(packed_val.x, zero) ? packed_val.x : zero;
      packed_val.y = __hgt(packed_val.y, zero) ? packed_val.y : zero;
#else
      // No half arithmetic before sm_53: compare in float, same select rule.
      const float lo = static_cast<float>(packed_val.x);
      const float hi = static_cast<float>(packed_val.y);
      packed_val.x = lo > 0.f ? lo : 0.f;
      packed_val.y = hi > 0.f ? hi : 0.f;
#endif
    }
    data[tid] = packed_val;
  }
}

// Scalar half bias(+ReLU) kernel, one block per row.
//
// Block b updates row b of `data` (N contiguous elements starting at
// b * N): data[row][i] += bias[i], then optionally clamps at zero.
// Launch with gridDim.x == number of rows and blockDim.x == BlockDim.
template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N,
                                     const half* bias,
                                     half* data) {
  // 64-bit row offset: blockIdx.x * N can exceed INT_MAX for large outputs.
  const int64_t offset = static_cast<int64_t>(blockIdx.x) * N;
  for (int i = threadIdx.x; i < N; i += BlockDim) {
    half temp;
#if defined(__HIPCC__) || __CUDA_ARCH__ >= 350
    // Only `bias` is read-only for the kernel's lifetime, so only it goes
    // through the non-coherent read-only cache; `data` is written below.
    temp = __hadd(data[offset + i], __ldg(bias + i));
#else
    temp = __hadd(data[offset + i], bias[i]);
#endif
    if (DoRelu) {
#if __CUDA_ARCH__ >= 800
      data[offset + i] = __hmax(0, temp);
#elif __CUDA_ARCH__ >= 530
      // Select rather than multiply by the __hgt mask: 0 * -inf and 0 * NaN
      // are NaN, which would leak through the ReLU.
      const half zero = __float2half(0.f);
      data[offset + i] = __hgt(temp, zero) ? temp : zero;
#else
      // No half arithmetic before sm_53: compare in float, same select rule.
      const float v = static_cast<float>(temp);
      data[offset + i] = v > 0.f ? v : 0.f;
#endif
    } else {
      data[offset + i] = temp;
    }
  }
}

// float16 specialization for CUDA builds: adds bias B (length N) to every row
// of the M x N matrix Y in place, optionally applying ReLU, on `stream`.
//
// The half2 kernel is used when N is even and both pointers are aligned for
// half2 access. Otherwise the one-block-per-row half kernel is used.
template <>
void AddReluKernel(cudaStream_t stream,
                   const int M,
                   const int N,
                   float16* Y,
                   const float16* B,
                   bool relu) {
  // An empty output needs no work, and a zero-block launch would be an
  // invalid launch configuration.
  if (M <= 0 || N <= 0) {
    return;
  }
  typedef typename FcTypeTraits<float16>::Type trans_type;
  const int threads = 256;
  // half2 accesses fault on addresses not aligned for half2 (e.g. a bias
  // view that starts on an odd element), so check both base pointers.
  const bool vectorizable =
      N % 2 == 0 &&
      reinterpret_cast<std::uintptr_t>(Y) % alignof(trans_type) == 0 &&
      reinterpret_cast<std::uintptr_t>(B) % alignof(trans_type) == 0;

  if (vectorizable) {
    // Count in 64 bits: M * N can overflow int even when M * N / 2 fits.
    // NOTE: the kernel indexes with int, so num itself must fit in int.
    const int64_t num64 = static_cast<int64_t>(M) * N / 2;
    const int num = static_cast<int>(num64);
    const int blocks = static_cast<int>((num64 + threads - 1) / threads);
    auto* bias_ptr_v2 = reinterpret_cast<const trans_type*>(B);
    auto* data_ptr_v2 = reinterpret_cast<trans_type*>(Y);
    if (relu) {
      bias_relu_v2<true><<<blocks, threads, 0, stream>>>(
          num, bias_ptr_v2, data_ptr_v2, N / 2);
    } else {
      bias_relu_v2<false><<<blocks, threads, 0, stream>>>(
          num, bias_ptr_v2, data_ptr_v2, N / 2);
    }
  } else {
    // One block per row of Y.
    const int blocks = M;
    auto* halfB = reinterpret_cast<const half*>(B);
    auto* halfY = reinterpret_cast<half*>(Y);
    if (relu) {
      InplaceAddReluKernel<true, threads>
          <<<blocks, threads, 0, stream>>>(N, halfB, halfY);
    } else {
      InplaceAddReluKernel<false, threads>
          <<<blocks, threads, 0, stream>>>(N, halfB, halfY);
    }
  }
}
#else
M
ming1753 已提交
227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276
// Scalar float16 bias(+ReLU) kernel for builds without PADDLE_WITH_CUDA.
//
// Block b updates row b of `data` (N contiguous elements starting at
// b * N): row[col] += bias[col], then optionally clamps at zero via fmaxf.
// Launch with gridDim.x == number of rows and blockDim.x == BlockDim.
template <bool DoRelu, int BlockDim>
__global__ void InplaceAddReluKernel(const int N,
                                     const float16* bias,
                                     float16* data) {
  const int row_start = blockIdx.x * N;
  float16* row = data + row_start;
  for (int col = threadIdx.x; col < N; col += BlockDim) {
    const float16 sum = row[col] + bias[col];
    if (!DoRelu) {
      row[col] = sum;
      continue;
    }
    row[col] = fmaxf(0.f, sum);
  }
}

// float16 specialization for builds without PADDLE_WITH_CUDA: adds bias B
// (length N) to every row of the M x N matrix Y in place, optionally applying
// ReLU, on `stream`.
//
// The 4-wide float16_4 kernel is used when N is a multiple of 4. Otherwise
// the one-block-per-row scalar kernel is used. float16_4 only needs float16
// alignment, so no extra pointer alignment check is required here.
template <>
void AddReluKernel(gpuStream_t stream,
                   const int M,
                   const int N,
                   float16* Y,
                   const float16* B,
                   bool relu) {
  // An empty output needs no work, and a zero-block launch would be an
  // invalid launch configuration.
  if (M <= 0 || N <= 0) {
    return;
  }
  const int threads = 256;
  if (N % 4 == 0) {
    // Count in 64 bits: M * N can overflow int even when M * N / 4 fits.
    // NOTE: the kernel indexes with int, so num itself must fit in int.
    const int64_t num64 = static_cast<int64_t>(M) * N / 4;
    const int num = static_cast<int>(num64);
    const int blocks = static_cast<int>((num64 + threads - 1) / threads);
    typedef typename FcTypeTraits<float16>::Type trans_type;
    auto* bias_ptr_v4 = reinterpret_cast<const trans_type*>(B);
    auto* data_ptr_v4 = reinterpret_cast<trans_type*>(Y);
    if (relu) {
      bias_relu_v4<trans_type, true><<<blocks, threads, 0, stream>>>(
          num, bias_ptr_v4, data_ptr_v4, N / 4);
    } else {
      bias_relu_v4<trans_type, false><<<blocks, threads, 0, stream>>>(
          num, bias_ptr_v4, data_ptr_v4, N / 4);
    }
  } else {
    // One block per row of Y.
    const int blocks = M;

    if (relu) {
      InplaceAddReluKernel<true, threads>
          <<<blocks, threads, 0, stream>>>(N, B, Y);
    } else {
      InplaceAddReluKernel<false, threads>
          <<<blocks, threads, 0, stream>>>(N, B, Y);
    }
  }
}
M
ming1753 已提交
277
#endif
M
ming1753 已提交
278

279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294
// Fully-connected forward on the GPU: Y = X * W, then (if B is non-null)
// Y += B broadcast over rows, optionally followed by ReLU.
//
// X is M x K, W is K x N and Y is M x N; both GEMM operands are used without
// transposition. B, when given, is a length-N bias added to each of the M
// rows of Y. `padding_weights` is not supported on GPU and must be false
// (enforced with PermissionDenied).
// NOTE(review): `relu` is only applied together with a bias; when B is null
// the GEMM result is returned unclamped — confirm callers never pass
// relu == true without a bias.
template <typename DeviceContext, typename T>
void FCFunctor<DeviceContext, T>::operator()(const DeviceContext& context,
                                             const int M,
                                             const int N,
                                             const int K,
                                             const T* X,
                                             const T* W,
                                             T* Y,
                                             const T* B,
                                             bool relu,
                                             bool padding_weights) {
  PADDLE_ENFORCE_EQ(padding_weights,
                    false,
                    errors::PermissionDenied(
                        "Weight padding in fc can not be used in GPU scope."));
  auto blas = phi::funcs::GetBlas<DeviceContext, T>(context);
  blas.GEMM(CblasNoTrans,
            CblasNoTrans,
            M,
            N,
            K,
            static_cast<T>(1.0),
            X,
            W,
            static_cast<T>(0.0),
            Y);
  if (B == nullptr) {
    return;
  }

  // Y is M x N: add the bias row (and ReLU if requested) in place on the
  // context's stream.
  AddReluKernel(context.stream(), M, N, Y, B, relu);
}

M
ming1753 已提交
313
// Explicit instantiations of FCFunctor for the element types supported on
// the GPU context.
template class FCFunctor<GPUContext, float>;
template class FCFunctor<GPUContext, double>;
template class FCFunctor<GPUContext, float16>;

}  // namespace funcs
}  // namespace phi