diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 488a35aafc8600bb8bb252fc3a5161c72a2f6df1..3eb97f60b59848d23bcd15ea1e3d2f21b721f6a4 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -108,7 +108,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) { PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), "Rank of first input must >= rank of second input.") - if (x_dims == y_dims) { + if (x_dims == y_dims || product(y_dims) == 1) { functor f; f.template Run(x, y, z, ctx); return; @@ -174,6 +174,12 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { return; } + if (product(y_dims) == 1) { + functor1 f; + f(place, x, y, out, dx, dy, dout); + return; + } + int axis = ctx.Attr("axis"); axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); diff --git a/python/paddle/v2/framework/tests/test_elementwise_add_op.py b/python/paddle/v2/framework/tests/test_elementwise_add_op.py index 57daddd5698f77527bc5b78c436065a851867ae0..f3101a709b8bcf58e8682ab3d0ca5217a7f3572d 100644 --- a/python/paddle/v2/framework/tests/test_elementwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_elementwise_add_op.py @@ -92,33 +92,5 @@ class TestElementwiseAddOp_broadcast_3(TestElementwiseOp): } -class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float32), - 'Y': np.random.rand(3, 4).astype(np.float32) - } - - self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 4) - } - - -class TestElementwiseAddOp_rowwise_add_1(TestElementwiseOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 1).astype(np.float32), - 'Y': np.random.rand(1).astype(np.float32) - } - - self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 1) - } - - if __name__ == '__main__': unittest.main()