提交 950dbde5 编写于 作者: Q qiaolongfei

fix rowwise add grad op

上级 b7a6cc9c
...@@ -63,7 +63,7 @@ class RowwiseAddGradKernel : public framework::OpKernel { ...@@ -63,7 +63,7 @@ class RowwiseAddGradKernel : public framework::OpKernel {
// https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html // https://eigen.tuxfamily.org/dox/unsupported/TensorBase_8h_source.html
// colwise add // colwise add
Eigen::array<int, 1> dims{{1}}; /* dimension to reduce */ Eigen::array<int, 1> dims{{0}}; /* dimension to reduce */
EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims); EigenVector<T>::Flatten(*db).device(place) = OutGrad.sum(dims);
} }
}; };
......
...@@ -20,7 +20,7 @@ class RowwiseAddGradOpTest(GradientChecker): ...@@ -20,7 +20,7 @@ class RowwiseAddGradOpTest(GradientChecker):
def test_rowwise_add(self): def test_rowwise_add(self):
op = create_op("rowwise_add") op = create_op("rowwise_add")
inputs = { inputs = {
"X": np.random.uniform(0.1, 1, [10, 10]).astype("float32"), "X": np.random.uniform(0.1, 1, [5, 10]).astype("float32"),
"b": np.random.uniform(0.1, 1, [10]).astype("float32") "b": np.random.uniform(0.1, 1, [10]).astype("float32")
} }
self.check_grad(op, inputs, set(["X", "b"]), "Out") self.check_grad(op, inputs, set(["X", "b"]), "Out")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册