未验证 提交 bbe5228c 编写于 作者: Z Zhang Zheng 提交者: GitHub

Optimize perf of softmax_with_cross_entropy (#39553)

* Optimize perf of softmax_with_cross_entropy

* fix

* fix

* fix accuracy error
上级 2fedd39b
......@@ -27,6 +27,8 @@ namespace cub = hipcub;
namespace paddle {
namespace operators {
#define ALIGN_BYTES 16
using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using DataLayout = platform::DataLayout;
using Tensor = framework::Tensor;
......@@ -47,6 +49,18 @@ static __device__ __forceinline__ T Exp(T x) {
return math::TolerableValue<T>()(static_cast<T>(expx));
}
// Binary reduction functor: folds exp(x - max) into a running sum.
// Tx is the operand type; Ty is the result type (defaults to Tx).
template <typename Tx, typename Ty = Tx>
struct ExpAddFunctor {
  // Captures the value that is subtracted from every element before exp().
  HOSTDEVICE inline ExpAddFunctor(Tx max) : max_(max) {}

  // Returns (sum + exp(x - max)) converted to Ty.
  HOSTDEVICE inline Ty operator()(const Tx& sum, const Tx& x) const {
    const auto shifted_exp = std::exp(x - max_);
    return static_cast<Ty>(sum + shifted_exp);
  }

 private:
  Tx max_;
};
// log2(value)
static inline int Log2Ceil(int value) {
int log2_value = 0;
......@@ -419,10 +433,272 @@ void SwitchWarpSoftmaxForward(T* loss, T* softmax, const T* src,
}
}
// Stores the loss for one sample if the element handled by the caller is the
// labelled class.
//
// The element's class index is computed as vec_size * tid + offset. Nothing is
// written unless that index equals label_value. On a match, loss[label_id]
// receives loss_value, except that when IgnoreIndex is true and label_value
// equals ignore_index it receives zero instead.
template <typename T, bool IgnoreIndex>
__device__ __forceinline__ void ComputeLoss(T* loss, const T loss_value,
                                            const int label_id,
                                            const int64_t label_value,
                                            const int tid, const int vec_size,
                                            const int offset,
                                            const int ignore_index) {
  const int class_idx = vec_size * tid + offset;
  if (label_value != class_idx) {
    return;  // this element is not the labelled class
  }
  const bool ignored = IgnoreIndex && (label_value == ignore_index);
  loss[label_id] = ignored ? static_cast<T>(0.0f) : loss_value;
}
// Per-thread partial reduction over `size` elements starting at `input`.
//
// Each thread folds the elements it visits into an accumulator seeded with
// `init`, using `reducer(acc, element)`, and returns that accumulator. The
// per-thread results still have to be combined by the caller.
//
// `offset` is the number of elements by which `input` is moved back before
// vector loads are issued; the caller in this file derives it from the row
// address modulo ALIGN_BYTES. When it is non-zero, the first blockDim.x
// slots of the moved-back row are handled one element per thread.
template <typename T, typename AccT, int VecSize, class ReduceFunctor>
__device__ __forceinline__ AccT ThreadReduce(const T* input, int size,
                                             const int offset, AccT init,
                                             ReduceFunctor reducer) {
  using VecT = kps::details::VectorType<T, VecSize>;
  AccT acc = init;
  int idx = threadIdx.x;

  // Head: rebase the pointer by `offset`, let threads at or beyond `offset`
  // consume one element each, then step past the first blockDim.x slots.
  if (offset > 0) {
    input -= offset;
    size += offset;
    if (idx >= offset) {
      acc = reducer(acc, input[idx]);
    }
    size -= blockDim.x;
    input += blockDim.x;
  }

  // Elements that do not fill a whole (VecSize * blockDim.x) tile.
  const int tail = size % (VecSize * blockDim.x);
  const int vec_end = size - tail;

  T buf[VecSize];
  VecT* buf_vec = reinterpret_cast<VecT*>(&buf);

  // Body: one VecT load per iteration, then fold its VecSize lanes.
  while (VecSize * idx < vec_end) {
    *buf_vec = reinterpret_cast<const VecT*>(input)[idx];
#pragma unroll
    for (int lane = 0; lane < VecSize; ++lane) {
      acc = reducer(acc, buf[lane]);
    }
    idx += blockDim.x;
  }

  // Tail: remaining elements, one per thread per iteration.
  for (idx = vec_end + threadIdx.x; idx < size; idx += blockDim.x) {
    acc = reducer(acc, input[idx]);
  }
  return acc;
}
// Writes softmax for one row (one block per row) using vector loads/stores
// and records the hard-label loss for that row.
//
// For every element, func(logit) yields the log-softmax value; exp() of it is
// stored to `softmax`, and its negation is handed to ComputeLoss, which stores
// it into loss[blockIdx.x] when the element index equals the row's label.
// A label outside [0, size) makes thread 0 store a zero loss instead.
//
// `offset` is the number of elements by which both `logits` and `softmax` are
// moved back before vector accesses are issued; the caller in this file only
// takes this path when both pointers have the same offset.
template <typename T, typename AccT, typename LabelT, int VecSize,
          bool IgnoreIndex>
__device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
    T* loss, T* softmax, const T* logits, const LabelT* label, int size,
    const int offset, const LogSoftmaxForwardFunctor<AccT>& func,
    const int ignore_index) {
  using VecT = kps::details::VectorType<T, VecSize>;
  const int row = blockIdx.x;
  const auto target = static_cast<int64_t>(label[row]);
  // Checked against the row length as passed in, before `size` is adjusted.
  const bool target_in_range = (target >= 0) && (target < size);

  int idx = threadIdx.x;
  // Added to a position relative to the current `logits` base to recover the
  // element's index within the original row.
  int base_shift = 0;

  // Head: rebase by `offset`, handle one element per thread for threads at or
  // beyond `offset`, then step past the first blockDim.x slots.
  if (offset > 0) {
    logits -= offset;
    softmax -= offset;
    size += offset;
    base_shift -= offset;
    if (idx >= offset) {
      const AccT log_prob = func(static_cast<AccT>(logits[idx]));
      softmax[idx] = static_cast<T>(std::exp(log_prob));
      if (target_in_range) {
        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_prob), row,
                                    target, idx, 1, base_shift, ignore_index);
      }
    }
    size -= blockDim.x;
    logits += blockDim.x;
    softmax += blockDim.x;
    base_shift += blockDim.x;
  }

  // Elements that do not fill a whole (VecSize * blockDim.x) tile.
  const int tail = size % (VecSize * blockDim.x);
  const int vec_end = size - tail;

  T in_buf[VecSize];
  T out_buf[VecSize];
  VecT* in_vec = reinterpret_cast<VecT*>(&in_buf);
  VecT* out_vec = reinterpret_cast<VecT*>(&out_buf);

  // Body: load VecSize logits at once, compute, store VecSize outputs at once.
  while (VecSize * idx < vec_end) {
    *in_vec = reinterpret_cast<const VecT*>(logits)[idx];
#pragma unroll
    for (int lane = 0; lane < VecSize; ++lane) {
      const AccT log_prob = func(static_cast<AccT>(in_buf[lane]));
      out_buf[lane] = static_cast<T>(std::exp(log_prob));
      if (target_in_range) {
        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_prob), row,
                                    target, idx, VecSize, base_shift + lane,
                                    ignore_index);
      }
    }
    reinterpret_cast<VecT*>(softmax)[idx] = *out_vec;
    idx += blockDim.x;
  }

  // Tail: leftover elements, one per thread per iteration.
  for (idx = vec_end + threadIdx.x; idx < size; idx += blockDim.x) {
    const AccT log_prob = func(static_cast<AccT>(logits[idx]));
    softmax[idx] = static_cast<T>(std::exp(log_prob));
    if (target_in_range) {
      ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_prob), row, target,
                                  idx, 1, base_shift, ignore_index);
    }
  }

  // Out-of-range label: a single thread stores a zero loss for the row.
  if (!target_in_range && threadIdx.x == 0) {
    loss[row] = static_cast<T>(0.0f);
  }
}
// Element-wise (non-vectorized) softmax + hard-label loss for one row.
//
// One block handles the row selected by blockIdx.x. For every element,
// func(logit) yields the log-softmax value; exp() of it is stored to
// `softmax`, and its negation is handed to ComputeLoss, which stores it into
// loss[blockIdx.x] when the element index equals the row's label. A label
// outside [0, size) makes thread 0 store a zero loss instead.
//
// The main loop processes VecSize elements per thread per iteration, spaced
// blockDim.x apart; the tail loop handles the remaining
// size % (VecSize * blockDim.x) elements one at a time.
template <typename T, typename AccT, typename LabelT, int VecSize,
          bool IgnoreIndex>
__device__ __forceinline__ void ScalarSoftmaxForwardImpl(
    T* loss, T* softmax, const T* logits, const LabelT* label, const int size,
    const LogSoftmaxForwardFunctor<AccT>& func, const int ignore_index) {
  int tid = threadIdx.x;
  int remain = size % (VecSize * blockDim.x);
  int label_id = blockIdx.x;
  auto label_value = static_cast<int64_t>(label[label_id]);
  const bool label_valid = label_value >= 0 && label_value < size;

  // main part
  for (; tid < (size - remain); tid += VecSize * blockDim.x) {
    T ins[VecSize];

#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      ins[i] = logits[tid + i * blockDim.x];
    }
#pragma unroll
    for (int i = 0; i < VecSize; ++i) {
      // Index of the element actually read/written in this step.
      const int elem_id = tid + i * blockDim.x;
      AccT log_softmax = func(static_cast<AccT>(ins[i]));
      softmax[elem_id] = static_cast<T>(std::exp(log_softmax));
      // loss
      if (label_valid) {
        // ComputeLoss matches the label against vec_size * tid + offset, so
        // pass the real element index with vec_size = 1 and offset = 0.
        // Passing (tid, VecSize, i) would test VecSize * tid + i, which is
        // not the element handled here because elements are strided by
        // blockDim.x, not packed per thread.
        ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
                                    label_id, label_value, elem_id, 1, 0,
                                    ignore_index);
      }
    }
  }

  // tail part
  for (; tid < size; tid += blockDim.x) {
    AccT log_softmax = func(static_cast<AccT>(logits[tid]));
    softmax[tid] = static_cast<T>(std::exp(log_softmax));
    // loss
    if (label_valid) {
      ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax), label_id,
                                  label_value, tid, 1, 0, ignore_index);
    }
  }

  // invalid label, write once
  if (!label_valid && threadIdx.x == 0) {
    loss[label_id] = static_cast<T>(0.0f);
  }
}
// Kernel: softmax and hard-label cross-entropy loss along the last axis.
//
// Launch layout: one block per row; blockIdx.x selects the row and the
// threads of the block stride over its mid_dim elements. `high_dim` is not
// read inside the kernel.
//
// Steps per row:
//   1. block-wide maximum of the logits,
//   2. block-wide sum of exp(logit - max),
//   3. softmax / loss write-out, vectorized when `logits` and `softmax` sit
//      at the same element offset from an ALIGN_BYTES boundary, element-wise
//      otherwise.
template <typename T, typename AccT, typename LabelT, int VecSize,
          bool IgnoreIndex>
__global__ void VectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
                                         const LabelT* label,
                                         const int high_dim, const int mid_dim,
                                         const int ignore_index) {
  // each block deal with one batch
  // Widen before multiplying: blockIdx.x * mid_dim is otherwise evaluated in
  // 32-bit unsigned arithmetic and wraps for very large tensors.
  const int64_t row_start = static_cast<int64_t>(blockIdx.x) * mid_dim;
  logits += row_start;
  softmax += row_start;

  // Number of elements each row pointer lies past an ALIGN_BYTES boundary.
  const int input_offset = ((uint64_t)logits) % ALIGN_BYTES / sizeof(T);
  const int output_offset = ((uint64_t)softmax) % ALIGN_BYTES / sizeof(T);

  // 1. reduce max
  AccT max = ThreadReduce<T, AccT, VecSize, kps::MaxFunctor<AccT>>(
      logits, mid_dim, input_offset, -std::numeric_limits<AccT>::infinity(),
      kps::MaxFunctor<AccT>());
  max = kps::details::BlockXReduce<AccT, kps::MaxFunctor<AccT>>(
      max, kps::MaxFunctor<AccT>());

  // 2. reduce sum
  AccT sum = ThreadReduce<T, AccT, VecSize, ExpAddFunctor<AccT>>(
      logits, mid_dim, input_offset, static_cast<AccT>(0),
      ExpAddFunctor<AccT>(max));
  sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
      sum, kps::AddFunctor<AccT>());

  // 3. softmax
  LogSoftmaxForwardFunctor<AccT> func(max, sum);
  if (input_offset == output_offset) {
    VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
        loss, softmax, logits, label, mid_dim, input_offset, func,
        ignore_index);
  } else {
    // Input and output are misaligned relative to each other, so a single
    // rebasing cannot align both; fall back to element-wise accesses.
    ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
        loss, softmax, logits, label, mid_dim, func, ignore_index);
  }
}
// Host-side launcher for the VectorizedSoftmaxForward kernel.
//
// Launches high_dim blocks on `stream`. The block size is the smallest power
// of two that is not below a cap derived from mid_dim and the vector width
// (sizeof(float4) / sizeof(T)), and never below the warp size. The cap is at
// most 1024 and is halved when the vector width is greater than one.
template <typename T, typename LabelT, bool IgnoreIndex>
void LaunchVectorizedSoftmaxForward(T* loss, T* softmax, const T* logits,
                                    const LabelT* label, const int high_dim,
                                    const int mid_dim, const int ignore_index,
                                    gpuStream_t stream) {
  using AccT = typename details::MPTypeTrait<T>::Type;
  constexpr int vec_size = sizeof(float4) / sizeof(T);
  const int max_num_threads = 1024;

  // Upper bound on the useful thread count for one row.
  int thread_cap = std::min(mid_dim / vec_size, max_num_threads);
  if (vec_size > 1) {
    thread_cap /= 2;
  }

  // Round up to a power of two, then enforce the warp-size minimum.
  int block_size = 1;
  for (; block_size < thread_cap; block_size *= 2) {
  }
  block_size = std::max(block_size, kps::details::kWarpSize);

  const dim3 grids(high_dim);
  const dim3 blocks(block_size);
  VectorizedSoftmaxForward<T, AccT, LabelT, vec_size,
                           IgnoreIndex><<<grids, blocks, 0, stream>>>(
      loss, softmax, logits, label, high_dim, mid_dim, ignore_index);
}
/*
Wrapper of softmax with cross entropy hard label.
- SwitchWarpSoftmaxForward for small size
- cudnn function for large size
- SwitchWarpSoftmaxForward for small size when axis == -1
- LaunchVectorizedSoftmaxForward for large size when axis == -1
- cudnn function for axis != -1
*/
template <typename T, typename LabelT, bool IgnoreIndex>
static void SoftmaxWithCrossEntropyHardLabel(
......@@ -431,11 +707,17 @@ static void SoftmaxWithCrossEntropyHardLabel(
T* softmax_data, int N, int dim, int D, const int ignore_index) {
auto stream = ctx.stream();
constexpr int max_dim = 320;
if (D == 1 && dim <= max_dim) { // small size
const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(
loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
ignore_index, stream);
if (D == 1) {
if (dim <= max_dim) { // small size
const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(
loss_data, softmax_data, logits_data, labels_data, N, dim, dim,
ignore_index, stream);
} else { // large size
LaunchVectorizedSoftmaxForward<T, LabelT, IgnoreIndex>(
loss_data, softmax_data, logits_data, labels_data, N, dim,
ignore_index, stream);
}
} else {
ScopedTensorDescriptor desc;
std::vector<int> tensor_dims = {N, dim, D, 1};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册