diff --git a/paddle/operators/expand_op.h b/paddle/operators/expand_op.h index 5285d7525b64d71764f07059494f950962feabfb..2de849c4844a53dd5277c11aa70a84039be2a7c8 100644 --- a/paddle/operators/expand_op.h +++ b/paddle/operators/expand_op.h @@ -109,11 +109,23 @@ class ExpandGradKernel : public framework::OpKernel { } int dims = reshape_dims_vec.size() * 6 + reduce_dims_vec.size() - 7; - switch (dims) { - REP_EXPAND_GRAD_TEMPLATE(72) - default: - PADDLE_ENFORCE(false, "Only support tensor whose rank in [1, 6]."); - }; + // no need reduce, just copy + if (reduce_dims_vec.size() == 0) { + auto* in0 = context.Input(framework::GradVarName("Out")); + auto* out0 = context.Output(framework::GradVarName("X")); + out0->mutable_data(context.GetPlace()); + if (platform::is_cpu_place(context.GetPlace())) { + out0->CopyFrom(*in0, platform::CPUPlace()); + } else { + out0->CopyFrom(*in0, platform::GPUPlace()); + } + } else { + switch (dims) { + REP_EXPAND_GRAD_TEMPLATE(72) + default: + PADDLE_ENFORCE(false, "Only support tensor whose rank in [1, 6]."); + }; + } } protected: diff --git a/python/paddle/v2/framework/tests/test_expand_op.py b/python/paddle/v2/framework/tests/test_expand_op.py index 9f5bd5f522569266e4d07c4e4a93f39ae4e8bd1d..1bf9a9129898606ac1b0e85c04a18595201f4768 100644 --- a/python/paddle/v2/framework/tests/test_expand_op.py +++ b/python/paddle/v2/framework/tests/test_expand_op.py @@ -22,8 +22,8 @@ class TestExpandOpRank2(OpTest): def setUp(self): self.op_type = "expand" self.inputs = {'X': np.random.random((12, 14)).astype("float32")} - self.attrs = {'expandTimes': [3, 4]} - output = np.tile(self.inputs['X'], (3, 4)) + self.attrs = {'expandTimes': [1, 1]} + output = np.tile(self.inputs['X'], (1, 1)) self.outputs = {'Out': output} def test_check_output(self): @@ -37,8 +37,8 @@ class TestExpandOpRank3(OpTest): def setUp(self): self.op_type = "expand" self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")} - self.attrs = {'expandTimes': [3, 2, 1]} - output = np.tile(self.inputs['X'], (3, 2, 1)) + self.attrs = {'expandTimes': [1, 1, 1]} + output = np.tile(self.inputs['X'], (1, 1, 1)) self.outputs = {'Out': output} def test_check_output(self):