diff --git a/paddle/fluid/operators/expand_as_op.h b/paddle/fluid/operators/expand_as_op.h
old mode 100755
new mode 100644
index 249f4c35a7fff1986c2dd951977126c46e3957cd..b189aa6f12274f09738e7f01f16f56f7ff20b534
--- a/paddle/fluid/operators/expand_as_op.h
+++ b/paddle/fluid/operators/expand_as_op.h
@@ -31,9 +31,7 @@ limitations under the License. */
     break;                                                            \
   }
 #define REP_EXPAND_AS_TEMPLATE(n) BOOST_PP_REPEAT(n, EXPAND_AS_TEMPLATE, ~)
-#define COND(n)                                               \
-  BOOST_PP_GREATER_EQUAL(BOOST_PP_DIV(n, MAX_RANK_SUPPORTED), \
-                         BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
+#define COND(n) BOOST_PP_GREATER_EQUAL(n, BOOST_PP_MOD(n, MAX_RANK_SUPPORTED))
 #define EXPAND_AS_GRAD_CASE(n)                                        \
   case n: {                                                           \
     ExpandAsBackward<n>(context, reshape_dims_vec, reduce_dims_vec);  \
@@ -116,23 +114,20 @@ class ExpandAsGradKernel : public framework::OpKernel<T> {
     std::vector<int> reshape_dims_vec;
     std::vector<int> reduce_dims_vec;
     for (size_t i = 0; i < bcast_dims.size(); ++i) {
-      if (bcast_dims[i] == 1) {
-        reshape_dims_vec.push_back(x_dims[i]);
-      } else {
-        if (x_dims[i] == 1) {
-          reduce_dims_vec.push_back(reshape_dims_vec.size());
-          reshape_dims_vec.push_back(bcast_dims[i]);
-        } else {
-          reduce_dims_vec.push_back(reshape_dims_vec.size());
-          reshape_dims_vec.push_back(bcast_dims[i]);
-          reshape_dims_vec.push_back(x_dims[i]);
-        }
+      reduce_dims_vec.push_back(reshape_dims_vec.size());
+      reshape_dims_vec.push_back(bcast_dims[i]);
+      reshape_dims_vec.push_back(x_dims[i]);
+    }
+    int dims = reduce_dims_vec.size();
+    bool just_copy = true;
+    for (size_t i = 0; i < bcast_dims.size(); i++) {
+      if (bcast_dims[i] != 1) {
+        just_copy = false;
+        break;
       }
     }
-    int dims = reshape_dims_vec.size() * MAX_RANK_SUPPORTED +
-               reduce_dims_vec.size() - MAX_RANK_SUPPORTED - 1;
     // no need reduce, just copy
-    if (reduce_dims_vec.size() == 0) {
+    if (just_copy) {
       auto* in0 = context.Input<Tensor>(framework::GradVarName("Out"));
       auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
       out0->mutable_data<T>(context.GetPlace());
@@ -140,7 +135,7 @@ class ExpandAsGradKernel : public framework::OpKernel<T> {
                             out0);
     } else {
       switch (dims) {
-        REP_EXPAND_AS_GRAD_TEMPLATE(72)
+        REP_EXPAND_AS_GRAD_TEMPLATE(MAX_RANK_SUPPORTED)
         default:
           PADDLE_THROW("Only support tensor with rank being between 1 and 6.");
       }
@@ -152,8 +147,8 @@ class ExpandAsGradKernel : public framework::OpKernel<T> {
   void ExpandAsBackward(const framework::ExecutionContext& context,
                         const std::vector<int>& reshape_dims_vec,
                         const std::vector<int>& reduce_dims_vec) const {
-    size_t reshape_size = Dims / MAX_RANK_SUPPORTED + 1;
-    size_t reduce_size = Dims % MAX_RANK_SUPPORTED + 1;
+    size_t reshape_size = reshape_dims_vec.size();
+    size_t reduce_size = reduce_dims_vec.size();
     PADDLE_ENFORCE_EQ(reshape_size, reshape_dims_vec.size(),
                       "Inconsistent size between template Dims and "
                       "reshape dimensions.");
@@ -164,11 +159,11 @@ class ExpandAsGradKernel : public framework::OpKernel<T> {
     auto* out0 = context.Output<Tensor>(framework::GradVarName("X"));
     out0->mutable_data<T>(context.GetPlace());
     auto x_grad = EigenVector<T>::Flatten(*out0);
-    Eigen::DSizes<int, Dims / MAX_RANK_SUPPORTED + 1> reshape_dims;
+    Eigen::DSizes<int, Dims * 2> reshape_dims;
     for (size_t i = 0; i < reshape_size; ++i) {
       reshape_dims[i] = reshape_dims_vec[i];
     }
-    Eigen::DSizes<int, Dims % MAX_RANK_SUPPORTED + 1> reduce_dims;
+    Eigen::DSizes<int, Dims> reduce_dims;
     for (size_t i = 0; i < reduce_size; ++i) {
       reduce_dims[i] = reduce_dims_vec[i];
     }
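
For reference, here is a minimal standalone sketch (not part of the patch; names and the example shapes are purely illustrative) of the dimension bookkeeping the reworked ExpandAsGradKernel now does: every axis contributes one reduce index plus a (broadcast factor, original extent) pair, so reshape_dims_vec ends up with 2 * rank entries, reduce_dims_vec with rank entries, and the switch key is simply reduce_dims_vec.size(); when no axis is actually broadcast, the gradient is just copied through.

// Illustrative sketch of the new reshape/reduce dims construction.
#include <cstdio>
#include <vector>

int main() {
  // Example: x has shape [2, 1, 4] and is expanded as a [2, 3, 4] target,
  // so the per-axis broadcast factors are [1, 3, 1].
  std::vector<int> x_dims = {2, 1, 4};
  std::vector<int> bcast_dims = {1, 3, 1};

  // Same loop as the patched kernel: one reduce index and a
  // (bcast, x_dim) pair per axis.
  std::vector<int> reshape_dims_vec;
  std::vector<int> reduce_dims_vec;
  for (size_t i = 0; i < bcast_dims.size(); ++i) {
    reduce_dims_vec.push_back(reshape_dims_vec.size());
    reshape_dims_vec.push_back(bcast_dims[i]);
    reshape_dims_vec.push_back(x_dims[i]);
  }

  // "Just copy" fast path: if no axis is broadcast, dX == dOut.
  bool just_copy = true;
  for (size_t i = 0; i < bcast_dims.size(); i++) {
    if (bcast_dims[i] != 1) {
      just_copy = false;
      break;
    }
  }

  printf("reshape:");
  for (int d : reshape_dims_vec) printf(" %d", d);  // 1 2 3 1 1 4
  printf("\nreduce:");
  for (int d : reduce_dims_vec) printf(" %d", d);   // 0 2 4
  printf("\njust_copy: %d\n", just_copy ? 1 : 0);   // 0
  return 0;
}

With these inputs the sketch prints reshape 1 2 3 1 1 4, reduce 0 2 4, and just_copy 0, i.e. the shapes that the Eigen reshape-then-sum in ExpandAsBackward operates on; reducing over an axis whose broadcast factor is 1 is a no-op, which is why every axis can be listed unconditionally.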