From eb1aa015cda4fc12b6dc778ada6c3507b98134f5 Mon Sep 17 00:00:00 2001 From: qijun Date: Thu, 19 Oct 2017 17:13:52 -0700 Subject: [PATCH] revert code --- paddle/operators/elementwise_op_function.h | 8 +++++- .../tests/test_elementwise_add_op.py | 28 ------------------- 2 files changed, 7 insertions(+), 29 deletions(-) diff --git a/paddle/operators/elementwise_op_function.h b/paddle/operators/elementwise_op_function.h index 488a35aafc..3eb97f60b5 100644 --- a/paddle/operators/elementwise_op_function.h +++ b/paddle/operators/elementwise_op_function.h @@ -108,7 +108,7 @@ void ElementwiseCompute(const framework::ExecutionContext& ctx) { PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(), "Rank of first input must >= rank of second input.") - if (x_dims == y_dims) { + if (x_dims == y_dims || product(y_dims) == 1) { functor f; f.template Run(x, y, z, ctx); return; @@ -174,6 +174,12 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) { return; } + if (product(y_dims) == 1) { + functor1 f; + f(place, x, y, out, dx, dy, dout); + return; + } + int axis = ctx.Attr("axis"); axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis); diff --git a/python/paddle/v2/framework/tests/test_elementwise_add_op.py b/python/paddle/v2/framework/tests/test_elementwise_add_op.py index 57daddd569..f3101a709b 100644 --- a/python/paddle/v2/framework/tests/test_elementwise_add_op.py +++ b/python/paddle/v2/framework/tests/test_elementwise_add_op.py @@ -92,33 +92,5 @@ class TestElementwiseAddOp_broadcast_3(TestElementwiseOp): } -class TestElementwiseAddOp_rowwise_add_0(TestElementwiseOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 3, 4).astype(np.float32), - 'Y': np.random.rand(3, 4).astype(np.float32) - } - - self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 3, 4) - } - - -class TestElementwiseAddOp_rowwise_add_1(TestElementwiseOp): - def setUp(self): - self.op_type = "elementwise_add" - self.inputs = { - 'X': np.random.rand(2, 1).astype(np.float32), - 'Y': np.random.rand(1).astype(np.float32) - } - - self.attrs = {'axis': 1} - self.outputs = { - 'Out': self.inputs['X'] + self.inputs['Y'].reshape(1, 1) - } - - if __name__ == '__main__': unittest.main() -- GitLab