提交 92369177 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #4941 from QiJune/fix_elementwis_add_bug

fix elementwise add bug
...@@ -108,7 +108,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) { ...@@ -108,7 +108,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.") "Rank of first input must >= rank of second input.")
if (x_dims == y_dims || product(y_dims) == 1) { if (x_dims == y_dims) {
functor f; functor f;
f.template Run<Place, T>(x, y, z, ctx); f.template Run<Place, T>(x, y, z, ctx);
return; return;
...@@ -174,12 +174,6 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { ...@@ -174,12 +174,6 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
return; return;
} }
if (product(y_dims) == 1) {
functor1 f;
f(place, x, y, out, dx, dy, dout);
return;
}
int axis = ctx.Attr<int>("axis"); int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
......
...@@ -92,5 +92,33 @@ class TestElementwiseAddOp_broadcast_3(TestElementwiseOp): ...@@ -92,5 +92,33 @@ class TestElementwiseAddOp_broadcast_3(TestElementwiseOp):
} }
class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(3, 4).astype(np.float32)
}
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 4)
}
class TestElementwiseAddOp_rowwise_add_1(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 1).astype(np.float32),
'Y': np.random.rand(1).astype(np.float32)
}
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 1)
}
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册