提交 7d56c6d0 编写于 作者: X xuwei06

Adding Dim<0>

Dim<0> is for scalar (rank-0 tensor). Adding Dim<0> can simplify
a lot of code.
上级 a67cebaf
...@@ -26,12 +26,15 @@ Dim<i> make_dim(const int64_t* d) { ...@@ -26,12 +26,15 @@ Dim<i> make_dim(const int64_t* d) {
} }
template <> template <>
Dim<1> make_dim<1>(const int64_t* d) { Dim<0> make_dim<0>(const int64_t* d) {
return Dim<1>(*d); return Dim<0>(*d);
} }
void make_ddim(DDim& ddim, const int64_t* dims, int n) { void make_ddim(DDim& ddim, const int64_t* dims, int n) {
switch (n) { switch (n) {
case 0:
ddim = make_dim<0>(dims);
break;
case 1: case 1:
ddim = make_dim<1>(dims); ddim = make_dim<1>(dims);
break; break;
...@@ -190,7 +193,7 @@ struct VectorizeVisitor : public boost::static_visitor<> { ...@@ -190,7 +193,7 @@ struct VectorizeVisitor : public boost::static_visitor<> {
this->operator()(t.tail); this->operator()(t.tail);
} }
void operator()(const Dim<1>& t) { vector.push_back(t.head); } void operator()(const Dim<0>& t) {}
}; };
/// @endcond /// @endcond
...@@ -247,9 +250,8 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> { ...@@ -247,9 +250,8 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> {
} }
} }
void operator()(const Dim<1>& dim) { void operator()(const Dim<0>& dim) {
PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound."); PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound.");
vector.push_back(dim.head);
} }
}; };
......
...@@ -30,8 +30,8 @@ namespace framework { ...@@ -30,8 +30,8 @@ namespace framework {
* The number of dimensions must be between [1, 9]. * The number of dimensions must be between [1, 9].
*/ */
struct DDim { struct DDim {
typedef boost::variant<Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>, typedef boost::variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>,
Dim<8>, Dim<9>> Dim<7>, Dim<8>, Dim<9>>
DDimVar; DDimVar;
DDimVar var; DDimVar var;
......
...@@ -72,38 +72,36 @@ struct Dim { ...@@ -72,38 +72,36 @@ struct Dim {
// Base case specialization // Base case specialization
template <> template <>
struct Dim<1> { struct Dim<0> {
static constexpr int dimensions = 1; static constexpr int dimensions = 0;
HOSTDEVICE HOSTDEVICE
Dim(int64_t _head) : head(_head) {} Dim(int64_t _head) {}
HOSTDEVICE HOSTDEVICE
Dim() : head(0) {} Dim() {}
HOSTDEVICE HOSTDEVICE
Dim(int idx, const Dim<1>& size) : head(idx) { Dim(int idx, const Dim<0>& size) {
#ifndef __CUDA_ARCH__ #ifndef __CUDA_ARCH__
if (idx >= size.head) { if (idx > 0) {
throw std::invalid_argument("Index out of range."); throw std::invalid_argument("Index out of range.");
} }
#else #else
PADDLE_ASSERT(idx < size.head); PADDLE_ASSERT(idx == 0);
#endif #endif
} }
HOSTDEVICE HOSTDEVICE
bool operator==(const Dim<1>& o) const { return (head == o.head); } bool operator==(const Dim<0>& o) const { return true; }
HOSTDEVICE HOSTDEVICE
bool operator!=(const Dim<1>& o) const { return !(*this == o); } bool operator!=(const Dim<0>& o) const { return false; }
HOSTDEVICE HOSTDEVICE
int64_t& operator[](int idx); int64_t& operator[](int idx);
HOSTDEVICE HOSTDEVICE
int64_t operator[](int idx) const; int64_t operator[](int idx) const;
int64_t head;
}; };
namespace { namespace {
...@@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) { ...@@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
} }
template <> template <>
HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) { HOSTDEVICE int64_t& indexer<0>(Dim<0>& dim, int idx) {
#ifndef __CUDA_ARCH__ #ifndef __CUDA_ARCH__
if (idx != 0) {
throw std::invalid_argument("Invalid index"); throw std::invalid_argument("Invalid index");
}
#else #else
PADDLE_ASSERT(idx == 0); PADDLE_ASSERT(false);
#endif #endif
return dim.head; static int64_t head = 0;
return head;
} }
template <int D> template <int D>
...@@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) { ...@@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
} }
template <> template <>
HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) { HOSTDEVICE int64_t indexer<0>(const Dim<0>& dim, int idx) {
#ifndef __CUDA_ARCH__ #ifndef __CUDA_ARCH__
if (idx != 0) {
throw std::invalid_argument("Invalid index"); throw std::invalid_argument("Invalid index");
}
#else #else
PADDLE_ASSERT(idx == 0); PADDLE_ASSERT(false);
#endif #endif
return dim.head; static int64_t head = 0;
return head;
} }
} // namespace } // namespace
...@@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim<l>::operator[](int i) { ...@@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
} }
// Dynamic access to constant Dim // Dynamic access to constant Dim
inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const { inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const {
return indexer(*this, i); return indexer(*this, i);
} }
// Dynamic access to mutable Dim // Dynamic access to mutable Dim
inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) { inline HOSTDEVICE int64_t& Dim<0>::operator[](int i) {
return indexer(*this, i); return indexer(*this, i);
} }
...@@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) { ...@@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
// Base case dot product of two Dims // Base case dot product of two Dims
// Notice it is inline because it is no longer a template // Notice it is inline because it is no longer a template
template <> template <>
HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) { HOSTDEVICE inline int64_t linearize(const Dim<0>& a, const Dim<0>& b) {
return a.head * b.head; return 0;
} }
// Product of a Dim // Product of a Dim
...@@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) { ...@@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
// Base case product of a Dim // Base case product of a Dim
// Notice it is inline because it is no longer a template // Notice it is inline because it is no longer a template
template <> template <>
HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) { HOSTDEVICE inline int64_t product(const Dim<0>& a, int prod) {
return prod * a.head; return prod;
} }
// Is 0 <= idx_i < size_i for all i? // Is 0 <= idx_i < size_i for all i?
...@@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) { ...@@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) {
// Base case of is 0 <= idx_i < size_i ? // Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template // Notice it is inline because it is no longer a template
template <> template <>
HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) { HOSTDEVICE inline bool contained(const Dim<0>& idx, const Dim<0>& size) {
return ((0 <= idx.head) && (idx.head < size.head)); return true;
} }
/** /**
...@@ -294,8 +290,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) { ...@@ -294,8 +290,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) {
// Base case of ex_prefix_mul // Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template // Notice it is inline because it is no longer a template
template <> template <>
HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) { HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0>& src, int mul) {
return Dim<1>(mul); return Dim<0>();
} }
///\endcond ///\endcond
...@@ -309,8 +305,8 @@ HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) { ...@@ -309,8 +305,8 @@ HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) {
// Base case // Base case
template <> template <>
HOSTDEVICE inline Dim<1> dim_plus(const Dim<1>& a, const Dim<1>& b) { HOSTDEVICE inline Dim<0> dim_plus(const Dim<0>& a, const Dim<0>& b) {
return Dim<1>(a.head + b.head); return Dim<0>();
} }
template <int i> template <int i>
...@@ -328,8 +324,8 @@ HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) { ...@@ -328,8 +324,8 @@ HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) {
// Base case // Base case
template <> template <>
HOSTDEVICE inline Dim<1> dim_mult(const Dim<1>& a, const Dim<1>& b) { HOSTDEVICE inline Dim<0> dim_mult(const Dim<0>& a, const Dim<0>& b) {
return Dim<1>(a.head * b.head); return Dim<0>();
} }
template <int i> template <int i>
...@@ -356,10 +352,9 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) { ...@@ -356,10 +352,9 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) {
///\cond HIDDEN ///\cond HIDDEN
template <> template <>
HOSTDEVICE inline Dim<1> normalize_strides(const Dim<1>& size, HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size,
const Dim<1>& stride) { const Dim<0>& stride) {
int norm_stride = size.head == 1 ? 0 : stride.head; return Dim<0>();
return Dim<1>(norm_stride);
} }
///\endcond ///\endcond
...@@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<( ...@@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
return os; return os;
} }
inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
return os;
}
template <int i> template <int i>
HOST std::string Dim<i>::to_string() const { HOST std::string Dim<i>::to_string() const {
std::stringstream stream; std::stringstream stream;
......
...@@ -24,6 +24,29 @@ namespace detail { ...@@ -24,6 +24,29 @@ namespace detail {
template <typename T, int Rank> template <typename T, int Rank>
struct StridedMemcpyFunctor; struct StridedMemcpyFunctor;
template <typename T>
struct StridedMemcpyFunctor<T, 0> {
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
framework::Dim<0> src_stride, framework::Dim<0> dst_dim,
framework::Dim<0> dst_stride, T* dst) const {
auto place = dev_ctx.GetPlace();
if (platform::is_cpu_place(place)) {
auto& cpu_place = boost::get<platform::CPUPlace>(place);
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T));
} else {
#ifdef PADDLE_WITH_CUDA
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T),
cuda_ctx.stream());
#else
PADDLE_THROW("Paddle is not compiled with GPU");
#endif
}
}
};
template <typename T> template <typename T>
struct StridedMemcpyFunctor<T, 1> { struct StridedMemcpyFunctor<T, 1> {
void operator()(const platform::DeviceContext& dev_ctx, const T* src, void operator()(const platform::DeviceContext& dev_ctx, const T* src,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册