Unverified commit 8bfd45ad, authored by Zhang Zheng, committed via GitHub

[Cherry-Pick] Move valid check from python to kernel (#46980)

To improve performance, the label boundary check is moved from the Python side into the kernel, eliminating the extra op calls (min, max) and the synchronous copy they require.
    The template parameter IgnoreIndex currently only takes effect when ignore_index lies in [0, dim). However, when a label value is out of bounds and ignore_index equals that label, the computation should still proceed normally. Although the current logic does not produce incorrect results in that case, it is still logically flawed, and the IgnoreIndex template parameter is unnecessary.
Parent commit: 5c2bea17
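For reference, a minimal standalone sketch of the pattern this commit adopts: the hard-label kernel validates every label on the device, so the host no longer needs the extra min/max ops or the synchronous copy of their results. The kernel name, the flat [n, dim] layout, and the use of assert in place of PADDLE_ENFORCE are illustrative assumptions; the actual kernels in the diff below also handle an inner dimension d and report a formatted error message.

#include <cassert>
#include <cstdint>

// Simplified hard-label cross-entropy: the label boundary check runs
// inside the kernel instead of on the Python side.
template <typename T, typename LabelT>
__global__ void CrossEntropyHardLabelSketch(T* loss,
                                            const T* softmax,     // [n, dim] softmax probabilities
                                            const LabelT* labels, // [n] class indices
                                            int64_t n,
                                            int64_t dim,
                                            int64_t ignore_index) {
  int64_t i = blockIdx.x * static_cast<int64_t>(blockDim.x) + threadIdx.x;
  if (i >= n) return;
  auto lbl = static_cast<int64_t>(labels[i]);
  // In-kernel validity check: a label must lie in [0, dim) or equal
  // ignore_index; anything else traps here (the real kernels use
  // PADDLE_ENFORCE with a formatted message instead of assert).
  assert((lbl >= 0 && lbl < dim) || lbl == ignore_index);
  if (lbl == ignore_index) {
    loss[i] = static_cast<T>(0.0);  // ignored samples contribute zero loss
  } else {
    loss[i] = static_cast<T>(-logf(static_cast<float>(softmax[i * dim + lbl])));
  }
}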
......@@ -170,7 +170,7 @@ __global__ void CrossEntropySoftLabel(T* loss,
/*
Hard label cross entropy.
*/
template <typename T, typename LabelT, bool IgnoreIndex>
template <typename T, typename LabelT>
__global__ void CrossEntropyHardLabel(T* loss,
const T* softmax,
const LabelT* labels,
......@@ -185,22 +185,18 @@ __global__ void CrossEntropyHardLabel(T* loss,
// thread ids compute loss[ids] using softmax[idx]
if (ids < n * d) {
auto lbl = static_cast<int64_t>(labels[ids]);
if (lbl < 0) { // label is negative
loss[ids] = static_cast<T>(0.0);
} else { // label is positive of zero
int64_t idx = idx_n * dim * d + lbl * d + idx_d;
if (IgnoreIndex == true) {
// IgnoreIndex is true
PADDLE_ENFORCE(lbl >= 0 && lbl < dim || lbl == ignore_idx,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
dim,
ignore_idx,
lbl);
if (lbl == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
int64_t idx = idx_n * dim * d + lbl * d + idx_d;
loss[ids] = -Log(softmax[idx]);
}
} else {
// IgnoreIndex is false
loss[ids] = -Log(softmax[idx]);
}
}
}
}
......@@ -209,7 +205,7 @@ __global__ void CrossEntropyHardLabel(T* loss,
Input: log softmax
Output: loss and exp(input)
*/
template <typename T, typename LabelT, bool IgnoreIndex>
template <typename T, typename LabelT>
__global__ void CrossEntropyExpHardLabel(T* loss,
T* softmax,
const LabelT* labels,
......@@ -225,24 +221,18 @@ __global__ void CrossEntropyExpHardLabel(T* loss,
if (idx < n * dim * d) {
auto lbl = static_cast<int64_t>(labels[ids]);
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (idx_dim == lbl) {
PADDLE_ENFORCE(lbl >= 0 && lbl < dim || lbl == ignore_idx,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
dim,
ignore_idx,
lbl);
if (lbl == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
loss[ids] = -softmax[idx];
}
}
} else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < dim) {
if (lbl == idx_dim) {
loss[ids] = -softmax[idx];
}
} else {
loss[ids] = static_cast<T>(0.0);
}
}
softmax[idx] = Exp(softmax[idx]);
}
......@@ -290,7 +280,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input,
return val;
}
template <typename T, bool IgnoreIndex>
template <typename T>
__device__ __forceinline__ void ComputeLoss(T* loss,
const T loss_value,
const int label_id,
......@@ -300,14 +290,8 @@ __device__ __forceinline__ void ComputeLoss(T* loss,
const int offset,
const int ignore_index) {
int loss_id = vec_size * tid + offset;
if (IgnoreIndex) {
if (label_value == loss_id) {
if (label_value == ignore_index) {
loss[label_id] = static_cast<T>(0.0f);
} else {
loss[label_id] = loss_value;
}
}
} else {
if (label_value == loss_id) {
loss[label_id] = loss_value;
......@@ -315,11 +299,7 @@ __device__ __forceinline__ void ComputeLoss(T* loss,
}
}
template <typename T,
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
template <typename T, typename AccT, typename LabelT, int VecSize>
__device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
T* loss,
T* softmax,
......@@ -333,7 +313,13 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
int tid = threadIdx.x;
int label_id = blockIdx.x;
auto label_value = static_cast<int64_t>(label[label_id]);
const bool label_valid = label_value >= 0 && label_value < size;
PADDLE_ENFORCE(
label_value >= 0 && label_value < size || label_value == ignore_index,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
size,
ignore_index,
label_value);
int loss_id_offset = 0;
if (offset > 0) {
......@@ -345,8 +331,7 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
AccT log_softmax = func(static_cast<AccT>(logits[tid]));
softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
ComputeLoss<T>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
......@@ -355,7 +340,6 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
loss_id_offset,
ignore_index);
}
}
size -= blockDim.x;
logits += blockDim.x;
softmax += blockDim.x;
......@@ -380,8 +364,7 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
outs[i] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
ComputeLoss<T>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
......@@ -390,7 +373,6 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
loss_id_offset + i,
ignore_index);
}
}
// write
reinterpret_cast<VecT*>(softmax)[tid] = *outs_vec;
......@@ -403,8 +385,7 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
ComputeLoss<T>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
......@@ -413,19 +394,9 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
loss_id_offset,
ignore_index);
}
}
// invalid label, write once
if (!label_valid && threadIdx.x == 0) {
loss[label_id] = static_cast<T>(0.0f);
}
}
template <typename T,
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
template <typename T, typename AccT, typename LabelT, int VecSize>
__device__ __forceinline__ void ScalarSoftmaxForwardImpl(
T* loss,
T* softmax,
......@@ -438,7 +409,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
int remain = size % (VecSize * blockDim.x);
int label_id = blockIdx.x;
auto label_value = static_cast<int64_t>(label[label_id]);
const bool label_valid = label_value >= 0 && label_value < size;
PADDLE_ENFORCE(
label_value >= 0 && label_value < size || label_value == ignore_index,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
size,
ignore_index,
label_value);
// main part
for (; tid < (size - remain); tid += VecSize * blockDim.x) {
......@@ -453,8 +430,7 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
AccT log_softmax = func(static_cast<AccT>(ins[i]));
softmax[tid + i * blockDim.x] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
ComputeLoss<T>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
......@@ -464,15 +440,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
ignore_index);
}
}
}
// tail part
for (; tid < size; tid += blockDim.x) {
AccT log_softmax = func(static_cast<AccT>(logits[tid]));
softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss
if (label_valid) {
ComputeLoss<T, IgnoreIndex>(loss,
ComputeLoss<T>(loss,
static_cast<T>(-log_softmax),
label_id,
label_value,
......@@ -481,19 +455,9 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
0,
ignore_index);
}
}
// invalid label, write once
if (!label_valid && threadIdx.x == 0) {
loss[label_id] = static_cast<T>(0.0f);
}
}
template <typename T,
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
template <typename T, typename AccT, typename LabelT, int VecSize>
__global__ void VectorizedSoftmaxForward(T* loss,
T* softmax,
const T* logits,
......@@ -533,8 +497,7 @@ __global__ void VectorizedSoftmaxForward(T* loss,
// 3. softmax
phi::LogSoftmaxForwardFunctor<AccT> func(max, sum);
if (input_offset == output_offset) {
VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
loss,
VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize>(loss,
softmax,
logits,
label,
......@@ -543,7 +506,7 @@ __global__ void VectorizedSoftmaxForward(T* loss,
func,
ignore_index);
} else {
ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>(
ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize>(
loss, softmax, logits, label, mid_dim, func, ignore_index);
}
}
......@@ -556,8 +519,8 @@ The computation includes
- Compute: sum of - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} -
log(sum[i]))}
One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle
api to compute max (sum) in one warp.
For reduction max (sum), firstly compute max (sum) to one warp, then use
shuffle api to compute max (sum) in one warp.
*/
template <typename T, typename VecT, typename AccT, int Log2Elements>
__global__ void WarpSoftmaxForwardSoftLabel(T* loss,
......@@ -876,8 +839,7 @@ template <typename T,
typename VecT,
typename AccT,
int Log2Elements,
SoftmaxMode mode,
bool IgnoreIndex>
SoftmaxMode mode>
__global__ void WarpSoftmaxForward(T* loss,
T* softmax,
const T* src,
......@@ -1029,23 +991,21 @@ __global__ void WarpSoftmaxForward(T* loss,
// label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize;
auto lbl = static_cast<int64_t>(label[first_batch + i]);
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (lbl == loss_idx) {
if (lbl != ignore_index) {
loss[first_batch + i] = -logsoftmax;
} else {
if (lbl == ignore_index) {
loss[first_batch + i] = static_cast<T>(0.0);
}
}
} else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < element_count) {
if (lbl == loss_idx) {
loss[first_batch + i] = -logsoftmax;
}
} else {
loss[first_batch + i] = static_cast<T>(0.0);
PADDLE_ENFORCE(
false,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
element_count,
ignore_index,
lbl);
}
}
} else { // softmax
......@@ -1072,19 +1032,21 @@ __global__ void WarpSoftmaxForward(T* loss,
// label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s;
auto lbl = static_cast<int64_t>(label[first_batch + i]);
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (lbl == loss_idx && lbl != ignore_index) {
loss[first_batch + i] = -logsoftmax;
}
if (lbl == ignore_index) {
loss[first_batch + i] = static_cast<T>(0.0);
} else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < element_count) {
if (lbl == loss_idx) {
loss[first_batch + i] = -logsoftmax;
}
} else {
loss[first_batch + i] = static_cast<T>(0.0);
PADDLE_ENFORCE(
false,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
element_count,
ignore_index,
lbl);
}
}
} else { // softmax
......@@ -1103,7 +1065,7 @@ __global__ void WarpSoftmaxForward(T* loss,
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \
case Log2Elements: \
WarpSoftmaxForward<T, LabelT, VecT, AccT, Log2Elements, mode, IgnoreIndex> \
WarpSoftmaxForward<T, LabelT, VecT, AccT, Log2Elements, mode> \
<<<blocks, threads, 0, stream>>>(loss, \
softmax, \
src, \
......@@ -1117,7 +1079,7 @@ __global__ void WarpSoftmaxForward(T* loss,
/*
Wrapper of softmax with cross entropy forward hard label.
*/
template <typename T, typename LabelT, SoftmaxMode mode, bool IgnoreIndex>
template <typename T, typename LabelT, SoftmaxMode mode>
void SwitchWarpSoftmaxForward(T* loss,
T* softmax,
const T* src,
......@@ -1156,7 +1118,7 @@ void SwitchWarpSoftmaxForward(T* loss,
}
}
template <typename T, typename LabelT, bool IgnoreIndex>
template <typename T, typename LabelT>
void LaunchVectorizedSoftmaxForward(T* loss,
T* softmax,
const T* logits,
......@@ -1180,7 +1142,7 @@ void LaunchVectorizedSoftmaxForward(T* loss,
block_size = std::max(block_size, kps::details::kWarpSize);
dim3 grids(high_dim);
dim3 blocks(block_size);
VectorizedSoftmaxForward<T, AccT, LabelT, vec_size, IgnoreIndex>
VectorizedSoftmaxForward<T, AccT, LabelT, vec_size>
<<<grids, blocks, 0, stream>>>(
loss, softmax, logits, label, high_dim, mid_dim, ignore_index);
}
......@@ -1191,7 +1153,7 @@ void LaunchVectorizedSoftmaxForward(T* loss,
- LaunchVectorizedSoftmaxForward for large size when axis == -1
- cudnn function for axis != -1
*/
template <typename T, typename LabelT, bool IgnoreIndex>
template <typename T, typename LabelT>
static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int rank,
int axis,
......@@ -1208,7 +1170,7 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
if (D == 1) {
if (dim <= max_dim) { // small size
const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(loss_data,
SwitchWarpSoftmaxForward<T, LabelT, mode>(loss_data,
softmax_data,
logits_data,
labels_data,
......@@ -1218,7 +1180,7 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
ignore_index,
stream);
} else { // large size
LaunchVectorizedSoftmaxForward<T, LabelT, IgnoreIndex>(loss_data,
LaunchVectorizedSoftmaxForward<T, LabelT>(loss_data,
softmax_data,
logits_data,
labels_data,
......@@ -1269,8 +1231,7 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int threads = 128;
int blocks = (N * dim * D + threads - 1) / threads;
// compute cross entropy, input is log softmax
CrossEntropyExpHardLabel<T, LabelT, IgnoreIndex>
<<<blocks, threads, 0, stream>>>(
CrossEntropyExpHardLabel<T, LabelT><<<blocks, threads, 0, stream>>>(
loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
}
}
......@@ -1367,8 +1328,7 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
auto* labels_data = labels.data<LabelT>();
int threads = 128;
int blocks = (n * d / axis_dim + threads - 1) / threads;
if (ignore_index >= 0 && ignore_index < axis_dim) {
CrossEntropyHardLabel<T, LabelT, true>
CrossEntropyHardLabel<T, LabelT>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data,
logits_data,
labels_data,
......@@ -1376,16 +1336,6 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
axis_dim,
d / axis_dim,
ignore_index);
} else {
CrossEntropyHardLabel<T, LabelT, false>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data,
logits_data,
labels_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
}
// cause of input is softmax
......@@ -1450,20 +1400,7 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
} else {
auto* logits_data = logits.data<T>();
auto* labels_data = label.data<LabelT>();
if (ignore_index >= 0 && ignore_index < axis_dim) {
SoftmaxWithCrossEntropyHardLabel<T, LabelT, true>(dev_ctx,
rank,
axis_v,
logits_data,
labels_data,
loss_data,
softmax_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
} else {
SoftmaxWithCrossEntropyHardLabel<T, LabelT, false>(dev_ctx,
SoftmaxWithCrossEntropyHardLabel<T, LabelT>(dev_ctx,
rank,
axis_v,
logits_data,
......@@ -1476,7 +1413,6 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
ignore_index);
}
}
}
}
template <typename T, typename Context>
......
......@@ -2386,14 +2386,6 @@ def cross_entropy(input,
if soft_label == False:
valid_label = paddle.cast(label != ignore_index,
dtype=label.dtype) * label
label_min = paddle.min(valid_label)
label_max = paddle.max(valid_label)
if label_min < 0:
raise ValueError("Target {} is out of lower bound.".format(
label_min.item()))
if label_max >= input.shape[axis]:
raise ValueError("Target {} is out of upper bound.".format(
label_max.item()))
if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
if soft_label == False:
_, _, out = _legacy_C_ops.softmax_with_cross_entropy(
......
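For comparison, a hedged sketch of the simplified host-side launch after this change, modeled on the call site in the diff above: the runtime branch on whether ignore_index falls inside [0, axis_dim) disappears, and a single instantiation of CrossEntropyHardLabel<T, LabelT> is launched. The wrapper name below is illustrative, not an actual Paddle function, and it assumes the CrossEntropyHardLabel kernel shown in the diff.

template <typename T, typename LabelT>
void LaunchCrossEntropyHardLabelSketch(T* loss_data,
                                       const T* softmax_data,
                                       const LabelT* labels_data,
                                       int n,
                                       int axis_dim,
                                       int d,
                                       int ignore_index,
                                       cudaStream_t stream) {
  int threads = 128;
  int blocks = (n * d / axis_dim + threads - 1) / threads;
  // Single launch path: the kernel itself validates labels and handles
  // ignore_index, so no dispatch on a compile-time IgnoreIndex flag remains.
  CrossEntropyHardLabel<T, LabelT><<<blocks, threads, 0, stream>>>(
      loss_data, softmax_data, labels_data, n, axis_dim, d / axis_dim, ignore_index);
}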