提交 85c4f488 编写于 作者: F fengjiayi

Refactor DDim's product() and add slice_ddim()

1. Refactor DDim's product() to make it more efficiently.

2. Add slice_ddim().
上级 ee90c2d2
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -190,6 +191,46 @@ ssize_t product(const DDim& ddim) { ...@@ -190,6 +191,46 @@ ssize_t product(const DDim& ddim) {
return boost::apply_visitor(visitor, ddim); return boost::apply_visitor(visitor, ddim);
} }
struct SliceVectorizeVisitor : public boost::static_visitor<> {
std::vector<int>& vector;
int begin;
int end;
SliceVectorizeVisitor(std::vector<int>& v, int b, int e)
: vector(v), begin(b), end(e) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in ddim slice.");
PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice.");
}
template <int S>
void operator()(const Dim<S>& dim) {
if (begin == 0) {
vector.push_back(dim.head);
} else {
--begin;
}
--end;
if (end > 0) {
this->operator()(dim.tail);
}
}
void operator()(const Dim<1>& dim) {
PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound.");
vector.push_back(dim.head);
}
};
DDim slice_ddim(const DDim& dim, int begin, int end) {
std::vector<int> vec;
vec.reserve(end - begin);
SliceVectorizeVisitor visitor(vec, begin, end);
boost::apply_visitor(visitor, dim);
return make_ddim(vec);
}
///\cond HIDDEN ///\cond HIDDEN
struct ArityVisitor : boost::static_visitor<int> { struct ArityVisitor : boost::static_visitor<int> {
......
...@@ -81,6 +81,8 @@ std::vector<int> vectorize(const DDim& ddim); ...@@ -81,6 +81,8 @@ std::vector<int> vectorize(const DDim& ddim);
ssize_t product(const DDim& ddim); ssize_t product(const DDim& ddim);
DDim slice_ddim(const DDim& dim, int begin, int end);
/** /**
* \brief What is the length of this dimension? * \brief What is the length of this dimension?
* *
......
...@@ -55,6 +55,23 @@ TEST(DDim, Equality) { ...@@ -55,6 +55,23 @@ TEST(DDim, Equality) {
EXPECT_EQ( EXPECT_EQ(
paddle::framework::product(paddle::framework::make_ddim({3, 2, 5, 3})), paddle::framework::product(paddle::framework::make_ddim({3, 2, 5, 3})),
90); 90);
// slice a DDim
paddle::framework::DDim ddim2 =
paddle::framework::make_ddim({1, 2, 3, 4, 5, 6});
paddle::framework ::DDim ss = paddle::framework::slice_ddim(ddim2, 2, 5);
EXPECT_EQ(arity(ss), 3);
EXPECT_EQ(ss[0], 3);
EXPECT_EQ(ss[1], 4);
EXPECT_EQ(ss[2], 5);
paddle::framework ::DDim ss2 = paddle::framework::slice_ddim(ddim2, 0, 6);
EXPECT_EQ(arity(ss2), 6);
EXPECT_EQ(ss2[0], 1);
EXPECT_EQ(ss2[1], 2);
EXPECT_EQ(ss2[2], 3);
EXPECT_EQ(ss2[3], 4);
EXPECT_EQ(ss2[4], 5);
EXPECT_EQ(ss2[5], 6);
} }
TEST(DDim, Print) { TEST(DDim, Print) {
......
...@@ -401,20 +401,5 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) { ...@@ -401,20 +401,5 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
return result; return result;
} }
template <int D, int S>
Dim<D> slice(const Dim<S>& dim, int begin, int end) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in Dim slice.");
PADDLE_ENFORCE(begin >= 0 && end <= S && end - begin == D,
"Index error occurs in Dim slice.");
if (begin > 0) {
return slice<D>(dim.tail, begin - 1, end - 1);
}
if (D == 1) {
return Dim<1>(dim.head);
}
return Dim<D>(dim.head, slice<D - 1>(dim.tail, 0, end - 1));
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <sstream> #include <sstream>
#include "paddle/framework/dim.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/framework/dim.h"
__global__ void test(paddle::framework::Dim<2>* o) { __global__ void test(paddle::framework::Dim<2>* o) {
o[0] = paddle::framework::make_dim(5, 6); o[0] = paddle::framework::make_dim(5, 6);
...@@ -21,7 +21,7 @@ TEST(Dim, Equality) { ...@@ -21,7 +21,7 @@ TEST(Dim, Equality) {
// construct a Dim on the GPU // construct a Dim on the GPU
thrust::device_vector<paddle::framework::Dim<2>> t(2); thrust::device_vector<paddle::framework::Dim<2>> t(2);
test<<<1,1>>>(thrust::raw_pointer_cast(t.data())); test<<<1, 1>>>(thrust::raw_pointer_cast(t.data()));
a = t[0]; a = t[0];
EXPECT_EQ(paddle::framework::get<0>(a), 5); EXPECT_EQ(paddle::framework::get<0>(a), 5);
EXPECT_EQ(paddle::framework::get<1>(a), 6); EXPECT_EQ(paddle::framework::get<1>(a), 6);
...@@ -48,12 +48,13 @@ TEST(Dim, Equality) { ...@@ -48,12 +48,13 @@ TEST(Dim, Equality) {
// dynamic access on GPU // dynamic access on GPU
thrust::device_vector<int> r(1); thrust::device_vector<int> r(1);
dyn_idx_gpu<<<1,1>>>(thrust::raw_pointer_cast(r.data())); dyn_idx_gpu<<<1, 1>>>(thrust::raw_pointer_cast(r.data()));
int res = r[0]; int res = r[0];
EXPECT_EQ(res, 6); EXPECT_EQ(res, 6);
// ex_prefix_mul // ex_prefix_mul
paddle::framework::Dim<3> c = paddle::framework::ex_prefix_mul(paddle::framework::Dim<3>(3, 4, 5)); paddle::framework::Dim<3> c =
paddle::framework::ex_prefix_mul(paddle::framework::Dim<3>(3, 4, 5));
EXPECT_EQ(paddle::framework::get<0>(c), 1); EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<1>(c), 3); EXPECT_EQ(paddle::framework::get<1>(c), 3);
EXPECT_EQ(paddle::framework::get<2>(c), 12); EXPECT_EQ(paddle::framework::get<2>(c), 12);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册