math_function.cu 17.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <vector>
16

17
#include "paddle/fluid/platform/device_context.h"
18
#include "paddle/phi/backends/gpu/gpu_context.h"
19
#include "paddle/phi/common/bfloat16.h"
20
#include "paddle/phi/common/data_type.h"
21
#include "paddle/phi/common/float16.h"
22
#include "paddle/phi/common/memory_utils.h"
23 24 25
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h"
26

27
namespace phi {
28 29
namespace funcs {

// The following part of the code refers to NVIDIA-cutlass
// https://github.com/NVIDIA/cutlass/blob/master/tools/util/include/cutlass/util/device_nchw_to_nhwc.h
// Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights
// reserved. SPDX-License-Identifier: BSD-3-Clause
//
// Transposes each row-major M x N matrix of `input` into an N x M matrix of
// `output`. Consecutive batch entries are M * N elements apart in both
// buffers; the batch entry handled by a block is blockIdx.z (`batch` itself
// is not read here).
//
// Expected launch (see BatchTranspose): block = (32, 8), grid.x tiles N by 32,
// grid.y tiles M by 32. Each block stages one 32 x 32 tile through shared
// memory; warp `wid` (0..7) handles tile rows wid, wid + 8, wid + 16, wid + 24.
template <typename T>
__global__ void batch_transpose_kernel(
    T* output, const T* input, const int batch, const int M, const int N) {
  // 64-bit element offsets: batch_i * M * N (and row * N) can exceed INT_MAX
  // even when M and N individually fit in int.
  const int64_t num = static_cast<int64_t>(M) * N;
  // Row pitch of 33 instead of 32 ("+1") to avoid smem bank conflicts.
  __shared__ T shbuf[32 * (32 + 1)];
  const int32_t tid = threadIdx.y * blockDim.x + threadIdx.x;
  const int32_t wid = tid / 32;  // warp index inside the block
  const int32_t lid = tid % 32;  // lane index inside the warp
  const int32_t batch_i = blockIdx.z;
  const int32_t mi0 = blockIdx.y * 32;  // first input row of the tile
  const int32_t ni0 = blockIdx.x * 32;  // first input column of the tile

  const int64_t input_idx =
      batch_i * num + static_cast<int64_t>(mi0 + wid) * N + ni0;
  const T* A = input + input_idx;
  const int64_t row_step = 8 * static_cast<int64_t>(N);
  if (ni0 + lid < N) {
    const int lid_x_33 = lid * 33;
    if ((mi0 + 32) <= M) {
      // Whole tile lies inside M: no per-row bounds check needed.
      int mi = wid;  // between 0 and 7
#pragma unroll
      for (int mLoopIdx = 0; mLoopIdx < 4; mLoopIdx++) {
        shbuf[lid_x_33 + mi] = A[lid];
        A = &A[row_step];
        mi += 8;
      }
    } else {
      for (int mi = wid; mi < 32; mi += 8) {
        if ((mi + mi0) < M) {
          shbuf[lid_x_33 + mi] = A[lid];
        }
        A = &A[row_step];
      }
    }
  }
  // Reached by every thread of the block (outside all divergent branches).
  __syncthreads();

  const int32_t miOut = mi0 + lid;
  output = &output[batch_i * num + miOut];
  if (miOut < M) {
    // `<=` mirrors the input side: a tile ending exactly at N is still full.
    if (ni0 + 32 <= N) {
      int nI = wid;
#pragma unroll
      for (int nLoopIdx = 0; nLoopIdx < 4; ++nLoopIdx) {
        output[static_cast<int64_t>(ni0 + nI) * M] = shbuf[nI * 33 + lid];
        nI += 8;
      }
    } else {
      for (int nI = wid; nI < 32; nI += 8) {
        if (ni0 + nI < N) {
          output[static_cast<int64_t>(ni0 + nI) * M] = shbuf[nI * 33 + lid];
        }
      }
    }
  }
}

// Host launcher for batch_transpose_kernel: transposes `batch` row-major
// m x n matrices stored back to back in `input` into n x m matrices stored
// back to back in `output`.
//
// Launches on stream 0 (no stream is taken). Empty inputs return without
// launching (a zero-sized grid is an invalid configuration). Batches larger
// than the gridDim.z limit are split into several launches.
template <typename T>
void BatchTranspose(T* output, const T* input, int batch, int m, int n) {
  if (batch <= 0 || m <= 0 || n <= 0) {
    return;
  }
  // Maximum gridDim.z on all current CUDA compute capabilities.
  constexpr int kMaxGridZ = 65535;
  const int64_t matrix_size = static_cast<int64_t>(m) * n;
  const dim3 block(32, 8);
  for (int64_t b0 = 0; b0 < batch; b0 += kMaxGridZ) {
    const int chunk =
        static_cast<int>(std::min<int64_t>(batch - b0, kMaxGridZ));
    const dim3 grid((n + 31) / 32, (m + 31) / 32, chunk);
    // The kernel indexes batches from blockIdx.z == 0, so offset the
    // pointers to the first matrix of this chunk.
    const int64_t offset = b0 * matrix_size;
    batch_transpose_kernel<<<grid, block>>>(
        output + offset, input + offset, chunk, m, n);
  }
}

using float16 = phi::dtype::float16;
using bfloat16 = phi::dtype::bfloat16;

// Explicit instantiations of BatchTranspose.
template void BatchTranspose(
    float16* output, const float16* input, int batch, int m, int n);
template void BatchTranspose(
    float* output, const float* input, int batch, int m, int n);

// SetConstant on phi::GPUContext.
template struct SetConstant<phi::GPUContext, float16>;
template struct SetConstant<phi::GPUContext, bfloat16>;
template struct SetConstant<phi::GPUContext, float>;
template struct SetConstant<phi::GPUContext, double>;
template struct SetConstant<phi::GPUContext, uint8_t>;
template struct SetConstant<phi::GPUContext, int>;
template struct SetConstant<phi::GPUContext, int16_t>;
template struct SetConstant<phi::GPUContext, int64_t>;
template struct SetConstant<phi::GPUContext, bool>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<float>>;
template struct SetConstant<phi::GPUContext, phi::dtype::complex<double>>;

// SetConstant on paddle::platform::CUDAPinnedDeviceContext.
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
                            bfloat16>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, float>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, double>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, uint8_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int16_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, int64_t>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext, bool>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
                            phi::dtype::complex<float>>;
template struct SetConstant<paddle::platform::CUDAPinnedDeviceContext,
                            phi::dtype::complex<double>>;

// Explicitly instantiates Transpose<phi::GPUContext, T, RANK> for every
// element type listed below.
#define DEFINE_GPU_TRANS(RANK)                                     \
  template struct Transpose<phi::GPUContext, bool, RANK>;          \
  template struct Transpose<phi::GPUContext, unsigned char, RANK>; \
  template struct Transpose<phi::GPUContext, float, RANK>;         \
  template struct Transpose<phi::GPUContext, double, RANK>;        \
  template struct Transpose<phi::GPUContext, float16, RANK>;       \
  template struct Transpose<phi::GPUContext, bfloat16, RANK>;      \
  template struct Transpose<phi::GPUContext, int8_t, RANK>;        \
  template struct Transpose<phi::GPUContext, int16_t, RANK>;       \
  template struct Transpose<phi::GPUContext, int32_t, RANK>;       \
  template struct Transpose<phi::GPUContext, int64_t, RANK>;       \
  template struct Transpose<phi::GPUContext,                       \
                            phi::dtype::complex<float>,            \
                            RANK>;                                 \
  template struct Transpose<phi::GPUContext, phi::dtype::complex<double>, RANK>;

// Ranks 1 through 6.
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
DEFINE_GPU_TRANS(3);
DEFINE_GPU_TRANS(4);
DEFINE_GPU_TRANS(5);
DEFINE_GPU_TRANS(6);

// Declares a variable `T* DST_PTR` initialized with
// reinterpret_cast<T*>(SRC_PTR).
#define REINTERPRET(T, DST_PTR, SRC_PTR) \
  T* DST_PTR = reinterpret_cast<T*>(SRC_PTR)

// Generic permutation copy: for every output index out_idx < element, splits
// out_idx into `rank` coordinates using out_stride_ptr, maps coordinate i to
// input stride in_stride_ptr[axis_ptr[i]], and copies that input element.
//
// in_stride_ptr, out_stride_ptr and axis_ptr must each point to `rank`
// device-readable values. The loop below is an explicit grid-stride loop with
// 64-bit indices, so any grid size covers all `element` outputs and indices
// above INT_MAX are not truncated (independent of CUDA_KERNEL_LOOP's index
// type).
template <typename T>
__global__ void TransposeNormalKernel(const T* in_ptr,
                                      T* out_ptr,
                                      int64_t element,
                                      const int64_t* in_stride_ptr,
                                      const int64_t* out_stride_ptr,
                                      const int64_t* axis_ptr,
                                      int rank) {
  const int64_t grid_stride = static_cast<int64_t>(blockDim.x) * gridDim.x;
  for (int64_t out_idx =
           static_cast<int64_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       out_idx < element;
       out_idx += grid_stride) {
    int64_t in_idx = 0;
    int64_t tmp_idx = out_idx;
    for (int i = 0; i < rank; ++i) {
      const int64_t coordinate = tmp_idx / out_stride_ptr[i];
      tmp_idx -= coordinate * out_stride_ptr[i];
      in_idx += coordinate * in_stride_ptr[axis_ptr[i]];
    }
    out_ptr[out_idx] = in_ptr[in_idx];
  }
}

// Shared implementation of both TransposeNormal entry points below. Writes
// into `out` the permutation of `in` described by `axis` (output dimension i
// takes input dimension axis[i]) by launching TransposeNormalKernel on
// context.stream().
//
// The input strides, output strides and axis are packed into one host buffer
// laid out as [in_stride | out_stride | axis] (3 * rank int64 values) and sent
// to the device with a single stream-ordered copy.
//
// NOTE(review): cuda_buf_holder is released when this function returns while
// the kernel reading it may still be queued on context.stream(); this relies
// on the allocator behind memory_utils::Alloc(place, size) ordering reuse with
// that stream -- confirm for non-default streams.
template <typename Context, typename T>
static void LaunchTransposeNormal(const Context& context,
                                  const phi::DenseTensor& in,
                                  phi::DenseTensor* out,
                                  const std::vector<int>& axis) {
  const int64_t elements = in.numel();
  // Nothing to move. Returning here also avoids std::log2(0) == -inf (UB when
  // cast to int) and a zero-sized grid, which is an invalid launch.
  if (elements <= 0) {
    return;
  }

  const int rank = static_cast<int>(axis.size());
  auto in_stride = phi::stride(in.dims());
  auto out_stride = phi::stride(out->dims());
  auto* in_ptr = in.data<T>();
  auto* out_ptr = out->data<T>();

  // copy in_stride, out_stride, axis to gpu device
  const phi::GPUPlace& cuda_place = context.GetPlace();
  phi::CPUPlace cpu_place = phi::CPUPlace();
  const size_t size = 3 * rank * sizeof(int64_t);
  auto cpu_buf_holder = phi::memory_utils::Alloc(cpu_place, size);
  auto cuda_buf_holder = phi::memory_utils::Alloc(cuda_place, size);
  REINTERPRET(int64_t, cpu_buf, cpu_buf_holder->ptr());
  REINTERPRET(int64_t, cuda_buf, cuda_buf_holder->ptr());
  for (int i = 0; i < rank; ++i) {
    cpu_buf[i] = in_stride[i];
    cpu_buf[rank + i] = out_stride[i];
    cpu_buf[2 * rank + i] = axis[i];
  }
  memory_utils::Copy(
      cuda_place, cuda_buf, cpu_place, cpu_buf, size, context.stream());
  REINTERPRET(const int64_t, in_stride_ptr, cuda_buf);
  REINTERPRET(const int64_t, out_stride_ptr, cuda_buf + rank);
  REINTERPRET(const int64_t, axis_ptr, cuda_buf + 2 * rank);

  // Block: the device limit, or the largest power of two <= elements.
  const int max_block_dim = context.GetMaxThreadsPerBlock();
  const int max_grid_dim = context.GetMaxPhysicalThreadCount() / max_block_dim;
  const int block_size = (elements >= max_block_dim)
                             ? max_block_dim
                             : (1 << static_cast<int>(std::log2(elements)));
  // Clamp in 64-bit before narrowing to int. TransposeNormalKernel loops over
  // all `elements` outputs, so a grid smaller than elements / block_size is
  // still complete.
  const int grid_size = static_cast<int>(
      std::min<int64_t>(elements / block_size, max_grid_dim));
  TransposeNormalKernel<T><<<grid_size, block_size, 0, context.stream()>>>(
      in_ptr, out_ptr, elements, in_stride_ptr, out_stride_ptr, axis_ptr, rank);
}

// Primary-template definition; forwards to LaunchTransposeNormal.
template <typename DeviceContext, typename T>
void TransposeNormal<DeviceContext, T>::operator()(
    const DeviceContext& context,
    const phi::DenseTensor& in,
    phi::DenseTensor* out,
    const std::vector<int>& axis) {
  LaunchTransposeNormal<DeviceContext, T>(context, in, out, axis);
}

// phi::GPUContext specialization (the one instantiated by
// DEFINE_GPU_TRANS_NORMAL); forwards to the same shared implementation.
template <typename T>
struct TransposeNormal<phi::GPUContext, T> {
  void operator()(const phi::GPUContext& context,
                  const DenseTensor& in,
                  DenseTensor* out,
                  const std::vector<int>& axis) {
    LaunchTransposeNormal<phi::GPUContext, T>(context, in, out, axis);
  }
};
// define transpose normal: explicit instantiation of
// TransposeNormal<phi::GPUContext, TYPE> for each element type below.
#define DEFINE_GPU_TRANS_NORMAL(TYPE) \
  template struct TransposeNormal<phi::GPUContext, TYPE>

DEFINE_GPU_TRANS_NORMAL(float16);
DEFINE_GPU_TRANS_NORMAL(bfloat16);
DEFINE_GPU_TRANS_NORMAL(float);
DEFINE_GPU_TRANS_NORMAL(double);
DEFINE_GPU_TRANS_NORMAL(int);
DEFINE_GPU_TRANS_NORMAL(int64_t);
DEFINE_GPU_TRANS_NORMAL(bool);
DEFINE_GPU_TRANS_NORMAL(int16_t);
DEFINE_GPU_TRANS_NORMAL(uint8_t);
DEFINE_GPU_TRANS_NORMAL(int8_t);
DEFINE_GPU_TRANS_NORMAL(phi::dtype::complex<float>);
DEFINE_GPU_TRANS_NORMAL(phi::dtype::complex<double>);

// Visitor for phi::VisitDataType: apply<T>() fills `tensor_` with `value_`
// converted to T, using SetConstant on the GPU context. The stored context
// must actually be a phi::GPUContext.
struct TensorSetConstantGPU {
  TensorSetConstantGPU(const phi::DeviceContext& context,
                       phi::DenseTensor* tensor,
                       float value)
      : context_(context), tensor_(tensor), value_(value) {}

  template <typename T>
  void apply() const {
    SetConstant<phi::GPUContext, T> functor;
    // static_cast is the proper base-to-derived downcast (it applies any base
    // subobject offset); reinterpret_cast would silently assume offset 0.
    functor(static_cast<const phi::GPUContext&>(context_),
            tensor_,
            static_cast<T>(value_));
  }

  const phi::DeviceContext& context_;
  phi::DenseTensor* tensor_;
  float value_;
};

// GPUPlace specialization: dispatches on the tensor's runtime dtype and fills
// every element with `value` through TensorSetConstantGPU.
template <>
void set_constant_with_place<phi::GPUPlace>(const phi::DeviceContext& context,
                                            phi::DenseTensor* tensor,
                                            float value) {
  phi::VisitDataType(tensor->dtype(),
                     TensorSetConstantGPU(context, tensor, value));
}

// c[i] = a[i] + b[i % width] for i in [0, num): adds the length-`width` vector
// b to every row of the row-major matrix a.
//
// The column index uses exact integer modulo. Multiplying i by a
// floating-point reciprocal of width and truncating can land on the previous
// row at exact row boundaries (giving w == width, one past the end of b) and
// loses precision for float once i exceeds 2^24.
template <typename T>
__global__ void RowwiseAddKernel(
    const T* a, const T* b, T* c, int width, int num) {
  CUDA_KERNEL_LOOP(i, num) {
    const int w = i % width;
    c[i] = a[i] + b[w];
  }
}

// output = input + vector broadcast over rows. `input` is viewed as a
// row-major [in_dims[0], input.numel() / in_dims[0]] matrix; `vector` must hold
// one value per column of that view and `output` must have input's shape.
template <typename T>
struct RowwiseAdd<phi::GPUContext, T> {
  void operator()(const phi::GPUContext& context,
                  const phi::DenseTensor& input,
                  const phi::DenseTensor& vector,
                  phi::DenseTensor* output) {
    auto in_dims = input.dims();
    auto out_dims = output->dims();
    auto size = input.numel() / in_dims[0];
    PADDLE_ENFORCE_EQ(
        vector.numel(),
        size,
        phi::errors::InvalidArgument(
            "The input vector size"
            " should be equal to the size of each row of input tensor."
            " Expected vector size=%d, but received %d",
            size,
            vector.numel()));
    // Keep the strings alive for the whole check: c_str() of the temporary
    // returned by to_str() would dangle once that full-expression ended.
    const auto in_dims_str = in_dims.to_str();
    const auto out_dims_str = out_dims.to_str();
    PADDLE_ENFORCE_EQ(
        out_dims,
        in_dims,
        phi::errors::InvalidArgument(
            "The output tensor shape should be same as the input tensor"
            " shape. Expected output tensor shape: %s,"
            " but received %s",
            in_dims_str.c_str(),
            out_dims_str.c_str()));
    const int64_t num = input.numel();
    if (num == 0) {
      return;  // nothing to add; a zero-block launch would be invalid
    }
    int blocks = 512;
    int grids = (num + blocks - 1) / blocks;
    // Row width is `size` (the quantity validated above), not in_dims[1]:
    // the two differ when input has rank > 2.
    RowwiseAddKernel<T><<<grids, blocks, 0, context.stream()>>>(
        input.data<T>(),
        vector.data<T>(),
        output->data<T>(),
        static_cast<int>(size),
        static_cast<int>(num));
  }
};

template struct RowwiseAdd<phi::GPUContext, float>;
template struct RowwiseAdd<phi::GPUContext, double>;
template struct ColwiseSum<phi::GPUContext, float>;
template struct ColwiseSum<phi::GPUContext, int>;
template struct ColwiseSum<phi::GPUContext, int64_t>;
// template struct ColwiseSum<phi::GPUContext, double>;
// The ColwiseSum<phi::GPUContext, double> failed in debug mode,
// and only failed for this case. So reimplemented it.
//
// Column sums via BLAS GEMV with the transpose flag set: `input` is the
// row-major [in_dims[0], size] matrix (size = input.numel() / in_dims[0]) and
// `one` is a length-in_dims[0] vector of ones, giving `size` column sums.
template <>
void ColwiseSum<phi::GPUContext, double>::operator()(
    const phi::GPUContext& context,
    const phi::DenseTensor& input,
    phi::DenseTensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(),
                    size,
                    phi::errors::InvalidArgument(
                        "The size of input vector"
                        " should be equal to the size of input tensor column"
                        " dimension. Expected vector size=%d, but received %d",
                        size,
                        vector->numel()));
  phi::DenseTensor one;
  one.Resize({in_dims[0]});
  context.template Alloc<double>(&one);

  SetConstant<phi::GPUContext, double> set;
  set(context, &one, static_cast<double>(1.0));
  // Column count is `size` (validated above) rather than in_dims[1], so a
  // rank > 2 input is reduced over its flattened trailing dimensions and the
  // GEMV writes exactly vector->numel() values.
  phi::funcs::GetBlas<phi::GPUContext, double>(context).GEMV(
      true,
      static_cast<int>(in_dims[0]),
      static_cast<int>(size),
      1.0,
      input.data<double>(),
      one.data<double>(),
      0.0,
      vector->data<double>());
}

template struct RowwiseSum<phi::GPUContext, float>;
// template struct RowwiseSum<phi::GPUContext, double>;
// The RowwiseSum<phi::GPUContext, double> failed in debug mode,
// and only failed for this case. So reimplemented it.
//
// Row sums via BLAS GEMV without transpose: `input` is the row-major
// [in_dims[0], size] matrix (size = input.numel() / in_dims[0]) and `one` is a
// length-`size` vector of ones, giving in_dims[0] row sums.
template <>
void RowwiseSum<phi::GPUContext, double>::operator()(
    const phi::GPUContext& context,
    const phi::DenseTensor& input,
    phi::DenseTensor* vector) {
  auto in_dims = input.dims();
  auto size = input.numel() / in_dims[0];
  PADDLE_ENFORCE_EQ(vector->numel(),
                    in_dims[0],
                    phi::errors::InvalidArgument(
                        "The size of input vector"
                        " should be equal to the size of input tensor row"
                        " dimension. Expected vector size=%d, but received %d",
                        in_dims[0],
                        vector->numel()));
  phi::DenseTensor one;
  one.Resize({size});
  context.template Alloc<double>(&one);

  SetConstant<phi::GPUContext, double> set;
  set(context, &one, static_cast<double>(1.0));
  // Same GEMV argument convention as ColwiseSum<double> (matrix operand first
  // with shape [rows, cols], then the vector operand), without the transpose.
  // The matrix must be `input`: treating `one` (only `size` values) as a
  // [cols, rows] matrix reads past its end and produces wrong sums.
  phi::funcs::GetBlas<phi::GPUContext, double>(context).GEMV(
      false,
      static_cast<int>(in_dims[0]),
      static_cast<int>(size),
      1.0,
      input.data<double>(),
      one.data<double>(),
      0.0,
      vector->data<double>());
}

template struct RowwiseMean<phi::GPUContext, float>;
template struct RowwiseMean<phi::GPUContext, double>;

}  // namespace funcs
450
}  // namespace phi