未验证 提交 450af30c 编写于 作者: Z Zhang Zheng 提交者: GitHub

Correct the logic and remove unnecessary template param (#46623)

* Correct the logic and remove unnecessary template param

* fix error throw

* fix print format

* fix ci
上级 1230a3f4
...@@ -170,7 +170,7 @@ __global__ void CrossEntropySoftLabel(T* loss, ...@@ -170,7 +170,7 @@ __global__ void CrossEntropySoftLabel(T* loss,
/* /*
Hard label cross entropy. Hard label cross entropy.
*/ */
template <typename T, typename LabelT, bool IgnoreIndex> template <typename T, typename LabelT>
__global__ void CrossEntropyHardLabel(T* loss, __global__ void CrossEntropyHardLabel(T* loss,
const T* softmax, const T* softmax,
const LabelT* labels, const LabelT* labels,
...@@ -185,24 +185,19 @@ __global__ void CrossEntropyHardLabel(T* loss, ...@@ -185,24 +185,19 @@ __global__ void CrossEntropyHardLabel(T* loss,
// thread ids compute loss[ids] using softmax[idx] // thread ids compute loss[ids] using softmax[idx]
if (ids < n * d) { if (ids < n * d) {
auto lbl = static_cast<int64_t>(labels[ids]); auto lbl = static_cast<int64_t>(labels[ids]);
assert(lbl >= 0 && lbl < dim || lbl == ignore_idx); PADDLE_ENFORCE(lbl >= 0 && lbl < dim || lbl == ignore_idx,
if (lbl < 0 || lbl >= dim) { // label is out of bound "The value of label expected >= 0 and < %d, or == %d, "
loss[ids] = static_cast<T>(0.0); "but got %ld. Please check label value.",
} else { dim,
int64_t idx = idx_n * dim * d + lbl * d + idx_d; ignore_idx,
if (IgnoreIndex == true) { lbl);
// IgnoreIndex is true
if (lbl == ignore_idx) { if (lbl == ignore_idx) {
loss[ids] = static_cast<T>(0.0); loss[ids] = static_cast<T>(0.0);
} else { } else {
loss[ids] = -Log(softmax[idx]); int64_t idx = idx_n * dim * d + lbl * d + idx_d;
}
} else {
// IgnoreIndex is false
loss[ids] = -Log(softmax[idx]); loss[ids] = -Log(softmax[idx]);
} }
} }
}
} }
/* /*
...@@ -210,7 +205,7 @@ __global__ void CrossEntropyHardLabel(T* loss, ...@@ -210,7 +205,7 @@ __global__ void CrossEntropyHardLabel(T* loss,
Input: log softmax Input: log softmax
Output: loss and exp(input) Output: loss and exp(input)
*/ */
template <typename T, typename LabelT, bool IgnoreIndex> template <typename T, typename LabelT>
__global__ void CrossEntropyExpHardLabel(T* loss, __global__ void CrossEntropyExpHardLabel(T* loss,
T* softmax, T* softmax,
const LabelT* labels, const LabelT* labels,
...@@ -226,25 +221,18 @@ __global__ void CrossEntropyExpHardLabel(T* loss, ...@@ -226,25 +221,18 @@ __global__ void CrossEntropyExpHardLabel(T* loss,
if (idx < n * dim * d) { if (idx < n * dim * d) {
auto lbl = static_cast<int64_t>(labels[ids]); auto lbl = static_cast<int64_t>(labels[ids]);
assert(lbl >= 0 && lbl < dim || lbl == ignore_idx); PADDLE_ENFORCE(lbl >= 0 && lbl < dim || lbl == ignore_idx,
if (IgnoreIndex == true) { "The value of label expected >= 0 and < %d, or == %d, "
// IgnoreIndex is true "but got %ld. Please check label value.",
if (idx_dim == lbl) { dim,
ignore_idx,
lbl);
if (lbl == ignore_idx) { if (lbl == ignore_idx) {
loss[ids] = static_cast<T>(0.0); loss[ids] = static_cast<T>(0.0);
} else { } else {
loss[ids] = -softmax[idx];
}
}
} else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < dim) {
if (lbl == idx_dim) { if (lbl == idx_dim) {
loss[ids] = -softmax[idx]; loss[ids] = -softmax[idx];
} }
} else {
loss[ids] = static_cast<T>(0.0);
}
} }
softmax[idx] = Exp(softmax[idx]); softmax[idx] = Exp(softmax[idx]);
} }
...@@ -292,7 +280,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input, ...@@ -292,7 +280,7 @@ __device__ __forceinline__ AccT ThreadReduce(const T* input,
return val; return val;
} }
template <typename T, bool IgnoreIndex> template <typename T>
__device__ __forceinline__ void ComputeLoss(T* loss, __device__ __forceinline__ void ComputeLoss(T* loss,
const T loss_value, const T loss_value,
const int label_id, const int label_id,
...@@ -302,14 +290,8 @@ __device__ __forceinline__ void ComputeLoss(T* loss, ...@@ -302,14 +290,8 @@ __device__ __forceinline__ void ComputeLoss(T* loss,
const int offset, const int offset,
const int ignore_index) { const int ignore_index) {
int loss_id = vec_size * tid + offset; int loss_id = vec_size * tid + offset;
if (IgnoreIndex) {
if (label_value == loss_id) {
if (label_value == ignore_index) { if (label_value == ignore_index) {
loss[label_id] = static_cast<T>(0.0f); loss[label_id] = static_cast<T>(0.0f);
} else {
loss[label_id] = loss_value;
}
}
} else { } else {
if (label_value == loss_id) { if (label_value == loss_id) {
loss[label_id] = loss_value; loss[label_id] = loss_value;
...@@ -317,11 +299,7 @@ __device__ __forceinline__ void ComputeLoss(T* loss, ...@@ -317,11 +299,7 @@ __device__ __forceinline__ void ComputeLoss(T* loss,
} }
} }
template <typename T, template <typename T, typename AccT, typename LabelT, int VecSize>
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
__device__ __forceinline__ void VectorizedSoftmaxForwardImpl( __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
T* loss, T* loss,
T* softmax, T* softmax,
...@@ -335,8 +313,13 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -335,8 +313,13 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
int tid = threadIdx.x; int tid = threadIdx.x;
int label_id = blockIdx.x; int label_id = blockIdx.x;
auto label_value = static_cast<int64_t>(label[label_id]); auto label_value = static_cast<int64_t>(label[label_id]);
assert(label_value >= 0 && label_value < size || label_value == ignore_index); PADDLE_ENFORCE(
const bool label_valid = label_value >= 0 && label_value < size; label_value >= 0 && label_value < size || label_value == ignore_index,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
size,
ignore_index,
label_value);
int loss_id_offset = 0; int loss_id_offset = 0;
if (offset > 0) { if (offset > 0) {
...@@ -348,8 +331,7 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -348,8 +331,7 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
AccT log_softmax = func(static_cast<AccT>(logits[tid])); AccT log_softmax = func(static_cast<AccT>(logits[tid]));
softmax[tid] = static_cast<T>(std::exp(log_softmax)); softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss // loss
if (label_valid) { ComputeLoss<T>(loss,
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax), static_cast<T>(-log_softmax),
label_id, label_id,
label_value, label_value,
...@@ -358,7 +340,6 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -358,7 +340,6 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
loss_id_offset, loss_id_offset,
ignore_index); ignore_index);
} }
}
size -= blockDim.x; size -= blockDim.x;
logits += blockDim.x; logits += blockDim.x;
softmax += blockDim.x; softmax += blockDim.x;
...@@ -383,8 +364,7 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -383,8 +364,7 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
outs[i] = static_cast<T>(std::exp(log_softmax)); outs[i] = static_cast<T>(std::exp(log_softmax));
// loss // loss
if (label_valid) { ComputeLoss<T>(loss,
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax), static_cast<T>(-log_softmax),
label_id, label_id,
label_value, label_value,
...@@ -393,7 +373,6 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -393,7 +373,6 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
loss_id_offset + i, loss_id_offset + i,
ignore_index); ignore_index);
} }
}
// write // write
reinterpret_cast<VecT*>(softmax)[tid] = *outs_vec; reinterpret_cast<VecT*>(softmax)[tid] = *outs_vec;
...@@ -406,8 +385,7 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -406,8 +385,7 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
softmax[tid] = static_cast<T>(std::exp(log_softmax)); softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss // loss
if (label_valid) { ComputeLoss<T>(loss,
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax), static_cast<T>(-log_softmax),
label_id, label_id,
label_value, label_value,
...@@ -416,19 +394,9 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl( ...@@ -416,19 +394,9 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
loss_id_offset, loss_id_offset,
ignore_index); ignore_index);
} }
}
// invalid label, write once
if (!label_valid && threadIdx.x == 0) {
loss[label_id] = static_cast<T>(0.0f);
}
} }
template <typename T, template <typename T, typename AccT, typename LabelT, int VecSize>
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
__device__ __forceinline__ void ScalarSoftmaxForwardImpl( __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
T* loss, T* loss,
T* softmax, T* softmax,
...@@ -441,8 +409,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( ...@@ -441,8 +409,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
int remain = size % (VecSize * blockDim.x); int remain = size % (VecSize * blockDim.x);
int label_id = blockIdx.x; int label_id = blockIdx.x;
auto label_value = static_cast<int64_t>(label[label_id]); auto label_value = static_cast<int64_t>(label[label_id]);
assert(label_value >= 0 && label_value < size || label_value == ignore_index); PADDLE_ENFORCE(
const bool label_valid = label_value >= 0 && label_value < size; label_value >= 0 && label_value < size || label_value == ignore_index,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
size,
ignore_index,
label_value);
// main part // main part
for (; tid < (size - remain); tid += VecSize * blockDim.x) { for (; tid < (size - remain); tid += VecSize * blockDim.x) {
...@@ -457,8 +430,7 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( ...@@ -457,8 +430,7 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
AccT log_softmax = func(static_cast<AccT>(ins[i])); AccT log_softmax = func(static_cast<AccT>(ins[i]));
softmax[tid + i * blockDim.x] = static_cast<T>(std::exp(log_softmax)); softmax[tid + i * blockDim.x] = static_cast<T>(std::exp(log_softmax));
// loss // loss
if (label_valid) { ComputeLoss<T>(loss,
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax), static_cast<T>(-log_softmax),
label_id, label_id,
label_value, label_value,
...@@ -468,15 +440,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( ...@@ -468,15 +440,13 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
ignore_index); ignore_index);
} }
} }
}
// tail part // tail part
for (; tid < size; tid += blockDim.x) { for (; tid < size; tid += blockDim.x) {
AccT log_softmax = func(static_cast<AccT>(logits[tid])); AccT log_softmax = func(static_cast<AccT>(logits[tid]));
softmax[tid] = static_cast<T>(std::exp(log_softmax)); softmax[tid] = static_cast<T>(std::exp(log_softmax));
// loss // loss
if (label_valid) { ComputeLoss<T>(loss,
ComputeLoss<T, IgnoreIndex>(loss,
static_cast<T>(-log_softmax), static_cast<T>(-log_softmax),
label_id, label_id,
label_value, label_value,
...@@ -485,19 +455,9 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl( ...@@ -485,19 +455,9 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
0, 0,
ignore_index); ignore_index);
} }
}
// invalid label, write once
if (!label_valid && threadIdx.x == 0) {
loss[label_id] = static_cast<T>(0.0f);
}
} }
template <typename T, template <typename T, typename AccT, typename LabelT, int VecSize>
typename AccT,
typename LabelT,
int VecSize,
bool IgnoreIndex>
__global__ void VectorizedSoftmaxForward(T* loss, __global__ void VectorizedSoftmaxForward(T* loss,
T* softmax, T* softmax,
const T* logits, const T* logits,
...@@ -537,8 +497,7 @@ __global__ void VectorizedSoftmaxForward(T* loss, ...@@ -537,8 +497,7 @@ __global__ void VectorizedSoftmaxForward(T* loss,
// 3. softmax // 3. softmax
phi::LogSoftmaxForwardFunctor<AccT> func(max, sum); phi::LogSoftmaxForwardFunctor<AccT> func(max, sum);
if (input_offset == output_offset) { if (input_offset == output_offset) {
VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>( VectorizedSoftmaxForwardImpl<T, AccT, LabelT, VecSize>(loss,
loss,
softmax, softmax,
logits, logits,
label, label,
...@@ -547,7 +506,7 @@ __global__ void VectorizedSoftmaxForward(T* loss, ...@@ -547,7 +506,7 @@ __global__ void VectorizedSoftmaxForward(T* loss,
func, func,
ignore_index); ignore_index);
} else { } else {
ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize, IgnoreIndex>( ScalarSoftmaxForwardImpl<T, AccT, LabelT, VecSize>(
loss, softmax, logits, label, mid_dim, func, ignore_index); loss, softmax, logits, label, mid_dim, func, ignore_index);
} }
} }
...@@ -560,8 +519,8 @@ The computation includes ...@@ -560,8 +519,8 @@ The computation includes
- Compute: sum of - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} - - Compute: sum of - sum_{j}{ label_{i,j} * (src_{i,j} - maxvalue_{i} -
log(sum[i]))} log(sum[i]))}
One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize). One warp (32 threads) is used to compute 1 or 2 batch (kBatchSize).
For reduction max (sum), firstly compute max (sum) to one warp, then use shuffle For reduction max (sum), firstly compute max (sum) to one warp, then use
api to compute max (sum) in one warp. shuffle api to compute max (sum) in one warp.
*/ */
template <typename T, typename VecT, typename AccT, int Log2Elements> template <typename T, typename VecT, typename AccT, int Log2Elements>
__global__ void WarpSoftmaxForwardSoftLabel(T* loss, __global__ void WarpSoftmaxForwardSoftLabel(T* loss,
...@@ -880,8 +839,7 @@ template <typename T, ...@@ -880,8 +839,7 @@ template <typename T,
typename VecT, typename VecT,
typename AccT, typename AccT,
int Log2Elements, int Log2Elements,
SoftmaxMode mode, SoftmaxMode mode>
bool IgnoreIndex>
__global__ void WarpSoftmaxForward(T* loss, __global__ void WarpSoftmaxForward(T* loss,
T* softmax, T* softmax,
const T* src, const T* src,
...@@ -1033,24 +991,21 @@ __global__ void WarpSoftmaxForward(T* loss, ...@@ -1033,24 +991,21 @@ __global__ void WarpSoftmaxForward(T* loss,
// label // label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize; int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize;
auto lbl = static_cast<int64_t>(label[first_batch + i]); auto lbl = static_cast<int64_t>(label[first_batch + i]);
assert(lbl >= 0 && lbl < element_count || lbl == ignore_index); if (lbl == ignore_index) {
if (IgnoreIndex == true) {
// IgnoreIndex is true
if (lbl == loss_idx) {
if (lbl != ignore_index) {
loss[first_batch + i] = -logsoftmax;
} else {
loss[first_batch + i] = static_cast<T>(0.0); loss[first_batch + i] = static_cast<T>(0.0);
}
}
} else { } else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < element_count) { if (lbl >= 0 && lbl < element_count) {
if (lbl == loss_idx) { if (lbl == loss_idx) {
loss[first_batch + i] = -logsoftmax; loss[first_batch + i] = -logsoftmax;
} }
} else { } else {
loss[first_batch + i] = static_cast<T>(0.0); PADDLE_ENFORCE(
false,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
element_count,
ignore_index,
lbl);
} }
} }
} else { // softmax } else { // softmax
...@@ -1077,20 +1032,21 @@ __global__ void WarpSoftmaxForward(T* loss, ...@@ -1077,20 +1032,21 @@ __global__ void WarpSoftmaxForward(T* loss,
// label // label
int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s; int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s;
auto lbl = static_cast<int64_t>(label[first_batch + i]); auto lbl = static_cast<int64_t>(label[first_batch + i]);
assert(lbl >= 0 && lbl < element_count || lbl == ignore_index); if (lbl == ignore_index) {
if (IgnoreIndex == true) { loss[first_batch + i] = static_cast<T>(0.0);
// IgnoreIndex is true
if (lbl == loss_idx && lbl != ignore_index) {
loss[first_batch + i] = -logsoftmax;
}
} else { } else {
// IgnoreIndex is false
if (lbl >= 0 && lbl < element_count) { if (lbl >= 0 && lbl < element_count) {
if (lbl == loss_idx) { if (lbl == loss_idx) {
loss[first_batch + i] = -logsoftmax; loss[first_batch + i] = -logsoftmax;
} }
} else { } else {
loss[first_batch + i] = static_cast<T>(0.0); PADDLE_ENFORCE(
false,
"The value of label expected >= 0 and < %d, or == %d, "
"but got %ld. Please check label value.",
element_count,
ignore_index,
lbl);
} }
} }
} else { // softmax } else { // softmax
...@@ -1109,7 +1065,7 @@ __global__ void WarpSoftmaxForward(T* loss, ...@@ -1109,7 +1065,7 @@ __global__ void WarpSoftmaxForward(T* loss,
#define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \ #define SOFTMAX_WARP_FORWARD_CASE(Log2Elements, LabelT, VecT, AccT) \
case Log2Elements: \ case Log2Elements: \
WarpSoftmaxForward<T, LabelT, VecT, AccT, Log2Elements, mode, IgnoreIndex> \ WarpSoftmaxForward<T, LabelT, VecT, AccT, Log2Elements, mode> \
<<<blocks, threads, 0, stream>>>(loss, \ <<<blocks, threads, 0, stream>>>(loss, \
softmax, \ softmax, \
src, \ src, \
...@@ -1123,7 +1079,7 @@ __global__ void WarpSoftmaxForward(T* loss, ...@@ -1123,7 +1079,7 @@ __global__ void WarpSoftmaxForward(T* loss,
/* /*
Wrapper of softmax with cross entropy forward hard label. Wrapper of softmax with cross entropy forward hard label.
*/ */
template <typename T, typename LabelT, SoftmaxMode mode, bool IgnoreIndex> template <typename T, typename LabelT, SoftmaxMode mode>
void SwitchWarpSoftmaxForward(T* loss, void SwitchWarpSoftmaxForward(T* loss,
T* softmax, T* softmax,
const T* src, const T* src,
...@@ -1162,7 +1118,7 @@ void SwitchWarpSoftmaxForward(T* loss, ...@@ -1162,7 +1118,7 @@ void SwitchWarpSoftmaxForward(T* loss,
} }
} }
template <typename T, typename LabelT, bool IgnoreIndex> template <typename T, typename LabelT>
void LaunchVectorizedSoftmaxForward(T* loss, void LaunchVectorizedSoftmaxForward(T* loss,
T* softmax, T* softmax,
const T* logits, const T* logits,
...@@ -1186,7 +1142,7 @@ void LaunchVectorizedSoftmaxForward(T* loss, ...@@ -1186,7 +1142,7 @@ void LaunchVectorizedSoftmaxForward(T* loss,
block_size = std::max(block_size, kps::details::kWarpSize); block_size = std::max(block_size, kps::details::kWarpSize);
dim3 grids(high_dim); dim3 grids(high_dim);
dim3 blocks(block_size); dim3 blocks(block_size);
VectorizedSoftmaxForward<T, AccT, LabelT, vec_size, IgnoreIndex> VectorizedSoftmaxForward<T, AccT, LabelT, vec_size>
<<<grids, blocks, 0, stream>>>( <<<grids, blocks, 0, stream>>>(
loss, softmax, logits, label, high_dim, mid_dim, ignore_index); loss, softmax, logits, label, high_dim, mid_dim, ignore_index);
} }
...@@ -1197,7 +1153,7 @@ void LaunchVectorizedSoftmaxForward(T* loss, ...@@ -1197,7 +1153,7 @@ void LaunchVectorizedSoftmaxForward(T* loss,
- LaunchVectorizedSoftmaxForward for large size when axis == -1 - LaunchVectorizedSoftmaxForward for large size when axis == -1
- cudnn function for axis != -1 - cudnn function for axis != -1
*/ */
template <typename T, typename LabelT, bool IgnoreIndex> template <typename T, typename LabelT>
static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int rank, int rank,
int axis, int axis,
...@@ -1214,7 +1170,7 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, ...@@ -1214,7 +1170,7 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
if (D == 1) { if (D == 1) {
if (dim <= max_dim) { // small size if (dim <= max_dim) { // small size
const SoftmaxMode mode = SoftmaxMode::kCrossEntropy; const SoftmaxMode mode = SoftmaxMode::kCrossEntropy;
SwitchWarpSoftmaxForward<T, LabelT, mode, IgnoreIndex>(loss_data, SwitchWarpSoftmaxForward<T, LabelT, mode>(loss_data,
softmax_data, softmax_data,
logits_data, logits_data,
labels_data, labels_data,
...@@ -1224,7 +1180,7 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, ...@@ -1224,7 +1180,7 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
ignore_index, ignore_index,
stream); stream);
} else { // large size } else { // large size
LaunchVectorizedSoftmaxForward<T, LabelT, IgnoreIndex>(loss_data, LaunchVectorizedSoftmaxForward<T, LabelT>(loss_data,
softmax_data, softmax_data,
logits_data, logits_data,
labels_data, labels_data,
...@@ -1275,8 +1231,7 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx, ...@@ -1275,8 +1231,7 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int threads = 128; int threads = 128;
int blocks = (N * dim * D + threads - 1) / threads; int blocks = (N * dim * D + threads - 1) / threads;
// compute cross entropy, input is log softmax // compute cross entropy, input is log softmax
CrossEntropyExpHardLabel<T, LabelT, IgnoreIndex> CrossEntropyExpHardLabel<T, LabelT><<<blocks, threads, 0, stream>>>(
<<<blocks, threads, 0, stream>>>(
loss_data, softmax_data, labels_data, N, dim, D, ignore_index); loss_data, softmax_data, labels_data, N, dim, D, ignore_index);
} }
} }
...@@ -1373,8 +1328,7 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, ...@@ -1373,8 +1328,7 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
auto* labels_data = labels.data<LabelT>(); auto* labels_data = labels.data<LabelT>();
int threads = 128; int threads = 128;
int blocks = (n * d / axis_dim + threads - 1) / threads; int blocks = (n * d / axis_dim + threads - 1) / threads;
if (ignore_index >= 0 && ignore_index < axis_dim) { CrossEntropyHardLabel<T, LabelT>
CrossEntropyHardLabel<T, LabelT, true>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data, <<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data,
logits_data, logits_data,
labels_data, labels_data,
...@@ -1382,16 +1336,6 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, ...@@ -1382,16 +1336,6 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
axis_dim, axis_dim,
d / axis_dim, d / axis_dim,
ignore_index); ignore_index);
} else {
CrossEntropyHardLabel<T, LabelT, false>
<<<blocks, threads, 0, dev_ctx.stream()>>>(loss_data,
logits_data,
labels_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
} }
// cause of input is softmax // cause of input is softmax
...@@ -1456,8 +1400,7 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, ...@@ -1456,8 +1400,7 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
} else { } else {
auto* logits_data = logits.data<T>(); auto* logits_data = logits.data<T>();
auto* labels_data = label.data<LabelT>(); auto* labels_data = label.data<LabelT>();
if (ignore_index >= 0 && ignore_index < axis_dim) { SoftmaxWithCrossEntropyHardLabel<T, LabelT>(dev_ctx,
SoftmaxWithCrossEntropyHardLabel<T, LabelT, true>(dev_ctx,
rank, rank,
axis_v, axis_v,
logits_data, logits_data,
...@@ -1468,19 +1411,6 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx, ...@@ -1468,19 +1411,6 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
axis_dim, axis_dim,
d / axis_dim, d / axis_dim,
ignore_index); ignore_index);
} else {
SoftmaxWithCrossEntropyHardLabel<T, LabelT, false>(dev_ctx,
rank,
axis_v,
logits_data,
labels_data,
loss_data,
softmax_data,
n,
axis_dim,
d / axis_dim,
ignore_index);
}
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册