Unverified commit f8955602, authored by Feng Xing, committed by GitHub

optimize softmax with cross entropy soft label (#32387)

softmax_with_cross_entropy optimization with soft label. This PR optimizes:
    "SoftmaxWithCrossEntropySoftLabel": compute log_softmax, then compute the loss.
    "CrossEntropySoftLabel": compute the loss with softmax as input.
These optimizations use the following techniques:
    read data into buffers with vectorized loads
    compute max and sum within a warp (see the sketch below)
    fix the loop size with a macro
Performance (computation time):
    softmax_with_cross_entropy_0 (forward) : -40.1%
    softmax_with_cross_entropy_0 (backward): -41%
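
A minimal sketch of the warp-reduction idea (illustrative only, not code from
this diff): each of a warp's 32 lanes holds a partial sum, and repeated
butterfly shuffles combine them so that every lane ends up with the warp-wide
total.

    template <typename T>
    __device__ T WarpReduceSumSketch(T val) {
      // butterfly reduction: after log2(32) = 5 steps every lane holds the sum
      for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_xor_sync(0xffffffffu, val, offset);
      }
      return val;
    }

The max reduction is the same loop with += replaced by a max; the kernels in
this diff use this pattern via WarpReduceSum / WarpReduceMax.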
Parent bb0713b2
@@ -20,6 +20,7 @@ namespace cub = hipcub;
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/softmax_cudnn_op.cu.h"
#include "paddle/fluid/operators/softmax_with_cross_entropy_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/for_range.h"
@@ -391,8 +392,8 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
using AccT = typename details::MPTypeTrait<T>::Type;
// use 128 threads per block to maximize gpu utilization
const int Log2Elements = static_cast<int>(Log2Ceil(element_count));
const int kDimCeil = 1 << Log2Elements;
const int log2_elements = static_cast<int>(Log2Ceil(element_count));
const int kDimCeil = 1 << log2_elements;
int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;
constexpr int threads_per_block = 128;
@@ -401,7 +402,7 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
int blocks = (batch_size + batches_per_block - 1) / batches_per_block;
dim3 threads(kWarpSize, warps_per_block, 1);
switch (Log2Elements) {
switch (log2_elements) {
SOFTMAX_WARP_FORWARD_CASE(0, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(1, T, AccT);
SOFTMAX_WARP_FORWARD_CASE(2, T, AccT);
@@ -494,6 +495,368 @@ __global__ void SoftmaxWithCrossEntropyGradHardLabel(
}
}
/*
Cross entropy soft label with dynamic size on axis (log2_elements is
variable).
  - if the input is softmax, compute the loss from softmax directly
  - if the input is log_softmax, compute the loss from log_softmax and
    overwrite it with softmax
*/
template <typename T, typename VecT, bool InLogMode = false>
__global__ void CrossEntropySoftLabel(T* loss, T* softmaxwrt, const T* softmax,
const T* labels, const int n,
const int dim, const int d,
int log2_elements) {
const int kDimCeil = 1 << log2_elements;
const int kVSize = sizeof(VecT) / sizeof(T);
#ifdef __HIPCC__
const int kThreadPerBlock = 256;
#else
const int kThreadPerBlock = 512;
#endif
const int kBatchPerBlock = 1;
const int kWarpSize = 32; // (dim < 32) ? dim : 32;
const int kBatchSize = 1;
const int kThreadPerBatch = kThreadPerBlock / kBatchPerBlock;
const int kWarpPerBatch = kThreadPerBatch / kWarpSize;
const int kIterations = (dim + kThreadPerBatch - 1) / kThreadPerBatch;
const int kIterationsV = (kIterations >= kVSize) ? (kIterations / kVSize) : 1;
const int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
T sum[kBatchSize]{static_cast<T>(0.0)};
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
int ids = first_batch + i;
if (ids >= n * d) break;
int idx_n = ids / d;
int idx_d = ids % d;
#pragma unroll
for (int it = 0; it < kIterations; ++it) {
int idx_dim = it * kThreadPerBatch + threadIdx.x;
int idx = idx_n * dim * d + idx_dim * d + idx_d;
if (idx_n < n && idx_dim < dim) {
VecT softmaxdata;
if (InLogMode) {
softmaxdata = reinterpret_cast<VecT*>(&softmaxwrt[idx])[0];
} else {
softmaxdata = reinterpret_cast<const VecT*>(&softmax[idx])[0];
}
VecT labelsdata = reinterpret_cast<const VecT*>(&labels[idx])[0];
T* softmaxptr = reinterpret_cast<T*>(&softmaxdata);
T* labelsptr = reinterpret_cast<T*>(&labelsdata);
#pragma unroll
for (int s = 0; s < kVSize; s++) {
if (InLogMode) {
sum[i] -= softmaxptr[s] * labelsptr[s];
softmaxptr[s] = Exp(softmaxptr[s]);
} else {
sum[i] -= Log(softmaxptr[s]) * labelsptr[s];
}
}
if (InLogMode) {
reinterpret_cast<VecT*>(&softmaxwrt[idx])[0] = softmaxdata;
}
}
}
}
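  // Two-stage reduction: first reduce each thread's partial sum within its
  // warp via shuffle, then combine the per-warp results through the
  // shared-memory buffer below.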
WarpReduceSum<T, kBatchSize, kWarpSize>(sum);
__syncthreads();
__shared__ T sumshare[kWarpPerBatch][kBatchPerBlock][kBatchSize];
if (threadIdx.x % kWarpSize == 0) {
#pragma unroll
for (int i = 0; i < kBatchSize; i++) {
sumshare[threadIdx.x / kWarpSize][threadIdx.y][i] = sum[i];
}
}
__syncthreads();
// write
if (threadIdx.x == 0) {
for (int i = 0; i < kBatchSize; i++) {
int ids = first_batch + i;
if (ids < n * d) {
loss[ids] = sumshare[0][threadIdx.y][i];
for (int s = 1; s < kWarpPerBatch; s++) {
loss[ids] += sumshare[s][threadIdx.y][i];
}
}
}
}
}
/*
Core function of softmax with cross entropy forward soft label.
The computation includes
  - Compute maximum per batch: maxvalue_{i} = max_j src_{i,j}
  - Compute sum of exp per batch: s_{i} = sum_{j}{ exp(src_{i,j} - maxvalue_{i}) }
  - Compute loss: loss_{i} = - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} - log(s_{i})) }
One warp (32 threads) computes one or two batches (kBatchSize).
For the max (sum) reduction, each thread first accumulates a thread-local
max (sum), then the warp-wide max (sum) is computed with shuffle intrinsics.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements>
__global__ void WarpSoftmaxForwardSoftLabel(T* loss, T* softmax, const T* src,
const T* label,
const int batch_size,
const int stride,
const int element_count) {
const bool LogMode = true;
constexpr int kDimCeil = 1 << Log2Elements;
constexpr int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
constexpr int kVSize = sizeof(VecT) / sizeof(T);
constexpr int kIterations = kDimCeil / kWarpSize;
constexpr int kIterationsV =
(kIterations >= kVSize) ? (kIterations / kVSize) : 1;
constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
int local_batches = batch_size - first_batch;
if (local_batches > kBatchSize) {
local_batches = kBatchSize;
}
// read data from global memory
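  // Vectorized reads: each VecT load moves kVSize elements at once (e.g. one
  // float4 instead of four separate floats). Out-of-range elements are filled
  // with -max (so they cannot win the max reduction) and labels with 0.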
VecT srcdata[kBatchSize][kIterationsV];
VecT labeldata[kBatchSize][kIterationsV];
for (int i = 0; i < kBatchSize; ++i) {
const VecT* src_v =
reinterpret_cast<const VecT*>(&src[(first_batch + i) * stride]);
const VecT* label_v =
reinterpret_cast<const VecT*>(&label[(first_batch + i) * stride]);
// max index to read
int idx_max = (i < local_batches) ? element_count : 0;
int idx_max_v = idx_max / kVSize;
// read data
for (int it = 0; it < kIterationsV; ++it) {
int src_idx = threadIdx.x + it * kWarpSize;
if (src_idx < idx_max_v) {
srcdata[i][it] = src_v[src_idx];
labeldata[i][it] = label_v[src_idx];
} else {
#pragma unroll
for (int s = 0; s < kVSize; s++) {
reinterpret_cast<T*>(&srcdata[i][it])[s] =
-std::numeric_limits<AccT>::max();
reinterpret_cast<T*>(&labeldata[i][it])[s] = 0.0;
}
}
}
}
// compute max value
AccT max_value[kBatchSize];
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
max_value[i] = -std::numeric_limits<AccT>::infinity();
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
T* srcptr_v = reinterpret_cast<T*>(&srcdata[i][it]);
T valmax = srcptr_v[0];
#pragma unroll
for (int s = 1; s < kVSize; ++s) {
valmax = (valmax > srcptr_v[s]) ? valmax : srcptr_v[s];
}
max_value[i] = (max_value[i] > static_cast<AccT>(valmax))
? max_value[i]
: static_cast<AccT>(valmax);
}
}
WarpReduceMax<AccT, kBatchSize, kWarpSize>(max_value);
// compute sum
AccT sum[kBatchSize]{0.0};
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
T* srcptr_v = reinterpret_cast<T*>(&srcdata[i][it]);
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (LogMode) {
sum[i] += std::exp(static_cast<AccT>(srcptr_v[s]) - max_value[i]);
} else {
srcptr_v[s] = std::exp(static_cast<AccT>(srcptr_v[s]) - max_value[i]);
sum[i] += static_cast<AccT>(srcptr_v[s]);
}
}
}
}
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sum);
// log_softmax and loss
AccT sumloss[kBatchSize]{0.0};
#pragma unroll
for (int i = 0; i < kBatchSize; ++i) {
if (i >= local_batches) break;
VecT* softmax_v =
reinterpret_cast<VecT*>(&softmax[(first_batch + i) * stride]);
// max index to write
int idx_max = (i < local_batches) ? element_count : 0;
int idx_max_v = idx_max / kVSize;
if (LogMode) {
sum[i] = std::log(sum[i]);
}
#pragma unroll
for (int it = 0; it < kIterationsV; ++it) {
T* srcvp = reinterpret_cast<T*>(&srcdata[i][it]);
T* labelvp = reinterpret_cast<T*>(&labeldata[i][it]);
VecT tmpv;
T* tmpvp = reinterpret_cast<T*>(&tmpv);
#pragma unroll
for (int s = 0; s < kVSize; ++s) {
if (LogMode) {
AccT logsoftmax = static_cast<AccT>(srcvp[s]) - max_value[i] - sum[i];
sumloss[i] -= logsoftmax * static_cast<AccT>(labelvp[s]);
tmpvp[s] = std::exp(logsoftmax);
} else {
tmpvp[s] = static_cast<AccT>(srcvp[s]) / sum[i];
}
}
int idx = threadIdx.x + it * kWarpSize;
if (idx < idx_max_v) {
softmax_v[idx] = tmpv;
}
}
}
// loss
WarpReduceSum<AccT, kBatchSize, kWarpSize>(sumloss);
for (int i = 0; i < kBatchSize; i++) {
if (i >= local_batches) break;
loss[first_batch + i] = sumloss[i];
}
}
#define SOFTMAX_WARP_FORWARD_SOFT_CASE(Log2Elements, VecT, AccT) \
case Log2Elements: \
WarpSoftmaxForwardSoftLabel<T, VecT, AccT, \
Log2Elements><<<blocks, threads, 0, stream>>>( \
loss, softmax, src, label, batch_size, stride, element_count); \
break;
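/*
  The switch below turns the runtime log2_elements into the compile-time
  template parameter Log2Elements, so kIterations and the inner loops are
  constants that #pragma unroll can fully unroll; this is the "fixed loop
  size with macro" technique from the commit message.
*/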
/*
Wrapper of softmax with cross entropy forward soft label.
*/
template <typename T>
void SwitchWarpSoftmaxForwardSoftLabel(const int blocks, const dim3 threads,
gpuStream_t stream, T* loss, T* softmax,
const T* src, const T* label,
const int batch_size, const int stride,
const int element_count,
const int log2_elements) {
using AccT = typename details::MPTypeTrait<T>::Type;
switch (log2_elements) {
SOFTMAX_WARP_FORWARD_SOFT_CASE(0, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(1, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(2, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(3, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(4, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(5, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(6, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(7, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(8, T, AccT);
SOFTMAX_WARP_FORWARD_SOFT_CASE(9, T, AccT);
default:
break;
}
}
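/*
  Note: log2_elements never exceeds 9 here, because the caller takes this
  path only when dim <= 320 and Log2Ceil(320) == 9; larger sizes go through
  the cuDNN softmax path instead.
*/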
template <typename T>
static void SoftmaxWithCrossEntropySoftLabel(
const platform::CUDADeviceContext& ctx, const int rank, const int axis,
const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
int N, int dim, int D) {
#ifdef __HIPCC__
constexpr int kMaxBlockDim = 256;
#else
constexpr int kMaxBlockDim = 512;
#endif
int64_t block_dim = dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(dim)));
int64_t grid_dim = N * D;
constexpr int max_dim = 320;
const int kDimLog2 = static_cast<int>(Log2Ceil(dim));
const int kDimCeil = 1 << kDimLog2;
auto stream = ctx.stream();
if (D == 1 && dim <= max_dim) {
int kWarpSize = (kDimCeil < 32) ? kDimCeil : 32;
int batches_per_warp = (kDimCeil <= 128) ? 2 : 1;
// use 128 threads per block to maximize gpu utilization
constexpr int threads_per_block = 128;
int warps_per_block = (threads_per_block / kWarpSize);
int batches_per_block = warps_per_block * batches_per_warp;
int blocks = (N + batches_per_block - 1) / batches_per_block;
dim3 threads(kWarpSize, warps_per_block, 1);
SwitchWarpSoftmaxForwardSoftLabel<T>(blocks, threads, stream, loss_data,
softmax_data, logits_data, labels_data,
N, dim, dim, kDimLog2);
} else {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
DataLayout layout = DataLayout::kNCHW;
#ifdef PADDLE_WITH_HIP
miopenTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#else
cudnnTensorDescriptor_t descp = desc.descriptor<T>(layout, tensor_dims);
#endif
auto handle = ctx.cudnn_handle();
#ifdef PADDLE_WITH_HIP
auto mode = axis == rank - 1 ? MIOPEN_SOFTMAX_MODE_INSTANCE
: MIOPEN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::miopenSoftmaxForward_V2(
handle, platform::CudnnDataType<T>::kOne(), descp, logits_data,
platform::CudnnDataType<T>::kZero(), descp, softmax_data,
MIOPEN_SOFTMAX_LOG, mode));
#else
auto mode = axis == rank - 1 ? CUDNN_SOFTMAX_MODE_INSTANCE
: CUDNN_SOFTMAX_MODE_CHANNEL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cudnnSoftmaxForward(
handle, CUDNN_SOFTMAX_LOG, mode, platform::CudnnDataType<T>::kOne(),
descp, logits_data, platform::CudnnDataType<T>::kZero(), descp,
softmax_data));
#endif
const int kDimLog2 = static_cast<int>(Log2Ceil(dim));
const int kDimCeil = 1 << kDimLog2;
#ifdef __HIPCC__
int kThreadPerBlock = 256;
#else
int kThreadPerBlock = 512;
#endif
int kBatchPerBlock = 1;
int blocks = (N * D + kBatchPerBlock - 1) / kBatchPerBlock;
dim3 threads(kThreadPerBlock / kBatchPerBlock, kBatchPerBlock, 1);
CrossEntropySoftLabel<T, T, true><<<blocks, threads, 0, stream>>>(
loss_data, softmax_data, NULL, labels_data, N, dim, D, kDimLog2);
}
}
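// Minimal usage sketch (hypothetical buffer names, not part of this diff):
// for logits of shape [N, dim] (rank 2, softmax along axis 1, so D == 1),
// the fused warp-kernel path is taken whenever dim <= 320:
//
//   SoftmaxWithCrossEntropySoftLabel<float>(ctx, /*rank=*/2, /*axis=*/1,
//                                           d_logits, d_labels, d_softmax,
//                                           d_loss, N, dim, /*D=*/1);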
template <typename T>
__global__ void SoftCrossEntropyGradientKernel(T* logit_grad,
const T* loss_grad,
@@ -560,373 +923,6 @@ __global__ void ScaleCrossEntropyGradient(T* logit_grad, const T* loss_grad,
}
}
static __device__ __forceinline__ platform::float16 exp_on_device(
platform::float16 x) {
return ::Eigen::numext::exp(x);
}
static __device__ __forceinline__ float exp_on_device(float x) {
return expf(x);
}
static __device__ __forceinline__ double exp_on_device(double x) {
return exp(x);
}
static __device__ __forceinline__ platform::float16 log_on_device(
platform::float16 x) {
return math::TolerableValue<platform::float16>()(::Eigen::numext::log(x));
}
static __device__ __forceinline__ float log_on_device(float x) {
return math::TolerableValue<float>()(logf(x));
}
static __device__ __forceinline__ double log_on_device(double x) {
return math::TolerableValue<double>()(log(x));
}
/** In the following code, 3 CUDA kernels are implemented to calculate softmax
 * and loss **/
/*
Supposing x is `logits` and y is `labels`, the equations are as
follows:
cross\_entropy_i = \sum_{j}[- y_i_j * log({e^{x_i_j}/\sum_{j}e^{x_i_j}})]
= \sum_{j}[- y_i_j * log({e^{x_i_j - max_i}/\sum_{j}e^{x_i_j-max_i}})]
= \sum_{j}[-y_i_j * (x_i_j - max_i - log\sum_{j}e^{x_i_j - max_i})]
= \sum_{j}[-y_i_j * (x_i_j - max_i - logDiffMaxSum_i)]
= \sum_{j}(-y_i_j * tmp_i_j)
softmax_i_j = e^{tmp_i_j}
where:
max_i = \max_{j}{x_i_j}
logDiffMaxSum_i = log\sum_{j}e^{x_i_j - max_i}
tmp_i_j = x_i_j - max_i - logDiffMaxSum_i
Therefore, the calculation can be separated into 3 steps:
Step 1: row-wise operation to calculate max_i
Step 2: row-wise operation to calculate logDiffMaxSum_i
Step 3: calculate tmp_i_j, and finally get softmax_i_j and cross\_entropy_i
To save memory, we can share memory among max_i, logDiffMaxSum_i and
cross\_entropy_i.
In this way, the 3 steps should be changed to:
Step 1 (RowReductionForMax): row-wise operation to calculate max_i
Step 2 (RowReductionForDiffMaxSum): calculate the intermediate result
softmax'_i_j = x_i_j - max_i, and a row-wise operation to calculate
logDiffMaxSum_i
Step 3 (RowReductionForSoftmaxAndCrossEntropy): calculate tmp_i_j =
softmax'_i_j - logDiffMaxSum_i, and finally get softmax_i_j and
cross\_entropy_i
*/
// There are 3 kinds of reduce algorithms in cub:
// BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY
// BLOCK_REDUCE_RAKING
// BLOCK_REDUCE_WARP_REDUCTIONS (default)
template <typename T, int BlockDim>
using BlockReduce =
cub::BlockReduce<T, BlockDim /*, cub::BLOCK_REDUCE_WARP_REDUCTIONS*/>;
template <typename T, int BlockDim>
using BlockReduceTempStorage = typename BlockReduce<T, BlockDim>::TempStorage;
// Make sure that BlockDim <= axis_dim
// This kernel is used to calculate the max element of each row
template <typename T, int BlockDim>
static __global__ void RowReductionForMax(const T* logits_data, T* max_data,
int64_t d, int axis_dim) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
// logits_data view as [n, axis_dim, remain]
// max_data view as [n, 1, remain]
// gridDim.x = n * remain, split blockIdx into idx_n and idx_remain
int64_t remain = d / axis_dim;
int64_t idx_n = blockIdx.x / remain;
int64_t idx_remain = blockIdx.x % remain;
int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
int64_t end_idx = (idx_n + 1) * d;
int64_t step = BlockDim * remain;
T cur_max = logits_data[beg_idx];
beg_idx += step;
while (beg_idx < end_idx) {
if (cur_max < logits_data[beg_idx]) {
cur_max = logits_data[beg_idx];
}
beg_idx += step;
}
cur_max = BlockReduce<T, BlockDim>(temp_storage).Reduce(cur_max, cub::Max());
if (threadIdx.x == 0) max_data[blockIdx.x] = cur_max;
}
// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiffMaxSum(const T* logits_data,
T* max_data, T* softmax,
int64_t d, int axis_dim) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
// logits, softmax data view as [n, axis_dim, remain]
// max_data view as [n, 1, remain]
// gridDim.x = n * remain, split blockIdx into idx_n and idx_remain
int64_t remain = d / axis_dim;
int64_t idx_n = blockIdx.x / remain;
int64_t idx_remain = blockIdx.x % remain;
int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
int64_t end_idx = (idx_n + 1) * d;
auto block_max = max_data[blockIdx.x];
int64_t step = BlockDim * remain;
// In the numerically stable mode of softmax_with_loss, we compute the loss
// with tmp_i_j = x_i_j - max_i - logDiffMaxSum_i instead of
// log(exp(x_i_j - max_i)/DiffMaxSum_i), so log(0) cannot occur.
// We also compute softmax_i_j = e^{tmp_i_j}; its maximum and minimum values
// are 1.0 and 0.0, representing probabilities of exactly 1 and 0, so there
// is no need to clip shift_softmax.
softmax[beg_idx] = logits_data[beg_idx] - block_max;
T diff_max_sum = exp_on_device(softmax[beg_idx]);
auto idx = beg_idx + step;
while (idx < end_idx) {
softmax[idx] = logits_data[idx] - block_max;
diff_max_sum += exp_on_device(softmax[idx]);
idx += step;
}
diff_max_sum =
BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum);
if (!CalculateLogSoftmax) return;
__syncthreads();
diff_max_sum = max_data[blockIdx.x];
softmax[beg_idx] -= diff_max_sum;
beg_idx += step;
while (beg_idx < end_idx) {
softmax[beg_idx] -= diff_max_sum;
beg_idx += step;
}
// Note(zhiqiu): since different threads may use max_data[blockIdx.x] to
// calculate diff_max_sum, __syncthreads() is needed here.
__syncthreads();
if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}
#ifdef __HIPCC__ // @{ HIP Separate Kernel for RowReductionForDiffMaxSum
// Note(qili93): HIP does not support early return inside a kernel, so
// RowReductionForDiffMaxSum is separated into the two kernels below
template <typename T, int BlockDim>
static __global__ void RowReductionForSum(const T* logits_data, T* max_data,
T* softmax, int64_t d, int axis_dim) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
int64_t remain = d / axis_dim;
int64_t idx_n = blockIdx.x / remain;
int64_t idx_remain = blockIdx.x % remain;
int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
int64_t end_idx = (idx_n + 1) * d;
auto block_max = max_data[blockIdx.x];
int64_t step = BlockDim * remain;
softmax[beg_idx] = logits_data[beg_idx] - block_max;
T diff_max_sum = exp_on_device(softmax[beg_idx]);
auto idx = beg_idx + step;
while (idx < end_idx) {
softmax[idx] = logits_data[idx] - block_max;
diff_max_sum += exp_on_device(softmax[idx]);
idx += step;
}
diff_max_sum =
BlockReduce<T, BlockDim>(temp_storage).Reduce(diff_max_sum, cub::Sum());
if (threadIdx.x == 0) max_data[blockIdx.x] = log_on_device(diff_max_sum);
}
template <typename T, int BlockDim, bool CalculateLogSoftmax = false>
static __global__ void RowReductionForDiff(const T* logits_data, T* max_data,
T* softmax, int d, int axis_dim) {
int remain = d / axis_dim;
int idx_n = blockIdx.x / remain;
int idx_remain = blockIdx.x % remain;
int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
int end_idx = (idx_n + 1) * d;
int step = BlockDim * remain;
T diff_max_sum = max_data[blockIdx.x];
softmax[beg_idx] -= diff_max_sum;
beg_idx += step;
while (beg_idx < end_idx) {
softmax[beg_idx] -= diff_max_sum;
beg_idx += step;
}
__syncthreads();
if (threadIdx.x == 0) max_data[blockIdx.x] = 0;
}
#endif // @} End HIP Separate Kernel for RowReductionForDiffMaxSum
// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim>
static __global__ void RowReductionForSoftmaxAndCrossEntropy(
const T* logits_data, const T* labels_data, T* loss_data, T* softmax,
int64_t d, int axis_dim) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
// logits, softmax, labels data view as [n, axis_dim, remain]
// loss_data view as [n, 1, remain]
// gridDim.x = n * remain, split blockIdx into idx_n and idx_remain
int64_t remain = d / axis_dim;
int64_t idx_n = blockIdx.x / remain;
int64_t idx_remain = blockIdx.x % remain;
int64_t beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
int64_t end_idx = (idx_n + 1) * d;
// log_diff_max_sum shares memory with loss
auto block_log_diff_max_sum = loss_data[blockIdx.x];
auto tmp = softmax[beg_idx] - block_log_diff_max_sum;
softmax[beg_idx] = exp_on_device(tmp);
auto loss = -labels_data[beg_idx] * tmp;
int64_t step = BlockDim * remain;
beg_idx += step;
while (beg_idx < end_idx) {
tmp = softmax[beg_idx] - block_log_diff_max_sum;
softmax[beg_idx] = exp_on_device(tmp);
loss -= (labels_data[beg_idx] * tmp);
beg_idx += step;
}
loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}
// Make sure that BlockDim <= axis_dim
template <typename T, int BlockDim>
static __global__ void RowReductionForCrossEntropy(const T* logits_data,
const T* labels_data,
T* loss_data, int d,
int axis_dim) {
__shared__ BlockReduceTempStorage<T, BlockDim> temp_storage;
// logits, softmax, labels data view as [n, axis_dim, remain]
// loss_data view as [n, 1, remain]
// gridDim.x = n * remain, split blockIdx into idx_n and idx_remain
int remain = d / axis_dim;
int idx_n = blockIdx.x / remain;
int idx_remain = blockIdx.x % remain;
int beg_idx = idx_n * d + threadIdx.x * remain + idx_remain;
int end_idx = (idx_n + 1) * d;
// log_diff_max_sum shares memory with loss
auto block_log_diff_max_sum = loss_data[blockIdx.x];
auto tmp = log_on_device(logits_data[beg_idx]); // when not with softmax,
// softmax is stored in
// logits_data
auto loss = -labels_data[beg_idx] * tmp;
int step = BlockDim * remain;
beg_idx += step;
while (beg_idx < end_idx) {
tmp = log_on_device(logits_data[beg_idx]); // when not with softmax,
// softmax is stored in
// logits_data
loss -= (labels_data[beg_idx] * tmp);
beg_idx += step;
}
loss = BlockReduce<T, BlockDim>(temp_storage).Reduce(loss, cub::Sum());
if (threadIdx.x == 0) loss_data[blockIdx.x] = loss;
}
template <typename T>
static void SoftmaxWithCrossEntropyFusedKernel(
const T* logits_data, const T* labels_data, T* softmax_data, T* loss_data,
int64_t n, int64_t d, int axis_dim, gpuStream_t stream) {
#ifdef __HIPCC__
constexpr int kMaxBlockDim = 256;
#else
constexpr int kMaxBlockDim = 512;
#endif
int64_t block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(axis_dim)));
int64_t grid_dim = n * d / axis_dim;
#ifdef __HIPCC__
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: \
hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForMax<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
loss_data, d, axis_dim); \
hipLaunchKernelGGL(HIP_KERNEL_NAME(RowReductionForSum<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, \
loss_data, softmax_data, d, axis_dim); \
hipLaunchKernelGGL( \
HIP_KERNEL_NAME(RowReductionForSoftmaxAndCrossEntropy<T, BlockDim>), \
dim3(grid_dim), dim3(BlockDim), 0, stream, logits_data, labels_data, \
loss_data, softmax_data, d, axis_dim); \
break
#else
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: \
RowReductionForMax<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
logits_data, loss_data, d, axis_dim); \
RowReductionForDiffMaxSum<T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
logits_data, loss_data, softmax_data, d, axis_dim); \
RowReductionForSoftmaxAndCrossEntropy< \
T, BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
logits_data, labels_data, loss_data, softmax_data, d, axis_dim); \
break
#endif
switch (block_dim) {
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
default:
PADDLE_THROW(platform::errors::Unavailable(
"Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
break;
}
#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
// not with softmax
template <typename T>
static void CrossEntropyFusedKernel(const T* logits_data, const T* labels_data,
T* loss_data, int n, int d, int axis_dim,
gpuStream_t stream) {
constexpr int kMaxBlockDim = 512;
int block_dim = axis_dim >= kMaxBlockDim
? kMaxBlockDim
: (1 << static_cast<int>(std::log2(axis_dim)));
int grid_dim = n * d / axis_dim;
#define CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(BlockDim) \
case BlockDim: \
RowReductionForCrossEntropy<T, \
BlockDim><<<grid_dim, BlockDim, 0, stream>>>( \
logits_data, labels_data, loss_data, d, axis_dim); \
break
switch (block_dim) {
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(512);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(256);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(128);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(64);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(32);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(16);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(8);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(4);
CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL(2);
default:
PADDLE_THROW(platform::errors::Unavailable(
"Block Dimension must be 2^n in softmax_with_cross_entropy_op."));
break;
}
#undef CALL_SOFTMAX_WITH_CROSS_ENTROPY_FUSED_KERNEL
}
template <typename T>
class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
public:
@@ -983,9 +979,22 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
if (soft_label) {
auto* logits_data = softmax->data<T>();
auto* labels_data = labels->data<T>();
CrossEntropyFusedKernel(logits_data, labels_data, loss_data, n, d,
axis_dim,
context.cuda_device_context().stream());
const int kDimLog2 = static_cast<int>(Log2Ceil(axis_dim));
const int kDimCeil = 1 << kDimLog2;
#ifdef __HIPCC__
int kThreadPerBlock = 256;
#else
int kThreadPerBlock = 512;
#endif
int kBatchPerBlock = 1;
int blocks = (n * d + kBatchPerBlock - 1) / kBatchPerBlock;
dim3 threads(kThreadPerBlock / kBatchPerBlock, kBatchPerBlock, 1);
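// InLogMode = false: the input here already holds probabilities (softmax
// was applied beforehand), so the kernel takes log() itself.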
CrossEntropySoftLabel<T, T, false><<<
blocks, threads, 0, context.cuda_device_context().stream()>>>(
loss_data, NULL, logits_data, labels_data, n, axis_dim,
d / axis_dim, kDimLog2);
} else { // HardLabel
auto* logits_data = softmax->data<T>();
auto* labels_data = labels->data<int64_t>();
@@ -1040,9 +1049,9 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
if (soft_label) {
auto* logits_data = logits->data<T>();
auto* labels_data = labels->data<T>();
SoftmaxWithCrossEntropyFusedKernel(
logits_data, labels_data, softmax_data, loss_data, n, d, axis_dim,
context.cuda_device_context().stream());
SoftmaxWithCrossEntropySoftLabel<T>(
context.cuda_device_context(), rank, axis, logits_data, labels_data,
softmax_data, loss_data, n, axis_dim, d / axis_dim);
} else {
if (!context.Attr<bool>("numeric_stable_mode")) {
// cuDNN kernel only supports 2-D tensors and performs softmax on the last dim
@@ -1103,7 +1112,11 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
const int64_t d = SizeFromAxis(axis, logit_grad->dims());
const int64_t remain = d / axis_dim;
#ifdef __HIPCC__
int block = 256;
#else
int block = 512;
#endif
auto stream = context.cuda_device_context().stream();
auto ignore_index = context.Attr<int>("ignore_index");
auto use_softmax = context.Attr<bool>("use_softmax");
......