/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cstdint>
#include <limits>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/math/math_cuda_utils.h"
#include "paddle/fluid/operators/softmax_impl.cuh"
#include "paddle/fluid/operators/softmax_op.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#ifdef PADDLE_WITH_HIP
#include "paddle/fluid/platform/miopen_helper.h"
#else
#include "paddle/fluid/platform/cudnn_helper.h"
#endif

// Forward declarations of the platform types that this file refers to by name
// (plat::CUDAPlace in the kernel registrations, platform::float16 in the
// vectorization traits).
namespace paddle {
namespace platform {
struct CUDAPlace;
struct float16;
}  // namespace platform
}  // namespace paddle

namespace paddle {
namespace operators {

// Local shorthands for the platform / framework types used by the op kernels
// in this file.
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor;

// VecT4<T>::Type is the built-in vector type used to move four consecutive
// elements of T with a single load/store (intended size: 4 * sizeof(T)).
// Only the element types listed below are supported; the primary template is
// deliberately empty.
// NOTE(review): long4 is 4 * sizeof(long), which matches 4 * sizeof(double)
// only where long is 64-bit -- confirm for targets with a 32-bit long.
template <typename T>
struct VecT4 {};

template <>
struct VecT4<double> {
  typedef long4 Type;
};

template <>
struct VecT4<float> {
  typedef int4 Type;
};

template <>
struct VecT4<platform::float16> {
  typedef int2 Type;
};

// VecT2<T>::Type is the built-in vector type used to move two consecutive
// elements of T with a single load/store (intended size: 2 * sizeof(T)).
// Only the element types listed below are supported; the primary template is
// deliberately empty.
template <typename T>
struct VecT2 {};

template <>
struct VecT2<double> {
  typedef int4 Type;
};

template <>
struct VecT2<float> {
  typedef int2 Type;
};

template <>
struct VecT2<platform::float16> {
  typedef int Type;
};

// Returns the smallest exponent e >= 0 such that (1 << e) >= value, i.e.
// ceil(log2(value)) for value >= 1. Values <= 1 (including zero and negative
// inputs) yield 0.
//
// The power of two is formed in 64-bit arithmetic: with a 32-bit shift any
// value above 2^30 would need 1 << 31, which overflows a signed int. Such
// inputs now return 31.
static inline int log2_ceil(int value) {
  int log2_value = 0;
  while ((1LL << log2_value) < value) ++log2_value;
  return log2_value;
}

/*
Core function of computing softmax forward for axis=-1.
For every row i the kernel computes
  - the row maximum:   maxvalue_{i} = max_j src_{i,j}
  - the sum of exps:   s_{i} = sum_j exp(src_{i,j} - maxvalue_{i})
  - the output:        exp(src_{i,j} - maxvalue_{i}) / s_{i}
    or, with LogMode:  src_{i,j} - maxvalue_{i} - log(s_{i})

Launch layout: threadIdx.x is the lane inside a group of kWarpSize threads
(kWarpSize = min(2^Log2Elements, 32)). Each (blockIdx.x, threadIdx.y) pair owns
kBatchSize (1 or 2) consecutive rows starting at row
(blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize. Every thread first
reduces the elements it loaded, then WarpReduceMax / WarpReduceSum combine the
per-thread partial results.

Parameters:
  softmax        output buffer, row r starts at softmax + r * stride
  src            input buffer, row r starts at src + r * stride
  batch_size     number of rows; rows at or beyond it are neither read nor
                 written
  stride         distance in elements between consecutive rows
  element_count  number of elements per row

Preconditions:
  - element_count is a multiple of kVSize = sizeof(VecT) / sizeof(T); a
    remainder would be dropped because the row length is divided by kVSize.
  - element_count is expected to be at most 2^Log2Elements; the per-thread
    register tile is sized from Log2Elements only.
  - when kVSize > 1 rows are accessed through VecT pointers, so every row
    start must be suitably aligned for VecT.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxForward(T* softmax, const T* src,
                                   const int batch_size, const int stride,
                                   const int element_count) {
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kIterations = kDimCeil / kWarpSize;
  constexpr int kIterationsV =
      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
  constexpr int kBatchSize = (kDimCeil <= 32) ? 2 : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;

  // Per-row exclusive upper bound of the vector index this thread may touch;
  // 0 for rows past the end of the batch, so those rows are skipped entirely.
  int idx_max_v[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; i++) {
    int idx_max = ((i + first_batch) < batch_size) ? element_count : 0;
    idx_max_v[i] = idx_max / kVSize;
  }

  // read data from global memory
  AccT srcdata[kBatchSize][kIterationsV][kVSize];

#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    // 64-bit row offset: row * stride can exceed INT_MAX for large inputs.
    const int64_t row_offset = static_cast<int64_t>(first_batch + i) * stride;
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      int src_idx = threadIdx.x + it * kWarpSize;
      if (kVSize == 1) {
        if (src_idx < idx_max_v[i]) {
          srcdata[i][it][0] = static_cast<AccT>(src[row_offset + src_idx]);
        } else {
          // Padding: -inf is neutral for max and contributes exp(-inf) = 0
          // to the sum.
          srcdata[i][it][0] = -std::numeric_limits<AccT>::infinity();
        }
      } else {
        const VecT* src_v = reinterpret_cast<const VecT*>(&src[row_offset]);
        if (src_idx < idx_max_v[i]) {
          // One vector load, then widen each packed element to AccT.
          VecT srctmp = src_v[src_idx];
          const T* srcinptr = reinterpret_cast<const T*>(&srctmp);
#pragma unroll
          for (int s = 0; s < kVSize; s++) {
            srcdata[i][it][s] = static_cast<AccT>(srcinptr[s]);
          }
        } else {
#pragma unroll
          for (int s = 0; s < kVSize; s++) {
            srcdata[i][it][s] = -std::numeric_limits<AccT>::infinity();
          }
        }
      }
    }
  }

  // compute max value
  AccT max_value[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    // it = 0
    AccT valmax = srcdata[i][0][0];
#pragma unroll
    for (int s = 1; s < kVSize; ++s) {
      valmax = (valmax > srcdata[i][0][s]) ? valmax : srcdata[i][0][s];
    }
    max_value[i] = valmax;

// it = 1, 2, ...
#pragma unroll
    for (int it = 1; it < kIterationsV; ++it) {
      AccT valmax = srcdata[i][it][0];
#pragma unroll
      for (int s = 1; s < kVSize; ++s) {
        valmax = (valmax > srcdata[i][it][s]) ? valmax : srcdata[i][it][s];
      }
      max_value[i] = (max_value[i] > valmax) ? max_value[i] : valmax;
    }
  }
  // Combine the per-thread maxima of each row across the kWarpSize lanes.
  WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);

  // compute sum
  // In normal mode srcdata is overwritten with exp(x - max) so the write-back
  // only needs a division; in LogMode the raw inputs are kept because the
  // result is x - max - log(sum).
  AccT sum[kBatchSize];
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    // it = 0
    if (LogMode) {
      sum[i] = std::exp(srcdata[i][0][0] - max_value[i]);
    } else {
      srcdata[i][0][0] = std::exp(srcdata[i][0][0] - max_value[i]);
      sum[i] = srcdata[i][0][0];
    }
#pragma unroll
    for (int s = 1; s < kVSize; ++s) {
      if (LogMode) {
        sum[i] += std::exp(srcdata[i][0][s] - max_value[i]);
      } else {
        srcdata[i][0][s] = std::exp(srcdata[i][0][s] - max_value[i]);
        sum[i] += srcdata[i][0][s];
      }
    }

// it = 1, 2, ...
#pragma unroll
    for (int it = 1; it < kIterationsV; ++it) {
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (LogMode) {
          sum[i] += std::exp(srcdata[i][it][s] - max_value[i]);
        } else {
          srcdata[i][it][s] = std::exp(srcdata[i][it][s] - max_value[i]);
          sum[i] += srcdata[i][it][s];
        }
      }
    }
  }
  // Combine the per-thread partial sums of each row across the lanes.
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

// write result to global memory
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (LogMode) {
      sum[i] = std::log(sum[i]);
    }
    const int64_t row_offset = static_cast<int64_t>(first_batch + i) * stride;

#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      int idx = threadIdx.x + it * kWarpSize;
      if (kVSize == 1) {
        if (idx < idx_max_v[i]) {
          if (LogMode) {
            softmax[row_offset + idx] =
                srcdata[i][it][0] - max_value[i] - sum[i];
          } else {
            softmax[row_offset + idx] = srcdata[i][it][0] / sum[i];
          }
        } else {
          // idx only grows with it, so nothing further is in range.
          break;
        }
      } else {
        VecT* softmax_v = reinterpret_cast<VecT*>(&softmax[row_offset]);
        // Pack kVSize results into one vector so they go out in one store.
        VecT tmpdata;
        T* tmpptr = reinterpret_cast<T*>(&tmpdata);
#pragma unroll
        for (int s = 0; s < kVSize; ++s) {
          if (LogMode) {
            tmpptr[s] = srcdata[i][it][s] - max_value[i] - sum[i];
          } else {
            tmpptr[s] = srcdata[i][it][s] / sum[i];
          }
        }

        if (idx < idx_max_v[i]) {
          softmax_v[idx] = tmpdata;
        } else {
          break;
        }
      }
    }
  }
}

/*
Core function of computing softmax backward for axis=-1.
For every row i the kernel computes
  - s_{i} = sum_j src_{i,j} * grad_{i,j}
    (LogMode: s_{i} = sum_j grad_{i,j})
  - dst_{i,j} = src_{i,j} * (grad_{i,j} - s_{i})
    (LogMode: dst_{i,j} = grad_{i,j} - exp(src_{i,j}) * s_{i})
where src is the forward output and grad the gradient w.r.t. that output.

Launch layout: threadIdx.x is the lane inside a group of kWarpSize threads
(kWarpSize = min(2^Log2Elements, 32)). Each (blockIdx.x, threadIdx.y) pair owns
kBatchSize (1 or 2) consecutive rows starting at row
(blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize. Every thread first
accumulates the elements it loaded, then WarpReduceSum combines the per-thread
partial sums.

Parameters:
  dst            output buffer, row r starts at dst + r * stride
  grad, src      input buffers with the same row layout as dst
  batch_size     number of rows; rows at or beyond it are neither read nor
                 written
  stride         distance in elements between consecutive rows
  element_count  number of elements per row

Preconditions:
  - element_count is a multiple of kVSize = sizeof(VecT) / sizeof(T); a
    remainder would be dropped because the row length is divided by kVSize.
  - element_count is expected to be at most 2^Log2Elements; the per-thread
    register tile is sized from Log2Elements only.
  - rows are accessed through VecT pointers, so every row start must be
    suitably aligned for VecT.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements,
          bool LogMode = false>
__global__ void WarpSoftmaxBackward(T* dst, const T* grad, const T* src,
                                    int batch_size, int stride,
                                    int element_count) {
  constexpr int kVSize = sizeof(VecT) / sizeof(T);
  constexpr int kDimCeil = 1 << Log2Elements;
  constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
  constexpr int kIterations = kDimCeil / kWarpSize;
  constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
  constexpr int kIterationsV =
      (kIterations >= kVSize) ? (kIterations / kVSize) : 1;

  int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
  // Number of valid rows owned by this thread group (may be <= 0 at the tail
  // of the grid).
  int local_batches = batch_size - first_batch;
  if (local_batches > kBatchSize) {
    local_batches = kBatchSize;
  }

  // read data from global memory
  VecT src_reg[kBatchSize][kIterationsV];
  VecT grad_reg[kBatchSize][kIterationsV];

#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    // 64-bit row offset: row * stride can exceed INT_MAX for large inputs.
    const int64_t row_offset = static_cast<int64_t>(first_batch + i) * stride;
    const VecT* src_v = reinterpret_cast<const VecT*>(&src[row_offset]);
    const VecT* grad_v = reinterpret_cast<const VecT*>(&grad[row_offset]);

    // max index to read (0 for rows past the end of the batch)
    int idx_max = (i < local_batches) ? element_count : 0;
    int idx_max_v = idx_max / kVSize;

    // read data
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      int src_idx = threadIdx.x + it * kWarpSize;
      if (src_idx < idx_max_v) {
        src_reg[i][it] = src_v[src_idx];
        grad_reg[i][it] = grad_v[src_idx];
      } else {
        // Padding: zeros contribute nothing to the row sum.
#pragma unroll
        for (int s = 0; s < kVSize; s++) {
          reinterpret_cast<T*>(&src_reg[i][it])[s] = 0.0;
          reinterpret_cast<T*>(&grad_reg[i][it])[s] = 0.0;
        }
      }
    }
  }

  // compute sum
  AccT sum[kBatchSize]{0.0};
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (LogMode) {
          sum[i] += static_cast<AccT>(gradptr[s]);
        } else {
          // Widen both factors before multiplying so the product is formed in
          // AccT precision rather than being rounded to T first.
          sum[i] +=
              static_cast<AccT>(gradptr[s]) * static_cast<AccT>(srcptr[s]);
        }
      }
    }
  }
  // Combine the per-thread partial sums of each row across the lanes.
  WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);

  // write result
  // Every valid row has the same vector count.
  const int idx_max_v = element_count / kVSize;
#pragma unroll
  for (int i = 0; i < kBatchSize; ++i) {
    if (i >= local_batches) break;

    const int64_t row_offset = static_cast<int64_t>(first_batch + i) * stride;
    VecT* dst_v = reinterpret_cast<VecT*>(&dst[row_offset]);

#pragma unroll
    for (int it = 0; it < kIterationsV; ++it) {
      // Pack kVSize results into one vector so they go out in one store.
      VecT tmpdata;
      T* tmpptr = reinterpret_cast<T*>(&tmpdata);
      T* gradptr = reinterpret_cast<T*>(&grad_reg[i][it]);
      T* srcptr = reinterpret_cast<T*>(&src_reg[i][it]);
#pragma unroll
      for (int s = 0; s < kVSize; ++s) {
        if (LogMode) {
          tmpptr[s] = static_cast<AccT>(gradptr[s]) -
                      std::exp(static_cast<AccT>(srcptr[s])) * sum[i];
        } else {
          tmpptr[s] = static_cast<AccT>(srcptr[s]) *
                      (static_cast<AccT>(gradptr[s]) - sum[i]);
        }
      }

      int idx = threadIdx.x + it * kWarpSize;
      if (idx < idx_max_v) {
        dst_v[idx] = tmpdata;
      }
    }
  }
}

// Expands to one `case` of the dispatch switch in SwitchWarpSoftmaxForward:
// launches the WarpSoftmaxForward instantiation for a compile-time
// Log2Elements on the stream of the op's CUDA device context. Relies on T,
// VecT, LogMode, blocks, threads, ctx, dst, src, batch_size, stride and
// element_count being visible at the expansion site.
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, AccT)                   \
  case Log2Elements: {                                                  \
    auto stream = ctx.cuda_device_context().stream();                   \
    WarpSoftmaxForward<T, VecT, AccT, Log2Elements, LogMode>            \
        <<<blocks, threads, 0, stream>>>(dst, src, batch_size, stride,  \
                                         element_count);                \
  } break;

/*
  Wrapper of softmax forward that turns the runtime value Log2Elements into
  the matching template instantiation of WarpSoftmaxForward.

  Values 0..9 are dispatched; any other value launches nothing.
  The accumulation type is taken from details::MPTypeTrait<T>.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxForward(const int blocks, const dim3 threads,
                              const framework::ExecutionContext& ctx, T* dst,
                              const T* src, const int batch_size,
                              const int stride, const int element_count,
                              int Log2Elements) {
  using MPType = typename details::MPTypeTrait<T>::Type;
  switch (Log2Elements) {
    SOFTMAX_WARP_FORWARD_CASE(0, MPType);
    SOFTMAX_WARP_FORWARD_CASE(1, MPType);
    SOFTMAX_WARP_FORWARD_CASE(2, MPType);
    SOFTMAX_WARP_FORWARD_CASE(3, MPType);
    SOFTMAX_WARP_FORWARD_CASE(4, MPType);
    SOFTMAX_WARP_FORWARD_CASE(5, MPType);
    SOFTMAX_WARP_FORWARD_CASE(6, MPType);
    SOFTMAX_WARP_FORWARD_CASE(7, MPType);
    SOFTMAX_WARP_FORWARD_CASE(8, MPType);
    SOFTMAX_WARP_FORWARD_CASE(9, MPType);
    default:
      break;
  }
}

// Expands to one `case` of the dispatch switch in SwitchWarpSoftmaxBackward:
// launches the WarpSoftmaxBackward instantiation for a compile-time
// Log2Elements on the stream of the op's CUDA device context. Relies on T,
// VecT, LogMode, blocks, threads, ctx, dst, grad, src, batch_size, stride and
// element_count being visible at the expansion site.
#define SOFTMAX_WARP_BACKWARD_CASE(Log2Elements, AccT)                       \
  case Log2Elements: {                                                       \
    auto stream = ctx.cuda_device_context().stream();                        \
    WarpSoftmaxBackward<T, VecT, AccT, Log2Elements, LogMode>                \
        <<<blocks, threads, 0, stream>>>(dst, grad, src, batch_size, stride, \
                                         element_count);                     \
  } break;

/*
  Wrapper of softmax backward that turns the runtime value Log2Elements into
  the matching template instantiation of WarpSoftmaxBackward.

  Values 0..9 are dispatched; any other value launches nothing.
  The accumulation type is taken from details::MPTypeTrait<T>.
*/
template <typename T, typename VecT, bool LogMode>
void SwitchWarpSoftmaxBackward(const int blocks, const dim3 threads,
                               const framework::ExecutionContext& ctx, T* dst,
                               const T* grad, const T* src,
                               const int batch_size, const int stride,
                               const int element_count, int Log2Elements) {
  using MPType = typename details::MPTypeTrait<T>::Type;
  switch (Log2Elements) {
    SOFTMAX_WARP_BACKWARD_CASE(0, MPType);
    SOFTMAX_WARP_BACKWARD_CASE(1, MPType);
    SOFTMAX_WARP_BACKWARD_CASE(2, MPType);
    SOFTMAX_WARP_BACKWARD_CASE(3, MPType);
    SOFTMAX_WARP_BACKWARD_CASE(4, MPType);
    SOFTMAX_WARP_BACKWARD_CASE(5, MPType);
    SOFTMAX_WARP_BACKWARD_CASE(6, MPType);
    SOFTMAX_WARP_BACKWARD_CASE(7, MPType);
    SOFTMAX_WARP_BACKWARD_CASE(8, MPType);
    SOFTMAX_WARP_BACKWARD_CASE(9, MPType);
    default:
      break;
  }
}

// The case macros are only meaningful inside the two dispatch wrappers.
#undef SOFTMAX_WARP_FORWARD_CASE
#undef SOFTMAX_WARP_BACKWARD_CASE

/*
Forward op kernel: reads input "X", writes output "Out".

With LogMode == false the plain softmax is produced, with LogMode == true its
logarithm (the library path then uses the *_SOFTMAX_LOG algorithm).

Two execution paths:
  - Warp path (WarpSoftmaxForward): taken when the softmax axis is the
    innermost one (D == 1), the axis length is at most max_dim (320) and
    sizeof(T) <= 4.
  - Library path (cuDNN, or MIOpen under PADDLE_WITH_HIP): everything else.
    The tensor is described as {N, dim, D, 1} in NCHW layout.
*/
template <typename T, bool LogMode = false>
class SoftmaxCUDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* x = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<Tensor>("Out");
    out->mutable_data<T>(ctx.GetPlace());
    auto* out_data = out->data<T>();

    // dim is the length of the softmax axis.
    // NOTE(review): N / D are presumably the products of the dimensions
    // before / after the axis -- confirm against SizeToAxis / SizeOutAxis.
    auto dims = x->dims();
    const int rank = dims.size();
    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
    const int dim = dims[axis];
    const int N = SizeToAxis(axis, dims);
    const int D = SizeOutAxis(axis, dims);

    constexpr int max_dim = 320;

    if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
      // These values mirror the constexpr configuration derived inside
      // WarpSoftmaxForward from Log2Elements and must stay in sync with it.
      const int kDimLog2 = static_cast<int>(log2_ceil(dim));
      const int kDimCeil = 1 << kDimLog2;
      int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
      int batches_per_warp = (kDimCeil <= 32) ? 2 : 1;

      // use 128 threads per block to maximize gpu utilization
      constexpr int threads_per_block = 128;

      int warps_per_block = (threads_per_block / kWarpSize);
      int batches_per_block = warps_per_block * batches_per_warp;
      int blocks = (N + batches_per_block - 1) / batches_per_block;
      dim3 threads(kWarpSize, warps_per_block, 1);

      // vectorization read/write: use the widest vector type whose element
      // count divides the row length.
      using T4 = typename VecT4<T>::Type;
      using T2 = typename VecT2<T>::Type;
      if (dim % 4 == 0) {
        SwitchWarpSoftmaxForward<T, T4, LogMode>(blocks, threads, ctx, out_data,
                                                 x->data<T>(), N, dim, dim,
                                                 kDimLog2);
      } else if (dim % 2 == 0) {
        SwitchWarpSoftmaxForward<T, T2, LogMode>(blocks, threads, ctx, out_data,
                                                 x->data<T>(), N, dim, dim,
                                                 kDimLog2);
      } else {
        SwitchWarpSoftmaxForward<T, T, LogMode>(blocks, threads, ctx, out_data,
                                                x->data<T>(), N, dim, dim,
                                                kDimLog2);
      }
    } else {
      ScopedTensorDescriptor desc;
      std::vector<int> tensor_dims = {N, dim, D, 1};
      DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
      miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#else
      cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#endif

      auto& dev_ctx =
          ctx.template device_context<platform::CUDADeviceContext>();
      auto handle = dev_ctx.cudnn_handle();

#ifdef PADDLE_WITH_HIP
      auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                   : MIOPEN_SOFTMAX_MODE_CHANNEL;
      auto algo = LogMode ? MIOPEN_SOFTMAX_LOG : MIOPEN_SOFTMAX_ACCURATE;
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
          handle, platform::CudnnDataType<T>::kOne(), desc_, x->data<T>(),
          platform::CudnnDataType<T>::kZero(), desc_, out_data, algo, mode));
#else
      auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
                                   : CUDNN_SOFTMAX_MODE_CHANNEL;
      auto algo = LogMode ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxForward(
          handle, algo, mode, platform::CudnnDataType<T>::kOne(), desc_,
          x->data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
          out_data));
#endif
    }
  }
};

/*
Backward op kernel: reads the forward output "Out" and its gradient
GradVarName("Out"), writes the input gradient GradVarName("X").

With LogMode == true the forward output is treated as a log-softmax (the
library path then uses the *_SOFTMAX_LOG algorithm).

Two execution paths:
  - Warp path (WarpSoftmaxBackward): taken when the softmax axis is the
    innermost one (D == 1), the axis length is at most max_dim (320) and
    sizeof(T) <= 4.
  - Library path (cuDNN, or MIOpen under PADDLE_WITH_HIP): everything else.
    The tensors are described as {N, dim, D, 1} in NCHW layout.
*/
template <typename T, bool LogMode = false>
class SoftmaxGradCUDNNKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* out = ctx.Input<Tensor>("Out");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    dx->mutable_data<T>(ctx.GetPlace());
    auto* dx_data = dx->data<T>();

    // dim is the length of the softmax axis.
    // NOTE(review): N / D are presumably the products of the dimensions
    // before / after the axis -- confirm against SizeToAxis / SizeOutAxis.
    auto dims = out->dims();
    const int rank = dims.size();
    const int axis = CanonicalAxis(ctx.Attr<int>("axis"), rank);
    const int dim = dims[axis];
    const int N = SizeToAxis(axis, dims);
    const int D = SizeOutAxis(axis, dims);

    constexpr int max_dim = 320;

    if (D == 1 && dim <= max_dim && sizeof(T) <= 4) {
      // These values mirror the constexpr configuration derived inside
      // WarpSoftmaxBackward from Log2Elements and must stay in sync with it.
      const int kDimLog2 = log2_ceil(dim);
      const int kDimCeil = 1 << kDimLog2;
      int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
      int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;
      constexpr int threads_per_block = 128;

      int warps_per_block = (threads_per_block / kWarpSize);
      int batches_per_block = warps_per_block * batches_per_warp;
      int blocks = (N + batches_per_block - 1) / batches_per_block;
      dim3 threads(kWarpSize, warps_per_block, 1);

      // vectorization read/write: use the widest vector type whose element
      // count divides the row length.
      using T4 = typename VecT4<T>::Type;
      using T2 = typename VecT2<T>::Type;
      if (dim % 4 == 0) {
        SwitchWarpSoftmaxBackward<T, T4, LogMode>(
            blocks, threads, ctx, dx_data, dout->data<T>(), out->data<T>(), N,
            dim, dim, kDimLog2);
      } else if (dim % 2 == 0) {
        SwitchWarpSoftmaxBackward<T, T2, LogMode>(
            blocks, threads, ctx, dx_data, dout->data<T>(), out->data<T>(), N,
            dim, dim, kDimLog2);
      } else {
        SwitchWarpSoftmaxBackward<T, T, LogMode>(
            blocks, threads, ctx, dx_data, dout->data<T>(), out->data<T>(), N,
            dim, dim, kDimLog2);
      }
    } else {
      ScopedTensorDescriptor desc;
      std::vector<int> tensor_dims = {N, dim, D, 1};
      DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
      miopenTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#else
      cudnnTensorDescriptor_t desc_ = desc.descriptor<T>(layout, tensor_dims);
#endif

      auto& dev_ctx =
          ctx.template device_context<platform::CUDADeviceContext>();
      auto handle = dev_ctx.cudnn_handle();

#ifdef PADDLE_WITH_HIP
      auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
                                   : MIOPEN_SOFTMAX_MODE_CHANNEL;
      auto algo = LogMode ? MIOPEN_SOFTMAX_LOG : MIOPEN_SOFTMAX_ACCURATE;
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::miopenSoftmaxBackward_V2(
          handle, platform::CudnnDataType<T>::kOne(), desc_, out->data<T>(),
          desc_, dout->data<T>(), platform::CudnnDataType<T>::kZero(), desc_,
          dx_data, algo, mode));
#else
      auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
                                   : CUDNN_SOFTMAX_MODE_CHANNEL;
      auto algo = LogMode ? CUDNN_SOFTMAX_LOG : CUDNN_SOFTMAX_ACCURATE;
      PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::cudnnSoftmaxBackward(
          handle, algo, mode, platform::CudnnDataType<T>::kOne(), desc_,
          out->data<T>(), desc_, dout->data<T>(),
          platform::CudnnDataType<T>::kZero(), desc_, dx_data));
#endif
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace plat = paddle::platform;
namespace ops = paddle::operators;

// Register the CUDNN-library kernels of `softmax` and `softmax_grad` for
// CUDAPlace. The HIP build leaves out the double instantiations.
#ifdef PADDLE_WITH_HIP
// MIOpen does not support double
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
                   ops::SoftmaxCUDNNKernel<float>,
                   ops::SoftmaxCUDNNKernel<plat::float16>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
                   ops::SoftmaxGradCUDNNKernel<float>,
                   ops::SoftmaxGradCUDNNKernel<plat::float16>);
#else
REGISTER_OP_KERNEL(softmax, CUDNN, plat::CUDAPlace,
                   ops::SoftmaxCUDNNKernel<float>,
                   ops::SoftmaxCUDNNKernel<double>,
                   ops::SoftmaxCUDNNKernel<plat::float16>);
REGISTER_OP_KERNEL(softmax_grad, CUDNN, plat::CUDAPlace,
                   ops::SoftmaxGradCUDNNKernel<float>,
                   ops::SoftmaxGradCUDNNKernel<double>,
                   ops::SoftmaxGradCUDNNKernel<plat::float16>);
#endif