diff --git a/paddle/fluid/framework/ddim.cc b/paddle/fluid/framework/ddim.cc index 97afd366387e9ba6476be59a4d73d53a38834d0e..05e423b8a52962d47a6615d48243444374b470e3 100644 --- a/paddle/fluid/framework/ddim.cc +++ b/paddle/fluid/framework/ddim.cc @@ -26,12 +26,15 @@ Dim make_dim(const int64_t* d) { } template <> -Dim<1> make_dim<1>(const int64_t* d) { - return Dim<1>(*d); +Dim<0> make_dim<0>(const int64_t* d) { + return Dim<0>(*d); } void make_ddim(DDim& ddim, const int64_t* dims, int n) { switch (n) { + case 0: + ddim = make_dim<0>(dims); + break; case 1: ddim = make_dim<1>(dims); break; @@ -190,7 +193,7 @@ struct VectorizeVisitor : public boost::static_visitor<> { this->operator()(t.tail); } - void operator()(const Dim<1>& t) { vector.push_back(t.head); } + void operator()(const Dim<0>& t) {} }; /// @endcond @@ -247,9 +250,8 @@ struct SliceVectorizeVisitor : public boost::static_visitor<> { } } - void operator()(const Dim<1>& dim) { - PADDLE_ENFORCE(end == 1, "End index in ddim slice is out of bound."); - vector.push_back(dim.head); + void operator()(const Dim<0>& dim) { + PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound."); } }; diff --git a/paddle/fluid/framework/ddim.h b/paddle/fluid/framework/ddim.h index 5aff10d3b95902fdb9fe432d9f31830304dd3d07..f05b5ee3faee856a41f1376e5952710b550e7c42 100644 --- a/paddle/fluid/framework/ddim.h +++ b/paddle/fluid/framework/ddim.h @@ -30,8 +30,8 @@ namespace framework { * The number of dimensions must be between [1, 9]. */ struct DDim { - typedef boost::variant, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, Dim<7>, - Dim<8>, Dim<9>> + typedef boost::variant, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, + Dim<7>, Dim<8>, Dim<9>> DDimVar; DDimVar var; diff --git a/paddle/fluid/framework/dim.h b/paddle/fluid/framework/dim.h index 08b708006aadc4769bde7b37347ac1adfeca2bf7..8d288120e30035673be0ec5dc6230f607dfd1ebe 100644 --- a/paddle/fluid/framework/dim.h +++ b/paddle/fluid/framework/dim.h @@ -72,38 +72,36 @@ struct Dim { // Base case specialization template <> -struct Dim<1> { - static constexpr int dimensions = 1; +struct Dim<0> { + static constexpr int dimensions = 0; HOSTDEVICE - Dim(int64_t _head) : head(_head) {} + Dim(int64_t _head) {} HOSTDEVICE - Dim() : head(0) {} + Dim() {} HOSTDEVICE - Dim(int idx, const Dim<1>& size) : head(idx) { + Dim(int idx, const Dim<0>& size) { #ifndef __CUDA_ARCH__ - if (idx >= size.head) { + if (idx > 0) { throw std::invalid_argument("Index out of range."); } #else - PADDLE_ASSERT(idx < size.head); + PADDLE_ASSERT(idx == 0); #endif } HOSTDEVICE - bool operator==(const Dim<1>& o) const { return (head == o.head); } + bool operator==(const Dim<0>& o) const { return true; } HOSTDEVICE - bool operator!=(const Dim<1>& o) const { return !(*this == o); } + bool operator!=(const Dim<0>& o) const { return false; } HOSTDEVICE int64_t& operator[](int idx); HOSTDEVICE int64_t operator[](int idx) const; - - int64_t head; }; namespace { @@ -154,15 +152,14 @@ HOSTDEVICE int64_t& indexer(Dim& dim, int idx) { } template <> -HOSTDEVICE int64_t& indexer<1>(Dim<1>& dim, int idx) { +HOSTDEVICE int64_t& indexer<0>(Dim<0>& dim, int idx) { #ifndef __CUDA_ARCH__ - if (idx != 0) { - throw std::invalid_argument("Invalid index"); - } + throw std::invalid_argument("Invalid index"); #else - PADDLE_ASSERT(idx == 0); + PADDLE_ASSERT(false); #endif - return dim.head; + static int64_t head = 0; + return head; } template @@ -181,15 +178,14 @@ HOSTDEVICE int64_t indexer(const Dim& dim, int idx) { } template <> -HOSTDEVICE int64_t indexer<1>(const Dim<1>& dim, int idx) { +HOSTDEVICE int64_t indexer<0>(const Dim<0>& dim, int idx) { #ifndef __CUDA_ARCH__ - if (idx != 0) { - throw std::invalid_argument("Invalid index"); - } + throw std::invalid_argument("Invalid index"); #else - PADDLE_ASSERT(idx == 0); + PADDLE_ASSERT(false); #endif - return dim.head; + static int64_t head = 0; + return head; } } // namespace @@ -218,12 +214,12 @@ HOSTDEVICE int64_t& Dim::operator[](int i) { } // Dynamic access to constant Dim -inline HOSTDEVICE int64_t Dim<1>::operator[](int i) const { +inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const { return indexer(*this, i); } // Dynamic access to mutable Dim -inline HOSTDEVICE int64_t& Dim<1>::operator[](int i) { +inline HOSTDEVICE int64_t& Dim<0>::operator[](int i) { return indexer(*this, i); } @@ -251,8 +247,8 @@ HOSTDEVICE int64_t linearize(const Dim& a, const Dim& b) { // Base case dot product of two Dims // Notice it is inline because it is no longer a template template <> -HOSTDEVICE inline int64_t linearize(const Dim<1>& a, const Dim<1>& b) { - return a.head * b.head; +HOSTDEVICE inline int64_t linearize(const Dim<0>& a, const Dim<0>& b) { + return 0; } // Product of a Dim @@ -264,8 +260,8 @@ HOSTDEVICE int64_t product(const Dim& a, int prod = 1) { // Base case product of a Dim // Notice it is inline because it is no longer a template template <> -HOSTDEVICE inline int64_t product(const Dim<1>& a, int prod) { - return prod * a.head; +HOSTDEVICE inline int64_t product(const Dim<0>& a, int prod) { + return prod; } // Is 0 <= idx_i < size_i for all i? @@ -278,8 +274,8 @@ HOSTDEVICE bool contained(const Dim& idx, const Dim& size) { // Base case of is 0 <= idx_i < size_i ? // Notice it is inline because it is no longer a template template <> -HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) { - return ((0 <= idx.head) && (idx.head < size.head)); +HOSTDEVICE inline bool contained(const Dim<0>& idx, const Dim<0>& size) { + return true; } /** @@ -294,8 +290,8 @@ HOSTDEVICE Dim ex_prefix_mul(const Dim& src, int mul = 1) { // Base case of ex_prefix_mul // Notice it is inline because it is no longer a template template <> -HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) { - return Dim<1>(mul); +HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0>& src, int mul) { + return Dim<0>(); } ///\endcond @@ -309,8 +305,8 @@ HOSTDEVICE Dim dim_plus(const Dim& a, const Dim& b) { // Base case template <> -HOSTDEVICE inline Dim<1> dim_plus(const Dim<1>& a, const Dim<1>& b) { - return Dim<1>(a.head + b.head); +HOSTDEVICE inline Dim<0> dim_plus(const Dim<0>& a, const Dim<0>& b) { + return Dim<0>(); } template @@ -328,8 +324,8 @@ HOSTDEVICE Dim dim_mult(const Dim& a, const Dim& b) { // Base case template <> -HOSTDEVICE inline Dim<1> dim_mult(const Dim<1>& a, const Dim<1>& b) { - return Dim<1>(a.head * b.head); +HOSTDEVICE inline Dim<0> dim_mult(const Dim<0>& a, const Dim<0>& b) { + return Dim<0>(); } template @@ -356,10 +352,9 @@ HOSTDEVICE Dim normalize_strides(const Dim& size, const Dim& stride) { ///\cond HIDDEN template <> -HOSTDEVICE inline Dim<1> normalize_strides(const Dim<1>& size, - const Dim<1>& stride) { - int norm_stride = size.head == 1 ? 0 : stride.head; - return Dim<1>(norm_stride); +HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size, + const Dim<0>& stride) { + return Dim<0>(); } ///\endcond @@ -394,6 +389,10 @@ typename std::enable_if<(i == 1), std::ostream&>::type operator<<( return os; } +inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) { + return os; +} + template HOST std::string Dim::to_string() const { std::stringstream stream; diff --git a/paddle/fluid/operators/detail/strided_memcpy.h b/paddle/fluid/operators/detail/strided_memcpy.h index bac5cdc99c0133b1e6da3f6a23bc0512ca4177f5..0b7c470fe72eb4270b8d5b2d227642d85683c16d 100644 --- a/paddle/fluid/operators/detail/strided_memcpy.h +++ b/paddle/fluid/operators/detail/strided_memcpy.h @@ -24,6 +24,29 @@ namespace detail { template struct StridedMemcpyFunctor; +template +struct StridedMemcpyFunctor { + void operator()(const platform::DeviceContext& dev_ctx, const T* src, + framework::Dim<0> src_stride, framework::Dim<0> dst_dim, + framework::Dim<0> dst_stride, T* dst) const { + auto place = dev_ctx.GetPlace(); + if (platform::is_cpu_place(place)) { + auto& cpu_place = boost::get(place); + memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T)); + } else { +#ifdef PADDLE_WITH_CUDA + auto& gpu_place = boost::get(place); + auto& cuda_ctx = + reinterpret_cast(dev_ctx); + memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T), + cuda_ctx.stream()); +#else + PADDLE_THROW("Paddle is not compiled with GPU"); +#endif + } + } +}; + template struct StridedMemcpyFunctor { void operator()(const platform::DeviceContext& dev_ctx, const T* src,