提交 600f6d82 编写于 作者: S sneaxiy

polish code

test=develop
上级 89b9d86d
...@@ -42,7 +42,8 @@ struct DDimEqualityVisitor { ...@@ -42,7 +42,8 @@ struct DDimEqualityVisitor {
}; };
bool DDim::operator==(const DDim& d) const { bool DDim::operator==(const DDim& d) const {
return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.Get())); return size() == d.size() &&
this->apply_visitor(DDimEqualityVisitor(d.Get()));
} }
bool DDim::operator!=(const DDim& d) const { return !(*this == d); } bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
...@@ -61,7 +62,7 @@ struct DDimPlusVisitor { ...@@ -61,7 +62,7 @@ struct DDimPlusVisitor {
}; };
DDim DDim::operator+(const DDim& d) const { DDim DDim::operator+(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_); PADDLE_ENFORCE(size() == d.size());
DDim ret; DDim ret;
ret.rank_ = rank_; ret.rank_ = rank_;
ret.apply_visitor(DDimPlusVisitor(Get(), d.Get())); ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
...@@ -82,7 +83,7 @@ struct DDimMulVisitor { ...@@ -82,7 +83,7 @@ struct DDimMulVisitor {
}; };
DDim DDim::operator*(const DDim& d) const { DDim DDim::operator*(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_); PADDLE_ENFORCE(size() == d.size());
DDim ret; DDim ret;
ret.rank_ = rank_; ret.rank_ = rank_;
ret.apply_visitor(DDimMulVisitor(Get(), d.Get())); ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
...@@ -121,13 +122,11 @@ int64_t product(const DDim& ddim) { ...@@ -121,13 +122,11 @@ int64_t product(const DDim& ddim) {
} }
DDim slice_ddim(const DDim& dim, int begin, int end) { DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0, PADDLE_ENFORCE(begin >= 0 && end <= dim.size(),
"Begin index can't be less than zero in ddim slice."); "[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
int len = end - begin; begin, end, dim.size());
DDim ret; // Constructor of DDim would check whether end - begin is valid
ret.rank_ = len; return DDim(dim.Get() + begin, end - begin);
dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_);
return ret;
} }
int arity(const DDim& d) { return d.size(); } int arity(const DDim& d) { return d.size(); }
...@@ -138,8 +137,8 @@ struct DDimPrinter { ...@@ -138,8 +137,8 @@ struct DDimPrinter {
std::ostream& os; std::ostream& os;
explicit DDimPrinter(std::ostream& os_) : os(os_) {} explicit DDimPrinter(std::ostream& os_) : os(os_) {}
template <typename T> template <int D>
void operator()(const T& t) { void operator()(const Dim<D>& t) {
os << t; os << t;
} }
}; };
...@@ -152,12 +151,11 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) { ...@@ -152,12 +151,11 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
} }
DDim flatten_to_2d(const DDim& src, int num_col_dims) { DDim flatten_to_2d(const DDim& src, int num_col_dims) {
int rank = src.size(); return DDim({product(slice_ddim(src, 0, num_col_dims)),
return make_ddim({product(slice_ddim(src, 0, num_col_dims)), product(slice_ddim(src, num_col_dims, src.size()))});
product(slice_ddim(src, num_col_dims, rank))});
} }
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } DDim flatten_to_1d(const DDim& src) { return DDim({product(src)}); }
DDim stride(const DDim& ddim) { DDim stride(const DDim& ddim) {
DDim strides; DDim strides;
......
...@@ -124,16 +124,16 @@ class DDim { ...@@ -124,16 +124,16 @@ class DDim {
inline int size() const { return rank_; } inline int size() const { return rank_; }
private: private:
template <int M> template <int D>
inline Dim<M>& UnsafeCast() { inline Dim<D>& UnsafeCast() {
return const_cast<Dim<M>&>(const_cast<const DDim*>(this)->UnsafeCast<M>()); return const_cast<Dim<D>&>(const_cast<const DDim*>(this)->UnsafeCast<D>());
} }
template <int M> template <int D>
inline const Dim<M>& UnsafeCast() const { inline const Dim<D>& UnsafeCast() const {
static_assert(M >= 0 && M <= kMaxRank, "Invalid rank"); static_assert(D >= 0 && D <= kMaxRank, "Invalid rank");
auto* p = static_cast<const void*>(&dim_); auto* p = static_cast<const void*>(&dim_);
return *reinterpret_cast<const Dim<M>*>(p); return *reinterpret_cast<const Dim<D>*>(p);
} }
friend DDim slice_ddim(const DDim& dim, int begin, int end); friend DDim slice_ddim(const DDim& dim, int begin, int end);
......
...@@ -28,17 +28,17 @@ namespace paddle { ...@@ -28,17 +28,17 @@ namespace paddle {
namespace framework { namespace framework {
// Statically sized, statically indexed dimension // Statically sized, statically indexed dimension
template <int N> template <int D>
class Dim : public Array<int64_t, N> { class Dim : public Array<int64_t, D> {
public: public:
static_assert(N >= 0, "N must be not less than 0"); static_assert(D >= 0, "D must be not less than 0");
static constexpr int kRank = N; static constexpr int kRank = D;
using BaseClass = Array<int64_t, N>; using BaseClass = Array<int64_t, D>;
inline Dim(int64_t head, const Dim<N - 1>& tail) { inline Dim(int64_t head, const Dim<D - 1>& tail) {
(*this)[0] = head; (*this)[0] = head;
new (this->GetMutable() + 1) Dim<N - 1>(tail); new (this->GetMutable() + 1) Dim<D - 1>(tail);
} }
template <typename... Args> template <typename... Args>
...@@ -47,7 +47,7 @@ class Dim : public Array<int64_t, N> { ...@@ -47,7 +47,7 @@ class Dim : public Array<int64_t, N> {
/** Construct a Dim from a linear index and size. Uses Fortran order /** Construct a Dim from a linear index and size. Uses Fortran order
* indexing. */ * indexing. */
HOSTDEVICE Dim(int64_t idx, const Dim<N>& size); HOSTDEVICE Dim(int64_t idx, const Dim<D>& size);
/** Construct a Dim with each dimension set to the given index */ /** Construct a Dim with each dimension set to the given index */
HOSTDEVICE explicit Dim(int64_t idx) { this->Fill(idx); } HOSTDEVICE explicit Dim(int64_t idx) { this->Fill(idx); }
...@@ -77,42 +77,42 @@ struct FortranOrderIndexingConstructorFunctor<kStart, kEnd, true> { ...@@ -77,42 +77,42 @@ struct FortranOrderIndexingConstructorFunctor<kStart, kEnd, true> {
}; };
} // namespace detail } // namespace detail
template <int N> template <int D>
HOSTDEVICE Dim<N>::Dim(int64_t idx, const Dim<N>& size) { HOSTDEVICE Dim<D>::Dim(int64_t idx, const Dim<D>& size) {
detail::FortranOrderIndexingConstructorFunctor<0, N, N == 0>::Run( detail::FortranOrderIndexingConstructorFunctor<0, D, D == 0>::Run(
size.Get(), &idx, this->GetMutable()); size.Get(), &idx, this->GetMutable());
} }
template <int idx, int N> template <int idx, int D>
HOSTDEVICE inline int64_t get(const Dim<N>& dim) { HOSTDEVICE inline int64_t get(const Dim<D>& dim) {
return dim[idx]; return dim[idx];
} }
template <int idx, int N> template <int idx, int D>
HOSTDEVICE inline int64_t& get(Dim<N>& dim) { // NOLINT HOSTDEVICE inline int64_t& get(Dim<D>& dim) { // NOLINT
return dim[idx]; return dim[idx];
} }
template <int N> template <int D>
HOSTDEVICE inline int64_t get(const Dim<N>& dim, int idx) { HOSTDEVICE inline int64_t get(const Dim<D>& dim, int idx) {
return dim[idx]; return dim[idx];
} }
template <int N> template <int D>
HOSTDEVICE inline int64_t& get(Dim<N>& dim, int idx) { // NOLINT HOSTDEVICE inline int64_t& get(Dim<D>& dim, int idx) { // NOLINT
return dim[idx]; return dim[idx];
} }
// Dot product of two dims // Dot product of two dims
template <int N> template <int D>
HOSTDEVICE inline int64_t linearize(const Dim<N>& a, const Dim<N>& b) { HOSTDEVICE inline int64_t linearize(const Dim<D>& a, const Dim<D>& b) {
return UnrollProduct<N>::Run(a.Get(), b.Get()); return UnrollProduct<D>::Run(a.Get(), b.Get());
} }
// Product of a Dim // Product of a Dim
template <int N> template <int D>
HOSTDEVICE inline int64_t product(const Dim<N>& a) { HOSTDEVICE inline int64_t product(const Dim<D>& a) {
return UnrollProduct<N>::Run(a.Get()); return UnrollProduct<D>::Run(a.Get());
} }
// Is 0 <= idx_i < size_i for all i? // Is 0 <= idx_i < size_i for all i?
...@@ -135,9 +135,9 @@ struct ContainedFunctor<kStart, kEnd, true> { ...@@ -135,9 +135,9 @@ struct ContainedFunctor<kStart, kEnd, true> {
}; };
} // namespace detail } // namespace detail
template <int N> template <int D>
HOSTDEVICE inline bool contained(const Dim<N>& idx, const Dim<N>& size) { HOSTDEVICE inline bool contained(const Dim<D>& idx, const Dim<D>& size) {
return detail::ContainedFunctor<0, N, N == 0>::Run(idx.Get(), size.Get()); return detail::ContainedFunctor<0, D, D == 0>::Run(idx.Get(), size.Get());
} }
/** /**
...@@ -160,40 +160,40 @@ struct ExPrefixMulFunctor<kStart, kEnd, true> { ...@@ -160,40 +160,40 @@ struct ExPrefixMulFunctor<kStart, kEnd, true> {
}; };
} // namespace detail } // namespace detail
template <int N> template <int D>
HOSTDEVICE inline Dim<N> ex_prefix_mul(const Dim<N>& src) { HOSTDEVICE inline Dim<D> ex_prefix_mul(const Dim<D>& src) {
Dim<N> ret; Dim<D> ret;
detail::ExPrefixMulFunctor<0, N, N == 0>::Run(src.Get(), ret.GetMutable()); detail::ExPrefixMulFunctor<0, D, D == 0>::Run(src.Get(), ret.GetMutable());
return ret; return ret;
} }
/** /**
* Add two dimensions together * Add two dimensions together
*/ */
template <int N> template <int D>
HOSTDEVICE inline Dim<N> dim_plus(const Dim<N>& a, const Dim<N>& b) { HOSTDEVICE inline Dim<D> dim_plus(const Dim<D>& a, const Dim<D>& b) {
Dim<N> ret; Dim<D> ret;
UnrollAdd<N>::Run(a.Get(), b.Get(), ret.GetMutable()); UnrollAdd<D>::Run(a.Get(), b.Get(), ret.GetMutable());
return ret; return ret;
} }
template <int N> template <int D>
HOSTDEVICE inline Dim<N> operator+(const Dim<N>& lhs, const Dim<N>& rhs) { HOSTDEVICE inline Dim<D> operator+(const Dim<D>& lhs, const Dim<D>& rhs) {
return dim_plus(lhs, rhs); return dim_plus(lhs, rhs);
} }
/** /**
* Multiply two dimensions together * Multiply two dimensions together
*/ */
template <int N> template <int D>
HOSTDEVICE inline Dim<N> dim_mult(const Dim<N>& a, const Dim<N>& b) { HOSTDEVICE inline Dim<D> dim_mult(const Dim<D>& a, const Dim<D>& b) {
Dim<N> ret; Dim<D> ret;
UnrollMul<N>::Run(a.Get(), b.Get(), ret.GetMutable()); UnrollMul<D>::Run(a.Get(), b.Get(), ret.GetMutable());
return ret; return ret;
} }
template <int i> template <int D>
HOSTDEVICE Dim<i> operator*(const Dim<i>& lhs, const Dim<i>& rhs) { HOSTDEVICE Dim<D> operator*(const Dim<D>& lhs, const Dim<D>& rhs) {
return dim_mult(lhs, rhs); return dim_mult(lhs, rhs);
} }
...@@ -224,10 +224,10 @@ struct NormalizeStridesFunctor<kStart, kEnd, true> { ...@@ -224,10 +224,10 @@ struct NormalizeStridesFunctor<kStart, kEnd, true> {
}; };
} // namespace detail } // namespace detail
template <int N> template <int D>
HOSTDEVICE Dim<N> normalize_strides(const Dim<N>& size, const Dim<N>& stride) { HOSTDEVICE Dim<D> normalize_strides(const Dim<D>& size, const Dim<D>& stride) {
Dim<N> ret; Dim<D> ret;
detail::NormalizeStridesFunctor<0, N, N == 0>::Run(size.Get(), stride.Get(), detail::NormalizeStridesFunctor<0, D, D == 0>::Run(size.Get(), stride.Get(),
ret.GetMutable()); ret.GetMutable());
return ret; return ret;
} }
...@@ -245,10 +245,10 @@ HOSTDEVICE inline Dim<sizeof...(Args)> make_dim(Args... idxes) { ...@@ -245,10 +245,10 @@ HOSTDEVICE inline Dim<sizeof...(Args)> make_dim(Args... idxes) {
} }
// Allows us to output a Dim // Allows us to output a Dim
template <int N> template <int D>
inline std::ostream& operator<<(std::ostream& os, const Dim<N>& d) { inline std::ostream& operator<<(std::ostream& os, const Dim<D>& d) {
os << d[0]; os << d[0];
for (int i = 1; i < N; ++i) { for (int i = 1; i < D; ++i) {
os << ", " << d[i]; os << ", " << d[i];
} }
return os; return os;
...@@ -258,23 +258,23 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) { ...@@ -258,23 +258,23 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
return os; return os;
} }
template <int N> template <int D>
HOST std::string Dim<N>::to_string() const { HOST std::string Dim<D>::to_string() const {
std::stringstream stream; std::stringstream stream;
stream << *this; stream << *this;
return stream.str(); return stream.str();
} }
template <int N> template <int D>
HOSTDEVICE Dim<N> linear_to_dimension(int linear_index, const Dim<N>& extents) { HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, const Dim<D>& extents) {
Dim<N> result; Dim<D> result;
for (int i = 0; i < N - 1; ++i) { for (int i = 0; i < D - 1; ++i) {
result[i] = linear_index % extents[i]; result[i] = linear_index % extents[i];
linear_index /= extents[i]; linear_index /= extents[i];
} }
result[N - 1] = linear_index; result[D - 1] = linear_index;
return result; return result;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册