提交 83ac8515 编写于 作者: S sneaxiy

polish code

test=develop
上级 045dc127
...@@ -18,13 +18,6 @@ limitations under the License. */ ...@@ -18,13 +18,6 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <int N>
Dim<N> make_dim(const int64_t* d) {
Dim<N> ret;
fix_dim_assign(d, ret.GetMutable());
return ret;
}
DDim make_ddim(std::initializer_list<int64_t> dims) { DDim make_ddim(std::initializer_list<int64_t> dims) {
return DDim(dims.begin(), dims.size()); return DDim(dims.begin(), dims.size());
} }
...@@ -69,8 +62,7 @@ struct DDimPlusVisitor { ...@@ -69,8 +62,7 @@ struct DDimPlusVisitor {
DDim DDim::operator+(const DDim& d) const { DDim DDim::operator+(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_); PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret; DDim ret(rank_);
ret.rank_ = rank_;
ret.apply_visitor(DDimPlusVisitor(Get(), d.Get())); ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
return ret; return ret;
} }
...@@ -90,8 +82,7 @@ struct DDimMulVisitor { ...@@ -90,8 +82,7 @@ struct DDimMulVisitor {
DDim DDim::operator*(const DDim& d) const { DDim DDim::operator*(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_); PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret; DDim ret(rank_);
ret.rank_ = rank_;
ret.apply_visitor(DDimMulVisitor(Get(), d.Get())); ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
return ret; return ret;
} }
...@@ -118,7 +109,7 @@ std::vector<int> vectorize2int(const DDim& ddim) { ...@@ -118,7 +109,7 @@ std::vector<int> vectorize2int(const DDim& ddim) {
struct ProductVisitor { struct ProductVisitor {
template <int D> template <int D>
int64_t operator()(const Dim<D>& dim) { inline int64_t operator()(const Dim<D>& dim) {
return product(dim); return product(dim);
} }
}; };
...@@ -130,8 +121,7 @@ int64_t product(const DDim& ddim) { ...@@ -130,8 +121,7 @@ int64_t product(const DDim& ddim) {
DDim slice_ddim(const DDim& dim, int begin, int end) { DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0, PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice."); "Begin index can't be less than zero in ddim slice.");
DDim ret; DDim ret(end - begin);
ret.rank_ = end - begin;
dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_); dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_);
return ret; return ret;
} }
...@@ -166,8 +156,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) { ...@@ -166,8 +156,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
DDim stride(const DDim& ddim) { DDim stride(const DDim& ddim) {
DDim strides; DDim strides(ddim.size());
strides.rank_ = ddim.size();
strides[ddim.size() - 1] = 1; strides[ddim.size() - 1] = 1;
for (int i = ddim.size() - 2; i >= 0; --i) { for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i + 1]; strides[i] = strides[i + 1] * ddim[i + 1];
...@@ -175,9 +164,8 @@ DDim stride(const DDim& ddim) { ...@@ -175,9 +164,8 @@ DDim stride(const DDim& ddim) {
return strides; return strides;
} }
DDim stride_numel(const framework::DDim& ddim) { DDim stride_numel(const DDim& ddim) {
DDim strides; DDim strides(ddim.size());
strides.rank_ = ddim.size();
strides[ddim.size() - 1] = ddim[ddim.size() - 1]; strides[ddim.size() - 1] = ddim[ddim.size() - 1];
for (int i = ddim.size() - 2; i >= 0; --i) { for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i]; strides[i] = strides[i + 1] * ddim[i];
......
...@@ -22,27 +22,31 @@ limitations under the License. */ ...@@ -22,27 +22,31 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
#define PADDLE_VISIT_DDIM_BASE(rank, callback) \
case (rank): { \
constexpr auto kRank = (rank); \
return (callback); \
}
#define PADDLE_VISIT_DDIM(rank, callback) \
switch (rank) { \
PADDLE_VISIT_DDIM_BASE(0, callback); \
PADDLE_VISIT_DDIM_BASE(1, callback); \
PADDLE_VISIT_DDIM_BASE(2, callback); \
PADDLE_VISIT_DDIM_BASE(3, callback); \
PADDLE_VISIT_DDIM_BASE(4, callback); \
PADDLE_VISIT_DDIM_BASE(5, callback); \
PADDLE_VISIT_DDIM_BASE(6, callback); \
PADDLE_VISIT_DDIM_BASE(7, callback); \
PADDLE_VISIT_DDIM_BASE(8, callback); \
PADDLE_VISIT_DDIM_BASE(9, callback); \
default: \
PADDLE_THROW("Invalid rank %d", rank); \
}
template <typename T1, typename T2> template <typename T1, typename T2>
inline void dynamic_dim_assign(const T1* in, T2* out, int n) { inline void dynamic_dim_assign(const T1* in, T2* out, int n) {
#define STATIC_DIM_ASSIGN_CASE(rank) \ PADDLE_VISIT_DDIM(n, (static_dim_assign<kRank, T1, T2>(in, out)));
case rank: \
static_dim_assign<rank, T1, T2>(in, out); \
return
switch (n) {
STATIC_DIM_ASSIGN_CASE(0);
STATIC_DIM_ASSIGN_CASE(1);
STATIC_DIM_ASSIGN_CASE(2);
STATIC_DIM_ASSIGN_CASE(3);
STATIC_DIM_ASSIGN_CASE(4);
STATIC_DIM_ASSIGN_CASE(5);
STATIC_DIM_ASSIGN_CASE(6);
STATIC_DIM_ASSIGN_CASE(7);
STATIC_DIM_ASSIGN_CASE(8);
STATIC_DIM_ASSIGN_CASE(9);
default:
PADDLE_THROW("Invalid rank %d", n);
}
#undef STATIC_DIM_ASSIGN_CASE
} }
/** /**
...@@ -84,22 +88,26 @@ class DDim { ...@@ -84,22 +88,26 @@ class DDim {
inline int64_t operator[](int idx) const { return dim_[idx]; } inline int64_t operator[](int idx) const { return dim_[idx]; }
inline int64_t& at(int idx) { inline int64_t& at(int idx) {
PADDLE_ENFORCE(idx >= 0 && idx < rank_); PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx);
return dim_[idx]; return dim_[idx];
} }
inline int64_t at(int idx) const { inline int64_t at(int idx) const {
PADDLE_ENFORCE(idx >= 0 && idx < rank_); PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx);
return dim_[idx]; return dim_[idx];
} }
template <typename Visitor> template <typename Visitor>
typename std::result_of<Visitor(Dim<0>&)>::type apply_visitor( typename std::result_of<Visitor(Dim<0>&)>::type apply_visitor(
Visitor&& visitor); Visitor&& visitor) {
PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast<kRank>()));
}
template <typename Visitor> template <typename Visitor>
typename std::result_of<Visitor(const Dim<0>&)>::type apply_visitor( typename std::result_of<Visitor(const Dim<0>&)>::type apply_visitor(
Visitor&& visitor) const; Visitor&& visitor) const {
PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast<kRank>()));
}
bool operator==(const DDim& d) const; bool operator==(const DDim& d) const;
...@@ -128,55 +136,22 @@ class DDim { ...@@ -128,55 +136,22 @@ class DDim {
return *reinterpret_cast<const Dim<M>*>(p); return *reinterpret_cast<const Dim<M>*>(p);
} }
// Construct DDim with given rank
// Only used in friend functions
explicit DDim(int rank) : rank_(rank) {
PADDLE_ENFORCE(rank_ >= 0 && rank_ < kMaxRank, "Invalid rank %d", rank);
}
friend DDim slice_ddim(const DDim& dim, int begin, int end); friend DDim slice_ddim(const DDim& dim, int begin, int end);
friend DDim stride(const DDim& ddim); friend DDim stride(const DDim& ddim);
friend DDim stride_numel(const DDim& ddim); friend DDim stride_numel(const DDim& ddim);
private:
Dim<kMaxRank> dim_; Dim<kMaxRank> dim_;
int rank_; int rank_;
}; };
#define PADDLE_VISIT_DDIM(rank) \ #undef PADDLE_VISIT_DDIM_BASE
case rank: \
return visitor(UnsafeCast<rank>())
template <typename Visitor>
typename std::result_of<Visitor(Dim<0>&)>::type DDim::apply_visitor(
Visitor&& visitor) {
switch (rank_) {
PADDLE_VISIT_DDIM(0);
PADDLE_VISIT_DDIM(1);
PADDLE_VISIT_DDIM(2);
PADDLE_VISIT_DDIM(3);
PADDLE_VISIT_DDIM(4);
PADDLE_VISIT_DDIM(5);
PADDLE_VISIT_DDIM(6);
PADDLE_VISIT_DDIM(7);
PADDLE_VISIT_DDIM(8);
PADDLE_VISIT_DDIM(9);
default:
PADDLE_THROW("Invalid rank %d", rank_);
}
}
template <typename Visitor>
typename std::result_of<Visitor(const Dim<0>&)>::type DDim::apply_visitor(
Visitor&& visitor) const {
switch (rank_) {
PADDLE_VISIT_DDIM(0);
PADDLE_VISIT_DDIM(1);
PADDLE_VISIT_DDIM(2);
PADDLE_VISIT_DDIM(3);
PADDLE_VISIT_DDIM(4);
PADDLE_VISIT_DDIM(5);
PADDLE_VISIT_DDIM(6);
PADDLE_VISIT_DDIM(7);
PADDLE_VISIT_DDIM(8);
PADDLE_VISIT_DDIM(9);
default:
PADDLE_THROW("Invalid rank %d", rank_);
}
}
#undef PADDLE_VISIT_DDIM #undef PADDLE_VISIT_DDIM
/** /**
......
...@@ -98,8 +98,8 @@ struct StridedCopyDimVisitor { ...@@ -98,8 +98,8 @@ struct StridedCopyDimVisitor {
template <int D> template <int D>
void operator()(const framework::Dim<D>& dst_dim) const { void operator()(const framework::Dim<D>& dst_dim) const {
StridedMemcpyFunctor<T, D> functor; StridedMemcpyFunctor<T, D> functor;
functor(dev_ctx_, src_, src_stride_.data(), dst_dim.data(), functor(dev_ctx_, src_, src_stride_.Get(), dst_dim.Get(), dst_stride_.Get(),
dst_stride_.data(), dst_); dst_);
} }
const platform::DeviceContext& dev_ctx_; const platform::DeviceContext& dev_ctx_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册