提交 53f6c699 编写于 作者: S sneaxiy

polish code

test=develop
上级 a500dfa5
......@@ -18,32 +18,10 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T>
struct DDimAssignFunctor {
static_assert(std::is_integral<T>::value, "T must be integral type");
using result_type = void;
explicit DDimAssignFunctor(const T* in) : in_(in) {}
template <int D>
inline void operator()(Dim<D>& dim) { // NOLINT
UnrollAssign<D>::Run(in_, dim.data());
}
const T* in_;
};
DDim::DDim(const int* d, int n) : rank_(n) {
this->apply_visitor(DDimAssignFunctor<int>(d));
}
DDim::DDim(const int64_t* d, int n) : rank_(n) {
this->apply_visitor(DDimAssignFunctor<int64_t>(d));
}
template <int N>
Dim<N> make_dim(const int64_t* d) {
Dim<N> ret;
for (int i = 0; i < N; ++i) ret[i] = d[i];
fix_dim_assign(d, ret.GetMutable());
return ret;
}
......@@ -64,14 +42,14 @@ struct DDimEqualityVisitor {
template <int D>
inline bool operator()(const Dim<D>& self) const {
return UnrollCompare<D>::Run(self.data(), d_);
return UnrollCompare<D>::Run(self.Get(), d_);
}
const int64_t* d_;
};
bool DDim::operator==(const DDim& d) const {
return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.data()));
return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.Get()));
}
bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
......@@ -82,7 +60,7 @@ struct DDimPlusVisitor {
template <int D>
inline void operator()(Dim<D>& self) const {
UnrollAdd<D>::Run(d1_, d2_, self.data());
UnrollAdd<D>::Run(d1_, d2_, self.GetMutable());
}
const int64_t* d1_;
......@@ -93,7 +71,7 @@ DDim DDim::operator+(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret;
ret.rank_ = rank_;
ret.apply_visitor(DDimPlusVisitor(data(), d.data()));
ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
return ret;
}
......@@ -103,7 +81,7 @@ struct DDimMulVisitor {
template <int D>
inline void operator()(Dim<D>& self) const {
UnrollMul<D>::Run(d1_, d2_, self.data());
UnrollMul<D>::Run(d1_, d2_, self.GetMutable());
}
const int64_t* d1_;
......@@ -114,7 +92,7 @@ DDim DDim::operator*(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret;
ret.rank_ = rank_;
ret.apply_visitor(DDimMulVisitor(data(), d.data()));
ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
return ret;
}
......@@ -124,9 +102,7 @@ void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } // NOLINT
std::vector<int64_t> vectorize(const DDim& ddim) {
std::vector<int64_t> result(DDim::kMaxRank);
for (int i = 0; i < ddim.size(); ++i) {
result[i] = ddim[i];
}
dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
result.resize(ddim.size());
return result;
}
......@@ -135,9 +111,7 @@ std::vector<int64_t> vectorize(const DDim& ddim) {
// which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim& ddim) {
std::vector<int> result(DDim::kMaxRank);
for (int i = 0; i < ddim.size(); ++i) {
result[i] = ddim[i];
}
dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
result.resize(ddim.size());
return result;
}
......@@ -154,15 +128,11 @@ int64_t product(const DDim& ddim) {
}
DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in ddim slice.");
PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice.");
DDim ret;
ret.rank_ = end - begin;
for (int i = 0; i < ret.rank_; ++i) {
ret[i] = dim[i + begin];
}
dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_);
return ret;
}
......
......@@ -22,6 +22,29 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <typename T1, typename T2>
inline void dynamic_dim_assign(const T1* in, T2* out, int n) {
#define STATIC_DIM_ASSIGN_CASE(rank) \
case rank: \
static_dim_assign<rank, T1, T2>(in, out); \
return
switch (n) {
STATIC_DIM_ASSIGN_CASE(0);
STATIC_DIM_ASSIGN_CASE(1);
STATIC_DIM_ASSIGN_CASE(2);
STATIC_DIM_ASSIGN_CASE(3);
STATIC_DIM_ASSIGN_CASE(4);
STATIC_DIM_ASSIGN_CASE(5);
STATIC_DIM_ASSIGN_CASE(6);
STATIC_DIM_ASSIGN_CASE(7);
STATIC_DIM_ASSIGN_CASE(8);
STATIC_DIM_ASSIGN_CASE(9);
default:
PADDLE_THROW("Invalid rank %d", n);
}
#undef STATIC_DIM_ASSIGN_CASE
}
/**
* \brief A dynamically sized dimension.
*
......@@ -33,8 +56,13 @@ class DDim {
DDim() : rank_(1) { dim_[0] = 0; }
DDim(const int* d, int n);
DDim(const int64_t* d, int n);
DDim(const int* d, int n) : rank_(n) {
dynamic_dim_assign(d, dim_.GetMutable(), n);
}
DDim(const int64_t* d, int n) : rank_(n) {
dynamic_dim_assign(d, dim_.GetMutable(), n);
}
template <int D>
/*implicit*/ DDim(const Dim<D>& in) : rank_(D) { // NOLINT
......@@ -81,19 +109,11 @@ class DDim {
DDim operator*(const DDim& d) const;
// Make DDim act like std::vector<int64_t>
using iterator = int64_t*;
using const_iterator = const int64_t*;
int64_t* data() { return dim_.data(); }
const int64_t* data() const { return dim_.data(); }
inline const int64_t* Get() const { return dim_.Get(); }
iterator begin() { return data(); }
const_iterator begin() const { return data(); }
iterator end() { return data() + rank_; }
const_iterator end() const { return data() + rank_; }
inline int64_t* GetMutable() { return dim_.GetMutable(); }
int size() const { return rank_; }
inline int size() const { return rank_; }
private:
template <int M>
......
......@@ -54,10 +54,6 @@ class Dim : public Array<int64_t, N> {
HOSTDEVICE Dim() = default;
HOSTDEVICE int64_t* data() { return this->GetMutable(); }
HOSTDEVICE const int64_t* data() const { return this->Get(); }
HOST std::string to_string() const;
};
......@@ -283,5 +279,10 @@ HOSTDEVICE Dim<N> linear_to_dimension(int linear_index, const Dim<N>& extents) {
return result;
}
template <int D, typename T1, typename T2>
inline void static_dim_assign(const T1* in, T2* out) {
UnrollAssign<D>::Run(in, out);
}
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册