提交 e0ca4d7a 编写于 作者: C caoying03

fix shape of output tensor of cross_entropy_op.

上级 8778957c
...@@ -29,7 +29,7 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel { ...@@ -29,7 +29,7 @@ class OnehotCrossEntropyOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2."); PADDLE_ENFORCE_EQ(X->dims().size(), 2, "X's dimension must be 2.");
PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1."); PADDLE_ENFORCE_EQ(label->dims().size(), 1, "label's dimension must be 1.");
PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]); PADDLE_ENFORCE_EQ(X->dims()[0], label->dims()[0]);
ctx.Output<Tensor>("Y")->Resize({X->dims()[0]}); ctx.Output<Tensor>("Y")->Resize({X->dims()[0], 1});
} }
}; };
......
...@@ -8,20 +8,22 @@ class TestCrossEntropy(OpTest): ...@@ -8,20 +8,22 @@ class TestCrossEntropy(OpTest):
self.op_type = "onehot_cross_entropy" self.op_type = "onehot_cross_entropy"
batch_size = 30 batch_size = 30
class_num = 10 class_num = 10
X = numpy.random.uniform(0.1, 1.0, X = numpy.random.uniform(0.1, 1.0,
[batch_size, class_num]).astype("float32") [batch_size, class_num]).astype("float32")
label = (class_num / 2) * numpy.ones(batch_size).astype("int32") labels = numpy.random.randint(0, class_num, batch_size, dtype="int32")
self.inputs = {'X': X, 'label': label}
Y = [] cross_entropy = numpy.asmatrix(
for i in range(0, batch_size): [[-numpy.log(X[i][labels[i]])] for i in range(X.shape[0])],
Y.append(-numpy.log(X[i][label[i]])) dtype="float32")
self.outputs = {'Y': numpy.array(Y).astype("float32")} self.inputs = {"X": X, "label": labels}
self.outputs = {"Y": cross_entropy}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
def test_check_grad(self): def test_check_grad(self):
self.check_grad(['X'], 'Y') self.check_grad(["X"], "Y")
if __name__ == "__main__": if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册