// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include "lite/backends/cuda/math/activation.h"
#include "lite/backends/cuda/math/utils.h"

namespace paddle {
namespace lite {
namespace cuda {
namespace math {

/**
 * Element-wise leaky ReLU:
 *   output[i] = input[i] >= 0 ? input[i] : input[i] * alpha
 *
 * Launch layout: 1-D grid of 1-D blocks, one thread per element; threads whose
 * global index is >= num do nothing.
 *
 * @param num    number of elements in input / output.
 * @param alpha  slope applied to negative inputs (alpha == 0 is plain ReLU).
 * @param input  device-accessible pointer to num read-only elements.
 * @param output device-accessible pointer to num writable elements.
 */
template <typename T>
__global__ void relu_kernel(const int num,
                            const float alpha,
                            const T* input,
                            T* output) {
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  if (index < num) {
#if __CUDA_ARCH__ >= 350
    // Single load through the read-only data cache (sm_35+).
    const T value = __ldg(input + index);
#else
    const T value = input[index];
#endif
    output[index] = value >= 0 ? value : value * alpha;
  }
}

/**
 * half-precision specialization of relu_kernel:
 *   output[i] = input[i] > 0 ? input[i] : input[i] * alpha
 *
 * On sm_53+ the comparison and the multiply use native half intrinsics
 * (__hgt / __hmul, with alpha rounded to half first). Older targets compare in
 * float and round the float product back to half.
 * One thread per element; threads with index >= num return immediately.
 */
template <>
__global__ void relu_kernel<half>(const int num,
                                  const float alpha,
                                  const half* input,
                                  half* output) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= num) {
    return;  // tail thread past the last element
  }
#if __CUDA_ARCH__ >= 530
  const half x = __ldg(input + i);
  const bool positive = __hgt(x, __float2half(0.0f));
  output[i] = positive ? x : __hmul(x, __float2half(alpha));
#else
  const half x = input[i];
  const float xf = __half2float(x);
  output[i] = (xf > 0) ? x : __float2half(xf * alpha);
#endif
}

/**
 * Element-wise leaky ReLU with a T-typed slope:
 *   output[i] = input[i] >= 0 ? input[i] : input[i] * alpha
 *
 * Despite its name, this kernel takes no bias operand and adds no bias.
 * One thread per element; threads with index >= num return immediately.
 */
template <typename T>
__global__ void bias_relu_kernel(const int num,
                                 const T alpha,
                                 const T* input,
                                 T* output) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= num) {
    return;
  }
#if __CUDA_ARCH__ >= 350
  const T x = __ldg(input + i);  // read-only cache load (sm_35+)
#else
  const T x = input[i];
#endif
  // Keep non-negative values; scale negative ones by alpha.
  output[i] = (x >= 0) ? x : x * alpha;
}

/**
 * Scalar fused scale + per-channel bias + leaky ReLU.
 *
 * For each flat index i < num, with channel c = i % C (channels innermost):
 *   y = in[i] * scale[c] + bias[c]
 *   out[i] = from_float<Dtype>(y > 0 ? y : y * alpha)
 *
 * N, H and W are accepted but unused. One thread per element.
 */
template <typename Dtype>
__global__ void bias_relu_int8_nhwc_kernel(int num,
                                           const float* in,
                                           const float* bias,
                                           Dtype* out,
                                           int N,
                                           int C,
                                           int H,
                                           int W,
                                           const float* scale,
                                           float alpha) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= num) {
    return;
  }
  const int c = i % C;
#if __CUDA_ARCH__ >= 350
  const float x = __ldg(in + i);
  const float s = __ldg(scale + c);
  const float b = __ldg(bias + c);
#else
  const float x = in[i];
  const float s = scale[c];
  const float b = bias[c];
#endif
  const float y = x * s + b;
  out[i] = from_float<Dtype>(y > 0 ? y : y * alpha);
}

// Vectorized fused scale + per-channel bias + leaky ReLU, float output.
//
// Each thread processes one float4 (four consecutive channels):
//   y = in[tid] * scale[g] + bias[g];  out[tid] = y > 0 ? y : y * alpha
// with g = tid % K, so K is the number of float4 channel groups.
// The select form gives the same result as the scalar
// bias_relu_int8_nhwc_kernel for every alpha. The shortcut
// fmaxf(y * alpha, y) only agrees with it when alpha <= 1.
// N, H and W are unused.
// NOTE(review): in/bias/scale/out must be 16-byte aligned for float4 access;
// confirm callers guarantee this.
__global__ void bias_relu_int8_nhwc4_kernel(int num,
                                            const float4* in,
                                            const float4* bias,
                                            float4* out,
                                            int N,
                                            int K,
                                            int H,
                                            int W,
                                            const float4* scale,
                                            float alpha) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < num) {
    int bias_idx = tid % K;
    const float4 bias_val = bias[bias_idx];
    const float4 scale_val = scale[bias_idx];
    const float4 in_val = in[tid];

    float4 packed_val;
    packed_val.x = in_val.x * scale_val.x + bias_val.x;
    packed_val.x = packed_val.x > 0.f ? packed_val.x : packed_val.x * alpha;
    packed_val.y = in_val.y * scale_val.y + bias_val.y;
    packed_val.y = packed_val.y > 0.f ? packed_val.y : packed_val.y * alpha;
    packed_val.z = in_val.z * scale_val.z + bias_val.z;
    packed_val.z = packed_val.z > 0.f ? packed_val.z : packed_val.z * alpha;
    packed_val.w = in_val.w * scale_val.w + bias_val.w;
    packed_val.w = packed_val.w > 0.f ? packed_val.w : packed_val.w * alpha;
    out[tid] = packed_val;
  }
}

// Vectorized fused scale + per-channel bias + leaky ReLU, int8 (char4) output.
//
// Each thread processes one float4 (four consecutive channels):
//   y = in[tid] * scale[g] + bias[g]
//   out[tid] = from_float<int8_t>(y > 0 ? y : y * alpha)   (per lane)
// with g = tid % K, so K is the number of float4 channel groups.
// The select form gives the same result as the scalar
// bias_relu_int8_nhwc_kernel for every alpha. The shortcut
// fmaxf(y * alpha, y) only agrees with it when alpha <= 1.
// N, H and W are unused.
// NOTE(review): in/bias/scale must be 16-byte aligned (float4) and out 4-byte
// aligned (char4); confirm callers guarantee this.
__global__ void bias_relu_int8_nhwc4_kernel(int num,
                                            const float4* in,
                                            const float4* bias,
                                            char4* out,
                                            int N,
                                            int K,
                                            int H,
                                            int W,
                                            const float4* scale,
                                            float alpha) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < num) {
    int bias_idx = tid % K;
    const float4 bias_val = bias[bias_idx];
    const float4 scale_val = scale[bias_idx];
    const float4 in_val = in[tid];

    float4 packed_val;
    packed_val.x = in_val.x * scale_val.x + bias_val.x;
    packed_val.y = in_val.y * scale_val.y + bias_val.y;
    packed_val.z = in_val.z * scale_val.z + bias_val.z;
    packed_val.w = in_val.w * scale_val.w + bias_val.w;

    char4 result_val;
    result_val.x = from_float<int8_t>(
        packed_val.x > 0.f ? packed_val.x : packed_val.x * alpha);
    result_val.y = from_float<int8_t>(
        packed_val.y > 0.f ? packed_val.y : packed_val.y * alpha);
    result_val.z = from_float<int8_t>(
        packed_val.z > 0.f ? packed_val.z : packed_val.z * alpha);
    result_val.w = from_float<int8_t>(
        packed_val.w > 0.f ? packed_val.w : packed_val.w * alpha);

    out[tid] = result_val;
  }
}

/**
 * Scalar fused scale + per-channel bias (no activation).
 *
 * For each flat index i < num, with channel c = i % C (channels innermost):
 *   out[i] = from_float<Dtype>(in[i] * scale[c] + bias[c])
 *
 * N, H and W are accepted but unused. One thread per element.
 */
template <typename Dtype>
__global__ void bias_int8_nhwc_kernel(int num,
                                      const float* in,
                                      const float* bias,
                                      Dtype* out,
                                      int N,
                                      int C,
                                      int H,
                                      int W,
                                      const float* scale) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i >= num) {
    return;
  }
  const int c = i % C;
#if __CUDA_ARCH__ >= 350
  const float x = __ldg(in + i);
  const float s = __ldg(scale + c);
  const float b = __ldg(bias + c);
#else
  const float x = in[i];
  const float s = scale[c];
  const float b = bias[c];
#endif
  out[i] = from_float<Dtype>(x * s + b);
}

// Vectorized scale + leaky ReLU (no bias), float output.
//
// Each thread processes one float4:
//   y = in[tid] * scale[g];  out[tid] = y > 0 ? y : y * alpha   (per lane)
// with g = tid % K, so K is the number of float4 scale groups.
// The select form is the leaky-ReLU definition for every alpha. The shortcut
// fmaxf(y * alpha, y) only matches it when alpha <= 1.
// N, H and W are unused.
// NOTE(review): in/scale/out must be 16-byte aligned for float4 access;
// confirm callers guarantee this.
__global__ void relu_int8_nhwc4_kernel(int num,
                                       const float4* in,
                                       float4* out,
                                       int N,
                                       int K,
                                       int H,
                                       int W,
                                       const float4* scale,
                                       float alpha) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < num) {
    int scale_idx = tid % K;
    const float4 scale_val = scale[scale_idx];
    const float4 in_val = in[tid];

    float4 packed_val;
    packed_val.x = in_val.x * scale_val.x;
    packed_val.x = packed_val.x > 0.f ? packed_val.x : packed_val.x * alpha;
    packed_val.y = in_val.y * scale_val.y;
    packed_val.y = packed_val.y > 0.f ? packed_val.y : packed_val.y * alpha;
    packed_val.z = in_val.z * scale_val.z;
    packed_val.z = packed_val.z > 0.f ? packed_val.z : packed_val.z * alpha;
    packed_val.w = in_val.w * scale_val.w;
    packed_val.w = packed_val.w > 0.f ? packed_val.w : packed_val.w * alpha;
    out[tid] = packed_val;
  }
}

// Vectorized scale + leaky ReLU (no bias), int8 (char4) output.
//
// Each thread processes one float4:
//   y = in[tid] * scale[g]
//   out[tid] = from_float<int8_t>(y > 0 ? y : y * alpha)   (per lane)
// with g = tid % K, so K is the number of float4 scale groups.
// The select form is the leaky-ReLU definition for every alpha. The shortcut
// fmaxf(y * alpha, y) only matches it when alpha <= 1.
// N, H and W are unused.
// NOTE(review): in/scale must be 16-byte aligned (float4) and out 4-byte
// aligned (char4); confirm callers guarantee this.
__global__ void relu_int8_nhwc4_kernel(int num,
                                       const float4* in,
                                       char4* out,
                                       int N,
                                       int K,
                                       int H,
                                       int W,
                                       const float4* scale,
                                       float alpha) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < num) {
    int scale_idx = tid % K;
    const float4 scale_val = scale[scale_idx];
    const float4 in_val = in[tid];

    float4 packed_val;
    packed_val.x = in_val.x * scale_val.x;
    packed_val.y = in_val.y * scale_val.y;
    packed_val.z = in_val.z * scale_val.z;
    packed_val.w = in_val.w * scale_val.w;

    char4 result_val;
    result_val.x = from_float<int8_t>(
        packed_val.x > 0.f ? packed_val.x : packed_val.x * alpha);
    result_val.y = from_float<int8_t>(
        packed_val.y > 0.f ? packed_val.y : packed_val.y * alpha);
    result_val.z = from_float<int8_t>(
        packed_val.z > 0.f ? packed_val.z : packed_val.z * alpha);
    result_val.w = from_float<int8_t>(
        packed_val.w > 0.f ? packed_val.w : packed_val.w * alpha);

    out[tid] = result_val;
  }
}

// Host launcher: fused scale + per-channel bias + leaky ReLU on NHWC data with
// float output:
//   out[i] = y > 0 ? y : y * alpha,  y = in[i] * scale[c] + bias[c],
//   c = i % C.
// When both C and num are multiples of 4, a float4-vectorized kernel handles
// num / 4 packed elements. Otherwise the scalar kernel handles one element
// per thread, so no tail is dropped. Work is enqueued on `stream`. Launch
// errors are printed to stdout.
// NOTE(review): the float4 path requires in/bias/scale/out to be 16-byte
// aligned; confirm callers guarantee this.
template <>
void bias_relu_int8_nhwc<float>(int num,
                                const void* in,
                                const void* bias,
                                void* out,
                                int N,
                                int C,
                                int H,
                                int W,
                                const void* scale,
                                float alpha,
                                cudaStream_t stream) {
  // Nothing to do for an empty tensor; a zero-block launch would also fail
  // with an invalid-configuration error.
  if (num <= 0) {
    return;
  }
  int thread = 256;
  if (C % 4 == 0 && num % 4 == 0) {
    int block = (num / 4 + thread - 1) / thread;
    bias_relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
        num / 4,
        static_cast<const float4*>(in),
        static_cast<const float4*>(bias),
        static_cast<float4*>(out),
        N,
        C / 4,
        H,
        W,
        static_cast<const float4*>(scale),
        alpha);
  } else {
    int block = (num + thread - 1) / thread;
    bias_relu_int8_nhwc_kernel<<<block, thread, 0, stream>>>(
        num,
        static_cast<const float*>(in),
        static_cast<const float*>(bias),
        static_cast<float*>(out),
        N,
        C,
        H,
        W,
        static_cast<const float*>(scale),
        alpha);
  }
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) std::cout << cudaGetErrorString(error) << std::endl;
}

// Host launcher: fused scale + per-channel bias + leaky ReLU on NHWC data with
// int8 output:
//   out[i] = from_float<int8_t>(y > 0 ? y : y * alpha),
//   y = in[i] * scale[c] + bias[c], c = i % C.
// When both C and num are multiples of 4, a float4 -> char4 vectorized kernel
// handles num / 4 packed elements. Otherwise the scalar kernel handles one
// element per thread, so no tail is dropped. Work is enqueued on `stream`.
// Launch errors are printed to stdout.
// NOTE(review): the vector path requires in/bias/scale to be 16-byte aligned
// and out 4-byte aligned; confirm callers guarantee this.
template <>
void bias_relu_int8_nhwc<int8_t>(int num,
                                 const void* in,
                                 const void* bias,
                                 void* out,
                                 int N,
                                 int C,
                                 int H,
                                 int W,
                                 const void* scale,
                                 float alpha,
                                 cudaStream_t stream) {
  // Nothing to do for an empty tensor; a zero-block launch would also fail
  // with an invalid-configuration error.
  if (num <= 0) {
    return;
  }
  int thread = 256;
  if (C % 4 == 0 && num % 4 == 0) {
    int block = (num / 4 + thread - 1) / thread;
    bias_relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
        num / 4,
        static_cast<const float4*>(in),
        static_cast<const float4*>(bias),
        static_cast<char4*>(out),
        N,
        C / 4,
        H,
        W,
        static_cast<const float4*>(scale),
        alpha);
  } else {
    int block = (num + thread - 1) / thread;
    bias_relu_int8_nhwc_kernel<<<block, thread, 0, stream>>>(
        num,
        static_cast<const float*>(in),
        static_cast<const float*>(bias),
        static_cast<int8_t*>(out),
        N,
        C,
        H,
        W,
        static_cast<const float*>(scale),
        alpha);
  }
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) std::cout << cudaGetErrorString(error) << std::endl;
}

// Host launcher: fused scale + per-channel bias (no activation) on NHWC data:
//   out[i] = from_float<out_type>(in[i] * scale[i % C] + bias[i % C])
// One thread per element, 256-thread blocks, enqueued on `stream`. Launch
// errors are printed to stdout.
template <typename out_type>
void bias_int8_nhwc(int num,
                    const void* in,
                    const void* bias,
                    void* out,
                    int N,
                    int C,
                    int H,
                    int W,
                    const void* scale,
                    cudaStream_t stream) {
  // Nothing to do for an empty tensor; a zero-block launch would also fail
  // with an invalid-configuration error.
  if (num <= 0) {
    return;
  }
  int thread = 256;
  int block = (num + thread - 1) / thread;
  bias_int8_nhwc_kernel<<<block, thread, 0, stream>>>(
      num,
      static_cast<const float*>(in),
      static_cast<const float*>(bias),
      static_cast<out_type*>(out),
      N,
      C,
      H,
      W,
      static_cast<const float*>(scale));
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) std::cout << cudaGetErrorString(error) << std::endl;
}

// Explicit instantiations of bias_int8_nhwc for float and int8_t outputs.
template void bias_int8_nhwc<float>(
    int, const void*, const void*, void*, int, int, int, int, const void*,
    cudaStream_t);
template void bias_int8_nhwc<int8_t>(
    int, const void*, const void*, void*, int, int, int, int, const void*,
    cudaStream_t);

// Host launcher: per-group scale + leaky ReLU on float4-packed data, float
// output. `num` counts float4 elements (one thread each). K is the number of
// float4 scale groups (the scale index is i % K). Enqueued on `stream`.
// Launch errors are printed to stdout.
// NOTE(review): in/scale/out must be 16-byte aligned for float4 access, and K
// is presumably C / 4; confirm with callers.
template <>
void relu_int8_nhwc4<float>(int num,
                            const void* in,
                            void* out,
                            int N,
                            int K,
                            int H,
                            int W,
                            const void* scale,
                            float alpha,
                            cudaStream_t stream) {
  // Nothing to do for an empty input; a zero-block launch would also fail
  // with an invalid-configuration error.
  if (num <= 0) {
    return;
  }
  int thread = 256;
  int block = (num + thread - 1) / thread;
  relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
      num,
      static_cast<const float4*>(in),
      static_cast<float4*>(out),
      N,
      K,
      H,
      W,
      static_cast<const float4*>(scale),
      alpha);
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) std::cout << cudaGetErrorString(error) << std::endl;
}

// Host launcher: per-group scale + leaky ReLU on float4-packed data, int8
// (char4) output. `num` counts float4 input elements (one thread each). K is
// the number of float4 scale groups (the scale index is i % K). Enqueued on
// `stream`. Launch errors are printed to stdout.
// NOTE(review): in/scale must be 16-byte aligned and out 4-byte aligned, and K
// is presumably C / 4; confirm with callers.
template <>
void relu_int8_nhwc4<int8_t>(int num,
                             const void* in,
                             void* out,
                             int N,
                             int K,
                             int H,
                             int W,
                             const void* scale,
                             float alpha,
                             cudaStream_t stream) {
  // Nothing to do for an empty input; a zero-block launch would also fail
  // with an invalid-configuration error.
  if (num <= 0) {
    return;
  }
  int thread = 256;
  int block = (num + thread - 1) / thread;
  relu_int8_nhwc4_kernel<<<block, thread, 0, stream>>>(
      num,
      static_cast<const float4*>(in),
      static_cast<char4*>(out),
      N,
      K,
      H,
      W,
      static_cast<const float4*>(scale),
      alpha);
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) std::cout << cudaGetErrorString(error) << std::endl;
}

// Host launcher for element-wise leaky ReLU:
//   dout[i] = din[i] >= 0 ? din[i] : din[i] * alpha   for i < num.
// One thread per element in 256-thread blocks, enqueued on `stream`. Launch
// errors are printed to stdout.
template <typename T>
void relu(int num, const T* din, T* dout, float alpha, cudaStream_t stream) {
  // Nothing to do for an empty input; a zero-block launch would also fail
  // with an invalid-configuration error.
  if (num <= 0) {
    return;
  }
  int thread = 256;
  int block = (num + thread - 1) / thread;
  relu_kernel<<<block, thread, 0, stream>>>(num, alpha, din, dout);
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) std::cout << cudaGetErrorString(error) << std::endl;
}
// half-precision host launcher for leaky ReLU:
//   dout[i] = din[i] > 0 ? din[i] : din[i] * alpha   for i < num.
// One thread per element in 256-thread blocks, enqueued on `stream`. Launch
// errors are printed to stdout.
template <>
void relu<half>(
    int num, const half* din, half* dout, float alpha, cudaStream_t stream) {
  // Empty (or negative-size) input: nothing to launch. A zero-block launch
  // would fail with an invalid-configuration error.
  if (num <= 0) {
    return;
  }
  int thread = 256;
  int block = (num + thread - 1) / thread;
  relu_kernel<half><<<block, thread, 0, stream>>>(num, alpha, din, dout);
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) std::cout << cudaGetErrorString(error) << std::endl;
}

// Host launcher named "bias_relu"; it currently launches the plain leaky-ReLU
// kernel:
//   dout[i] = din[i] >= 0 ? din[i] : din[i] * alpha   for i < num.
// NOTE(review): the `bias` argument is never read, so no bias is added. Its
// expected layout (per element or per channel) is not visible here. Confirm
// the intended semantics before wiring it in.
// One thread per element in 256-thread blocks, enqueued on `stream`. Launch
// errors are printed to stdout.
template <typename T>
void bias_relu(int num,
               const T* din,
               const float* bias,
               T* dout,
               float alpha,
               cudaStream_t stream) {
  // Nothing to do for an empty input; a zero-block launch would also fail
  // with an invalid-configuration error.
  if (num <= 0) {
    return;
  }
  int thread = 256;
  int block = (num + thread - 1) / thread;
  relu_kernel<<<block, thread, 0, stream>>>(num, alpha, din, dout);
  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) std::cout << cudaGetErrorString(error) << std::endl;
}
// Explicit instantiations emitted from this translation unit.
// relu<half> is explicitly specialized in this file. An explicit instantiation
// that follows an explicit specialization for the same arguments has no
// additional effect.
template void relu(int, const float*, float*, float, cudaStream_t);
template void relu(int, const half*, half*, float, cudaStream_t);
template void bias_relu(
    int, const float*, const float* bias, float*, float, cudaStream_t);

}  // namespace math
}  // namespace cuda
}  // namespace lite
}  // namespace paddle