提交 9e640444 编写于 作者: Q qijun

fix elementwise add bug

上级 43702a89
......@@ -108,7 +108,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) {
PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
"Rank of first input must >= rank of second input.")
if (x_dims == y_dims || product(y_dims) == 1) {
if (x_dims == y_dims) {
functor f;
f.template Run<Place, T>(x, y, z, ctx);
return;
......@@ -174,12 +174,6 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
return;
}
if (product(y_dims) == 1) {
functor1 f;
f(place, x, y, out, dx, dy, dout);
return;
}
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
......
......@@ -92,5 +92,33 @@ class TestElementwiseAddOp_broadcast_3(TestElementwiseOp):
}
class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 3, 4).astype(np.float32),
'Y': np.random.rand(3, 4).astype(np.float32)
}
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 4)
}
class TestElementwiseAddOp_rowwise_add_1(TestElementwiseOp):
def setUp(self):
self.op_type = "elementwise_add"
self.inputs = {
'X': np.random.rand(2, 1).astype(np.float32),
'Y': np.random.rand(1).astype(np.float32)
}
self.attrs = {'axis': 1}
self.outputs = {
'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 1)
}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册