提交 7d56c6d0 编写于 作者: X xuwei06

Adding Dim<0>

Dim<0> is for scalar (rank-0 tensor). Adding Dim<0> can simplify
a lot of code.
上级 a67cebaf
......@@ -26,12 +26,15 @@ Dim<i> make_dim(const int64_t* d) {
}
template <>
Dim<1> make_dim<1>(const int64_t* d) {
return Dim<1>(*d);
Dim<0> make_dim<0>(const int64_t* d) {
return Dim<0>(*d);
}
void make_ddim(DDim& ddim, const int64_t* dims, int n) {
switch (n) {
case 0:
ddim = make_dim<0>(dims);
break;
case 1:
ddim = make_dim<1>(dims);
break;
......@@ -190,7 +193,7 @@ struct VectorizeVisitor : public boost::static_visitor<> {
this->operator()(t.tail);
}
void operator()(const Dim<1>& t) { vector.push_back(t.head); }
void operator()(const Dim<0>& t) {}
};
/// @endcond
......@@ -247,9 +250,8 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> {
}
}
void operator()(const Dim<1>& dim) {
PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound.");
vector.push_back(dim.head);
void operator()(const Dim<0>& dim) {
PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound.");
}
};
......
......@@ -30,8 +30,8 @@ namespace framework {
* The number of dimensions must be between [1, 9].
*/
struct DDim {
typedef boost::variant<Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>,
Dim<8>, Dim<9>>
typedef boost::variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>,
Dim<7>, Dim<8>, Dim<9>>
DDimVar;
DDimVar var;
......
......@@ -72,38 +72,36 @@ struct Dim {
// Base case specialization
template <>
struct Dim<1> {
static constexpr int dimensions = 1;
struct Dim<0> {
static constexpr int dimensions = 0;
HOSTDEVICE
Dim(int64_t _head) : head(_head) {}
Dim(int64_t _head) {}
HOSTDEVICE
Dim() : head(0) {}
Dim() {}
HOSTDEVICE
Dim(int idx, const Dim<1>& size) : head(idx) {
Dim(int idx, const Dim<0>& size) {
#ifndef __CUDA_ARCH__
if (idx >= size.head) {
if (idx > 0) {
throw std::invalid_argument("Index out of range.");
}
#else
PADDLE_ASSERT(idx < size.head);
PADDLE_ASSERT(idx == 0);
#endif
}
HOSTDEVICE
bool operator==(const Dim<1>& o) const { return (head == o.head); }
bool operator==(const Dim<0>& o) const { return true; }
HOSTDEVICE
bool operator!=(const Dim<1>& o) const { return !(*this == o); }
bool operator!=(const Dim<0>& o) const { return false; }
HOSTDEVICE
int64_t& operator[](int idx);
HOSTDEVICE
int64_t operator[](int idx) const;
int64_t head;
};
namespace {
......@@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) {
}
template <>
HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) {
HOSTDEVICE int64_t& indexer<0>(Dim<0>& dim, int idx) {
#ifndef __CUDA_ARCH__
if (idx != 0) {
throw std::invalid_argument("Invalid index");
}
throw std::invalid_argument("Invalid index");
#else
PADDLE_ASSERT(idx == 0);
PADDLE_ASSERT(false);
#endif
return dim.head;
static int64_t head = 0;
return head;
}
template <int D>
......@@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) {
}
template <>
HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) {
HOSTDEVICE int64_t indexer<0>(const Dim<0>& dim, int idx) {
#ifndef __CUDA_ARCH__
if (idx != 0) {
throw std::invalid_argument("Invalid index");
}
throw std::invalid_argument("Invalid index");
#else
PADDLE_ASSERT(idx == 0);
PADDLE_ASSERT(false);
#endif
return dim.head;
static int64_t head = 0;
return head;
}
} // namespace
......@@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim<l>::operator[](int i) {
}
// Dynamic access to constant Dim
inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const {
inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const {
return indexer(*this, i);
}
// Dynamic access to mutable Dim
inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) {
inline HOSTDEVICE int64_t& Dim<0>::operator[](int i) {
return indexer(*this, i);
}
......@@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) {
// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) {
return a.head * b.head;
HOSTDEVICE inline int64_t linearize(const Dim<0>& a, const Dim<0>& b) {
return 0;
}
// Product of a Dim
......@@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) {
// Base case product of a Dim
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) {
return prod * a.head;
HOSTDEVICE inline int64_t product(const Dim<0>& a, int prod) {
return prod;
}
// Is 0 <= idx_i < size_i for all i?
......@@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) {
// Base case of is 0 <= idx_i < size_i ?
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) {
return ((0 <= idx.head) && (idx.head < size.head));
HOSTDEVICE inline bool contained(const Dim<0>& idx, const Dim<0>& size) {
return true;
}
/**
......@@ -294,8 +290,8 @@ HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) {
// Base case of ex_prefix_mul
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) {
return Dim<1>(mul);
HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0>& src, int mul) {
return Dim<0>();
}
///\endcond
......@@ -309,8 +305,8 @@ HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) {
// Base case
template <>
HOSTDEVICE inline Dim<1> dim_plus(const Dim<1>& a, const Dim<1>& b) {
return Dim<1>(a.head + b.head);
HOSTDEVICE inline Dim<0> dim_plus(const Dim<0>& a, const Dim<0>& b) {
return Dim<0>();
}
template <int i>
......@@ -328,8 +324,8 @@ HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) {
// Base case
template <>
HOSTDEVICE inline Dim<1> dim_mult(const Dim<1>& a, const Dim<1>& b) {
return Dim<1>(a.head * b.head);
HOSTDEVICE inline Dim<0> dim_mult(const Dim<0>& a, const Dim<0>& b) {
return Dim<0>();
}
template <int i>
......@@ -356,10 +352,9 @@ HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) {
///\cond HIDDEN
template <>
HOSTDEVICE inline Dim<1> normalize_strides(const Dim<1>& size,
const Dim<1>& stride) {
int norm_stride = size.head == 1 ? 0 : stride.head;
return Dim<1>(norm_stride);
HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size,
const Dim<0>& stride) {
return Dim<0>();
}
///\endcond
......@@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
return os;
}
inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
return os;
}
template <int i>
HOST std::string Dim<i>::to_string() const {
std::stringstream stream;
......
......@@ -24,6 +24,29 @@ namespace detail {
template <typename T, int Rank>
struct StridedMemcpyFunctor;
template <typename T>
struct StridedMemcpyFunctor<T, 0> {
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
framework::Dim<0> src_stride, framework::Dim<0> dst_dim,
framework::Dim<0> dst_stride, T* dst) const {
auto place = dev_ctx.GetPlace();
if (platform::is_cpu_place(place)) {
auto& cpu_place = boost::get<platform::CPUPlace>(place);
memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T));
} else {
#ifdef PADDLE_WITH_CUDA
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
auto& cuda_ctx =
reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T),
cuda_ctx.stream());
#else
PADDLE_THROW("Paddle is not compiled with GPU");
#endif
}
}
};
template <typename T>
struct StridedMemcpyFunctor<T, 1> {
void operator()(const platform::DeviceContext& dev_ctx, const T* src,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册