提交 0e46f5eb 编写于 作者: C Cao Ying 提交者: GitHub

Merge pull request #4094 from lcy-seso/fix_cross_entropy_op_output_shape

fix shape of output tensor of cross_entropy_op.
......@@ -29,7 +29,7 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1.");
PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]);
ctx.Output<Tensor>("Y")->Resize({X->dims()[0]});
ctx.Output<Tensor>("Y")->Resize({X->dims()[0], 1});
}
};
......
......@@ -8,20 +8,22 @@ class TestCrossEntropy(OpTest):
self.op_type = "onehot_cross_entropy"
batch_size = 30
class_num = 10
X = numpy.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32")
label = (class_num / 2) * numpy.ones(batch_size).astype("int32")
self.inputs = {'X': X, 'label': label}
Y = []
for i in range(0, batch_size):
Y.append(-numpy.log(X[i][label[i]]))
self.outputs = {'Y': numpy.array(Y).astype("float32")}
labels = numpy.random.randint(0, class_num, batch_size, dtype="int32")
cross_entropy = numpy.asmatrix(
[[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])],
dtype="float32")
self.inputs = {"X": X, "label": labels}
self.outputs = {"Y": cross_entropy}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Y')
self.check_grad(["X"], "Y")
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册