From e87068290e2f6b714b5b171d8cd4cbfe985bd921 Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Sat, 16 Sep 2017 18:57:13 +0800 Subject: [PATCH] Update cross entropy operator by following reviewer's comments. --- paddle/operators/cross_entropy_op.cc | 6 ++++++ paddle/operators/cross_entropy_op.cu | 3 ++- python/paddle/v2/framework/tests/test_cross_entropy_op.py | 2 +- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index c31c1328985..61d2104b95c 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -54,6 +54,9 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { protected: void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), + "Input(X) of CrossEntropyOp must not be null."); + auto dx = ctx.Output<Tensor>(framework::GradVarName("X")); auto x = ctx.Input<Tensor>("X"); @@ -74,11 +77,14 @@ CrossEntropy Operator. The second input (Label tensor) supports two kinds of shapes: 1) Rank(Label) = 1, Label[i] indicates the class index for sample i: + Y[i] = -log(X[i, Label[i]]) 2) Rank(Label) = 2, Label[i, j] indicates the soft label of class j for sample i: + Y[i] = \sum_j{-Label[i, j] * log(X[i, j])} + Please make sure that in this case the summation of each row of Label equals one. If each row of Label has only one non-zero element (equals 1), it degenerates to a standard one-hot representation. 
diff --git a/paddle/operators/cross_entropy_op.cu b/paddle/operators/cross_entropy_op.cu index 1f5e9c1b04e..e80dcec8e25 100644 --- a/paddle/operators/cross_entropy_op.cu +++ b/paddle/operators/cross_entropy_op.cu @@ -14,6 +14,7 @@ #include "paddle/framework/op_registry.h" #include "paddle/platform/assert.h" +#include "paddle/platform/hostdevice.h" namespace paddle { namespace operators { @@ -21,7 +22,7 @@ namespace operators { using Tensor = framework::Tensor; template <typename T> -__host__ __device__ T tolerable_value(const T x) { +HOSTDEVICE T tolerable_value(const T x) { PADDLE_ASSERT(std::is_floating_point<T>::value); const T kApproInf = 1e20; if (x == INFINITY) { diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index a630dea7f54..ccff2a386d3 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -45,7 +45,7 @@ class TestCrossEntropySoftLabel(OpTest): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y', max_relative_error=0.05) + self.check_grad(['X'], 'Y') if __name__ == "__main__": -- GitLab