From 3c98ec90ce7a2a959ad19539420c4e7165d49201 Mon Sep 17 00:00:00 2001 From: WangXi Date: Mon, 18 Nov 2019 11:06:50 +0800 Subject: [PATCH] Fix INF bug of softmax_cross_entropy_op (#21165) --- .../softmax_with_cross_entropy_op.cu | 11 +++-- .../test_fused_multihead_matmul_op.py | 4 +- .../unittests/test_locality_aware_nms_op.py | 4 +- .../tests/unittests/test_multiclass_nms_op.py | 28 ++++-------- .../fluid/tests/unittests/test_softmax_op.py | 4 +- .../test_softmax_with_cross_entropy_op.py | 43 ++++++++++++++++++- 6 files changed, 65 insertions(+), 29 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu index 12b64052a7..287f0670a8 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cu +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cu @@ -150,10 +150,7 @@ static __global__ void RowReductionForMax(const T* logits_data, T* max_data, cur_max = BlockReduce(temp_storage).Reduce(cur_max, cub::Max()); - if (threadIdx.x == 0) { - max_data[blockIdx.x] = - cur_max < static_cast(-64) ? static_cast(-64) : cur_max; - } + if (threadIdx.x == 0) max_data[blockIdx.x] = cur_max; } // Make sure that BlockDim <= axis_dim @@ -175,6 +172,12 @@ static __global__ void RowReductionForDiffMaxSum(const T* logits_data, auto block_max = max_data[blockIdx.x]; int step = BlockDim * remain; + // In numeric stable mode softmax_with_loss, we calc loss with + // tmp_i_j = x_i_j - max_i - logDiffMaxSum_i, instead of + // log(exp(x_i_j - max_i)/DiffMaxSum_i). Therefore, log(0) will not occur. + // Also we calc softmax_i_j = e^{tmp_i_j}, the maximum and minimum value will + // be 1.0 and 0.0, represent prob is 1.0 and 0.0. + // So there is no need to clip on shift_softmax. softmax[beg_idx] = logits_data[beg_idx] - block_max; T diff_max_sum = exp_on_device(softmax[beg_idx]); auto idx = beg_idx + step; diff --git a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py index e574b987d8..ffb4cee8fb 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_multihead_matmul_op.py @@ -25,7 +25,9 @@ np.random.random(123) def stable_softmax(x): """Compute the softmax of vector x in a numerically stable way.""" - shiftx = x - np.max(x).clip(-64.) + # clip to shiftx, otherwise, when calc loss with + # log(exp(shiftx)), may get log(0)=INF + shiftx = (x - np.max(x)).clip(-64.) exps = np.exp(shiftx) return exps / np.sum(exps) diff --git a/python/paddle/fluid/tests/unittests/test_locality_aware_nms_op.py b/python/paddle/fluid/tests/unittests/test_locality_aware_nms_op.py index 1c8526f4df..bed185551b 100644 --- a/python/paddle/fluid/tests/unittests/test_locality_aware_nms_op.py +++ b/python/paddle/fluid/tests/unittests/test_locality_aware_nms_op.py @@ -200,7 +200,9 @@ class TestLocalAwareNMSOp(OpTest): scores = np.random.random((N * M, C)).astype('float32') def softmax(x): - shiftx = x - np.max(x).clip(-64.) + # clip to shiftx, otherwise, when calc loss with + # log(exp(shiftx)), may get log(0)=INF + shiftx = (x - np.max(x)).clip(-64.) exps = np.exp(shiftx) return exps / np.sum(exps) diff --git a/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py b/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py index 9839126088..a7240bc65f 100644 --- a/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py +++ b/python/paddle/fluid/tests/unittests/test_multiclass_nms_op.py @@ -19,6 +19,14 @@ import copy from op_test import OpTest +def softmax(x): + # clip to shiftx, otherwise, when calc loss with + # log(exp(shiftx)), may get log(0)=INF + shiftx = (x - np.max(x)).clip(-64.) + exps = np.exp(shiftx) + return exps / np.sum(exps) + + def iou(box_a, box_b, norm): """Apply intersection-over-union overlap between box_a and box_b """ @@ -254,11 +262,6 @@ class TestMulticlassNMSOp(OpTest): scores = np.random.random((N * M, C)).astype('float32') - def softmax(x): - shiftx = x - np.max(x).clip(-64.) - exps = np.exp(shiftx) - return exps / np.sum(exps) - scores = np.apply_along_axis(softmax, 1, scores) scores = np.reshape(scores, (N, M, C)) scores = np.transpose(scores, (0, 2, 1)) @@ -318,11 +321,6 @@ class TestMulticlassNMSLoDInput(OpTest): scores = np.random.random((M, C)).astype('float32') - def softmax(x): - shiftx = x - np.max(x).clip(-64.) - exps = np.exp(shiftx) - return exps / np.sum(exps) - scores = np.apply_along_axis(softmax, 1, scores) boxes = np.random.random((M, C, BOX_SIZE)).astype('float32') @@ -382,11 +380,6 @@ class TestMulticlassNMS2Op(TestMulticlassNMSOp): scores = np.random.random((N * M, C)).astype('float32') - def softmax(x): - shiftx = x - np.max(x).clip(-64.) - exps = np.exp(shiftx) - return exps / np.sum(exps) - scores = np.apply_along_axis(softmax, 1, scores) scores = np.reshape(scores, (N, M, C)) scores = np.transpose(scores, (0, 2, 1)) @@ -447,11 +440,6 @@ class TestMulticlassNMS2LoDInput(TestMulticlassNMSLoDInput): scores = np.random.random((M, C)).astype('float32') - def softmax(x): - shiftx = x - np.max(x).clip(-64.) - exps = np.exp(shiftx) - return exps / np.sum(exps) - scores = np.apply_along_axis(softmax, 1, scores) boxes = np.random.random((M, C, BOX_SIZE)).astype('float32') diff --git a/python/paddle/fluid/tests/unittests/test_softmax_op.py b/python/paddle/fluid/tests/unittests/test_softmax_op.py index 50b29ba1e5..f99d99850b 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_op.py @@ -24,7 +24,9 @@ from paddle.fluid import compiler, Program, program_guard def stable_softmax(x): """Compute the softmax of vector x in a numerically stable way.""" - shiftx = x - np.max(x).clip(-64.) + # clip to shiftx, otherwise, when calc loss with + # log(exp(shiftx)), may get log(0)=INF + shiftx = (x - np.max(x)).clip(-64.) exps = np.exp(shiftx) return exps / np.sum(exps) diff --git a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py index 5892937479..ae1429719b 100644 --- a/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py +++ b/python/paddle/fluid/tests/unittests/test_softmax_with_cross_entropy_op.py @@ -58,7 +58,9 @@ class TestSoftmaxWithCrossEntropyOp(OpTest): def setUp(self): self.initParams() - logits = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype) + logits = getattr( + self, "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)) softmax = np.apply_along_axis(stable_softmax, self.axis, logits) if self.soft_label: @@ -119,7 +121,9 @@ class TestSoftmaxWithCrossEntropyOpFp16(TestSoftmaxWithCrossEntropyOp): self.op_type = "softmax_with_cross_entropy" # NOTE: numpy float16 have very low accuracy, use float32 for numpy check. - logits = np.random.uniform(0.1, 1.0, self.shape).astype(np.float32) + logits = getattr( + self, "logits", + np.random.uniform(0.1, 1.0, self.shape).astype(np.float32)) softmax = np.apply_along_axis(stable_softmax, self.axis, logits) axis_dim = self.shape[self.axis] @@ -405,5 +409,40 @@ class TestSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4( self.dtype = np.float64 +class TestSoftmaxWithCrossEntropyOpBoundary0(TestSoftmaxWithCrossEntropyOp): + """ + Test stable softmax with cross entropy operator will not product INF + with small logits value. + """ + + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.axis = -1 + self.ignore_index = -1 + self.dtype = np.float32 + self.logits = np.full(self.shape, -500.0).astype(self.dtype) + + +class TestSoftmaxWithCrossEntropyOpBoundary1(TestSoftmaxWithCrossEntropyOp): + """ + Test stable softmax with cross entropy operator will not product INF + with small logits value. + """ + + def initParams(self): + self.op_type = "softmax_with_cross_entropy" + self.numeric_stable_mode = True + self.soft_label = False + self.shape = [3, 5, 7, 11] + self.axis = -1 + self.ignore_index = -1 + self.dtype = np.float32 + self.logits = np.full(self.shape, 1000.0).astype(self.dtype) + self.logits[:, :, 0, :] = -1000.0 + + if __name__ == "__main__": unittest.main() -- GitLab