提交 83ac8515 编写于 作者: S sneaxiy

polish code

test=develop
上级 045dc127
......@@ -18,13 +18,6 @@ limitations under the License. */
namespace paddle {
namespace framework {
template <int N>
Dim<N> make_dim(const int64_t* d) {
Dim<N> ret;
fix_dim_assign(d, ret.GetMutable());
return ret;
}
DDim make_ddim(std::initializer_list<int64_t> dims) {
return DDim(dims.begin(), dims.size());
}
......@@ -69,8 +62,7 @@ struct DDimPlusVisitor {
DDim DDim::operator+(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret;
ret.rank_ = rank_;
DDim ret(rank_);
ret.apply_visitor(DDimPlusVisitor(Get(), d.Get()));
return ret;
}
......@@ -90,8 +82,7 @@ struct DDimMulVisitor {
DDim DDim::operator*(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret;
ret.rank_ = rank_;
DDim ret(rank_);
ret.apply_visitor(DDimMulVisitor(Get(), d.Get()));
return ret;
}
......@@ -118,7 +109,7 @@ std::vector<int> vectorize2int(const DDim& ddim) {
struct ProductVisitor {
template <int D>
int64_t operator()(const Dim<D>& dim) {
inline int64_t operator()(const Dim<D>& dim) {
return product(dim);
}
};
......@@ -130,8 +121,7 @@ int64_t product(const DDim& ddim) {
DDim slice_ddim(const DDim& dim, int begin, int end) {
PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice.");
DDim ret;
ret.rank_ = end - begin;
DDim ret(end - begin);
dynamic_dim_assign(dim.Get() + begin, ret.GetMutable(), ret.rank_);
return ret;
}
......@@ -166,8 +156,7 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
DDim stride(const DDim& ddim) {
DDim strides;
strides.rank_ = ddim.size();
DDim strides(ddim.size());
strides[ddim.size() - 1] = 1;
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i + 1];
......@@ -175,9 +164,8 @@ DDim stride(const DDim& ddim) {
return strides;
}
DDim stride_numel(const framework::DDim& ddim) {
DDim strides;
strides.rank_ = ddim.size();
DDim stride_numel(const DDim& ddim) {
DDim strides(ddim.size());
strides[ddim.size() - 1] = ddim[ddim.size() - 1];
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i];
......
......@@ -22,27 +22,31 @@ limitations under the License. */
namespace paddle {
namespace framework {
#define PADDLE_VISIT_DDIM_BASE(rank, callback) \
case (rank): { \
constexpr auto kRank = (rank); \
return (callback); \
}
#define PADDLE_VISIT_DDIM(rank, callback) \
switch (rank) { \
PADDLE_VISIT_DDIM_BASE(0, callback); \
PADDLE_VISIT_DDIM_BASE(1, callback); \
PADDLE_VISIT_DDIM_BASE(2, callback); \
PADDLE_VISIT_DDIM_BASE(3, callback); \
PADDLE_VISIT_DDIM_BASE(4, callback); \
PADDLE_VISIT_DDIM_BASE(5, callback); \
PADDLE_VISIT_DDIM_BASE(6, callback); \
PADDLE_VISIT_DDIM_BASE(7, callback); \
PADDLE_VISIT_DDIM_BASE(8, callback); \
PADDLE_VISIT_DDIM_BASE(9, callback); \
default: \
PADDLE_THROW("Invalid rank %d", rank); \
}
template <typename T1, typename T2>
inline void dynamic_dim_assign(const T1* in, T2* out, int n) {
#define STATIC_DIM_ASSIGN_CASE(rank) \
case rank: \
static_dim_assign<rank, T1, T2>(in, out); \
return
switch (n) {
STATIC_DIM_ASSIGN_CASE(0);
STATIC_DIM_ASSIGN_CASE(1);
STATIC_DIM_ASSIGN_CASE(2);
STATIC_DIM_ASSIGN_CASE(3);
STATIC_DIM_ASSIGN_CASE(4);
STATIC_DIM_ASSIGN_CASE(5);
STATIC_DIM_ASSIGN_CASE(6);
STATIC_DIM_ASSIGN_CASE(7);
STATIC_DIM_ASSIGN_CASE(8);
STATIC_DIM_ASSIGN_CASE(9);
default:
PADDLE_THROW("Invalid rank %d", n);
}
#undef STATIC_DIM_ASSIGN_CASE
PADDLE_VISIT_DDIM(n, (static_dim_assign<kRank, T1, T2>(in, out)));
}
/**
......@@ -84,22 +88,26 @@ class DDim {
inline int64_t operator[](int idx) const { return dim_[idx]; }
inline int64_t& at(int idx) {
PADDLE_ENFORCE(idx >= 0 && idx < rank_);
PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx);
return dim_[idx];
}
inline int64_t at(int idx) const {
PADDLE_ENFORCE(idx >= 0 && idx < rank_);
PADDLE_ENFORCE(idx >= 0 && idx < rank_, "Invalid idx %d", idx);
return dim_[idx];
}
template <typename Visitor>
typename std::result_of<Visitor(Dim<0>&)>::type apply_visitor(
Visitor&& visitor);
Visitor&& visitor) {
PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast<kRank>()));
}
template <typename Visitor>
typename std::result_of<Visitor(const Dim<0>&)>::type apply_visitor(
Visitor&& visitor) const;
Visitor&& visitor) const {
PADDLE_VISIT_DDIM(rank_, visitor(UnsafeCast<kRank>()));
}
bool operator==(const DDim& d) const;
......@@ -128,55 +136,22 @@ class DDim {
return *reinterpret_cast<const Dim<M>*>(p);
}
// Construct DDim with given rank
// Only used in friend functions
explicit DDim(int rank) : rank_(rank) {
PADDLE_ENFORCE(rank_ >= 0 && rank_ < kMaxRank, "Invalid rank %d", rank);
}
friend DDim slice_ddim(const DDim& dim, int begin, int end);
friend DDim stride(const DDim& ddim);
friend DDim stride_numel(const DDim& ddim);
private:
Dim<kMaxRank> dim_;
int rank_;
};
#define PADDLE_VISIT_DDIM(rank) \
case rank: \
return visitor(UnsafeCast<rank>())
template <typename Visitor>
typename std::result_of<Visitor(Dim<0>&)>::type DDim::apply_visitor(
Visitor&& visitor) {
switch (rank_) {
PADDLE_VISIT_DDIM(0);
PADDLE_VISIT_DDIM(1);
PADDLE_VISIT_DDIM(2);
PADDLE_VISIT_DDIM(3);
PADDLE_VISIT_DDIM(4);
PADDLE_VISIT_DDIM(5);
PADDLE_VISIT_DDIM(6);
PADDLE_VISIT_DDIM(7);
PADDLE_VISIT_DDIM(8);
PADDLE_VISIT_DDIM(9);
default:
PADDLE_THROW("Invalid rank %d", rank_);
}
}
template <typename Visitor>
typename std::result_of<Visitor(const Dim<0>&)>::type DDim::apply_visitor(
Visitor&& visitor) const {
switch (rank_) {
PADDLE_VISIT_DDIM(0);
PADDLE_VISIT_DDIM(1);
PADDLE_VISIT_DDIM(2);
PADDLE_VISIT_DDIM(3);
PADDLE_VISIT_DDIM(4);
PADDLE_VISIT_DDIM(5);
PADDLE_VISIT_DDIM(6);
PADDLE_VISIT_DDIM(7);
PADDLE_VISIT_DDIM(8);
PADDLE_VISIT_DDIM(9);
default:
PADDLE_THROW("Invalid rank %d", rank_);
}
}
#undef PADDLE_VISIT_DDIM_BASE
#undef PADDLE_VISIT_DDIM
/**
......
......@@ -98,8 +98,8 @@ struct StridedCopyDimVisitor {
template <int D>
void operator()(const framework::Dim<D>& dst_dim) const {
StridedMemcpyFunctor<T, D> functor;
functor(dev_ctx_, src_, src_stride_.data(), dst_dim.data(),
dst_stride_.data(), dst_);
functor(dev_ctx_, src_, src_stride_.Get(), dst_dim.Get(), dst_stride_.Get(),
dst_);
}
const platform::DeviceContext& dev_ctx_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册