Unverified commit 6d22f5c7, authored by: J Jack Zhou, committed by: GitHub

Add PADDLE_ENFORCE in nll loss cuda kernel (#26294)

* add nll loss API, update demo code of the comment
Parent commit: d03dd9d5
......@@ -44,6 +44,8 @@ __global__ void GPUNLLLossForward1D_no_reduce(T* out_data, const T* x_data,
out_data[i] = 0;
continue;
}
PADDLE_ENFORCE(cur_label >= 0 && cur_label < n_classes,
"label should not be out of bounds.");
const T cur_weight = weight_data ? weight_data[cur_label] : (T)1;
out_data[i] = -x_data[i * n_classes + cur_label] * cur_weight;
}
......@@ -62,6 +64,8 @@ __global__ void GPUNLLLossForward1D_with_reduce(
for (i = threadIdx.x; i < batch_size; i += NTHREADS) {
const auto cur_label = label_data[i];
if (cur_label != ignore_index) {
PADDLE_ENFORCE(cur_label >= 0 && cur_label < n_classes,
"label should not be out of bounds.");
const auto cur_weight = weight_data ? weight_data[cur_label] : (T)1;
sharedInputs[threadIdx.x] -=
x_data[i * n_classes + cur_label] * cur_weight;
......@@ -198,6 +202,8 @@ __global__ void GPUNLLLossForward2D_no_reduce(
out_data[index] = 0;
continue;
}
PADDLE_ENFORCE(cur_label >= 0 && cur_label < n_classes,
"label should not be out of bounds.");
const T cur_weight = weight_data ? weight_data[cur_label] : (T)1;
out_data[index] =
-x_data[b * sample_size + cur_label * map_size + h * in_dim3 + w] *
......@@ -226,6 +232,8 @@ __global__ void GPUNLLLossForward2D_with_reduce(
i < map_nelem; i += step) {
const int64_t cur_label = label_data[toffset + i];
if (cur_label != ignore_index) {
PADDLE_ENFORCE(cur_label >= 0 && cur_label < n_classes,
"label should not be out of bounds.");
const T cur_weight = weight_data ? weight_data[cur_label] : (T)1;
input_sum -= x_data[ioffset + i + map_nelem * cur_label] * cur_weight;
acc_weight += cur_weight;
......
......@@ -907,10 +907,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase):
def test_x_dim_imperative_lt_2():
with fluid.dygraph.guard():
x_np = np.array(
[0.88103855, 0.9908683, 0.6226845, 0.53331435,
0.07999352]).astype(np.float32)
label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)
x_np = np.random.random(size=(5, )).astype(np.float64)
label_np = np.random.randint(0, 10, size=(5, )).astype(np.int64)
x = paddle.to_variable(x_np)
label = paddle.to_variable(label_np)
nll_loss = paddle.nn.loss.NLLLoss()
......@@ -933,13 +931,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase):
def test_NLLLoss_reduction_imperative_not_sum_mean_none():
with fluid.dygraph.guard():
x_np = np.array(
[[0.88103855, 0.9908683, 0.6226845],
[0.53331435, 0.07999352, 0.8549948],
[0.25879037, 0.39530203, 0.698465],
[0.73427284, 0.63575995, 0.18827209],
[0.05689114, 0.0862954, 0.6325046]]).astype(np.float32)
label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)
x_np = np.random.random(size=(5, 3)).astype(np.float64)
label_np = np.random.randint(0, 3, size=(5, )).astype(np.int64)
x = paddle.to_variable(x_np)
label = paddle.to_variable(label_np)
nll_loss = paddle.nn.loss.NLLLoss(reduction='')
......@@ -962,13 +955,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase):
def test_nll_loss_function_reduction_imperative_not_sum_mean_none():
with fluid.dygraph.guard():
x_np = np.array(
[[0.88103855, 0.9908683, 0.6226845],
[0.53331435, 0.07999352, 0.8549948],
[0.25879037, 0.39530203, 0.698465],
[0.73427284, 0.63575995, 0.18827209],
[0.05689114, 0.0862954, 0.6325046]]).astype(np.float32)
label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64)
x_np = np.random.random(size=(5, 3)).astype(np.float64)
label_np = np.random.randint(0, 3, size=(5, )).astype(np.int64)
x = paddle.to_variable(x_np)
label = paddle.to_variable(label_np)
res = paddle.nn.functional.nll_loss(x, label, reduction='')
......
......@@ -41,6 +41,7 @@ from ...fluid.layers import edit_distance #DEFINE_ALIAS
from ...fluid.layers import huber_loss #DEFINE_ALIAS
from ...fluid.layers import sampled_softmax_with_cross_entropy #DEFINE_ALIAS
from ...fluid.layer_helper import LayerHelper
from ...fluid.framework import in_dygraph_mode
from ...fluid.framework import Variable
__all__ = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.