未验证 提交 8bfd45ad 编写于 作者: Z Zhang Zheng 提交者: GitHub

[Cherry-Pick]Move valid check from python to kernel (#46980)

为了提升性能,将label的边界检查从python端转移到kernel内,减少额外op的调用,如min、max和同步拷贝等
    当前的模板参数IgnoreIndex仅在ignore_index取值范围在[0, dim)时才生效,但是当某个label值超出了边界,ignore_index等于该label,这种情况下是应该仍然能正常计算。虽然当前的计算逻辑在结果上不会出错,但逻辑上仍是有问题的,且模板参数IgnoreIndex是没有必要的
上级 5c2bea17
...@@ -170,7 +170,7 @@ __global__ void CrossEntropySoftLabel(T* loss, ...@@ -170,7 +170,7 @@ __global__ void CrossEntropySoftLabel(T* loss,
/* /*
Hard label cross entropy. Hard label cross entropy.
*/ */
template <typename T, typename LabelT, bool IgnoreIndex> template <typename T, typename LabelT>
__global__ void CrossEntropyHardLabel(T* loss, __global__ void CrossEntropyHardLabel(T* loss,
const T* softmax, const T* softmax,
const LabelT* labels, const LabelT* labels,
...@@ -185,21 +185,17 @@ __global__ void CrossEntropyHardLabel(T* loss, ...@@ -185,21 +185,17 @@ __global__ void CrossEntropyHardLabel(T* loss,
// thread ids compute loss[ids] using softmax[idx] // thread ids compute loss[ids] using softmax[idx]
if (ids < n * d) { if (ids < n * d) {
auto lbl = static_cast<int64_t>(labels[ids]); auto lbl = static_cast<int64_t>(labels[ids]);
if (lbl < 0) { // label is negative PADDLE_ENFORCE(lbl >= 0 && lbl < dim || lbl == ignore_idx,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
dim,
ignore_idx,
lbl);
if (lbl == ignore_idx) {
loss[ids] = static_cast<T>(0.0); loss[ids] = static_cast<T>(0.0);
} else { // label is positive of zero } else {
int64_t idx = idx_n * dim * d + lbl * d + idx_d; int64_t idx = idx_n * dim * d + lbl * d + idx_d;
if (IgnoreIndex == true) { loss[ids] = -Log(softmax[idx]);
// IgnoreIndex is true
if (lbl == ignore_idx) {
loss[ids] = static_cast<T>(0.0);
} else {
loss[ids] = -Log(softmax[idx]);
}
} else {
// IgnoreIndex is false
loss[ids] = -Log(softmax[idx]);
}
} }
} }
} }
...@@ -209,7 +205,7 @@ __global__ void CrossEntropyHardLabel(T* loss, ...@@ -209,7 +205,7 @@ __global__ void CrossEntropyHardLabel(T* loss,
Input: log softmax Input: log softmax
Output: loss and exp(input) Output: loss and exp(input)
*/ */
template <typename T, typename LabelT, bool IgnoreIndex> template <typename T, typename LabelT>
__global__ void CrossEntropyExpHardLabel(T* loss, __global__ void CrossEntropyExpHardLabel(T* loss,
T* softmax, T* softmax,
const LabelT* labels, const LabelT* labels,
...@@ -225,23 +221,17 @@ __global__ void CrossEntropyExpHardLabel(T* loss, ...@@ -225,23 +221,17 @@ __global__ void CrossEntropyExpHardLabel(T* loss,
if (idx < n * dim * d) { if (idx < n * dim * d) {
auto lbl = static_cast<int64_t>(labels[ids]); auto lbl = static_cast<int64_t>(labels[ids]);
if (IgnoreIndex == true) { PADDLE_ENFORCE(lbl >= 0 && lbl < dim || lbl == ignore_idx,
// IgnoreIndex is true "The value of label expected >= 0 and < %d, or == %d, "
if (idx_dim == lbl) { "but got %ld. Please check label value.",
if (lbl == ignore_idx) { dim,
loss[ids] = static_cast<T>(0.0); ignore_idx,
} else { lbl);
loss[ids] = -softmax[idx]; if (lbl == ignore_idx) {
} loss[ids] = static_cast<T>(0.0);
}
} else { } else {
// IgnoreIndex is false if (lbl == idx_dim) {
if (lbl >= 0 && lbl < dim) { loss[ids] = -softmax[idx];
if (lbl == idx_dim) {
loss[ids] = -softmax[idx];
}
} else {
loss[ids] = static_cast<T>(0.0);
} }
} }
softmax[idx] = Exp(softmax[idx]); softmax[idx] = Exp(softmax[idx]);
...@@ -290,7 +280,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input, ...@@ -290,7 +280,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input,
return val; return val;
} }
template <typename T, bool IgnoreIndex> template <typename T>
__device__ __forceinline__ void ComputeLoss(T* loss, __device__ __forceinline__ void ComputeLoss(T* loss,
const T loss_value, const T loss_value,
const int label_id, const int label_id,
...@@ -300,14 +290,8 @@ __device__ __forceinline__ void ComputeLoss(T* loss, ...@@ -300,14 +290,8 @@ __device__ __forceinline__ void ComputeLoss(T* loss,
const int offset, const int offset,
const int ignore_index) { const int ignore_index) {
int loss_id = vec_size * tid + offset; int loss_id = vec_size * tid + offset;
if (IgnoreIndex) { if (label_value == ignore_index) {
if (label_value == loss_id) { loss[label_id] = static_cast<T>(0.0f);
if (label_value == ignore_index) {
loss[label_id] = static_cast<T>(0.0f);
} else {
loss[label_id] = loss_value;
}
}
} else { } else {
if (label_value == loss_id) { if (label_value == loss_id) {
loss[label_id] = loss_value; loss[label_id] = loss_value;
...@@ -315,11 +299,7 @@ __device__ __forceinline__ void ComputeLoss(T* loss, ...@@ -315,11 +299,7 @@ __device__ __forceinline__ void ComputeLoss(T* loss,
} }
} }
template <typename T, template <typename T, typename AccT, typename LabelT, int VecSize>
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
__device__ __forceinline__ void VectorizedSoftmaxForwardImpl( __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
T* loss, T* loss,
T* softmax, T* softmax,
...@@ -333,7 +313,13 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -333,7 +313,13 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
int tid = threadIdx.x; int tid = threadIdx.x;
int label_id = blockIdx.x; int label_id = blockIdx.x;
auto label_value = static_cast<int64_t>(label[label_id]); auto label_value = static_cast<int64_t>(label[label_id]);
const bool label_valid = label_value >= 0 && label_value < size; PADDLE_ENFORCE(
label_value >= 0 && label_value < size || label_value == ignore_index,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
size,
ignore_index,
label_value);
int loss_id_offset = 0; int loss_id_offset = 0;
if (offset > 0) { if (offset > 0) {
...@@ -345,16 +331,14 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -345,16 +331,14 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
AccT log_softmax = func(static_cast<AccT>(logits[tid])); AccT log_softmax = func(static_cast<AccT>(logits[tid]));
softmax[tid] = static_cast<T>(std::exp(log_softmax)); softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss // loss
if (label_valid) { ComputeLoss<T>(loss,
ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
static_cast<T>(-log_softmax), label_id,
label_id, label_value,
label_value, tid,
tid, 1,
1, loss_id_offset,
loss_id_offset, ignore_index);
ignore_index);
}
} }
size -= blockDim.x; size -= blockDim.x;
logits += blockDim.x; logits += blockDim.x;
...@@ -380,16 +364,14 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -380,16 +364,14 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
outs[i] = static_cast<T>(std::exp(log_softmax)); outs[i] = static_cast<T>(std::exp(log_softmax));
// loss // loss
if (label_valid) { ComputeLoss<T>(loss,
ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
static_cast<T>(-log_softmax), label_id,
label_id, label_value,
label_value, tid,
tid, VecSize,
VecSize, loss_id_offset + i,
loss_id_offset + i, ignore_index);
ignore_index);
}
} }
// write // write
...@@ -403,29 +385,18 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -403,29 +385,18 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
softmax[tid] = static_cast<T>(std::exp(log_softmax)); softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss // loss
if (label_valid) { ComputeLoss<T>(loss,
ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
static_cast<T>(-log_softmax), label_id,
label_id, label_value,
label_value, tid,
tid, 1,
1, loss_id_offset,
loss_id_offset, ignore_index);
ignore_index);
}
}
// invalid label, write once
if (!label_valid && threadIdx.x == 0) {
loss[label_id] = static_cast<T>(0.0f);
} }
} }
template <typename T, template <typename T, typename AccT, typename LabelT, int VecSize>
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
__device__ __forceinline__ void ScalarSoftmaxForwardImpl( __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
T* loss, T* loss,
T* softmax, T* softmax,
...@@ -438,7 +409,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( ...@@ -438,7 +409,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
int remain = size % (VecSize * blockDim.x); int remain = size % (VecSize * blockDim.x);
int label_id = blockIdx.x; int label_id = blockIdx.x;
auto label_value = static_cast<int64_t>(label[label_id]); auto label_value = static_cast<int64_t>(label[label_id]);
const bool label_valid = label_value >= 0 && label_value < size; PADDLE_ENFORCE(
label_value >= 0 && label_value < size || label_value == ignore_index,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
size,
ignore_index,
label_value);
// main part // main part
for (; tid < (size - remain); tid += VecSize * blockDim.x) { for (; tid < (size - remain); tid += VecSize * blockDim.x) {
...@@ -453,16 +430,14 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( ...@@ -453,16 +430,14 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
AccT log_softmax = func(static_cast<AccT>(ins[i])); AccT log_softmax = func(static_cast<AccT>(ins[i]));
softmax[tid + i * blockDim.x] = static_cast<T>(std::exp(log_softmax)); softmax[tid + i * blockDim.x] = static_cast<T>(std::exp(log_softmax));
// loss // loss
if (label_valid) { ComputeLoss<T>(loss,
ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
static_cast<T>(-log_softmax), label_id,
label_id, label_value,
label_value, tid,
tid, VecSize,
VecSize, i,
i, ignore_index);
ignore_index);
}
} }
} }
...@@ -471,29 +446,18 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( ...@@ -471,29 +446,18 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
AccT log_softmax = func(static_cast<AccT>(logits[tid])); AccT log_softmax = func(static_cast<AccT>(logits[tid]));
softmax[tid] = static_cast<T>(std::exp(log_softmax)); softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss // loss
if (label_valid) { ComputeLoss<T>(loss,
ComputeLoss<T, IgnoreIndex>(loss, static_cast<T>(-log_softmax),
static_cast<T>(-log_softmax), label_id,
label_id, label_value,
label_value, tid,
tid, 1,
1, 0,
0, ignore_index);
ignore_index);
}
}
// invalid label, write once
if (!label_valid && threadIdx.x == 0) {
loss[label_id] = static_cast<T>(0.0f);
} }
} }
template <typename T, template <typename T, typename AccT, typename LabelT, int VecSize>
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
__global__ void VectorizedSoftmaxForward(T* loss, __global__ void VectorizedSoftmaxForward(T* loss,
T* softmax, T* softmax,
const T* logits, const T* logits,
...@@ -533,17 +497,16 @@ __global__ void VectorizedSoftmaxForward(T* loss, ...@@ -533,17 +497,16 @@ __global__ void VectorizedSoftmaxForward(T* loss,
// 3. softmax // 3. softmax
phi::LogSoftmaxForwardFunctor<AccT> func(max, sum); phi::LogSoftmaxForwardFunctor<AccT> func(max, sum);
if (input_offset == output_offset) { if (input_offset == output_offset) {
VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>( VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize>(loss,
loss, softmax,
softmax, logits,
logits, label,
label, mid_dim,
mid_dim, input_offset,
input_offset, func,
func, ignore_index);
ignore_index);
} else { } else {
ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>( ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize>(
loss, softmax, logits, label, mid_dim, func, ignore_index); loss, softmax, logits, label, mid_dim, func, ignore_index);
} }
} }
...@@ -556,8 +519,8 @@ The computation includes ...@@ -556,8 +519,8 @@ The computation includes
- Compute: sum of - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} - - Compute: sum of - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} -
log(sum[i]))} log(sum[i]))}
One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle For reduction max (sum), firstly compute max (sum) to one warp, then use
api to compute max (sum) in one warp. shuffle api to compute max (sum) in one warp.
*/ */
template <typename T, typename VecT, typename AccT, int Log2Elements> template <typename T, typename VecT, typename AccT, int Log2Elements>
__global__ void WarpSoftmaxForwardSoftLabel(T* loss, __global__ void WarpSoftmaxForwardSoftLabel(T* loss,
...@@ -876,8 +839,7 @@ template <typename T, ...@@ -876,8 +839,7 @@ template <typename T,
typename VecT, typename VecT,
typename AccT, typename AccT,
int Log2Elements, int Log2Elements,
SoftmaxMode mode, SoftmaxMode mode>
bool IgnoreIndex>
__global__ void WarpSoftmaxForward(T* loss, __global__ void WarpSoftmaxForward(T* loss,
T* softmax, T* softmax,
const T* src, const T* src,
...@@ -1029,23 +991,21 @@ __global__ void WarpSoftmaxForward(T* loss, ...@@ -1029,23 +991,21 @@ __global__ void WarpSoftmaxForward(T* loss,
// label // label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize; int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize;
auto lbl = static_cast<int64_t>(label[first_batch + i]); auto lbl = static_cast<int64_t>(label[first_batch + i]);
if (IgnoreIndex == true) { if (lbl == ignore_index) {
// IgnoreIndex is true loss[first_batch + i] = static_cast<T>(0.0);
if (lbl == loss_idx) {
if (lbl != ignore_index) {
loss[first_batch + i] = -logsoftmax;
} else {
loss[first_batch + i] = static_cast<T>(0.0);
}
}
} else { } else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < element_count) { if (lbl >= 0 && lbl < element_count) {
if (lbl == loss_idx) { if (lbl == loss_idx) {
loss[first_batch + i] = -logsoftmax; loss[first_batch + i] = -logsoftmax;
} }
} else { } else {
loss[first_batch + i] = static_cast<T>(0.0); PADDLE_ENFORCE(
false,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
element_count,
ignore_index,
lbl);
} }
} }
} else { // softmax } else { // softmax
...@@ -1072,19 +1032,21 @@ __global__ void WarpSoftmaxForward(T* loss, ...@@ -1072,19 +1032,21 @@ __global__ void WarpSoftmaxForward(T* loss,
// label // label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s; int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s;
auto lbl = static_cast<int64_t>(label[first_batch + i]); auto lbl = static_cast<int64_t>(label[first_batch + i]);
if (IgnoreIndex == true) { if (lbl == ignore_index) {
// IgnoreIndex is true loss[first_batch + i] = static_cast<T>(0.0);
if (lbl == loss_idx && lbl != ignore_index) {
loss[first_batch + i] = -logsoftmax;
}
} else { } else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < element_count) { if (lbl >= 0 && lbl < element_count) {
if (lbl == loss_idx) { if (lbl == loss_idx) {
loss[first_batch + i] = -logsoftmax; loss[first_batch + i] = -logsoftmax;
} }
} else { } else {
loss[first_batch + i] = static_cast<T>(0.0); PADDLE_ENFORCE(
false,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
element_count,
ignore_index,
lbl);
} }
} }
} else { // softmax } else { // softmax
...@@ -1101,23 +1063,23 @@ __global__ void WarpSoftmaxForward(T* loss, ...@@ -1101,23 +1063,23 @@ __global__ void WarpSoftmaxForward(T* loss,
} }
} }
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \ #define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \
case Log2Elements: \ case Log2Elements: \
WarpSoftmaxForward<T, LabelT, VecT, AccT, Log2Elements, mode, IgnoreIndex> \ WarpSoftmaxForward<T, LabelT, VecT, AccT, Log2Elements, mode> \
<<<blocks, threads, 0, stream>>>(loss, \ <<<blocks, threads, 0, stream>>>(loss, \
softmax, \ softmax, \
src, \ src, \
label, \ label, \
batch_size, \ batch_size, \
stride, \ stride, \
element_count, \ element_count, \
ignore_index); \ ignore_index); \
break; break;
/* /*
Wrapper of softmax with cross entropy forward hard label. Wrapper of softmax with cross entropy forward hard label.
*/ */
template <typename T, typename LabelT, SoftmaxMode mode, bool IgnoreIndex> template <typename T, typename LabelT, SoftmaxMode mode>
void SwitchWarpSoftmaxForward(T* loss, void SwitchWarpSoftmaxForward(T* loss,
T* softmax, T* softmax,
const T* src, const T* src,
...@@ -1156,7 +1118,7 @@ void SwitchWarpSoftmaxForward(T* loss, ...@@ -1156,7 +1118,7 @@ void SwitchWarpSoftmaxForward(T* loss,
} }
} }
template <typename T, typename LabelT, bool IgnoreIndex> template <typename T, typename LabelT>
void LaunchVectorizedSoftmaxForward(T* loss, void LaunchVectorizedSoftmaxForward(T* loss,
T* softmax, T* softmax,
const T* logits, const T* logits,
...@@ -1180,7 +1142,7 @@ void LaunchVectorizedSoftmaxForward(T* loss, ...@@ -1180,7 +1142,7 @@ void LaunchVectorizedSoftmaxForward(T* loss,
block_size = std::max(block_size, kps::details::kWarpSize); block_size = std::max(block_size, kps::details::kWarpSize);
dim3 grids(high_dim); dim3 grids(high_dim);
dim3 blocks(block_size); dim3 blocks(block_size);
VectorizedSoftmaxForward<T, AccT, LabelT, vec_size, IgnoreIndex> VectorizedSoftmaxForward<T, AccT, LabelT, vec_size>
<<<grids, blocks, 0, stream>>>( <<<grids, blocks, 0, stream>>>(
loss, softmax, logits, label, high_dim, mid_dim, ignore_index); loss, softmax, logits, label, high_dim, mid_dim, ignore_index);
} }
...@@ -1191,7 +1153,7 @@ void LaunchVectorizedSoftmaxForward(T* loss, ...@@ -1191,7 +1153,7 @@ void LaunchVectorizedSoftmaxForward(T* loss,
- LaunchVectorizedSoftmaxForward for large size when axis == -1 - LaunchVectorizedSoftmaxForward for large size when axis == -1
- cudnn function for axis != -1 - cudnn function for axis != -1
*/ */
template <typename T, typename LabelT, bool IgnoreIndex> template <typename T, typename LabelT>
static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int rank, int rank,
int axis, int axis,
...@@ -1208,24 +1170,24 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, ...@@ -1208,24 +1170,24 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
if (D == 1) { if (D == 1) {
if (dim <= max_dim) { // small size if (dim <= max_dim) { // small size
const SoftmaxMode mode = SoftmaxMode::kCrossEntropy; const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(loss_data, SwitchWarpSoftmaxForward<T, LabelT, mode>(loss_data,
softmax_data, softmax_data,
logits_data, logits_data,
labels_data, labels_data,
N, N,
dim, dim,
dim, dim,
ignore_index, ignore_index,
stream); stream);
} else { // large size } else { // large size
LaunchVectorizedSoftmaxForward<T, LabelT, IgnoreIndex>(loss_data, LaunchVectorizedSoftmaxForward<T, LabelT>(loss_data,
softmax_data, softmax_data,
logits_data, logits_data,
labels_data, labels_data,
N, N,
dim, dim,
ignore_index, ignore_index,
stream); stream);
} }
} else { } else {
ScopedTensorDescriptor desc; ScopedTensorDescriptor desc;
...@@ -1269,9 +1231,8 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, ...@@ -1269,9 +1231,8 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int threads = 128; int threads = 128;
int blocks = (N * dim * D + threads - 1) / threads; int blocks = (N * dim * D + threads - 1) / threads;
// compute cross entropy, input is log softmax // compute cross entropy, input is log softmax
CrossEntropyExpHardLabel<T, LabelT, IgnoreIndex> CrossEntropyExpHardLabel<T, LabelT><<<blocks, threads, 0, stream>>>(
<<<blocks, threads, 0, stream>>>( loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
} }
} }
...@@ -1367,25 +1328,14 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, ...@@ -1367,25 +1328,14 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
auto* labels_data = labels.data<LabelT>(); auto* labels_data = labels.data<LabelT>();
int threads = 128; int threads = 128;
int blocks = (n * d / axis_dim + threads - 1) / threads; int blocks = (n * d / axis_dim + threads - 1) / threads;
if (ignore_index >= 0 && ignore_index < axis_dim) { CrossEntropyHardLabel<T, LabelT>
CrossEntropyHardLabel<T, LabelT, true> <<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data,
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data, logits_data,
logits_data, labels_data,
labels_data, n,
n, axis_dim,
axis_dim, d / axis_dim,
d / axis_dim, ignore_index);
ignore_index);
} else {
CrossEntropyHardLabel<T, LabelT, false>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data,
logits_data,
labels_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
} }
// cause of input is softmax // cause of input is softmax
...@@ -1450,31 +1400,17 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, ...@@ -1450,31 +1400,17 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
} else { } else {
auto* logits_data = logits.data<T>(); auto* logits_data = logits.data<T>();
auto* labels_data = label.data<LabelT>(); auto* labels_data = label.data<LabelT>();
if (ignore_index >= 0 && ignore_index < axis_dim) { SoftmaxWithCrossEntropyHardLabel<T, LabelT>(dev_ctx,
SoftmaxWithCrossEntropyHardLabel<T, LabelT, true>(dev_ctx, rank,
rank, axis_v,
axis_v, logits_data,
logits_data, labels_data,
labels_data, loss_data,
loss_data, softmax_data,
softmax_data, n,
n, axis_dim,
axis_dim, d / axis_dim,
d / axis_dim, ignore_index);
ignore_index);
} else {
SoftmaxWithCrossEntropyHardLabel<T, LabelT, false>(dev_ctx,
rank,
axis_v,
logits_data,
labels_data,
loss_data,
softmax_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
} }
} }
} }
......
...@@ -2386,14 +2386,6 @@ def cross_entropy(input, ...@@ -2386,14 +2386,6 @@ def cross_entropy(input,
if soft_label == False: if soft_label == False:
valid_label = paddle.cast(label != ignore_index, valid_label = paddle.cast(label != ignore_index,
dtype=label.dtype) * label dtype=label.dtype) * label
label_min = paddle.min(valid_label)
label_max = paddle.max(valid_label)
if label_min < 0:
raise ValueError("Target {} is out of lower bound.".format(
label_min.item()))
if label_max >= input.shape[axis]:
raise ValueError("Target {} is out of upper bound.".format(
label_max.item()))
if core.is_compiled_with_npu() or core.is_compiled_with_mlu(): if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
if soft_label == False: if soft_label == False:
_, _, out = _legacy_C_ops.softmax_with_cross_entropy( _, _, out = _legacy_C_ops.softmax_with_cross_entropy(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册