From e0ca4d7a29533a8ee7a4dc7af4c9623187539707 Mon Sep 17 00:00:00 2001 From: caoying03 Date: Thu, 14 Sep 2017 09:45:10 +0800 Subject: [PATCH] fix shape of output tensor of cross_entropy_op. --- paddle/operators/cross_entropy_op.cc | 2 +- .../v2/framework/tests/test_cross_entropy_op.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index ab1e1c101..337ec41e5 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -29,7 +29,7 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2."); PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1."); PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]); - ctx.Output("Y")->Resize({X->dims()[0]}); + ctx.Output("Y")->Resize({X->dims()[0], 1}); } }; diff --git a/python/paddle/v2/framework/tests/test_cross_entropy_op.py b/python/paddle/v2/framework/tests/test_cross_entropy_op.py index c2fc102a8..253e7b8a2 100644 --- a/python/paddle/v2/framework/tests/test_cross_entropy_op.py +++ b/python/paddle/v2/framework/tests/test_cross_entropy_op.py @@ -8,20 +8,22 @@ class TestCrossEntropy(OpTest): self.op_type = "onehot_cross_entropy" batch_size = 30 class_num = 10 + X = numpy.random.uniform(0.1, 1.0, [batch_size, class_num]).astype("float32") - label = (class_num / 2) * numpy.ones(batch_size).astype("int32") - self.inputs = {'X': X, 'label': label} - Y = [] - for i in range(0, batch_size): - Y.append(-numpy.log(X[i][label[i]])) - self.outputs = {'Y': numpy.array(Y).astype("float32")} + labels = numpy.random.randint(0, class_num, batch_size, dtype="int32") + + cross_entropy = numpy.asmatrix( + [[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])], + dtype="float32") + self.inputs = {"X": X, "label": labels} + self.outputs = {"Y": cross_entropy} def test_check_output(self): self.check_output() def test_check_grad(self): - self.check_grad(['X'], 'Y') + self.check_grad(["X"], "Y") if __name__ == "__main__": -- GitLab