diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index f7fee04c1e2730674f0b8e24470b53a33294717f..033d780faad267992d4c28fa486b093dc4e96af2 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -18,13 +18,6 @@ limitations under the License. */ namespace paddle { namespace framework { -template -Dim make_dim(const int64_t* d) { - Dim ret; - fix_dim_assign(d, ret.GetMutable()); - return ret; -} - DDim make_ddim(std::initializer_list dims) { return DDim(dims.begin(), dims.size()); } @@ -69,8 +62,7 @@ struct DDimPlusVisitor { DDim DDim::operator+(const DDim& d) const { PADDLE_ENFORCE(rank_ == d.rank_); - DDim ret; - ret.rank_ = rank_; + DDim ret(rank_); ret.apply_visitor(DDimPlusVisitor(Get(), d.Get())); return ret; } @@ -90,8 +82,7 @@ struct DDimMulVisitor { DDim DDim::operator*(const DDim& d) const { PADDLE_ENFORCE(rank_ == d.rank_); - DDim ret; - ret.rank_ = rank_; + DDim ret(rank_); ret.apply_visitor(DDimMulVisitor(Get(), d.Get())); return ret; } @@ -118,7 +109,7 @@ std::vector vectorize2int(const DDim& ddim) { struct ProductVisitor { template - int64_t operator()(const Dim& dim) { + inline int64_t operator()(const Dim& dim) { return product(dim); } }; @@ -130,8 +121,7 @@ int64_t product(const DDim& ddim) { DDim slice_ddim(const DDim& dim, int begin, int end) { PADDLE_ENFORCE(begin >= 0, "Begin index can't be less than zero in ddim slice."); - DDim ret; - ret.rank_ = end - begin; + DDim ret(end - begin); dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_); return ret; } @@ -166,8 +156,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) { DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } DDim stride(const DDim& ddim) { - DDim strides; - strides.rank_ = ddim.size(); + DDim strides(ddim.size()); strides[ddim.size() - 1] = 1; for (int i = ddim.size() - 2; i >= 0; --i) { strides[i] = strides[i + 1] * ddim[i + 1]; @@ -175,9 +164,8 @@ DDim stride(const DDim& ddim) { return strides; } -DDim stride_numel(const framework::DDim& ddim) { - DDim strides; - strides.rank_ = ddim.size(); +DDim stride_numel(const DDim& ddim) { + DDim strides(ddim.size()); strides[ddim.size() - 1] = ddim[ddim.size() - 1]; for (int i = ddim.size() - 2; i >= 0; --i) { strides[i] = strides[i + 1] * ddim[i]; diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index e65d451cdef11a7bd571d91483808a008308845f..36ad90a2ae48bb2620f2fe7ae6b82e665e3eba47 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -22,27 +22,31 @@ limitations under the License. */ namespace paddle { namespace framework { +#define PADDLE_VISIT_DDIM_BASE(rank, callback) \ + case (rank): { \ + constexpr auto kRank = (rank); \ + return (callback); \ + } + +#define PADDLE_VISIT_DDIM(rank, callback) \ + switch (rank) { \ + PADDLE_VISIT_DDIM_BASE(0, callback); \ + PADDLE_VISIT_DDIM_BASE(1, callback); \ + PADDLE_VISIT_DDIM_BASE(2, callback); \ + PADDLE_VISIT_DDIM_BASE(3, callback); \ + PADDLE_VISIT_DDIM_BASE(4, callback); \ + PADDLE_VISIT_DDIM_BASE(5, callback); \ + PADDLE_VISIT_DDIM_BASE(6, callback); \ + PADDLE_VISIT_DDIM_BASE(7, callback); \ + PADDLE_VISIT_DDIM_BASE(8, callback); \ + PADDLE_VISIT_DDIM_BASE(9, callback); \ + default: \ + PADDLE_THROW("Invalid rank %d", rank); \ + } + template inline void dynamic_dim_assign(const T1* in, T2* out, int n) { -#define STATIC_DIM_ASSIGN_CASE(rank) \ - case rank: \ - static_dim_assign(in, out); \ - return - switch (n) { - STATIC_DIM_ASSIGN_CASE(0); - STATIC_DIM_ASSIGN_CASE(1); - STATIC_DIM_ASSIGN_CASE(2); - STATIC_DIM_ASSIGN_CASE(3); - STATIC_DIM_ASSIGN_CASE(4); - STATIC_DIM_ASSIGN_CASE(5); - STATIC_DIM_ASSIGN_CASE(6); - STATIC_DIM_ASSIGN_CASE(7); - STATIC_DIM_ASSIGN_CASE(8); - STATIC_DIM_ASSIGN_CASE(9); - default: - PADDLE_THROW("Invalid rank %d", n); - } -#undef STATIC_DIM_ASSIGN_CASE + PADDLE_VISIT_DDIM(n, (static_dim_assign(in, out))); } /** @@ -84,22 +88,26 @@ class DDim { inline int64_t operator[](int idx) const { return dim_[idx]; } inline int64_t& at(int idx) { - PADDLE_ENFORCE(idx >= 0 && idx < rank_); + PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx); return dim_[idx]; } inline int64_t at(int idx) const { - PADDLE_ENFORCE(idx >= 0 && idx < rank_); + PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx); return dim_[idx]; } template typename std::result_of&)>::type apply_visitor( - Visitor&& visitor); + Visitor&& visitor) { + PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast())); + } template typename std::result_of&)>::type apply_visitor( - Visitor&& visitor) const; + Visitor&& visitor) const { + PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast())); + } bool operator==(const DDim& d) const; @@ -128,55 +136,22 @@ class DDim { return *reinterpret_cast*>(p); } + // Construct DDim with given rank + // Only used in friend functions + explicit DDim(int rank) : rank_(rank) { + PADDLE_ENFORCE(rank_ >= 0 && rank_ < kMaxRank, "Invalid rank %d", rank); + } + friend DDim slice_ddim(const DDim& dim, int begin, int end); friend DDim stride(const DDim& ddim); friend DDim stride_numel(const DDim& ddim); + private: Dim dim_; int rank_; }; -#define PADDLE_VISIT_DDIM(rank) \ - case rank: \ - return visitor(UnsafeCast()) - -template -typename std::result_of&)>::type DDim::apply_visitor( - Visitor&& visitor) { - switch (rank_) { - PADDLE_VISIT_DDIM(0); - PADDLE_VISIT_DDIM(1); - PADDLE_VISIT_DDIM(2); - PADDLE_VISIT_DDIM(3); - PADDLE_VISIT_DDIM(4); - PADDLE_VISIT_DDIM(5); - PADDLE_VISIT_DDIM(6); - PADDLE_VISIT_DDIM(7); - PADDLE_VISIT_DDIM(8); - PADDLE_VISIT_DDIM(9); - default: - PADDLE_THROW("Invalid rank %d", rank_); - } -} - -template -typename std::result_of&)>::type DDim::apply_visitor( - Visitor&& visitor) const { - switch (rank_) { - PADDLE_VISIT_DDIM(0); - PADDLE_VISIT_DDIM(1); - PADDLE_VISIT_DDIM(2); - PADDLE_VISIT_DDIM(3); - PADDLE_VISIT_DDIM(4); - PADDLE_VISIT_DDIM(5); - PADDLE_VISIT_DDIM(6); - PADDLE_VISIT_DDIM(7); - PADDLE_VISIT_DDIM(8); - PADDLE_VISIT_DDIM(9); - default: - PADDLE_THROW("Invalid rank %d", rank_); - } -} +#undef PADDLE_VISIT_DDIM_BASE #undef PADDLE_VISIT_DDIM /** diff --git a/paddle/fluid/operators/detail/strided_memcpy.h b/paddle/fluid/operators/detail/strided_memcpy.h index fc223ce55931e0b826e46ffe93a90ea70a1540af..94419d1f9a4ba654952e0aedb46ab94ea8d5c0a8 100644 --- a/paddle/fluid/operators/detail/strided_memcpy.h +++ b/paddle/fluid/operators/detail/strided_memcpy.h @@ -98,8 +98,8 @@ struct StridedCopyDimVisitor { template void operator()(const framework::Dim& dst_dim) const { StridedMemcpyFunctor functor; - functor(dev_ctx_, src_, src_stride_.data(), dst_dim.data(), - dst_stride_.data(), dst_); + functor(dev_ctx_, src_, src_stride_.Get(), dst_dim.Get(), dst_stride_.Get(), + dst_); } const platform::DeviceContext& dev_ctx_;