diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index 033d780faad267992d4c28fa486b093dc4e96af2..95078093e5905a47cd363815b62a1ed0db089bf2 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -62,7 +62,8 @@ struct DDimPlusVisitor { DDim DDim::operator+(const DDim& d) const { PADDLE_ENFORCE(rank_ == d.rank_); - DDim ret(rank_); + DDim ret; + ret.rank_ = rank_; ret.apply_visitor(DDimPlusVisitor(Get(), d.Get())); return ret; } @@ -82,7 +83,8 @@ struct DDimMulVisitor { DDim DDim::operator*(const DDim& d) const { PADDLE_ENFORCE(rank_ == d.rank_); - DDim ret(rank_); + DDim ret; + ret.rank_ = rank_; ret.apply_visitor(DDimMulVisitor(Get(), d.Get())); return ret; } @@ -121,7 +123,9 @@ int64_t product(const DDim& ddim) { DDim slice_ddim(const DDim& dim, int begin, int end) { PADDLE_ENFORCE(begin >= 0, "Begin index can't be less than zero in ddim slice."); - DDim ret(end - begin); + int len = end - begin; + DDim ret; + ret.rank_ = len; dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_); return ret; } @@ -156,7 +160,8 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) { DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } DDim stride(const DDim& ddim) { - DDim strides(ddim.size()); + DDim strides; + strides.rank_ = ddim.size(); strides[ddim.size() - 1] = 1; for (int i = ddim.size() - 2; i >= 0; --i) { strides[i] = strides[i + 1] * ddim[i + 1]; @@ -165,7 +170,8 @@ DDim stride(const DDim& ddim) { } DDim stride_numel(const DDim& ddim) { - DDim strides(ddim.size()); + DDim strides; + strides.rank_ = ddim.size(); strides[ddim.size() - 1] = ddim[ddim.size() - 1]; for (int i = ddim.size() - 2; i >= 0; --i) { strides[i] = strides[i + 1] * ddim[i]; diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index 36ad90a2ae48bb2620f2fe7ae6b82e665e3eba47..0d7b12120525b538ac714e8f89156f8d79f8c41d 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -136,12 +136,6 @@ class DDim { return *reinterpret_cast*>(p); } - // Construct DDim with given rank - // Only used in friend functions - explicit DDim(int rank) : rank_(rank) { - PADDLE_ENFORCE(rank_ >= 0 && rank_ < kMaxRank, "Invalid rank %d", rank); - } - friend DDim slice_ddim(const DDim& dim, int begin, int end); friend DDim stride(const DDim& ddim); friend DDim stride_numel(const DDim& ddim);