diff --git a/paddle/fluid/operators/nll_loss_op.cu b/paddle/fluid/operators/nll_loss_op.cu index 3d618805f02aa9b6d5310bfc8a79857f522f8ac5..531c175e03e5eee3eba609c322944b1398253726 100644 --- a/paddle/fluid/operators/nll_loss_op.cu +++ b/paddle/fluid/operators/nll_loss_op.cu @@ -44,6 +44,8 @@ __global__ void GPUNLLLossForward1D_no_reduce(T* out_data, const T* x_data, out_data[i] = 0; continue; } + PADDLE_ENFORCE(cur_label >= 0 && cur_label < n_classes, + "label should not be out of bounds."); const T cur_weight = weight_data ? weight_data[cur_label] : (T)1; out_data[i] = -x_data[i * n_classes + cur_label] * cur_weight; } @@ -62,6 +64,8 @@ __global__ void GPUNLLLossForward1D_with_reduce( for (i = threadIdx.x; i < batch_size; i += NTHREADS) { const auto cur_label = label_data[i]; if (cur_label != ignore_index) { + PADDLE_ENFORCE(cur_label >= 0 && cur_label < n_classes, + "label should not be out of bounds."); const auto cur_weight = weight_data ? weight_data[cur_label] : (T)1; sharedInputs[threadIdx.x] -= x_data[i * n_classes + cur_label] * cur_weight; @@ -198,6 +202,8 @@ __global__ void GPUNLLLossForward2D_no_reduce( out_data[index] = 0; continue; } + PADDLE_ENFORCE(cur_label >= 0 && cur_label < n_classes, + "label should not be out of bounds."); const T cur_weight = weight_data ? weight_data[cur_label] : (T)1; out_data[index] = -x_data[b * sample_size + cur_label * map_size + h * in_dim3 + w] * @@ -226,6 +232,8 @@ __global__ void GPUNLLLossForward2D_with_reduce( i < map_nelem; i += step) { const int64_t cur_label = label_data[toffset + i]; if (cur_label != ignore_index) { + PADDLE_ENFORCE(cur_label >= 0 && cur_label < n_classes, + "label should not be out of bounds."); const T cur_weight = weight_data ? weight_data[cur_label] : (T)1; input_sum -= x_data[ioffset + i + map_nelem * cur_label] * cur_weight; acc_weight += cur_weight; diff --git a/python/paddle/fluid/tests/unittests/test_nll_loss.py b/python/paddle/fluid/tests/unittests/test_nll_loss.py index c25f8832807bc9a9da84ee44ee8172e8d1d0dd94..e7154193beaf788a9d20f3c131b1df3420918266 100644 --- a/python/paddle/fluid/tests/unittests/test_nll_loss.py +++ b/python/paddle/fluid/tests/unittests/test_nll_loss.py @@ -907,10 +907,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase): def test_x_dim_imperative_lt_2(): with fluid.dygraph.guard(): - x_np = np.array( - [0.88103855, 0.9908683, 0.6226845, 0.53331435, - 0.07999352]).astype(np.float32) - label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64) + x_np = np.random.random(size=(5, )).astype(np.float64) + label_np = np.random.randint(0, 10, size=(5, )).astype(np.int64) x = paddle.to_variable(x_np) label = paddle.to_variable(label_np) nll_loss = paddle.nn.loss.NLLLoss() @@ -933,13 +931,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase): def test_NLLLoss_reduction_imperative_not_sum_mean_none(): with fluid.dygraph.guard(): - x_np = np.array( - [[0.88103855, 0.9908683, 0.6226845], - [0.53331435, 0.07999352, 0.8549948], - [0.25879037, 0.39530203, 0.698465], - [0.73427284, 0.63575995, 0.18827209], - [0.05689114, 0.0862954, 0.6325046]]).astype(np.float32) - label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64) + x_np = np.random.random(size=(5, 3)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, )).astype(np.int64) x = paddle.to_variable(x_np) label = paddle.to_variable(label_np) nll_loss = paddle.nn.loss.NLLLoss(reduction='') @@ -962,13 +955,8 @@ class TestNLLLossInvalidArgs(unittest.TestCase): def test_nll_loss_function_reduction_imperative_not_sum_mean_none(): with fluid.dygraph.guard(): - x_np = np.array( - [[0.88103855, 0.9908683, 0.6226845], - [0.53331435, 0.07999352, 0.8549948], - [0.25879037, 0.39530203, 0.698465], - [0.73427284, 0.63575995, 0.18827209], - [0.05689114, 0.0862954, 0.6325046]]).astype(np.float32) - label_np = np.array([0, 2, 1, 1, 0]).astype(np.int64) + x_np = np.random.random(size=(5, 3)).astype(np.float64) + label_np = np.random.randint(0, 3, size=(5, )).astype(np.int64) x = paddle.to_variable(x_np) label = paddle.to_variable(label_np) res = paddle.nn.functional.nll_loss(x, label, reduction='') diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 19ffe572a9cbad36e3701269525c41bba61a2bfd..e08c707b8daa6bae8bc30b2753852d41319cebb4 100644 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -41,6 +41,7 @@ from ...fluid.layers import edit_distance #DEFINE_ALIAS from ...fluid.layers import huber_loss #DEFINE_ALIAS from ...fluid.layers import sampled_softmax_with_cross_entropy #DEFINE_ALIAS from ...fluid.layer_helper import LayerHelper +from ...fluid.framework import in_dygraph_mode from ...fluid.framework import Variable __all__ = [