From 4bbd05fd724a650d32f2b7e842c9fcd55032722a Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 7 Aug 2017 20:42:43 +0800 Subject: [PATCH] check INFINITY in cross_entropy (#3287) * check INFINITY in cross_entropy * fix error * use onehot_cross_entropy without GPU kernel * add support_gpu * fix allclose * fix name error and symplify code --- paddle/framework/pybind.cc | 2 +- paddle/operators/cross_entropy_op.cu | 3 --- paddle/operators/cross_entropy_op.h | 29 +++++++++++++++++++++------- 3 files changed, 23 insertions(+), 11 deletions(-) diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index d4ac8fda541..b189e6f9e83 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -32,7 +32,7 @@ limitations under the License. */ namespace py = pybind11; USE_OP(add_two); -USE_OP(onehot_cross_entropy); +USE_OP_CPU(onehot_cross_entropy); USE_OP_WITHOUT_KERNEL(fc); USE_OP(sgd); USE_OP(mul); diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 2f453f8379c..ec73721a810 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -14,6 +14,3 @@ #define EIGEN_USE_GPU #include "paddle/operators/cross_entropy_op.h" - -REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 88d06e13469..e02e3e2945a 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -18,7 +18,24 @@ limitations under the License. */ namespace paddle { namespace operators { -static const float kCrossEntropyLogThreshold{1e-20}; +template +T tolerable_value(T x) { + static_assert(std::is_floating_point::value, + "tolerable_value works only on float, " + "double and double double."); + + const T kApproInf = 1e20; + + if (x == INFINITY) { + return kApproInf; + } + + if (x == -INFINITY) { + return -kApproInf; + } + + return x; +} template class OnehotCrossEntropyOpKernel : public OpKernel { @@ -36,10 +53,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel { int batch_size = X->dims()[0]; int class_num = X->dims()[1]; - // Y[i] = -log(X[i][j]) for (int i = 0; i < batch_size; ++i) { - Ydata[i] = -std::log(std::max(Xdata[i * class_num + label_data[i]], - kCrossEntropyLogThreshold)); + int index = i * class_num + label_data[i]; + Ydata[i] = -tolerable_value(std::log(Xdata[index])); } } }; @@ -62,9 +78,8 @@ class OnehotCrossEntropyGradientOpKernel : public OpKernel { const int class_num = X->dims()[1]; for (int i = 0; i < batch_size; ++i) { - dXdata[i * class_num + label_data[i]] = - -dYdata[i] / std::max(Xdata[i * class_num + label_data[i]], - kCrossEntropyLogThreshold); + int index = i * class_num + label_data[i]; + dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); } } }; -- GitLab