提交 53f6c699 编写于 作者: S sneaxiy

polish code

test=develop
上级 a500dfa5
...@@ -18,32 +18,10 @@ limitations under the License. */ ...@@ -18,32 +18,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T>
struct DDimAssignFunctor {
static_assert(std::is_integral<T>::value, "T must be integral type");
using result_type = void;
explicit DDimAssignFunctor(const T* in) : in_(in) {}
template <int D>
inline void operator()(Dim<D>& dim) { // NOLINT
UnrollAssign<D>::Run(in_, dim.data());
}
const T* in_;
};
DDim::DDim(const int* d, int n) : rank_(n) {
this->apply_visitor(DDimAssignFunctor<int>(d));
}
DDim::DDim(const int64_t* d, int n) : rank_(n) {
this->apply_visitor(DDimAssignFunctor<int64_t>(d));
}
template <int N> template <int N>
Dim<N> make_dim(const int64_t* d) { Dim<N> make_dim(const int64_t* d) {
Dim<N> ret; Dim<N> ret;
for (int i = 0; i < N; ++i) ret[i] = d[i]; fix_dim_assign(d, ret.GetMutable());
return ret; return ret;
} }
...@@ -64,14 +42,14 @@ struct DDimEqualityVisitor { ...@@ -64,14 +42,14 @@ struct DDimEqualityVisitor {
template <int D> template <int D>
inline bool operator()(const Dim<D>& self) const { inline bool operator()(const Dim<D>& self) const {
return UnrollCompare<D>::Run(self.data(), d_); return UnrollCompare<D>::Run(self.Get(), d_);
} }
const int64_t* d_; const int64_t* d_;
}; };
bool DDim::operator==(const DDim& d) const { bool DDim::operator==(const DDim& d) const {
return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.data())); return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.Get()));
} }
bool DDim::operator!=(const DDim& d) const { return !(*this == d); } bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
...@@ -82,7 +60,7 @@ struct DDimPlusVisitor { ...@@ -82,7 +60,7 @@ struct DDimPlusVisitor {
template <int D> template <int D>
inline void operator()(Dim<D>& self) const { inline void operator()(Dim<D>& self) const {
UnrollAdd<D>::Run(d1_, d2_, self.data()); UnrollAdd<D>::Run(d1_, d2_, self.GetMutable());
} }
const int64_t* d1_; const int64_t* d1_;
...@@ -93,7 +71,7 @@ DDim DDim::operator+(const DDim& d) const { ...@@ -93,7 +71,7 @@ DDim DDim::operator+(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_); PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret; DDim ret;
ret.rank_ = rank_; ret.rank_ = rank_;
ret.apply_visitor(DDimPlusVisitor(data(), d.data())); ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
return ret; return ret;
} }
...@@ -103,7 +81,7 @@ struct DDimMulVisitor { ...@@ -103,7 +81,7 @@ struct DDimMulVisitor {
template <int D> template <int D>
inline void operator()(Dim<D>& self) const { inline void operator()(Dim<D>& self) const {
UnrollMul<D>::Run(d1_, d2_, self.data()); UnrollMul<D>::Run(d1_, d2_, self.GetMutable());
} }
const int64_t* d1_; const int64_t* d1_;
...@@ -114,7 +92,7 @@ DDim DDim::operator*(const DDim& d) const { ...@@ -114,7 +92,7 @@ DDim DDim::operator*(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_); PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret; DDim ret;
ret.rank_ = rank_; ret.rank_ = rank_;
ret.apply_visitor(DDimMulVisitor(data(), d.data())); ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
return ret; return ret;
} }
...@@ -124,9 +102,7 @@ void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } // NOLINT ...@@ -124,9 +102,7 @@ void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } // NOLINT
std::vector<int64_t> vectorize(const DDim& ddim) { std::vector<int64_t> vectorize(const DDim& ddim) {
std::vector<int64_t> result(DDim::kMaxRank); std::vector<int64_t> result(DDim::kMaxRank);
for (int i = 0; i < ddim.size(); ++i) { dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
result[i] = ddim[i];
}
result.resize(ddim.size()); result.resize(ddim.size());
return result; return result;
} }
...@@ -135,9 +111,7 @@ std::vector<int64_t> vectorize(const DDim& ddim) { ...@@ -135,9 +111,7 @@ std::vector<int64_t> vectorize(const DDim& ddim) {
// which does not fit cudnn inputs. // which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim& ddim) { std::vector<int> vectorize2int(const DDim& ddim) {
std::vector<int> result(DDim::kMaxRank); std::vector<int> result(DDim::kMaxRank);
for (int i = 0; i < ddim.size(); ++i) { dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
result[i] = ddim[i];
}
result.resize(ddim.size()); result.resize(ddim.size());
return result; return result;
} }
...@@ -154,15 +128,11 @@ int64_t product(const DDim& ddim) { ...@@ -154,15 +128,11 @@ int64_t product(const DDim& ddim) {
} }
DDim slice_ddim(const DDim& dim, int begin, int end) { DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in ddim slice.");
PADDLE_ENFORCE(begin >= 0, PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice."); "Begin index can't be less than zero in ddim slice.");
DDim ret; DDim ret;
ret.rank_ = end - begin; ret.rank_ = end - begin;
for (int i = 0; i < ret.rank_; ++i) { dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_);
ret[i] = dim[i + begin];
}
return ret; return ret;
} }
......
...@@ -22,6 +22,29 @@ limitations under the License. */ ...@@ -22,6 +22,29 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
template <typename T1, typename T2>
inline void dynamic_dim_assign(const T1* in, T2* out, int n) {
#define STATIC_DIM_ASSIGN_CASE(rank) \
case rank: \
static_dim_assign<rank, T1, T2>(in, out); \
return
switch (n) {
STATIC_DIM_ASSIGN_CASE(0);
STATIC_DIM_ASSIGN_CASE(1);
STATIC_DIM_ASSIGN_CASE(2);
STATIC_DIM_ASSIGN_CASE(3);
STATIC_DIM_ASSIGN_CASE(4);
STATIC_DIM_ASSIGN_CASE(5);
STATIC_DIM_ASSIGN_CASE(6);
STATIC_DIM_ASSIGN_CASE(7);
STATIC_DIM_ASSIGN_CASE(8);
STATIC_DIM_ASSIGN_CASE(9);
default:
PADDLE_THROW("Invalid rank %d", n);
}
#undef STATIC_DIM_ASSIGN_CASE
}
/** /**
* \brief A dynamically sized dimension. * \brief A dynamically sized dimension.
* *
...@@ -33,8 +56,13 @@ class DDim { ...@@ -33,8 +56,13 @@ class DDim {
DDim() : rank_(1) { dim_[0] = 0; } DDim() : rank_(1) { dim_[0] = 0; }
DDim(const int* d, int n); DDim(const int* d, int n) : rank_(n) {
DDim(const int64_t* d, int n); dynamic_dim_assign(d, dim_.GetMutable(), n);
}
DDim(const int64_t* d, int n) : rank_(n) {
dynamic_dim_assign(d, dim_.GetMutable(), n);
}
template <int D> template <int D>
/*implicit*/ DDim(const Dim<D>& in) : rank_(D) { // NOLINT /*implicit*/ DDim(const Dim<D>& in) : rank_(D) { // NOLINT
...@@ -81,19 +109,11 @@ class DDim { ...@@ -81,19 +109,11 @@ class DDim {
DDim operator*(const DDim& d) const; DDim operator*(const DDim& d) const;
// Make DDim act like std::vector<int64_t> inline const int64_t* Get() const { return dim_.Get(); }
using iterator = int64_t*;
using const_iterator = const int64_t*;
int64_t* data() { return dim_.data(); }
const int64_t* data() const { return dim_.data(); }
iterator begin() { return data(); } inline int64_t* GetMutable() { return dim_.GetMutable(); }
const_iterator begin() const { return data(); }
iterator end() { return data() + rank_; }
const_iterator end() const { return data() + rank_; }
int size() const { return rank_; } inline int size() const { return rank_; }
private: private:
template <int M> template <int M>
......
...@@ -54,10 +54,6 @@ class Dim : public Array<int64_t, N> { ...@@ -54,10 +54,6 @@ class Dim : public Array<int64_t, N> {
HOSTDEVICE Dim() = default; HOSTDEVICE Dim() = default;
HOSTDEVICE int64_t* data() { return this->GetMutable(); }
HOSTDEVICE const int64_t* data() const { return this->Get(); }
HOST std::string to_string() const; HOST std::string to_string() const;
}; };
...@@ -283,5 +279,10 @@ HOSTDEVICE Dim<N> linear_to_dimension(int linear_index, const Dim<N>& extents) { ...@@ -283,5 +279,10 @@ HOSTDEVICE Dim<N> linear_to_dimension(int linear_index, const Dim<N>& extents) {
return result; return result;
} }
template <int D, typename T1, typename T2>
inline void static_dim_assign(const T1* in, T2* out) {
UnrollAssign<D>::Run(in, out);
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册