提交 85c4f488 编写于 作者: F fengjiayi

Refactor DDim's product() and add slice_ddim()

1. Refactor DDim's product() to make it more efficiently.

2. Add slice_ddim().
上级 ee90c2d2
#include "paddle/framework/ddim.h" #include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -190,6 +191,46 @@ ssize_t product(const DDim& ddim) { ...@@ -190,6 +191,46 @@ ssize_t product(const DDim& ddim) {
return boost::apply_visitor(visitor, ddim); return boost::apply_visitor(visitor, ddim);
} }
struct SliceVectorizeVisitor : public boost::static_visitor<> {
std::vector<int>& vector;
int begin;
int end;
SliceVectorizeVisitor(std::vector<int>& v, int b, int e)
: vector(v), begin(b), end(e) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in ddim slice.");
PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice.");
}
template <int S>
void operator()(const Dim<S>& dim) {
if (begin == 0) {
vector.push_back(dim.head);
} else {
--begin;
}
--end;
if (end > 0) {
this->operator()(dim.tail);
}
}
void operator()(const Dim<1>& dim) {
PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound.");
vector.push_back(dim.head);
}
};
DDim slice_ddim(const DDim& dim, int begin, int end) {
std::vector<int> vec;
vec.reserve(end - begin);
SliceVectorizeVisitor visitor(vec, begin, end);
boost::apply_visitor(visitor, dim);
return make_ddim(vec);
}
///\cond HIDDEN ///\cond HIDDEN
struct ArityVisitor : boost::static_visitor<int> { struct ArityVisitor : boost::static_visitor<int> {
......
...@@ -81,6 +81,8 @@ std::vector<int> vectorize(const DDim& ddim); ...@@ -81,6 +81,8 @@ std::vector<int> vectorize(const DDim& ddim);
ssize_t product(const DDim& ddim); ssize_t product(const DDim& ddim);
DDim slice_ddim(const DDim& dim, int begin, int end);
/** /**
* \brief What is the length of this dimension? * \brief What is the length of this dimension?
* *
......
...@@ -55,6 +55,23 @@ TEST(DDim, Equality) { ...@@ -55,6 +55,23 @@ TEST(DDim, Equality) {
EXPECT_EQ( EXPECT_EQ(
paddle::framework::product(paddle::framework::make_ddim({3, 2, 5, 3})), paddle::framework::product(paddle::framework::make_ddim({3, 2, 5, 3})),
90); 90);
// slice a DDim
paddle::framework::DDim ddim2 =
paddle::framework::make_ddim({1, 2, 3, 4, 5, 6});
paddle::framework ::DDim ss = paddle::framework::slice_ddim(ddim2, 2, 5);
EXPECT_EQ(arity(ss), 3);
EXPECT_EQ(ss[0], 3);
EXPECT_EQ(ss[1], 4);
EXPECT_EQ(ss[2], 5);
paddle::framework ::DDim ss2 = paddle::framework::slice_ddim(ddim2, 0, 6);
EXPECT_EQ(arity(ss2), 6);
EXPECT_EQ(ss2[0], 1);
EXPECT_EQ(ss2[1], 2);
EXPECT_EQ(ss2[2], 3);
EXPECT_EQ(ss2[3], 4);
EXPECT_EQ(ss2[4], 5);
EXPECT_EQ(ss2[5], 6);
} }
TEST(DDim, Print) { TEST(DDim, Print) {
......
...@@ -401,20 +401,5 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) { ...@@ -401,20 +401,5 @@ HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) {
return result; return result;
} }
template <int D, int S>
Dim<D> slice(const Dim<S>& dim, int begin, int end) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in Dim slice.");
PADDLE_ENFORCE(begin >= 0 && end <= S && end - begin == D,
"Index error occurs in Dim slice.");
if (begin > 0) {
return slice<D>(dim.tail, begin - 1, end - 1);
}
if (D == 1) {
return Dim<1>(dim.head);
}
return Dim<D>(dim.head, slice<D - 1>(dim.tail, 0, end - 1));
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <sstream> #include <sstream>
#include "paddle/framework/dim.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/framework/dim.h"
__global__ void test(paddle::framework::Dim<2>* o) { __global__ void test(paddle::framework::Dim<2>* o) {
o[0] = paddle::framework::make_dim(5, 6); o[0] = paddle::framework::make_dim(5, 6);
} }
__global__ void dyn_idx_gpu(int* o) { __global__ void dyn_idx_gpu(int* o) {
auto d = paddle::framework::make_dim(5, 6); auto d = paddle::framework::make_dim(5, 6);
o[0] = d[1]; o[0] = d[1];
} }
TEST(Dim, Equality) { TEST(Dim, Equality) {
// construct a Dim on the CPU // construct a Dim on the CPU
auto a = paddle::framework::make_dim(3, 4); auto a = paddle::framework::make_dim(3, 4);
EXPECT_EQ(paddle::framework::get<0>(a), 3); EXPECT_EQ(paddle::framework::get<0>(a), 3);
EXPECT_EQ(paddle::framework::get<1>(a), 4); EXPECT_EQ(paddle::framework::get<1>(a), 4);
// construct a Dim on the GPU // construct a Dim on the GPU
thrust::device_vector<paddle::framework::Dim<2>> t(2); thrust::device_vector<paddle::framework::Dim<2>> t(2);
test<<<1,1>>>(thrust::raw_pointer_cast(t.data())); test<<<1, 1>>>(thrust::raw_pointer_cast(t.data()));
a = t[0]; a = t[0];
EXPECT_EQ(paddle::framework::get<0>(a), 5); EXPECT_EQ(paddle::framework::get<0>(a), 5);
EXPECT_EQ(paddle::framework::get<1>(a), 6); EXPECT_EQ(paddle::framework::get<1>(a), 6);
// linearization // linearization
auto b = paddle::framework::make_dim(7, 8); auto b = paddle::framework::make_dim(7, 8);
EXPECT_EQ(paddle::framework::linearize(a, b), 83); EXPECT_EQ(paddle::framework::linearize(a, b), 83);
// product // product
EXPECT_EQ(paddle::framework::product(a), 30); EXPECT_EQ(paddle::framework::product(a), 30);
// mutate a Dim // mutate a Dim
paddle::framework::get<1>(b) = 10; paddle::framework::get<1>(b) = 10;
EXPECT_EQ(paddle::framework::get<0>(b), 7); EXPECT_EQ(paddle::framework::get<0>(b), 7);
EXPECT_EQ(paddle::framework::get<1>(b), 10); EXPECT_EQ(paddle::framework::get<1>(b), 10);
// dynamic access // dynamic access
paddle::framework::get(b, 0) = 8; paddle::framework::get(b, 0) = 8;
b[1] = 11; b[1] = 11;
EXPECT_EQ(paddle::framework::get<0>(b), 8); EXPECT_EQ(paddle::framework::get<0>(b), 8);
EXPECT_EQ(paddle::framework::get<1>(b), 11); EXPECT_EQ(paddle::framework::get<1>(b), 11);
EXPECT_EQ(paddle::framework::get(b, 0), 8); EXPECT_EQ(paddle::framework::get(b, 0), 8);
EXPECT_EQ(b[1], 11); EXPECT_EQ(b[1], 11);
// dynamic access on GPU // dynamic access on GPU
thrust::device_vector<int> r(1); thrust::device_vector<int> r(1);
dyn_idx_gpu<<<1,1>>>(thrust::raw_pointer_cast(r.data())); dyn_idx_gpu<<<1, 1>>>(thrust::raw_pointer_cast(r.data()));
int res = r[0]; int res = r[0];
EXPECT_EQ(res, 6); EXPECT_EQ(res, 6);
// ex_prefix_mul // ex_prefix_mul
paddle::framework::Dim<3> c = paddle::framework::ex_prefix_mul(paddle::framework::Dim<3>(3, 4, 5)); paddle::framework::Dim<3> c =
EXPECT_EQ(paddle::framework::get<0>(c), 1); paddle::framework::ex_prefix_mul(paddle::framework::Dim<3>(3, 4, 5));
EXPECT_EQ(paddle::framework::get<1>(c), 3); EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<2>(c), 12); EXPECT_EQ(paddle::framework::get<1>(c), 3);
EXPECT_EQ(paddle::framework::get<2>(c), 12);
// generate from an index
auto size = paddle::framework::make_dim(4, 5, 2); // generate from an index
c = paddle::framework::Dim<3>(14, size); auto size = paddle::framework::make_dim(4, 5, 2);
EXPECT_EQ(paddle::framework::get<0>(c), 2); c = paddle::framework::Dim<3>(14, size);
EXPECT_EQ(paddle::framework::get<1>(c), 3); EXPECT_EQ(paddle::framework::get<0>(c), 2);
EXPECT_EQ(paddle::framework::get<2>(c), 0); EXPECT_EQ(paddle::framework::get<1>(c), 3);
c = paddle::framework::Dim<3>(25, size); EXPECT_EQ(paddle::framework::get<2>(c), 0);
EXPECT_EQ(paddle::framework::get<0>(c), 1); c = paddle::framework::Dim<3>(25, size);
EXPECT_EQ(paddle::framework::get<1>(c), 1); EXPECT_EQ(paddle::framework::get<0>(c), 1);
EXPECT_EQ(paddle::framework::get<2>(c), 1); EXPECT_EQ(paddle::framework::get<1>(c), 1);
EXPECT_EQ(paddle::framework::get<2>(c), 1);
} }
TEST(Dim, Bool) { TEST(Dim, Bool) {
auto a = paddle::framework::make_dim(3, 4); auto a = paddle::framework::make_dim(3, 4);
auto b = paddle::framework::make_dim(5, 6); auto b = paddle::framework::make_dim(5, 6);
auto c = paddle::framework::make_dim(3, 4); auto c = paddle::framework::make_dim(3, 4);
// in_bounds check // in_bounds check
EXPECT_TRUE(paddle::framework::contained(a, b)); EXPECT_TRUE(paddle::framework::contained(a, b));
EXPECT_FALSE(paddle::framework::contained(b, a)); EXPECT_FALSE(paddle::framework::contained(b, a));
// comparison // comparison
EXPECT_TRUE(a == a); EXPECT_TRUE(a == a);
EXPECT_FALSE(a == b); EXPECT_FALSE(a == b);
EXPECT_TRUE(a == c); EXPECT_TRUE(a == c);
} }
TEST(Dim, Print) { TEST(Dim, Print) {
{ {
std::stringstream ss; std::stringstream ss;
auto a = paddle::framework::make_dim(2, 3); auto a = paddle::framework::make_dim(2, 3);
ss << a; ss << a;
EXPECT_EQ(ss.str(), "2, 3"); EXPECT_EQ(ss.str(), "2, 3");
} }
{ {
std::stringstream ss; std::stringstream ss;
ss << paddle::framework::make_dim(8); ss << paddle::framework::make_dim(8);
EXPECT_EQ(ss.str(), "8"); EXPECT_EQ(ss.str(), "8");
} }
} }
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册