Commit 4bbd05fd authored by Qiao Longfei, committed by GitHub

check INFINITY in cross_entropy (#3287)

* check INFINITY in cross_entropy

* fix error

* use onehot_cross_entropy without GPU kernel

* add support_gpu

* fix allclose

* fix name error and simplify code
Parent 0a38864a
@@ -32,7 +32,7 @@ limitations under the License. */
 namespace py = pybind11;
 USE_OP(add_two);
-USE_OP(onehot_cross_entropy);
+USE_OP_CPU(onehot_cross_entropy);
 USE_OP_WITHOUT_KERNEL(fc);
 USE_OP(sgd);
 USE_OP(mul);
@@ -14,6 +14,3 @@
 #define EIGEN_USE_GPU
 #include "paddle/operators/cross_entropy_op.h"
-REGISTER_OP_GPU_KERNEL(onehot_cross_entropy,
-                       ops::OnehotCrossEntropyOpKernel<ops::GPUPlace, float>);
@@ -18,7 +18,24 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
-static const float kCrossEntropyLogThreshold{1e-20};
+template <typename T>
+T tolerable_value(T x) {
+  static_assert(std::is_floating_point<T>::value,
+                "tolerable_value works only on float, "
+                "double and double double.");
+  const T kApproInf = 1e20;
+
+  if (x == INFINITY) {
+    return kApproInf;
+  }
+
+  if (x == -INFINITY) {
+    return -kApproInf;
+  }
+
+  return x;
+}
 
 template <typename Place, typename T>
 class OnehotCrossEntropyOpKernel : public OpKernel {
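For context (not part of the commit): a minimal standalone sketch of what the new helper buys. std::log(0.0f) evaluates to -INFINITY, and tolerable_value maps that back to a large finite value, so the loss ends up around 1e20 instead of inf. The helper below mirrors the one added in the hunk above (with a shortened assert message); main() and the variable names are purely illustrative.

```cpp
#include <cmath>
#include <cstdio>
#include <type_traits>

// Mirrors the helper added in the hunk above: clamp +/-INFINITY to a large
// finite value so downstream computation never sees inf.
template <typename T>
T tolerable_value(T x) {
  static_assert(std::is_floating_point<T>::value,
                "tolerable_value works only on floating point types.");
  const T kApproInf = 1e20;
  if (x == INFINITY) return kApproInf;
  if (x == -INFINITY) return -kApproInf;
  return x;
}

int main() {
  float zero_prob = 0.0f;                // probability of the labeled class
  float raw = std::log(zero_prob);       // -inf
  float clamped = tolerable_value(raw);  // -1e20
  std::printf("raw = %f, clamped = %e\n", raw, clamped);
  // The loss -tolerable_value(log(x)) is therefore 1e20 instead of inf.
  std::printf("loss = %e\n", -clamped);
  return 0;
}
```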
@@ -36,10 +53,9 @@ class OnehotCrossEntropyOpKernel : public OpKernel {
     int batch_size = X->dims()[0];
     int class_num = X->dims()[1];
     // Y[i] = -log(X[i][j])
     for (int i = 0; i < batch_size; ++i) {
-      Ydata[i] = -std::log(std::max(Xdata[i * class_num + label_data[i]],
-                                    kCrossEntropyLogThreshold));
+      int index = i * class_num + label_data[i];
+      Ydata[i] = -tolerable_value(std::log(Xdata[index]));
     }
   }
 };
@@ -62,9 +78,8 @@ class OnehotCrossEntropyGradientOpKernel : public OpKernel {
     const int class_num = X->dims()[1];
     for (int i = 0; i < batch_size; ++i) {
-      dXdata[i * class_num + label_data[i]] =
-          -dYdata[i] / std::max(Xdata[i * class_num + label_data[i]],
-                                kCrossEntropyLogThreshold);
+      int index = i * class_num + label_data[i];
+      dXdata[index] = -tolerable_value(dYdata[i] / Xdata[index]);
     }
   }
 };
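Also for context (not part of the commit): a self-contained sketch of how the forward and backward loops use the helper, written with plain std::vector buffers instead of Paddle's Tensor API. The shapes, label values, and printing here are illustrative assumptions; only the indexing expression i * class_num + label[i] and the clamped formulas follow the diff.

```cpp
#include <cmath>
#include <cstdio>
#include <type_traits>
#include <vector>

// Same clamping idea as the committed helper (assert message shortened).
template <typename T>
T tolerable_value(T x) {
  static_assert(std::is_floating_point<T>::value, "floating point only");
  const T kApproInf = 1e20;
  if (x == INFINITY) return kApproInf;
  if (x == -INFINITY) return -kApproInf;
  return x;
}

int main() {
  const int batch_size = 2, class_num = 3;
  // Row-major probabilities X[batch_size][class_num]; the second row has a
  // zero probability at its label, which would otherwise produce inf.
  std::vector<float> X = {0.2f, 0.7f, 0.1f,
                          0.0f, 0.5f, 0.5f};
  std::vector<int> label = {1, 0};

  // Forward: Y[i] = -log(X[i][label[i]]), clamped as in the new kernel.
  std::vector<float> Y(batch_size);
  for (int i = 0; i < batch_size; ++i) {
    int index = i * class_num + label[i];
    Y[i] = -tolerable_value(std::log(X[index]));
  }

  // Backward: dX[i][label[i]] = -dY[i] / X[i][label[i]], clamped the same way.
  std::vector<float> dY = {1.0f, 1.0f};
  std::vector<float> dX(batch_size * class_num, 0.0f);
  for (int i = 0; i < batch_size; ++i) {
    int index = i * class_num + label[i];
    dX[index] = -tolerable_value(dY[i] / X[index]);
  }

  std::printf("Y = [%e, %e]\n", Y[0], Y[1]);               // finite; ~1e20 for row 1
  std::printf("dX at labels = [%e, %e]\n", dX[1], dX[3]);  // finite; ~-1e20 for row 1
  return 0;
}
```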