Commit 11163dfc authored by qijun

make dim int to int64_t

Parent b64aac54
@@ -21,16 +21,16 @@ namespace framework {
/// @cond HIDDEN
template <int i>
-Dim<i> make_dim(const int* d) {
+Dim<i> make_dim(const int64_t* d) {
  return Dim<i>(*d, make_dim<i - 1>(d + 1));
}
template <>
-Dim<1> make_dim<1>(const int* d) {
+Dim<1> make_dim<1>(const int64_t* d) {
  return Dim<1>(*d);
}
-void make_ddim(DDim& ddim, const int* dims, int n) {
+void make_ddim(DDim& ddim, const int64_t* dims, int n) {
  switch (n) {
    case 1:
      ddim = make_dim<1>(dims);
@@ -67,13 +67,13 @@ void make_ddim(DDim& ddim, const int* dims, int n) {
/// @endcond
-DDim make_ddim(std::initializer_list<int> dims) {
+DDim make_ddim(std::initializer_list<int64_t> dims) {
  DDim result(make_dim(0));
  make_ddim(result, dims.begin(), dims.size());
  return result;
}
-DDim make_ddim(const std::vector<int>& dims) {
+DDim make_ddim(const std::vector<int64_t>& dims) {
  DDim result(make_dim(0));
  make_ddim(result, &dims[0], dims.size());
  return result;
@@ -81,12 +81,12 @@ DDim make_ddim(const std::vector<int>& dims) {
/// @cond HIDDEN
// XXX For some reason, putting this in an anonymous namespace causes errors
-class DynamicMutableIndexer : public boost::static_visitor<int&> {
+class DynamicMutableIndexer : public boost::static_visitor<int64_t&> {
 public:
  explicit DynamicMutableIndexer(int idx) : idx_(idx) {}
  template <int D>
-  int& operator()(Dim<D>& dim) const {
+  int64_t& operator()(Dim<D>& dim) const {
    return dim[idx_];
  }
@@ -94,12 +94,12 @@ class DynamicMutableIndexer : public boost::static_visitor<int&> {
  int idx_;
};
-class DynamicConstIndexer : public boost::static_visitor<int> {
+class DynamicConstIndexer : public boost::static_visitor<int64_t> {
 public:
  explicit DynamicConstIndexer(int idx) : idx_(idx) {}
  template <int D>
-  int operator()(const Dim<D>& dim) const {
+  int64_t operator()(const Dim<D>& dim) const {
    return dim[idx_];
  }
@@ -109,22 +109,22 @@ class DynamicConstIndexer : public boost::static_visitor<int> {
/// @endcond
-int& DDim::operator[](int idx) {
+int64_t& DDim::operator[](int idx) {
  return boost::apply_visitor(DynamicMutableIndexer(idx), var);
}
-int DDim::operator[](int idx) const {
+int64_t DDim::operator[](int idx) const {
  return boost::apply_visitor(DynamicConstIndexer(idx), var);
}
-ssize_t DDim::size() const { return arity(*this); }
+int64_t DDim::size() const { return arity(*this); }
bool DDim::operator==(DDim d) const {
  if (var.which() != d.getVar().which()) {
    return false;
  } else {
-    std::vector<int> v1 = vectorize(*this);
-    std::vector<int> v2 = vectorize(d);
+    std::vector<int64_t> v1 = vectorize(*this);
+    std::vector<int64_t> v2 = vectorize(d);
    for (unsigned int i = 0; i < v1.size(); i++) {
      if (v1[i] != v2[i]) {
@@ -139,10 +139,10 @@ bool DDim::operator==(DDim d) const {
bool DDim::operator!=(DDim d) const { return !(*this == d); }
DDim DDim::operator+(DDim d) const {
-  std::vector<int> v1 = vectorize(*this);
-  std::vector<int> v2 = vectorize(d);
-  std::vector<int> v3;
+  std::vector<int64_t> v1 = vectorize(*this);
+  std::vector<int64_t> v2 = vectorize(d);
+  std::vector<int64_t> v3;
  assert(v1.size() == v2.size());
@@ -154,10 +154,10 @@ DDim DDim::operator+(DDim d) const {
}
DDim DDim::operator*(DDim d) const {
-  std::vector<int> v1 = vectorize(*this);
-  std::vector<int> v2 = vectorize(d);
-  std::vector<int> v3;
+  std::vector<int64_t> v1 = vectorize(*this);
+  std::vector<int64_t> v2 = vectorize(d);
+  std::vector<int64_t> v3;
  assert(v1.size() == v2.size());
@@ -168,15 +168,15 @@ DDim DDim::operator*(DDim d) const {
  return make_ddim(v3);
}
-int get(const DDim& ddim, int idx) { return ddim[idx]; }
+int64_t get(const DDim& ddim, int idx) { return ddim[idx]; }
void set(DDim& ddim, int idx, int value) { ddim[idx] = value; }
/// @cond HIDDEN
struct VectorizeVisitor : public boost::static_visitor<> {
-  std::vector<int>& vector;
-  explicit VectorizeVisitor(std::vector<int>& v) : vector(v) {}
+  std::vector<int64_t>& vector;
+  explicit VectorizeVisitor(std::vector<int64_t>& v) : vector(v) {}
  template <typename T>
  void operator()(const T& t) {
@@ -188,31 +188,31 @@ struct VectorizeVisitor : public boost::static_visitor<> {
};
/// @endcond
-std::vector<int> vectorize(const DDim& ddim) {
-  std::vector<int> result;
+std::vector<int64_t> vectorize(const DDim& ddim) {
+  std::vector<int64_t> result;
  VectorizeVisitor visitor(result);
  boost::apply_visitor(visitor, ddim);
  return result;
}
-struct ProductVisitor : public boost::static_visitor<ssize_t> {
+struct ProductVisitor : public boost::static_visitor<int64_t> {
  template <int D>
-  ssize_t operator()(const Dim<D>& dim) {
+  int64_t operator()(const Dim<D>& dim) {
    return product(dim);
  }
};
-ssize_t product(const DDim& ddim) {
+int64_t product(const DDim& ddim) {
  ProductVisitor visitor;
  return boost::apply_visitor(visitor, ddim);
}
struct SliceVectorizeVisitor : public boost::static_visitor<> {
-  std::vector<int>& vector;
+  std::vector<int64_t>& vector;
  int begin;
  int end;
-  SliceVectorizeVisitor(std::vector<int>& v, int b, int e)
+  SliceVectorizeVisitor(std::vector<int64_t>& v, int b, int e)
      : vector(v), begin(b), end(e) {
    PADDLE_ENFORCE(begin < end,
                   "Begin index must be less than end index in ddim slice.");
@@ -240,7 +240,7 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> {
};
DDim slice_ddim(const DDim& dim, int begin, int end) {
-  std::vector<int> vec;
+  std::vector<int64_t> vec;
  vec.reserve(end - begin);
  SliceVectorizeVisitor visitor(vec, begin, end);
  boost::apply_visitor(visitor, dim);
@@ -280,7 +280,7 @@ std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
  return os;
}
-DDim::DDim(std::initializer_list<int> init_list) {
+DDim::DDim(std::initializer_list<int64_t> init_list) {
  *this = make_ddim(init_list);
}
}  // namespace framework
......
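Taken together, the changes in this file mean that every entry stored in a DDim, and the values returned by vectorize() and product(), are now 64-bit. A minimal usage sketch (not part of the commit; it only assumes the paddle::framework API shown above) illustrates why this matters once a shape holds more than 2^31 elements:

#include <cstdint>
#include <vector>
// Assumes the int64_t-based DDim declarations above are visible.

namespace f = paddle::framework;

void ddim_usage_sketch() {
  // 100000 * 30000 = 3,000,000,000 elements: more than INT32_MAX, but
  // representable now that DDim stores int64_t.
  f::DDim big = f::make_ddim({100000, 30000});
  std::vector<int64_t> dims = f::vectorize(big);  // {100000, 30000}
  int64_t numel = f::product(big);                // 3000000000
  (void)dims;
  (void)numel;
}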
@@ -40,7 +40,7 @@ struct DDim {
  template <int D>
  explicit DDim(const Dim<D>& in) : var(in) {}
-  /*implicit*/ DDim(std::initializer_list<int> init_list);
+  /*implicit*/ DDim(std::initializer_list<int64_t> init_list);
  template <int D>
  DDim& operator=(const Dim<D>& in) {
@@ -48,8 +48,8 @@ struct DDim {
    return *this;
  }
-  int& operator[](int idx);
-  int operator[](int idx) const;
+  int64_t& operator[](int idx);
+  int64_t operator[](int idx) const;
  template <typename Visitor>
  typename Visitor::result_type apply_visitor(Visitor& visitor) {
@@ -71,15 +71,15 @@ struct DDim {
  DDim operator*(DDim d) const;
-  ssize_t size() const;
+  int64_t size() const;
};
/**
- * \brief Make a DDim from std::vector<int>
+ * \brief Make a DDim from std::vector<int64_t>
 *
 * \param dims An vector of ints. Must be sized between [1, 9]
 */
-DDim make_ddim(const std::vector<int>& dims);
+DDim make_ddim(const std::vector<int64_t>& dims);
/**
 * \brief Make a DDim from an initializer list
@@ -87,14 +87,14 @@ DDim make_ddim(const std::vector<int>& dims);
 * \param dims An initializer list of ints. Must be sized between [1, 9]
 *
 */
-DDim make_ddim(std::initializer_list<int> dims);
+DDim make_ddim(std::initializer_list<int64_t> dims);
-int get(const DDim& dim, int idx);
+int64_t get(const DDim& dim, int idx);
void set(DDim& dim, int idx, int val);
-std::vector<int> vectorize(const DDim& ddim);
+std::vector<int64_t> vectorize(const DDim& ddim);
-ssize_t product(const DDim& ddim);
+int64_t product(const DDim& ddim);
/**
 * \brief Slice a ddim
......
@@ -12,7 +12,7 @@ TEST(DDim, Equality) {
  EXPECT_EQ(ddim[2], 5);
  // construct a DDim from a vector
-  std::vector<int> vec({9, 1, 5});
+  std::vector<int64_t> vec({9, 1, 5});
  paddle::framework::DDim vddim = paddle::framework::make_ddim(vec);
  EXPECT_EQ(ddim[0], 9);
  EXPECT_EQ(ddim[1], 1);
@@ -25,7 +25,7 @@ TEST(DDim, Equality) {
  EXPECT_EQ(paddle::framework::get(ddim, 0), 6);
  // vectorize a DDim
-  std::vector<int> res_vec = paddle::framework::vectorize(vddim);
+  std::vector<int64_t> res_vec = paddle::framework::vectorize(vddim);
  EXPECT_EQ(res_vec[0], 9);
  EXPECT_EQ(res_vec[1], 1);
  EXPECT_EQ(res_vec[2], 5);
......
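The updated test above only swaps the vector type; a case that actually exercises the wider range could look like the following sketch (not part of this commit; TEST and EXPECT_EQ come from gtest, as already used in this file):

// Sketch of an additional case: product() of a shape with more than 2^31
// elements must not wrap around now that dimensions are int64_t.
TEST(DDim, LargeProduct) {
  paddle::framework::DDim big = paddle::framework::make_ddim({100000, 30000});
  EXPECT_EQ(paddle::framework::product(big), 3000000000LL);
  EXPECT_EQ(big[0], 100000);
  EXPECT_EQ(big[1], 30000);
}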
@@ -17,13 +17,13 @@ struct Dim {
  static constexpr int dimensions = i;
  template <typename... Args>
-  HOSTDEVICE Dim(int _head, Args... _tail) : head(_head), tail(_tail...) {
+  HOSTDEVICE Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) {
    static_assert(sizeof...(_tail) == i - 1,
                  "Dim initialized with the wrong number of parameters");
  }
  HOSTDEVICE
-  Dim(int _head, const Dim<i - 1>& _tail) : head(_head), tail(_tail) {}
+  Dim(int64_t _head, const Dim<i - 1>& _tail) : head(_head), tail(_tail) {}
  HOSTDEVICE
  Dim() : head(0), tail() {}
@@ -31,12 +31,12 @@ struct Dim {
  /** Construct a Dim from a linear index and size. Uses Fortran order
   * indexing. */
  HOSTDEVICE
-  Dim(int idx, const Dim<i>& size)
+  Dim(int64_t idx, const Dim<i>& size)
      : head(idx % size.head), tail(idx / size.head, size.tail) {}
  /** Construct a Dim with each dimension set to the given index */
  HOSTDEVICE
-  Dim(int idx) : head(idx), tail(idx) {}
+  Dim(int64_t idx) : head(idx), tail(idx) {}
  HOSTDEVICE
  bool operator==(const Dim<i>& o) const {
@@ -47,13 +47,13 @@ struct Dim {
  bool operator!=(const Dim<i>& o) const { return !(*this == o); }
  HOSTDEVICE
-  int& operator[](int idx);
+  int64_t& operator[](int idx);
  HOSTDEVICE
-  int operator[](int idx) const;
+  int64_t operator[](int idx) const;
  HOST std::string to_string() const;
-  int head;
+  int64_t head;
  Dim<i - 1> tail;
};
@@ -63,7 +63,7 @@ struct Dim<1> {
  static constexpr int dimensions = 1;
  HOSTDEVICE
-  Dim(int _head) : head(_head) {}
+  Dim(int64_t _head) : head(_head) {}
  HOSTDEVICE
  Dim() : head(0) {}
@@ -86,11 +86,11 @@ struct Dim<1> {
  bool operator!=(const Dim<1>& o) const { return !(*this == o); }
  HOSTDEVICE
-  int& operator[](int idx);
+  int64_t& operator[](int idx);
  HOSTDEVICE
-  int operator[](int idx) const;
+  int64_t operator[](int idx) const;
-  int head;
+  int64_t head;
};
namespace {
@@ -100,12 +100,12 @@ template <int i>
struct DimGetter {
  // Return a copy if Dim is const
  template <typename D>
-  HOSTDEVICE static int impl(const D& d) {
+  HOSTDEVICE static int64_t impl(const D& d) {
    return DimGetter<i - 1>::impl(d.tail);
  }
  // Return a reference if Dim is mutable
  template <typename D>
-  HOSTDEVICE static int& impl(D& d) {
+  HOSTDEVICE static int64_t& impl(D& d) {
    return DimGetter<i - 1>::impl(d.tail);
  }
};
@@ -115,18 +115,18 @@ template <>
struct DimGetter<0> {
  // Return a copy if Dim is const
  template <typename D>
-  HOSTDEVICE static int impl(const D& d) {
+  HOSTDEVICE static int64_t impl(const D& d) {
    return d.head;
  }
  // Return a reference if Dim is mutable
  template <typename D>
-  HOSTDEVICE static int& impl(D& d) {
+  HOSTDEVICE static int64_t& impl(D& d) {
    return d.head;
  }
};
template <int D>
-HOSTDEVICE int& indexer(Dim<D>& dim, int idx) {
+HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
#ifndef __CUDA_ARCH__
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
@@ -141,7 +141,7 @@ HOSTDEVICE int& indexer(Dim<D>& dim, int idx) {
}
template <>
-HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) {
+HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) {
#ifndef __CUDA_ARCH__
  if (idx != 0) {
    throw std::invalid_argument("Invalid index");
@@ -153,7 +153,7 @@ HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) {
}
template <int D>
-HOSTDEVICE int indexer(const Dim<D>& dim, int idx) {
+HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
#ifndef __CUDA_ARCH__
  if (idx < 0) {
    throw std::invalid_argument("Tried to access a negative dimension");
@@ -168,7 +168,7 @@ HOSTDEVICE int indexer(const Dim<D>& dim, int idx) {
}
template <>
-HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) {
+HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) {
#ifndef __CUDA_ARCH__
  if (idx != 0) {
    throw std::invalid_argument("Invalid index");
@@ -182,73 +182,76 @@ HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) {
}  // namespace
// Static access to constant Dim
template <int i, int l>
-HOSTDEVICE int get(const Dim<l>& d) {
+HOSTDEVICE int64_t get(const Dim<l>& d) {
  return DimGetter<i>::impl(d);
}
// Static access to mutable Dim
template <int i, int l>
-HOSTDEVICE int& get(Dim<l>& d) {
+HOSTDEVICE int64_t& get(Dim<l>& d) {
  return DimGetter<i>::impl(d);
}
// Dynamic access to constant Dim
template <int l>
-HOSTDEVICE int Dim<l>::operator[](int i) const {
+HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
  return indexer(*this, i);
}
// Dynamic access to mutable Dim
template <int l>
-HOSTDEVICE int& Dim<l>::operator[](int i) {
+HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
  return indexer(*this, i);
}
// Dynamic access to constant Dim
-inline HOSTDEVICE int Dim<1>::operator[](int i) const {
+inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const {
  return indexer(*this, i);
}
// Dynamic access to mutable Dim
-inline HOSTDEVICE int& Dim<1>::operator[](int i) { return indexer(*this, i); }
+inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) {
+  return indexer(*this, i);
+}
// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
-HOSTDEVICE typename std::enable_if<(l > 0), int>::type get(const Dim<l>& d,
-                                                            int i) {
+HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l>& d,
+                                                                int i) {
  return d[i];
}
// Dynamic access to mutable Dim
template <int l>
-HOSTDEVICE typename std::enable_if<(l > 0), int&>::type get(Dim<l>& d, int i) {
+HOSTDEVICE typename std::enable_if<(l > 0), int64_t&>::type get(Dim<l>& d,
+                                                                 int i) {
  return d[i];
}
// Dot product of two dims
template <int i>
-HOSTDEVICE int linearize(const Dim<i>& a, const Dim<i>& b) {
+HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
  return a.head * b.head + linearize(a.tail, b.tail);
}
// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
-HOSTDEVICE inline int linearize(const Dim<1>& a, const Dim<1>& b) {
+HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) {
  return a.head * b.head;
}
// Product of a Dim
template <int i>
-HOSTDEVICE int product(const Dim<i>& a, int prod = 1) {
+HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
  return prod * a.head * product(a.tail);
}
// Base case product of a Dim
// Notice it is inline because it is no longer a template
template <>
-HOSTDEVICE inline int product(const Dim<1>& a, int prod) {
+HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) {
  return prod * a.head;
}
......
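Since Dim<N> now stores int64_t in head and tail, both the static get<i>() accessors and the dynamic operator[] deal in 64-bit values. A small sketch of that behavior (not part of the commit; it only assumes the paddle::framework::Dim API shown above):

// Sketch: static and dynamic access on the int64_t-based Dim<N>.
#include <cstdint>

void dim_usage_sketch() {
  paddle::framework::Dim<3> d(2, 3, 4);          // head/tail hold int64_t now
  int64_t first = paddle::framework::get<0>(d);  // static access, 64-bit value
  d[2] = 5000000000LL;                           // dynamic access returns int64_t&,
                                                 // so a >32-bit extent is representable
  int64_t n = product(d);                        // 2 * 3 * 5000000000, via ADL
  (void)first;
  (void)n;
}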
@@ -47,9 +47,9 @@ TEST(Dim, Equality) {
  EXPECT_EQ(b[1], 11);
  // dynamic access on GPU
-  thrust::device_vector<int> r(1);
+  thrust::device_vector<int64_t> r(1);
  dyn_idx_gpu<<<1, 1>>>(thrust::raw_pointer_cast(r.data()));
-  int res = r[0];
+  int64_t res = r[0];
  EXPECT_EQ(res, 6);
  // ex_prefix_mul
......
@@ -28,7 +28,7 @@ struct EigenDim {
  static Type From(const DDim& dims) {
    PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
    Type ret;
-    for (int d = 0; d < arity(dims); d++) {
+    for (int64_t d = 0; d < arity(dims); d++) {
      ret[d] = dims[d];
    }
    return ret;
......
@@ -58,7 +58,7 @@ inline T* Tensor::mutable_data(platform::Place place) {
                 "Tensor's numel must be larger than zero to call "
                 "Tensor::mutable_data. Call Tensor::set_dim first.");
  /* some versions of boost::variant don't have operator!= */
-  size_t size = product(dims_) * sizeof(T);
+  int64_t size = product(dims_) * sizeof(T);
  if (holder_ == nullptr || !(holder_->place() == place) ||
      holder_->size() < size + offset_) {
    if (platform::is_cpu_place(place)) {
@@ -131,7 +131,7 @@ inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
  PADDLE_ENFORCE_LT(begin_idx, end_idx,
                    "Begin index must be less than end index.");
  PADDLE_ENFORCE_NE(dims_[0], 1, "Can not slice a tensor with dims_[0] = 1.");
-  int base = product(dims_) / dims_[0];
+  size_t base = product(dims_) / dims_[0];
  Tensor dst;
  dst.holder_ = holder_;
  DDim dst_dims = dims_;
......
@@ -31,8 +31,8 @@ class CPUGaussianRandomKernel : public framework::OpKernel {
    }
    engine.seed(seed);
    std::normal_distribution<T> dist(mean, std);
-    ssize_t size = framework::product(tensor->dims());
-    for (ssize_t i = 0; i < size; ++i) {
+    int64_t size = framework::product(tensor->dims());
+    for (int64_t i = 0; i < size; ++i) {
      data[i] = dist(engine);
    }
  }
@@ -46,9 +46,14 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
  void InferShape(const framework::InferShapeContext& context) const override {
    auto* tensor = context.Output<framework::Tensor>("Out");
    auto dims = GetAttr<std::vector<int>>("dims");
+    std::vector<int64_t> temp;
+    temp.reserve(dims.size());
+    for (auto dim : dims) {
+      temp.push_back(static_cast<int64_t>(dim));
+    }
    PADDLE_ENFORCE(dims.size() > 0UL,
                   "dims can be one int or array. dims must be set.");
-    tensor->Resize(framework::make_ddim(dims));
+    tensor->Resize(framework::make_ddim(temp));
  }
};
......
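Because the "dims" attribute is still read as std::vector<int>, the shape has to be widened element by element before it reaches make_ddim. The conversion in isolation might look like the sketch below (an illustrative helper, not code from the commit); reserving capacity rather than constructing the vector with a size is what keeps zero-valued placeholder dimensions from being prepended ahead of the pushed-back values.

#include <cstdint>
#include <vector>

// Illustrative helper (hypothetical name, not from the commit).
std::vector<int64_t> WidenDims(const std::vector<int>& dims) {
  std::vector<int64_t> out;
  out.reserve(dims.size());  // reserve, not a sized constructor: a sized
                             // constructor would leave dims.size() leading
                             // zeros in front of the pushed-back values
  for (int dim : dims) {
    out.push_back(static_cast<int64_t>(dim));
  }
  return out;
}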
@@ -61,7 +61,7 @@ void ConcatOutputs(const std::vector<Scope*>& step_scopes,
    PADDLE_ENFORCE(step_scope_var != nullptr, "%s not in scope",
                   outlinks[i].internal);
    f::DDim step_dims = step_scope_var->template GetMutable<Tensor>()->dims();
-    std::vector<int> dims_vec = vectorize(step_dims);
+    std::vector<int64_t> dims_vec = vectorize(step_dims);
    dims_vec.insert(dims_vec.begin(), seq_len);
    output->Resize(f::make_ddim(dims_vec));
  } else {
......
@@ -35,8 +35,8 @@ class CPUUniformRandomKernel : public framework::OpKernel {
    std::uniform_real_distribution<T> dist(
        static_cast<T>(context.GetAttr<float>("min")),
        static_cast<T>(context.GetAttr<float>("max")));
-    ssize_t size = framework::product(tensor->dims());
-    for (ssize_t i = 0; i < size; ++i) {
+    int64_t size = framework::product(tensor->dims());
+    for (int64_t i = 0; i < size; ++i) {
      data[i] = dist(engine);
    }
  }
@@ -52,7 +52,12 @@ class UniformRandomOp : public framework::OperatorWithKernel {
                   "uniform_random's min must less then max");
    auto* tensor = ctx.Output<framework::Tensor>("Out");
    auto dims = GetAttr<std::vector<int>>("dims");
-    tensor->Resize(framework::make_ddim(dims));
+    std::vector<int64_t> temp;
+    temp.reserve(dims.size());
+    for (auto dim : dims) {
+      temp.push_back(static_cast<int64_t>(dim));
+    }
+    tensor->Resize(framework::make_ddim(temp));
  }
};
......
@@ -76,7 +76,7 @@ PYBIND11_PLUGIN(core) {
      .def("get_dims",
           [](const Tensor &self) { return vectorize(self.dims()); })
      .def("set_dims",
-           [](Tensor &self, const std::vector<int> &dim) {
+           [](Tensor &self, const std::vector<int64_t> &dim) {
             self.Resize(make_ddim(dim));
           })
      .def("alloc_float",
......
@@ -85,7 +85,7 @@ void PyCPUTensorSetFromArray(
    framework::Tensor &self,
    py::array_t<T, py::array::c_style | py::array::forcecast> array,
    paddle::platform::CPUPlace &place) {
-  std::vector<int> dims;
+  std::vector<int64_t> dims;
  dims.reserve(array.ndim());
  for (size_t i = 0; i < array.ndim(); ++i) {
    dims.push_back((int)array.shape()[i]);
@@ -102,7 +102,7 @@ void PyCUDATensorSetFromArray(
    framework::Tensor &self,
    py::array_t<T, py::array::c_style | py::array::forcecast> array,
    paddle::platform::GPUPlace &place) {
-  std::vector<int> dims;
+  std::vector<int64_t> dims;
  dims.reserve(array.ndim());
  for (size_t i = 0; i < array.ndim(); ++i) {
    dims.push_back((int)array.shape()[i]);
......
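At the Python boundary the shape originates from pybind11, whose arrays report their extents through ndim() and shape(), as used in the loops above. Collecting those extents into the std::vector<int64_t> that make_ddim now expects can be sketched as follows (a sketch only; the helper name is illustrative, and it assumes the pybind11 numpy header and the py namespace alias already used in this file):

#include <cstdint>
#include <vector>
#include <pybind11/numpy.h>
namespace py = pybind11;

// Illustrative helper: turn a numpy array's shape into int64_t dims.
std::vector<int64_t> ShapeToDims(const py::array &array) {
  std::vector<int64_t> dims;
  dims.reserve(array.ndim());
  for (size_t i = 0; i < static_cast<size_t>(array.ndim()); ++i) {
    dims.push_back(static_cast<int64_t>(array.shape()[i]));
  }
  return dims;
}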