Unverified commit 37bc2d7b authored by Zhang Zheng, committed by GitHub

Move valid check from python to kernel (#46412)

* Move valid check from python to kernel

* fix error throw

* fix

* invalid label check

* fix

* Revert "fix"

This reverts commit 79fad6799cfa4b30423dbc84e67d7d843d22b84a.

* Revert "invalid label check"

This reverts commit 402a9707390ad5386b3222e85844b92d2e9b9fa4.

* Revert "fix"

This reverts commit 09ba3080ee0587447f875c19cdf060485f15ae3b.

* Revert "fix error throw"

This reverts commit a901bfcc2179d5c120ec29af766f392b122dab52.

* Revert "Move valid check from python to kernel"

This reverts commit baa03cc4ef82d8d45516c30dfb52bf5aead30748.

* final fix

* fix

* fix
Parent c7d60ce4
@@ -185,9 +185,10 @@ __global__ void CrossEntropyHardLabel(T* loss,
   // thread ids compute loss[ids] using softmax[idx]
   if (ids < n * d) {
     auto lbl = static_cast<int64_t>(labels[ids]);
-    if (lbl < 0) {  // label is negative
+    assert(lbl >= 0 && lbl < dim || lbl == ignore_idx);
+    if (lbl < 0 || lbl >= dim) {  // label is out of bound
       loss[ids] = static_cast<T>(0.0);
-    } else {  // label is positive of zero
+    } else {
       int64_t idx = idx_n * dim * d + lbl * d + idx_d;
       if (IgnoreIndex == true) {
         // IgnoreIndex is true
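
The hunk above is the core of the change: the eager Python-side range check is replaced by a device-side assert, while the widened out-of-bound branch still zeroes the loss so the indexing stays safe even in builds where asserts are compiled out (NDEBUG). A minimal, self-contained sketch of this pattern follows; toy_hard_label_loss and all of its parameters are hypothetical stand-ins, not PaddlePaddle code.

    #include <assert.h>
    #include <math.h>
    #include <stdint.h>

    __global__ void toy_hard_label_loss(float* loss,
                                        const float* softmax,
                                        const int64_t* labels,
                                        int n,
                                        int dim,
                                        int64_t ignore_idx) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) {
        int64_t lbl = labels[i];
        // Debug builds trap bad labels right here; release builds
        // (NDEBUG) fall through to the bound check below instead.
        assert(lbl >= 0 && lbl < dim || lbl == ignore_idx);
        if (lbl < 0 || lbl >= dim) {
          // Out of bound, which also covers a negative ignore_idx such as
          // -100: contribute zero loss rather than read softmax out of bounds.
          loss[i] = 0.0f;
        } else {
          loss[i] = -logf(softmax[i * static_cast<int64_t>(dim) + lbl]);
        }
      }
    }
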
@@ -225,6 +226,7 @@ __global__ void CrossEntropyExpHardLabel(T* loss,
   if (idx < n * dim * d) {
     auto lbl = static_cast<int64_t>(labels[ids]);
+    assert(lbl >= 0 && lbl < dim || lbl == ignore_idx);
     if (IgnoreIndex == true) {
       // IgnoreIndex is true
       if (idx_dim == lbl) {
@@ -333,6 +335,7 @@ __device__ __forceinline__ void VectorizedSoftmaxForwardImpl(
   int tid = threadIdx.x;
   int label_id = blockIdx.x;
   auto label_value = static_cast<int64_t>(label[label_id]);
+  assert(label_value >= 0 && label_value < size || label_value == ignore_index);
   const bool label_valid = label_value >= 0 && label_value < size;
   int loss_id_offset = 0;
@@ -438,6 +441,7 @@ __device__ __forceinline__ void ScalarSoftmaxForwardImpl(
   int remain = size % (VecSize * blockDim.x);
   int label_id = blockIdx.x;
   auto label_value = static_cast<int64_t>(label[label_id]);
+  assert(label_value >= 0 && label_value < size || label_value == ignore_index);
   const bool label_valid = label_value >= 0 && label_value < size;
   // main part
@@ -1029,6 +1033,7 @@ __global__ void WarpSoftmaxForward(T* loss,
       // label
       int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize;
       auto lbl = static_cast<int64_t>(label[first_batch + i]);
+      assert(lbl >= 0 && lbl < element_count || lbl == ignore_index);
       if (IgnoreIndex == true) {
         // IgnoreIndex is true
         if (lbl == loss_idx) {
@@ -1072,6 +1077,7 @@ __global__ void WarpSoftmaxForward(T* loss,
         // label
         int loss_idx = (threadIdx.x + it * kWarpSize) * kVSize + s;
         auto lbl = static_cast<int64_t>(label[first_batch + i]);
+        assert(lbl >= 0 && lbl < element_count || lbl == ignore_index);
         if (IgnoreIndex == true) {
           // IgnoreIndex is true
           if (lbl == loss_idx && lbl != ignore_index) {
......
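
All six added asserts share one expression shape, and it relies on && binding tighter than ||: the condition parses as (lbl >= 0 && lbl < dim) || lbl == ignore_idx, i.e. the label is in range, or it equals the ignore index. A fully parenthesized equivalent, written as a hypothetical helper purely to make that precedence explicit:

    // Equivalent, fully parenthesized form of the guard asserted in each
    // hunk above; label_ok is a hypothetical name, not PaddlePaddle code.
    __device__ __forceinline__ bool label_ok(int64_t lbl,
                                             int64_t dim,
                                             int64_t ignore_idx) {
      return (lbl >= 0 && lbl < dim) || (lbl == ignore_idx);
    }
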
@@ -2382,14 +2382,6 @@ def cross_entropy(input,
         if soft_label == False:
             valid_label = paddle.cast(label != ignore_index,
                                       dtype=label.dtype) * label
-            label_min = paddle.min(valid_label)
-            label_max = paddle.max(valid_label)
-            if label_min < 0:
-                raise ValueError("Target {} is out of lower bound.".format(
-                    label_min.item()))
-            if label_max >= input.shape[axis]:
-                raise ValueError("Target {} is out of upper bound.".format(
-                    label_max.item()))
         if core.is_compiled_with_npu() or core.is_compiled_with_mlu():
             if soft_label == False:
                 _, _, out = _legacy_C_ops.softmax_with_cross_entropy(
......
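
With the eager Python check deleted, an out-of-bound label no longer raises a ValueError before launch; it trips the device-side assert inside the kernel, and the host observes the failure at the next synchronization as cudaErrorAssert ("device-side assert triggered"). A hedged sketch of that observable behavior, assuming it is compiled together with the hypothetical toy_hard_label_loss kernel sketched earlier:

    #include <stdint.h>
    #include <stdio.h>
    #include <cuda_runtime.h>

    // Forward declaration of the hypothetical kernel sketched earlier.
    __global__ void toy_hard_label_loss(float*, const float*, const int64_t*,
                                        int, int, int64_t);

    int main() {
      const int n = 1, dim = 4;
      const float h_softmax[4] = {0.25f, 0.25f, 0.25f, 0.25f};
      const int64_t h_label = 7;  // out of bound: valid labels are [0, dim)
      float* d_loss;
      float* d_softmax;
      int64_t* d_label;
      cudaMalloc(&d_loss, sizeof(float));
      cudaMalloc(&d_softmax, sizeof(h_softmax));
      cudaMalloc(&d_label, sizeof(int64_t));
      cudaMemcpy(d_softmax, h_softmax, sizeof(h_softmax),
                 cudaMemcpyHostToDevice);
      cudaMemcpy(d_label, &h_label, sizeof(int64_t), cudaMemcpyHostToDevice);

      toy_hard_label_loss<<<1, 32>>>(d_loss, d_softmax, d_label, n, dim,
                                     /*ignore_idx=*/-100);
      // The failed assert aborts the kernel; the error surfaces lazily here.
      cudaError_t err = cudaDeviceSynchronize();
      printf("sync result: %s\n", cudaGetErrorString(err));  // cudaErrorAssert
      return 0;
    }
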