From 4520afcf3e8255b97325d1d4ab79d77e13a0655f Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 13 Sep 2017 17:07:00 +0800 Subject: [PATCH] Consider corner case. --- paddle/operators/expand_op.h | 22 ++++++++++++++----- .../v2/framework/tests/test_expand_op.py | 8 +++---- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/paddle/operators/expand_op.h b/paddle/operators/expand_op.h index 5285d7525..2de849c48 100644 --- a/paddle/operators/expand_op.h +++ b/paddle/operators/expand_op.h @@ -109,11 +109,23 @@ class ExpandGradKernel : public framework::OpKernel { } int dims = reshape_dims_vec.size() * 6 + reduce_dims_vec.size() - 7; - switch (dims) { - REP_EXPAND_GRAD_TEMPLATE(72) - default: - PADDLE_ENFORCE(false, "Only support tensor whose rank in [1, 6]."); - }; + // no need reduce, just copy + if (reduce_dims_vec.size() == 0) { + auto* in0 = context.Input(framework::GradVarName("Out")); + auto* out0 = context.Output(framework::GradVarName("X")); + out0->mutable_data(context.GetPlace()); + if (platform::is_cpu_place(context.GetPlace())) { + out0->CopyFrom(*in0, platform::CPUPlace()); + } else { + out0->CopyFrom(*in0, platform::GPUPlace()); + } + } else { + switch (dims) { + REP_EXPAND_GRAD_TEMPLATE(72) + default: + PADDLE_ENFORCE(false, "Only support tensor whose rank in [1, 6]."); + }; + } } protected: diff --git a/python/paddle/v2/framework/tests/test_expand_op.py b/python/paddle/v2/framework/tests/test_expand_op.py index 9f5bd5f52..1bf9a9129 100644 --- a/python/paddle/v2/framework/tests/test_expand_op.py +++ b/python/paddle/v2/framework/tests/test_expand_op.py @@ -22,8 +22,8 @@ class TestExpandOpRank2(OpTest): def setUp(self): self.op_type = "expand" self.inputs = {'X': np.random.random((12, 14)).astype("float32")} - self.attrs = {'expandTimes': [3, 4]} - output = np.tile(self.inputs['X'], (3, 4)) + self.attrs = {'expandTimes': [1, 1]} + output = np.tile(self.inputs['X'], (1, 1)) self.outputs = {'Out': output} def test_check_output(self): @@ -37,8 +37,8 @@ class TestExpandOpRank3(OpTest): def setUp(self): self.op_type = "expand" self.inputs = {'X': np.random.random((2, 4, 5)).astype("float32")} - self.attrs = {'expandTimes': [3, 2, 1]} - output = np.tile(self.inputs['X'], (3, 2, 1)) + self.attrs = {'expandTimes': [1, 1, 1]} + output = np.tile(self.inputs['X'], (1, 1, 1)) self.outputs = {'Out': output} def test_check_output(self): -- GitLab