Unverified commit 450af30c authored by Zhang Zheng, committed by GitHub

Correct the logic and remove unnecessary template param (#46623)

* Correct the logic and remove unnecessary template param

* fix error throw

* fix print format

* fix ci
Parent 1230a3f4
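
In brief, the hard-label kernels previously carried a compile-time IgnoreIndex template flag and only honored ignore_index when that flag was true; this commit drops the flag, makes the ignore_index comparison an unconditional runtime check, and reports out-of-range labels through a device-side PADDLE_ENFORCE instead of a bare assert. A minimal sketch of the control-flow change, with made-up helper names and signatures that are not part of this commit:

// Sketch only: illustrates the control-flow change, not the real Paddle
// kernels; HardLabelLossOld/New and their signatures are hypothetical.
#include <cassert>
#include <cmath>
#include <cstdint>

// Before: ignore_index was honored only when the IgnoreIndex template
// parameter was true, and invalid labels were caught by a bare assert.
template <bool IgnoreIndex>
__device__ float HardLabelLossOld(const float* softmax, int64_t lbl,
                                  int64_t ignore_idx, int64_t row_offset) {
  assert(lbl >= 0);  // silent in release builds
  if (IgnoreIndex && lbl == ignore_idx) {
    return 0.0f;
  }
  return -logf(softmax[row_offset + lbl]);
}

// After: one runtime check, no template parameter; the real kernels also
// validate lbl with a device-side PADDLE_ENFORCE before this point.
__device__ float HardLabelLossNew(const float* softmax, int64_t lbl,
                                  int64_t ignore_idx, int64_t row_offset) {
  if (lbl == ignore_idx) {
    return 0.0f;
  }
  return -logf(softmax[row_offset + lbl]);
}

The same simplification propagates to the call sites further down, which no longer need to branch on whether ignore_index falls inside [0, axis_dim).
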
@@ -170,7 +170,7 @@ __global__ void CrossEntropySoftLabel(T* loss,
/*
Hard label cross entropy.
*/
template <typename T, typename LabelT, bool IgnoreIndex>
template <typename T, typename LabelT>
__global__ void CrossEntropyHardLabel(T* loss,
const T* softmax,
const LabelT* labels,
@@ -185,22 +185,17 @@ __global__ void CrossEntropyHardLabel(T* loss,
// thread ids compute loss[ids] using softmax[idx]
if (ids < n * d) {
auto lbl = static_cast<int64_t>(labels[ids]);
assert(lbl >= 0 && lbl < dim || lbl == ignore_idx);
if (lbl < 0 || lbl >= dim) { // label is out of bound
PADDLE_ENFORCE(lbl >= 0 && lbl < dim || lbl == ignore_idx,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
dim,
ignore_idx,
lbl);
if (lbl == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
int64_t idx = idx_n * dim * d + lbl * d + idx_d;
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (lbl == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
loss[ids] = -Log(softmax[idx]);
}
} else {
// IgnoreIndex is false
loss[ids] = -Log(softmax[idx]);
}
loss[ids] = -Log(softmax[idx]);
}
}
}
@@ -210,7 +205,7 @@ __global__ void CrossEntropyHardLabel(T* loss,
Input: log softmax
Output: loss and exp(input)
*/
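
As a hedged restatement of what this kernel computes (a hypothetical host-side helper, D == 1 case only, not part of this commit): the per-sample loss is the negative log-probability of the target class taken from the log-softmax input, and the buffer is overwritten elementwise with exp so it ends up holding the softmax.

// Sketch only: hypothetical host-side restatement, not the device kernel.
#include <cmath>
#include <cstdint>
#include <vector>

// log_softmax holds n rows of `dim` values; labels holds n hard labels.
// Returns the per-sample loss and turns log_softmax into softmax in place.
std::vector<float> HardLabelLossFromLogSoftmax(
    std::vector<float>& log_softmax, const std::vector<int64_t>& labels,
    int64_t dim, int64_t ignore_idx) {
  std::vector<float> loss(labels.size(), 0.0f);
  for (size_t i = 0; i < labels.size(); ++i) {
    const int64_t lbl = labels[i];
    if (lbl != ignore_idx) {
      loss[i] = -log_softmax[i * dim + lbl];  // -log p(label)
    }
    for (int64_t j = 0; j < dim; ++j) {
      log_softmax[i * dim + j] = std::exp(log_softmax[i * dim + j]);
    }
  }
  return loss;
}
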
template <typename T, typename LabelT, bool IgnoreIndex>
template <typename T, typename LabelT>
__global__ void CrossEntropyExpHardLabel(T* loss,
T* softmax,
const LabelT* labels,
@@ -226,24 +221,17 @@ __global__ void CrossEntropyExpHardLabel(T* loss,
if (idx < n * dim * d) {
auto lbl = static_cast<int64_t>(labels[ids]);
assert(lbl >= 0 && lbl < dim || lbl == ignore_idx);
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (idx_dim == lbl) {
if (lbl == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
loss[ids] = -softmax[idx];
}
}
PADDLE_ENFORCE(lbl >= 0 && lbl < dim || lbl == ignore_idx,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
dim,
ignore_idx,
lbl);
if (lbl == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < dim) {
if (lbl == idx_dim) {
loss[ids] = -softmax[idx];
}
} else {
loss[ids] = static_cast<T>(0.0);
if (lbl == idx_dim) {
loss[ids] = -softmax[idx];
}
}
softmax[idx] = Exp(softmax[idx]);
@@ -292,7 +280,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input,
return val;
}
template <typename T, bool IgnoreIndex>
template <typename T>
__device__ __forceinline__ void ComputeLoss(T* loss,
const T loss_value,
const int label_id,
@@ -302,14 +290,8 @@ __device__ __forceinline__ void ComputeLoss(T* loss,
const int offset,
const int ignore_index) {
int loss_id = vec_size * tid + offset;
if (IgnoreIndex) {
if (label_value == loss_id) {
if (label_value == ignore_index) {
loss[label_id] = static_cast<T>(0.0f);
} else {
loss[label_id] = loss_value;
}
}
if (label_value == ignore_index) {
loss[label_id] = static_cast<T>(0.0f);
} else {
if (label_value == loss_id) {
loss[label_id] = loss_value;
@@ -317,11 +299,7 @@ __device__ __forceinline__ void ComputeLoss(T* loss,
}
}
template <typename T,
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
template <typename T, typename AccT, typename LabelT, int VecSize>
__device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
T* loss,
T* softmax,
@@ -335,8 +313,13 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
int tid = threadIdx.x;
int label_id = blockIdx.x;
auto label_value = static_cast<int64_t>(label[label_id]);
assert(label_value >= 0 && label_value < size || label_value == ignore_index);
const bool label_valid = label_value >= 0 && label_value < size;
PADDLE_ENFORCE(
label_value >= 0 && label_value < size || label_value == ignore_index,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
size,
ignore_index,
label_value);
int loss_id_offset = 0;
if (offset > 0) {
@@ -348,16 +331,14 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
AccT log_softmax = func(static_cast<AccT>(logits[tid]));
softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
1,
loss_id_offset,
ignore_index);
}
ComputeLoss<T>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
1,
loss_id_offset,
ignore_index);
}
size -= blockDim.x;
logits += blockDim.x;
@@ -383,16 +364,14 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
outs[i] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
VecSize,
loss_id_offset + i,
ignore_index);
}
ComputeLoss<T>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
VecSize,
loss_id_offset + i,
ignore_index);
}
// write
@@ -406,29 +385,18 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
1,
loss_id_offset,
ignore_index);
}
}
// invalid label, write once
if (!label_valid && threadIdx.x == 0) {
loss[label_id] = static_cast<T>(0.0f);
ComputeLoss<T>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
1,
loss_id_offset,
ignore_index);
}
}
template <typename T,
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
template <typename T, typename AccT, typename LabelT, int VecSize>
__device__ __forceinline__ void ScalarSoftmaxForwardImpl(
T* loss,
T* softmax,
@@ -441,8 +409,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
int remain = size % (VecSize * blockDim.x);
int label_id = blockIdx.x;
auto label_value = static_cast<int64_t>(label[label_id]);
assert(label_value >= 0 && label_value < size || label_value == ignore_index);
const bool label_valid = label_value >= 0 && label_value < size;
PADDLE_ENFORCE(
label_value >= 0 && label_value < size || label_value == ignore_index,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
size,
ignore_index,
label_value);
// main part
for (; tid < (size - remain); tid += VecSize * blockDim.x) {
@@ -457,16 +430,14 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
AccT log_softmax = func(static_cast<AccT>(ins[i]));
softmax[tid + i * blockDim.x] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
VecSize,
i,
ignore_index);
}
ComputeLoss<T>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
VecSize,
i,
ignore_index);
}
}
@@ -475,29 +446,18 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
AccT log_softmax = func(static_cast<AccT>(logits[tid]));
softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
1,
0,
ignore_index);
}
}
// invalid label, write once
if (!label_valid && threadIdx.x == 0) {
loss[label_id] = static_cast<T>(0.0f);
ComputeLoss<T>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
tid,
1,
0,
ignore_index);
}
}
template <typename T,
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
template <typename T, typename AccT, typename LabelT, int VecSize>
__global__ void VectorizedSoftmaxForward(T* loss,
T* softmax,
const T* logits,
@@ -537,17 +497,16 @@ __global__ void VectorizedSoftmaxForward(T* loss,
// 3. softmax
phi::LogSoftmaxForwardFunctor<AccT> func(max, sum);
if (input_offset == output_offset) {
VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
loss,
softmax,
logits,
label,
mid_dim,
input_offset,
func,
ignore_index);
VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize>(loss,
softmax,
logits,
label,
mid_dim,
input_offset,
func,
ignore_index);
} else {
ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize>(
loss, softmax, logits, label, mid_dim, func, ignore_index);
}
}
@@ -560,8 +519,8 @@ The computation includes
- Compute: sum of - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} -
log(sum[i]))}
One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
api to compute max (sum) in one warp.
For reduction max (sum), firstly compute max (sum) to one warp, then use
shuffle api to compute max (sum) in one warp.
*/
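
The "shuffle api" the comment above refers to is the warp-level shuffle intrinsic family. A minimal sketch of that reduction pattern follows; the name WarpMax and its signature are illustrative and do not appear in this file.

// Sketch only: warp-wide max reduction via __shfl_xor_sync; a sum reduction
// is identical with fmaxf replaced by addition.
__device__ __forceinline__ float WarpMax(float val) {
  // Butterfly exchange: after log2(32) = 5 steps every lane holds the max.
  for (int offset = 16; offset > 0; offset >>= 1) {
    val = fmaxf(val, __shfl_xor_sync(0xffffffffu, val, offset));
  }
  return val;
}
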
template <typename T, typename VecT, typename AccT, int Log2Elements>
__global__ void WarpSoftmaxForwardSoftLabel(T* loss,
@@ -880,8 +839,7 @@ template <typename T,
typename VecT,
typename AccT,
int Log2Elements,
SoftmaxMode mode,
bool IgnoreIndex>
SoftmaxMode mode>
__global__ void WarpSoftmaxForward(T* loss,
T* softmax,
const T* src,
@@ -1033,24 +991,21 @@ __global__ void WarpSoftmaxForward(T* loss,
// label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize;
auto lbl = static_cast<int64_t>(label[first_batch + i]);
assert(lbl >= 0 && lbl < element_count || lbl == ignore_index);
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (lbl == loss_idx) {
if (lbl != ignore_index) {
loss[first_batch + i] = -logsoftmax;
} else {
loss[first_batch + i] = static_cast<T>(0.0);
}
}
if (lbl == ignore_index) {
loss[first_batch + i] = static_cast<T>(0.0);
} else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < element_count) {
if (lbl == loss_idx) {
loss[first_batch + i] = -logsoftmax;
}
} else {
loss[first_batch + i] = static_cast<T>(0.0);
PADDLE_ENFORCE(
false,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
element_count,
ignore_index,
lbl);
}
}
} else { // softmax
@@ -1077,20 +1032,21 @@ __global__ void WarpSoftmaxForward(T* loss,
// label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s;
auto lbl = static_cast<int64_t>(label[first_batch + i]);
assert(lbl >= 0 && lbl < element_count || lbl == ignore_index);
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (lbl == loss_idx && lbl != ignore_index) {
loss[first_batch + i] = -logsoftmax;
}
if (lbl == ignore_index) {
loss[first_batch + i] = static_cast<T>(0.0);
} else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < element_count) {
if (lbl == loss_idx) {
loss[first_batch + i] = -logsoftmax;
}
} else {
loss[first_batch + i] = static_cast<T>(0.0);
PADDLE_ENFORCE(
false,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
element_count,
ignore_index,
lbl);
}
}
} else { // softmax
@@ -1107,23 +1063,23 @@ __global__ void WarpSoftmaxForward(T* loss,
}
}
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \
case Log2Elements: \
WarpSoftmaxForward<T, LabelT, VecT, AccT, Log2Elements, mode, IgnoreIndex> \
<<<blocks, threads, 0, stream>>>(loss, \
softmax, \
src, \
label, \
batch_size, \
stride, \
element_count, \
ignore_index); \
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \
case Log2Elements: \
WarpSoftmaxForward<T, LabelT, VecT, AccT, Log2Elements, mode> \
<<<blocks, threads, 0, stream>>>(loss, \
softmax, \
src, \
label, \
batch_size, \
stride, \
element_count, \
ignore_index); \
break;
/*
Wrapper of softmax with cross entropy forward hard label.
*/
template <typename T, typename LabelT, SoftmaxMode mode, bool IgnoreIndex>
template <typename T, typename LabelT, SoftmaxMode mode>
void SwitchWarpSoftmaxForward(T* loss,
T* softmax,
const T* src,
@@ -1162,7 +1118,7 @@ void SwitchWarpSoftmaxForward(T* loss,
}
}
template <typename T, typename LabelT, bool IgnoreIndex>
template <typename T, typename LabelT>
void LaunchVectorizedSoftmaxForward(T* loss,
T* softmax,
const T* logits,
@@ -1186,7 +1142,7 @@ void LaunchVectorizedSoftmaxForward(T* loss,
block_size = std::max(block_size, kps::details::kWarpSize);
dim3 grids(high_dim);
dim3 blocks(block_size);
VectorizedSoftmaxForward<T, AccT, LabelT, vec_size, IgnoreIndex>
VectorizedSoftmaxForward<T, AccT, LabelT, vec_size>
<<<grids, blocks, 0, stream>>>(
loss, softmax, logits, label, high_dim, mid_dim, ignore_index);
}
@@ -1197,7 +1153,7 @@ void LaunchVectorizedSoftmaxForward(T* loss,
- LaunchVectorizedSoftmaxForward for large size when axis == -1
- cudnn function for axis != -1
*/
template <typename T, typename LabelT, bool IgnoreIndex>
template <typename T, typename LabelT>
static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int rank,
int axis,
@@ -1214,24 +1170,24 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
if (D == 1) {
if (dim <= max_dim) { // small size
const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(loss_data,
softmax_data,
logits_data,
labels_data,
N,
dim,
dim,
ignore_index,
stream);
SwitchWarpSoftmaxForward<T, LabelT, mode>(loss_data,
softmax_data,
logits_data,
labels_data,
N,
dim,
dim,
ignore_index,
stream);
} else { // large size
LaunchVectorizedSoftmaxForward<T, LabelT, IgnoreIndex>(loss_data,
softmax_data,
logits_data,
labels_data,
N,
dim,
ignore_index,
stream);
LaunchVectorizedSoftmaxForward<T, LabelT>(loss_data,
softmax_data,
logits_data,
labels_data,
N,
dim,
ignore_index,
stream);
}
} else {
ScopedTensorDescriptor desc;
@@ -1275,9 +1231,8 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int threads = 128;
int blocks = (N * dim * D + threads - 1) / threads;
// compute cross entropy, input is log softmax
CrossEntropyExpHardLabel<T, LabelT, IgnoreIndex>
<<<blocks, threads, 0, stream>>>(
loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
CrossEntropyExpHardLabel<T, LabelT><<<blocks, threads, 0, stream>>>(
loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
}
}
@@ -1373,25 +1328,14 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
auto* labels_data = labels.data<LabelT>();
int threads = 128;
int blocks = (n * d / axis_dim + threads - 1) / threads;
if (ignore_index >= 0 && ignore_index < axis_dim) {
CrossEntropyHardLabel<T, LabelT, true>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data,
logits_data,
labels_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
} else {
CrossEntropyHardLabel<T, LabelT, false>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data,
logits_data,
labels_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
CrossEntropyHardLabel<T, LabelT>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data,
logits_data,
labels_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
// cause of input is softmax
......@@ -1456,31 +1400,17 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
} else {
auto* logits_data = logits.data<T>();
auto* labels_data = label.data<LabelT>();
if (ignore_index >= 0 && ignore_index < axis_dim) {
SoftmaxWithCrossEntropyHardLabel<T, LabelT, true>(dev_ctx,
rank,
axis_v,
logits_data,
labels_data,
loss_data,
softmax_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
} else {
SoftmaxWithCrossEntropyHardLabel<T, LabelT, false>(dev_ctx,
rank,
axis_v,
logits_data,
labels_data,
loss_data,
softmax_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
SoftmaxWithCrossEntropyHardLabel<T, LabelT>(dev_ctx,
rank,
axis_v,
logits_data,
labels_data,
loss_data,
softmax_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
}
}