From 149a1e31242a3e9e5cb3b505ece43a3281503022 Mon Sep 17 00:00:00 2001 From: wangchaochaohu Date: Fri, 8 Nov 2019 14:30:55 +0800 Subject: [PATCH] Expand refine (#21063) * fix the expand op compile time cost long time test=develop * add tag for just copy test=develop --- paddle/fluid/operators/expand_op.h | 46 +++++++++++++----------------- 1 file changed, 20 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/expand_op.h b/paddle/fluid/operators/expand_op.h index eb3b46f913..4f167468eb 100644 --- a/paddle/fluid/operators/expand_op.h +++ b/paddle/fluid/operators/expand_op.h @@ -34,9 +34,7 @@ limitations under the License. */ break; \ } #define REP_EXPAND_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_TEMPLATE, ~) -#define COND(n) \ - BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \ - BOOST_PP_MOD(n, MAX_RANK_SUPPORTED)) +#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED)) #define EXPAND_GRAD_CASE(n) \ case n: { \ ExpandBackward(context, reshape_dims_vec, reduce_dims_vec); \ @@ -145,33 +143,29 @@ class ExpandGradKernel : public framework::OpKernel { // auto& expand_times = context.Attr>("expand_times"); auto expand_times = get_expand_times(context); auto x_dims = in0->dims(); - // 1. reshape_dims_vec is the broadcast parameter. For each dimension i, - // if expand_times[i] > 1 and x_dims[i] > 1, i will be splitted to two - // dimensions [expand_times[i], x_dims[i]]. + // 1. reshape_dims_vec is the broadcast parameter. // 2. reduce_dims_vec is the dimension parameter to compute gradients. For // each dimension expanded, the gradients should be summed to original // size. std::vector reshape_dims_vec; std::vector reduce_dims_vec; for (size_t i = 0; i < expand_times.size(); ++i) { - if (expand_times[i] == 1) { - reshape_dims_vec.push_back(x_dims[i]); - } else { - if (x_dims[i] == 1) { - reduce_dims_vec.push_back(reshape_dims_vec.size()); - reshape_dims_vec.push_back(expand_times[i]); - } else { - reduce_dims_vec.push_back(reshape_dims_vec.size()); - reshape_dims_vec.push_back(expand_times[i]); - reshape_dims_vec.push_back(x_dims[i]); - } - } + reduce_dims_vec.push_back(reshape_dims_vec.size()); + reshape_dims_vec.push_back(expand_times[i]); + reshape_dims_vec.push_back(x_dims[i]); } - int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED + - reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1; + int dims = reduce_dims_vec.size(); + + bool just_copy = true; + for (size_t i = 0; i < expand_times.size(); i++) { + if (expand_times[i] != 1) { + just_copy = false; + break; + } + } // no need reduce, just copy - if (reduce_dims_vec.size() == 0) { + if (just_copy) { auto* in0 = context.Input(framework::GradVarName("Out")); auto* out0 = context.Output(framework::GradVarName("X")); out0->mutable_data(context.GetPlace()); @@ -179,7 +173,7 @@ class ExpandGradKernel : public framework::OpKernel { out0); } else { switch (dims) { - REP_EXPAND_GRAD_TEMPLATE(72) + REP_EXPAND_GRAD_TEMPLATE(MAX_RANK_SUPPORTED) default: PADDLE_ENFORCE( false, "Only support tensor with rank being between 1 and 6."); @@ -192,8 +186,8 @@ class ExpandGradKernel : public framework::OpKernel { void ExpandBackward(const framework::ExecutionContext& context, const std::vector& reshape_dims_vec, const std::vector& reduce_dims_vec) const { - size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1; - size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1; + size_t reshape_size = reshape_dims_vec.size(); + size_t reduce_size = reduce_dims_vec.size(); PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(), "Inconsistent size between template Dims and " "reshape dimensions."); @@ -204,11 +198,11 @@ class ExpandGradKernel : public framework::OpKernel { auto* out0 = context.Output(framework::GradVarName("X")); out0->mutable_data(context.GetPlace()); auto x_grad = EigenVector::Flatten(*out0); - Eigen::DSizes reshape_dims; + Eigen::DSizes reshape_dims; for (size_t i = 0; i < reshape_size; ++i) { reshape_dims[i] = reshape_dims_vec[i]; } - Eigen::DSizes reduce_dims; + Eigen::DSizes reduce_dims; for (size_t i = 0; i < reduce_size; ++i) { reduce_dims[i] = reduce_dims_vec[i]; } -- GitLab