提交 4bbd05fd 编写于 作者: Q Qiao Longfei 提交者: GitHub

check INFINITY in cross_entropy (#3287)

* check INFINITY in cross_entropy

* fix error

* use onehot_cross_entropy without GPU kernel

* add support_gpu

* fix allclose

* fix name error and symplify code
上级 0a38864a
...@@ -32,7 +32,7 @@ limitations under the License. */ ...@@ -32,7 +32,7 @@ limitations under the License. */
namespace py = pybind11; namespace py = pybind11;
USE_OP(add_two); USE_OP(add_two);
USE_OP(onehot_cross_entropy); USE_OP_CPU(onehot_cross_entropy);
USE_OP_WITHOUT_KERNEL(fc); USE_OP_WITHOUT_KERNEL(fc);
USE_OP(sgd); USE_OP(sgd);
USE_OP(mul); USE_OP(mul);
......
...@@ -14,6 +14,3 @@ ...@@ -14,6 +14,3 @@
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#include "paddle/operators/cross_entropy_op.h" #include "paddle/operators/cross_entropy_op.h"
REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
ops::OnehotCrossEntropyOpKernel<ops::GPUPlace, float>);
...@@ -18,7 +18,24 @@ limitations under the License. */ ...@@ -18,7 +18,24 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
static const float kCrossEntropyLogThreshold{1e-20}; template <typename T>
T tolerable_value(T x) {
static_assert(std::is_floating_point<T>::value,
"tolerable_value works only on float, "
"double and double double.");
const T kApproInf = 1e20;
if (x == INFINITY) {
return kApproInf;
}
if (x == -INFINITY) {
return -kApproInf;
}
return x;
}
template <typename Place, typename T> template <typename Place, typename T>
class OnehotCrossEntropyOpKernel : public OpKernel { class OnehotCrossEntropyOpKernel : public OpKernel {
...@@ -36,10 +53,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel { ...@@ -36,10 +53,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
int batch_size = X->dims()[0]; int batch_size = X->dims()[0];
int class_num = X->dims()[1]; int class_num = X->dims()[1];
// Y[i] = -log(X[i][j])
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
Ydata[i] = -std::log(std::max(Xdata[i * class_num + label_data[i]], int index = i * class_num + label_data[i];
kCrossEntropyLogThreshold)); Ydata[i] = -tolerable_value(std::log(Xdata[index]));
} }
} }
}; };
...@@ -62,9 +78,8 @@ class OnehotCrossEntropyGradientOpKernel : public OpKernel { ...@@ -62,9 +78,8 @@ class OnehotCrossEntropyGradientOpKernel : public OpKernel {
const int class_num = X->dims()[1]; const int class_num = X->dims()[1];
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
dXdata[i * class_num + label_data[i]] = int index = i * class_num + label_data[i];
-dYdata[i] / std::max(Xdata[i * class_num + label_data[i]], dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
kCrossEntropyLogThreshold);
} }
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册