diff --git a/paddle/framework/pybind.cc b/paddle/framework/pybind.cc index d4ac8fda5411a51c7c4b86400ae5506c39a64c00..b189e6f9e833b4374f493cc51ac6bfc08491bece 100644 --- a/paddle/framework/pybind.cc +++ b/paddle/framework/pybind.cc @@ -32,7 +32,7 @@ limitations under the License. */ namespace py = pybind11; USE_OP(add_two); -USE_OP(onehot_cross_entropy); +USE_OP_CPU(onehot_cross_entropy); USE_OP_WITHOUT_KERNEL(fc); USE_OP(sgd); USE_OP(mul); diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 2f453f8379ca7ce0612fed757719acb2d2cf0ad8..ec73721a810fa86d65409f643401eb77248ad5de 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -14,6 +14,3 @@ #define EIGEN_USE_GPU #include "paddle/operators/cross_entropy_op.h" - -REGISTER_OP_GPU_KERNEL(onehot_cross_entropy, - ops::OnehotCrossEntropyOpKernel); diff --git a/paddle/operators/cross_entropy_op.h b/paddle/operators/cross_entropy_op.h index 88d06e13469f8e6fc9e634d804c1fe0bed5e2d75..e02e3e2945af13fe283f95f7faa03b2a76d06125 100644 --- a/paddle/operators/cross_entropy_op.h +++ b/paddle/operators/cross_entropy_op.h @@ -18,7 +18,24 @@ limitations under the License. */ namespace paddle { namespace operators { -static const float kCrossEntropyLogThreshold{1e-20}; +template +T tolerable_value(T x) { + static_assert(std::is_floating_point::value, + "tolerable_value works only on float, " + "double and double double."); + + const T kApproInf = 1e20; + + if (x == INFINITY) { + return kApproInf; + } + + if (x == -INFINITY) { + return -kApproInf; + } + + return x; +} template class OnehotCrossEntropyOpKernel : public OpKernel { @@ -36,10 +53,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel { int batch_size = X->dims()[0]; int class_num = X->dims()[1]; - // Y[i] = -log(X[i][j]) for (int i = 0; i < batch_size; ++i) { - Ydata[i] = -std::log(std::max(Xdata[i * class_num + label_data[i]], - kCrossEntropyLogThreshold)); + int index = i * class_num + label_data[i]; + Ydata[i] = -tolerable_value(std::log(Xdata[index])); } } }; @@ -62,9 +78,8 @@ class OnehotCrossEntropyGradientOpKernel : public OpKernel { const int class_num = X->dims()[1]; for (int i = 0; i < batch_size; ++i) { - dXdata[i * class_num + label_data[i]] = - -dYdata[i] / std::max(Xdata[i * class_num + label_data[i]], - kCrossEntropyLogThreshold); + int index = i * class_num + label_data[i]; + dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]); } } };