提交 e8706829 编写于 作者: X Xinghai Sun

Update cross entropy operator by following reviewer's comments.

上级 d7717f2e
...@@ -54,6 +54,9 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { ...@@ -54,6 +54,9 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(const framework::InferShapeContext &ctx) const override { void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input(X) of CrossEntropyOp must not be null.");
auto dx = ctx.Output<LoDTensor>(framework::GradVarName("X")); auto dx = ctx.Output<LoDTensor>(framework::GradVarName("X"));
auto x = ctx.Input<Tensor>("X"); auto x = ctx.Input<Tensor>("X");
...@@ -74,11 +77,14 @@ CrossEntropy Operator. ...@@ -74,11 +77,14 @@ CrossEntropy Operator.
The second input (Label tensor) supports two kinds of shapes: The second input (Label tensor) supports two kinds of shapes:
1) Rank(Label) = 1, Label[i] indicates the class index for sample i: 1) Rank(Label) = 1, Label[i] indicates the class index for sample i:
Y[i] = -log(X[i, Label[i]]) Y[i] = -log(X[i, Label[i]])
2) Rank(Label) = 2, Label[i, j] indicates the soft label of class j 2) Rank(Label) = 2, Label[i, j] indicates the soft label of class j
for sample i: for sample i:
Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} Y[i] = \sum_j{-Label[i, j] * log(X[i, j])}
Please make sure that in this case the summation of each row of Label Please make sure that in this case the summation of each row of Label
equals one. If each row of Label has only one non-zero element (equals 1), equals one. If each row of Label has only one non-zero element (equals 1),
it degenerates to a standard one-hot representation. it degenerates to a standard one-hot representation.
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/platform/assert.h" #include "paddle/platform/assert.h"
#include "paddle/platform/hostdevice.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -21,7 +22,7 @@ namespace operators { ...@@ -21,7 +22,7 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
template <typename T> template <typename T>
__host__ __device__ T tolerable_value(const T x) { HOSTDEVICE T tolerable_value(const T x) {
PADDLE_ASSERT(std::is_floating_point<T>::value); PADDLE_ASSERT(std::is_floating_point<T>::value);
const T kApproInf = 1e20; const T kApproInf = 1e20;
if (x == INFINITY) { if (x == INFINITY) {
......
...@@ -45,7 +45,7 @@ class TestCrossEntropySoftLabel(OpTest): ...@@ -45,7 +45,7 @@ class TestCrossEntropySoftLabel(OpTest):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Y', max_relative_error=0.05) self.check_grad(['X'], 'Y')
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册