提交 600f6d82 编写于 作者: S sneaxiy

polish code

test=develop
上级 89b9d86d
......@@ -42,7 +42,8 @@ struct DDimEqualityVisitor {
};
bool DDim::operator==(const DDim& d) const {
return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.Get()));
return size() == d.size() &&
this->apply_visitor(DDimEqualityVisitor(d.Get()));
}
bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
......@@ -61,7 +62,7 @@ struct DDimPlusVisitor {
};
DDim DDim::operator+(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_);
PADDLE_ENFORCE(size() == d.size());
DDim ret;
ret.rank_ = rank_;
ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
......@@ -82,7 +83,7 @@ struct DDimMulVisitor {
};
DDim DDim::operator*(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_);
PADDLE_ENFORCE(size() == d.size());
DDim ret;
ret.rank_ = rank_;
ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
......@@ -121,13 +122,11 @@ int64_t product(const DDim& ddim) {
}
DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice.");
int len = end - begin;
DDim ret;
ret.rank_ = len;
dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_);
return ret;
PADDLE_ENFORCE(begin >= 0 && end <= dim.size(),
"[begin(%d), end(%d)) must be inside [0, %d) in ddim slice.",
begin, end, dim.size());
// Constructor of DDim would check whether end - begin is valid
return DDim(dim.Get() + begin, end - begin);
}
int arity(const DDim& d) { return d.size(); }
......@@ -138,8 +137,8 @@ struct DDimPrinter {
std::ostream& os;
explicit DDimPrinter(std::ostream& os_) : os(os_) {}
template <typename T>
void operator()(const T& t) {
template <int D>
void operator()(const Dim<D>& t) {
os << t;
}
};
......@@ -152,12 +151,11 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
}
DDim flatten_to_2d(const DDim& src, int num_col_dims) {
int rank = src.size();
return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
product(slice_ddim(src, num_col_dims, rank))});
return DDim({product(slice_ddim(src, 0, num_col_dims)),
product(slice_ddim(src, num_col_dims, src.size()))});
}
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
DDim flatten_to_1d(const DDim& src) { return DDim({product(src)}); }
DDim stride(const DDim& ddim) {
DDim strides;
......
......@@ -124,16 +124,16 @@ class DDim {
inline int size() const { return rank_; }
private:
template <int M>
inline Dim<M>& UnsafeCast() {
return const_cast<Dim<M>&>(const_cast<const DDim*>(this)->UnsafeCast<M>());
template <int D>
inline Dim<D>& UnsafeCast() {
return const_cast<Dim<D>&>(const_cast<const DDim*>(this)->UnsafeCast<D>());
}
template <int M>
inline const Dim<M>& UnsafeCast() const {
static_assert(M >= 0 && M <= kMaxRank, "Invalid rank");
template <int D>
inline const Dim<D>& UnsafeCast() const {
static_assert(D >= 0 && D <= kMaxRank, "Invalid rank");
auto* p = static_cast<const void*>(&dim_);
return *reinterpret_cast<const Dim<M>*>(p);
return *reinterpret_cast<const Dim<D>*>(p);
}
friend DDim slice_ddim(const DDim& dim, int begin, int end);
......
......@@ -28,17 +28,17 @@ namespace paddle {
namespace framework {
// Statically sized, statically indexed dimension
template <int N>
class Dim : public Array<int64_t, N> {
template <int D>
class Dim : public Array<int64_t, D> {
public:
static_assert(N >= 0, "N must be not less than 0");
static_assert(D >= 0, "D must be not less than 0");
static constexpr int kRank = N;
using BaseClass = Array<int64_t, N>;
static constexpr int kRank = D;
using BaseClass = Array<int64_t, D>;
inline Dim(int64_t head, const Dim<N - 1>& tail) {
inline Dim(int64_t head, const Dim<D - 1>& tail) {
(*this)[0] = head;
new (this->GetMutable() + 1) Dim<N - 1>(tail);
new (this->GetMutable() + 1) Dim<D - 1>(tail);
}
template <typename... Args>
......@@ -47,7 +47,7 @@ class Dim : public Array<int64_t, N> {
/** Construct a Dim from a linear index and size. Uses Fortran order
* indexing. */
HOSTDEVICE Dim(int64_t idx, const Dim<N>& size);
HOSTDEVICE Dim(int64_t idx, const Dim<D>& size);
/** Construct a Dim with each dimension set to the given index */
HOSTDEVICE explicit Dim(int64_t idx) { this->Fill(idx); }
......@@ -77,42 +77,42 @@ struct FortranOrderIndexingConstructorFunctor<kStart, kEnd, true> {
};
} // namespace detail
template <int N>
HOSTDEVICE Dim<N>::Dim(int64_t idx, const Dim<N>& size) {
detail::FortranOrderIndexingConstructorFunctor<0, N, N == 0>::Run(
template <int D>
HOSTDEVICE Dim<D>::Dim(int64_t idx, const Dim<D>& size) {
detail::FortranOrderIndexingConstructorFunctor<0, D, D == 0>::Run(
size.Get(), &idx, this->GetMutable());
}
template <int idx, int N>
HOSTDEVICE inline int64_t get(const Dim<N>& dim) {
template <int idx, int D>
HOSTDEVICE inline int64_t get(const Dim<D>& dim) {
return dim[idx];
}
template <int idx, int N>
HOSTDEVICE inline int64_t& get(Dim<N>& dim) { // NOLINT
template <int idx, int D>
HOSTDEVICE inline int64_t& get(Dim<D>& dim) { // NOLINT
return dim[idx];
}
template <int N>
HOSTDEVICE inline int64_t get(const Dim<N>& dim, int idx) {
template <int D>
HOSTDEVICE inline int64_t get(const Dim<D>& dim, int idx) {
return dim[idx];
}
template <int N>
HOSTDEVICE inline int64_t& get(Dim<N>& dim, int idx) { // NOLINT
template <int D>
HOSTDEVICE inline int64_t& get(Dim<D>& dim, int idx) { // NOLINT
return dim[idx];
}
// Dot product of two dims
template <int N>
HOSTDEVICE inline int64_t linearize(const Dim<N>& a, const Dim<N>& b) {
return UnrollProduct<N>::Run(a.Get(), b.Get());
template <int D>
HOSTDEVICE inline int64_t linearize(const Dim<D>& a, const Dim<D>& b) {
return UnrollProduct<D>::Run(a.Get(), b.Get());
}
// Product of a Dim
template <int N>
HOSTDEVICE inline int64_t product(const Dim<N>& a) {
return UnrollProduct<N>::Run(a.Get());
template <int D>
HOSTDEVICE inline int64_t product(const Dim<D>& a) {
return UnrollProduct<D>::Run(a.Get());
}
// Is 0 <= idx_i < size_i for all i?
......@@ -135,9 +135,9 @@ struct ContainedFunctor<kStart, kEnd, true> {
};
} // namespace detail
template <int N>
HOSTDEVICE inline bool contained(const Dim<N>& idx, const Dim<N>& size) {
return detail::ContainedFunctor<0, N, N == 0>::Run(idx.Get(), size.Get());
template <int D>
HOSTDEVICE inline bool contained(const Dim<D>& idx, const Dim<D>& size) {
return detail::ContainedFunctor<0, D, D == 0>::Run(idx.Get(), size.Get());
}
/**
......@@ -160,40 +160,40 @@ struct ExPrefixMulFunctor<kStart, kEnd, true> {
};
} // namespace detail
template <int N>
HOSTDEVICE inline Dim<N> ex_prefix_mul(const Dim<N>& src) {
Dim<N> ret;
detail::ExPrefixMulFunctor<0, N, N == 0>::Run(src.Get(), ret.GetMutable());
template <int D>
HOSTDEVICE inline Dim<D> ex_prefix_mul(const Dim<D>& src) {
Dim<D> ret;
detail::ExPrefixMulFunctor<0, D, D == 0>::Run(src.Get(), ret.GetMutable());
return ret;
}
/**
* Add two dimensions together
*/
template <int N>
HOSTDEVICE inline Dim<N> dim_plus(const Dim<N>& a, const Dim<N>& b) {
Dim<N> ret;
UnrollAdd<N>::Run(a.Get(), b.Get(), ret.GetMutable());
template <int D>
HOSTDEVICE inline Dim<D> dim_plus(const Dim<D>& a, const Dim<D>& b) {
Dim<D> ret;
UnrollAdd<D>::Run(a.Get(), b.Get(), ret.GetMutable());
return ret;
}
template <int N>
HOSTDEVICE inline Dim<N> operator+(const Dim<N>& lhs, const Dim<N>& rhs) {
template <int D>
HOSTDEVICE inline Dim<D> operator+(const Dim<D>& lhs, const Dim<D>& rhs) {
return dim_plus(lhs, rhs);
}
/**
* Multiply two dimensions together
*/
template <int N>
HOSTDEVICE inline Dim<N> dim_mult(const Dim<N>& a, const Dim<N>& b) {
Dim<N> ret;
UnrollMul<N>::Run(a.Get(), b.Get(), ret.GetMutable());
template <int D>
HOSTDEVICE inline Dim<D> dim_mult(const Dim<D>& a, const Dim<D>& b) {
Dim<D> ret;
UnrollMul<D>::Run(a.Get(), b.Get(), ret.GetMutable());
return ret;
}
template <int i>
HOSTDEVICE Dim<i> operator*(const Dim<i>& lhs, const Dim<i>& rhs) {
template <int D>
HOSTDEVICE Dim<D> operator*(const Dim<D>& lhs, const Dim<D>& rhs) {
return dim_mult(lhs, rhs);
}
......@@ -224,10 +224,10 @@ struct NormalizeStridesFunctor<kStart, kEnd, true> {
};
} // namespace detail
template <int N>
HOSTDEVICE Dim<N> normalize_strides(const Dim<N>& size, const Dim<N>& stride) {
Dim<N> ret;
detail::NormalizeStridesFunctor<0, N, N == 0>::Run(size.Get(), stride.Get(),
template <int D>
HOSTDEVICE Dim<D> normalize_strides(const Dim<D>& size, const Dim<D>& stride) {
Dim<D> ret;
detail::NormalizeStridesFunctor<0, D, D == 0>::Run(size.Get(), stride.Get(),
ret.GetMutable());
return ret;
}
......@@ -245,10 +245,10 @@ HOSTDEVICE inline Dim<sizeof...(Args)> make_dim(Args... idxes) {
}
// Allows us to output a Dim
template <int N>
inline std::ostream& operator<<(std::ostream& os, const Dim<N>& d) {
template <int D>
inline std::ostream& operator<<(std::ostream& os, const Dim<D>& d) {
os << d[0];
for (int i = 1; i < N; ++i) {
for (int i = 1; i < D; ++i) {
os << ", " << d[i];
}
return os;
......@@ -258,23 +258,23 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
return os;
}
template <int N>
HOST std::string Dim<N>::to_string() const {
template <int D>
HOST std::string Dim<D>::to_string() const {
std::stringstream stream;
stream << *this;
return stream.str();
}
template <int N>
HOSTDEVICE Dim<N> linear_to_dimension(int linear_index, const Dim<N>& extents) {
Dim<N> result;
template <int D>
HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, const Dim<D>& extents) {
Dim<D> result;
for (int i = 0; i < N - 1; ++i) {
for (int i = 0; i < D - 1; ++i) {
result[i] = linear_index % extents[i];
linear_index /= extents[i];
}
result[N - 1] = linear_index;
result[D - 1] = linear_index;
return result;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册