提交 cef27dab 编写于 作者: D dongzhihong

"add fixl"

上级 c332e4ee
......@@ -63,6 +63,7 @@ class RowwiseAddGradOp : public framework::OperatorWithKernel {
"Input(Out@GRAD) should not be null");
auto dims0 = ctx.Input<Tensor>("X")->dims();
auto dims1 = ctx.Input<Tensor>("b")->dims();
PADDLE_ENFORCE_EQ(1, framework::product(dims1), "b dims should be 1")
ctx.Output<Tensor>(framework::GradVarName("X"))->Resize(dims0);
ctx.Output<Tensor>(framework::GradVarName("b"))->Resize(dims1);
}
......
......@@ -21,12 +21,10 @@ class RowwiseAddGradOpTest(GradientChecker):
op = create_op("rowwise_add")
inputs = {
"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32"),
"b": np.random.uniform(0.1, 1, [10, 1]).astype("float32")
"b": np.random.uniform(0.1, 1, [10]).astype("float32")
}
self.check_grad(op, inputs, set(["X", "b"]), "Out")
#TODO(dzh): rowwise_grad check
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册