Commit a500dfa5 authored by: S sneaxiy

rewrite ddim

test=develop
Parent e2130502
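In summary, this commit replaces the boost::variant-backed DDim (one Dim<N> alternative per rank) with a single fixed-capacity Dim<9> plus an explicit rank_ field, and rewrites the recursive head/tail Dim on top of a flat Array<int64_t, N> whose operations are unrolled at compile time (see the new unroll_array_ops.h below). Most free helpers (make_ddim, vectorize, product, slice_ddim, stride) keep their signatures. A minimal usage sketch of that unchanged surface, for orientation only (not part of the diff):

  #include "paddle/fluid/framework/ddim.h"
  using namespace paddle::framework;

  DDim d = make_ddim({2, 3, 4});    // now stored as one Dim<9> plus rank_ = 3
  int64_t n = product(d);           // 24
  DDim tail = slice_ddim(d, 1, 3);  // {3, 4}
  d[0] = 8;                         // operator[] now indexes the flat array directly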
@@ -36,7 +36,7 @@ add_subdirectory(details)
 proto_library(framework_proto SRCS framework.proto)
 proto_library(async_executor_proto SRCS data_feed.proto)
-cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
+cc_library(ddim SRCS ddim.cc DEPS eigen3 boost enforce)
 cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
 nv_test(dim_test SRCS dim_test.cu DEPS ddim)
 cc_library(data_type SRCS data_type.cc DEPS framework_proto ddim device_context)
......
@@ -15,34 +15,88 @@
 #pragma once
 
 #include <cstdint>
-#include "paddle/fluid/platform/hostdevice.h"
+#include "paddle/fluid/framework/unroll_array_ops.h"
+#include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace framework {
 
 template <typename T, size_t N>
 class Array {
-  static_assert(N > 0, "The size of array must be larger than 0");
-
  public:
-  HOSTDEVICE Array() {}
+  static constexpr size_t kSize = N;
+
+  HOSTDEVICE inline Array() = default;
 
-  HOSTDEVICE explicit Array(const T &val) {
-    for (size_t i = 0; i < N; ++i) data_[i] = val;
+  template <typename... Args>
+  HOSTDEVICE inline explicit Array(const T &val, Args... args) {
+    UnrollVarArgsAssign<T, N>::Run(data_, val, args...);
   }
 
-  HOSTDEVICE const T *Get() const { return data_; }
+  HOSTDEVICE inline void Fill(const T &val) {
+    UnrollFillConstant<N>::Run(data_, val);
+  }
 
-  HOSTDEVICE T *GetMutable() { return data_; }
+  HOSTDEVICE inline const T *Get() const { return data_; }
 
-  HOSTDEVICE T &operator[](size_t index) { return data_[index]; }
+  HOSTDEVICE inline T *GetMutable() { return data_; }
 
-  HOSTDEVICE const T &operator[](size_t index) const { return data_[index]; }
+  HOSTDEVICE inline T &operator[](size_t index) { return data_[index]; }
+
+  HOSTDEVICE inline const T &operator[](size_t index) const {
+    return data_[index];
+  }
 
   HOSTDEVICE constexpr size_t size() const { return N; }
 
+  HOSTDEVICE inline bool operator==(const Array<T, N> &other) const {
+    return UnrollCompare<N>::Run(data_, other.data_);
+  }
+
+  HOSTDEVICE inline bool operator!=(const Array<T, N> &other) const {
+    return !(*this == other);
+  }
+
  private:
  T data_[N];
 };
 
+template <typename T>
+class Array<T, 0> {
+ public:
+  static constexpr size_t kSize = 0;
+
+  HOSTDEVICE inline Array() = default;
+
+  HOSTDEVICE inline void Fill(const T &val) {}
+
+  HOSTDEVICE inline constexpr T *Get() const { return nullptr; }
+
+  // Add constexpr to GetMutable() cause warning in MAC
+  HOSTDEVICE inline T *GetMutable() { return nullptr; }
+
+  HOSTDEVICE inline T &operator[](size_t index) {
+#ifndef __CUDA_ARCH__
+    PADDLE_THROW("Array<T, 0> has no element");
+#endif
+  }
+
+  HOSTDEVICE inline const T &operator[](size_t index) const {
+#ifndef __CUDA_ARCH__
+    PADDLE_THROW("Array<T, 0> has no element");
+#endif
+  }
+
+  HOSTDEVICE constexpr size_t size() const { return 0; }
+
+  HOSTDEVICE constexpr bool operator==(const Array<T, 0> &other) const {
+    return true;
+  }
+
+  HOSTDEVICE constexpr bool operator!=(const Array<T, 0> &other) const {
+    return false;
+  }
+};
+
 }  // namespace framework
 }  // namespace paddle
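For orientation, a short sketch of how the rewritten Array behaves (a sketch only, assuming the class compiles as shown above): the variadic constructor routes through UnrollVarArgsAssign, Fill through UnrollFillConstant, and operator== through UnrollCompare, so no member function contains a runtime loop.

  using paddle::framework::Array;

  Array<int64_t, 3> a(1, 2, 3);  // every element written by UnrollVarArgsAssign
  Array<int64_t, 3> b;
  b.Fill(2);                     // {2, 2, 2} via UnrollFillConstant
  b[0] = 1;
  b[2] = 3;
  bool same = (a == b);          // element-wise UnrollCompare, true here
  static_assert(Array<int64_t, 3>::kSize == 3, "size is a compile-time constant");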
...@@ -18,201 +18,131 @@ limitations under the License. */ ...@@ -18,201 +18,131 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
/// @cond HIDDEN template <typename T>
struct DDimAssignFunctor {
static_assert(std::is_integral<T>::value, "T must be integral type");
using result_type = void;
explicit DDimAssignFunctor(const T* in) : in_(in) {}
template <int i> template <int D>
Dim<i> make_dim(const int64_t* d) { inline void operator()(Dim<D>& dim) { // NOLINT
return Dim<i>(*d, make_dim<i - 1>(d + 1)); UnrollAssign<D>::Run(in_, dim.data());
} }
const T* in_;
};
template <> DDim::DDim(const int* d, int n) : rank_(n) {
Dim<0> make_dim<0>(const int64_t* d) { this->apply_visitor(DDimAssignFunctor<int>(d));
return Dim<0>(*d);
} }
void make_ddim(DDim& ddim, const int64_t* dims, int n) { DDim::DDim(const int64_t* d, int n) : rank_(n) {
switch (n) { this->apply_visitor(DDimAssignFunctor<int64_t>(d));
case 0:
ddim = make_dim<0>(dims);
break;
case 1:
ddim = make_dim<1>(dims);
break;
case 2:
ddim = make_dim<2>(dims);
break;
case 3:
ddim = make_dim<3>(dims);
break;
case 4:
ddim = make_dim<4>(dims);
break;
case 5:
ddim = make_dim<5>(dims);
break;
case 6:
ddim = make_dim<6>(dims);
break;
case 7:
ddim = make_dim<7>(dims);
break;
case 8:
ddim = make_dim<8>(dims);
break;
case 9:
ddim = make_dim<9>(dims);
break;
default:
PADDLE_THROW("Dynamic dimensions must have between [1, 9] dimensions.");
}
} }
/// @endcond template <int N>
Dim<N> make_dim(const int64_t* d) {
Dim<N> ret;
for (int i = 0; i < N; ++i) ret[i] = d[i];
return ret;
}
DDim make_ddim(std::initializer_list<int64_t> dims) { DDim make_ddim(std::initializer_list<int64_t> dims) {
DDim result(make_dim(0)); return DDim(dims.begin(), dims.size());
make_ddim(result, dims.begin(), dims.size());
return result;
} }
DDim make_ddim(const std::vector<int64_t>& dims) { DDim make_ddim(const std::vector<int64_t>& dims) {
DDim result(make_dim(0)); return DDim(dims.data(), dims.size());
make_ddim(result, &dims[0], dims.size());
return result;
} }
DDim make_ddim(const std::vector<int>& dims) { DDim make_ddim(const std::vector<int>& dims) {
std::vector<int64_t> res(dims.size()); return DDim(dims.data(), dims.size());
std::transform(dims.begin(), dims.end(), res.begin(),
[](int d) { return static_cast<int64_t>(d); });
return make_ddim(res);
} }
/// @cond HIDDEN struct DDimEqualityVisitor {
// XXX For some reason, putting this in an anonymous namespace causes errors explicit DDimEqualityVisitor(const int64_t* d) : d_(d) {}
class DynamicMutableIndexer : public boost::static_visitor<int64_t&> {
public:
explicit DynamicMutableIndexer(int idx) : idx_(idx) {}
template <int D>
int64_t& operator()(Dim<D>& dim) const {
return dim[idx_];
}
private:
int idx_;
};
class DynamicConstIndexer : public boost::static_visitor<int64_t> {
public:
explicit DynamicConstIndexer(int idx) : idx_(idx) {}
template <int D> template <int D>
int64_t operator()(const Dim<D>& dim) const { inline bool operator()(const Dim<D>& self) const {
return dim[idx_]; return UnrollCompare<D>::Run(self.data(), d_);
} }
private: const int64_t* d_;
int idx_;
}; };
/// @endcond bool DDim::operator==(const DDim& d) const {
return rank_ == d.rank_ && this->apply_visitor(DDimEqualityVisitor(d.data()));
int64_t& DDim::operator[](int idx) {
return boost::apply_visitor(DynamicMutableIndexer(idx), var);
}
int64_t DDim::operator[](int idx) const {
return boost::apply_visitor(DynamicConstIndexer(idx), var);
} }
int DDim::size() const { return arity(*this); } bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
bool DDim::operator==(DDim d) const {
if (var.which() != d.getVar().which()) {
return false;
} else {
std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d);
for (unsigned int i = 0; i < v1.size(); i++) { struct DDimPlusVisitor {
if (v1[i] != v2[i]) { explicit DDimPlusVisitor(const int64_t* d1, const int64_t* d2)
return false; : d1_(d1), d2_(d2) {}
}
}
return true; template <int D>
inline void operator()(Dim<D>& self) const {
UnrollAdd<D>::Run(d1_, d2_, self.data());
} }
}
bool DDim::operator!=(DDim d) const { return !(*this == d); }
DDim DDim::operator+(DDim d) const {
std::vector<int64_t> v1 = vectorize(*this);
std::vector<int64_t> v2 = vectorize(d);
std::vector<int64_t> v3; const int64_t* d1_;
const int64_t* d2_;
assert(v1.size() == v2.size()); };
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] + v2[i]);
}
return make_ddim(v3); DDim DDim::operator+(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret;
ret.rank_ = rank_;
ret.apply_visitor(DDimPlusVisitor(data(), d.data()));
return ret;
} }
DDim DDim::operator*(DDim d) const { struct DDimMulVisitor {
std::vector<int64_t> v1 = vectorize(*this); explicit DDimMulVisitor(const int64_t* d1, const int64_t* d2)
std::vector<int64_t> v2 = vectorize(d); : d1_(d1), d2_(d2) {}
std::vector<int64_t> v3; template <int D>
inline void operator()(Dim<D>& self) const {
assert(v1.size() == v2.size()); UnrollMul<D>::Run(d1_, d2_, self.data());
for (unsigned int i = 0; i < v1.size(); i++) {
v3.push_back(v1[i] * v2[i]);
} }
return make_ddim(v3); const int64_t* d1_;
const int64_t* d2_;
};
DDim DDim::operator*(const DDim& d) const {
PADDLE_ENFORCE(rank_ == d.rank_);
DDim ret;
ret.rank_ = rank_;
ret.apply_visitor(DDimMulVisitor(data(), d.data()));
return ret;
} }
int64_t get(const DDim& ddim, int idx) { return ddim[idx]; } int64_t get(const DDim& ddim, int idx) { return ddim[idx]; }
void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } // NOLINT
/// @cond HIDDEN
struct VectorizeVisitor : public boost::static_visitor<> {
std::vector<int64_t>& vector;
explicit VectorizeVisitor(std::vector<int64_t>& v) : vector(v) {}
template <typename T>
void operator()(const T& t) {
vector.push_back(t.head);
this->operator()(t.tail);
}
void operator()(const Dim<0>& t) {}
};
/// @endcond
std::vector<int64_t> vectorize(const DDim& ddim) { std::vector<int64_t> vectorize(const DDim& ddim) {
std::vector<int64_t> result; std::vector<int64_t> result(DDim::kMaxRank);
VectorizeVisitor visitor(result); for (int i = 0; i < ddim.size(); ++i) {
boost::apply_visitor(visitor, ddim); result[i] = ddim[i];
}
result.resize(ddim.size());
return result; return result;
} }
// NOTE: framework::vectorize converts to type int64_t // NOTE: framework::vectorize converts to type int64_t
// which does not fit cudnn inputs. // which does not fit cudnn inputs.
std::vector<int> vectorize2int(const DDim& ddim) { std::vector<int> vectorize2int(const DDim& ddim) {
std::vector<int64_t> temp = vectorize(ddim); std::vector<int> result(DDim::kMaxRank);
std::vector<int> result(temp.begin(), temp.end()); for (int i = 0; i < ddim.size(); ++i) {
result[i] = ddim[i];
}
result.resize(ddim.size());
return result; return result;
} }
struct ProductVisitor : public boost::static_visitor<int64_t> { struct ProductVisitor {
template <int D> template <int D>
int64_t operator()(const Dim<D>& dim) { int64_t operator()(const Dim<D>& dim) {
return product(dim); return product(dim);
...@@ -220,65 +150,27 @@ struct ProductVisitor : public boost::static_visitor<int64_t> { ...@@ -220,65 +150,27 @@ struct ProductVisitor : public boost::static_visitor<int64_t> {
}; };
int64_t product(const DDim& ddim) { int64_t product(const DDim& ddim) {
ProductVisitor visitor; return ddim.apply_visitor(ProductVisitor());
return boost::apply_visitor(visitor, ddim);
} }
struct SliceVectorizeVisitor : public boost::static_visitor<> {
std::vector<int64_t>& vector;
int begin;
int end;
SliceVectorizeVisitor(std::vector<int64_t>& v, int b, int e)
: vector(v), begin(b), end(e) {
PADDLE_ENFORCE(begin < end,
"Begin index must be less than end index in ddim slice.");
PADDLE_ENFORCE(begin >= 0,
"Begin index can't be less than zero in ddim slice.");
}
template <int S>
void operator()(const Dim<S>& dim) {
if (begin == 0) {
vector.push_back(dim.head);
} else {
--begin;
}
--end;
if (end > 0) {
this->operator()(dim.tail);
}
}
void operator()(const Dim<0>& dim) {
PADDLE_ENFORCE(end == 0, "End index in ddim slice is out of bound.");
}
};
DDim slice_ddim(const DDim& dim, int begin, int end) { DDim slice_ddim(const DDim& dim, int begin, int end) {
std::vector<int64_t> vec; PADDLE_ENFORCE(begin < end,
vec.reserve(end - begin); "Begin index must be less than end index in ddim slice.");
SliceVectorizeVisitor visitor(vec, begin, end); PADDLE_ENFORCE(begin >= 0,
boost::apply_visitor(visitor, dim); "Begin index can't be less than zero in ddim slice.");
return make_ddim(vec); DDim ret;
} ret.rank_ = end - begin;
for (int i = 0; i < ret.rank_; ++i) {
/// \cond HIDDEN ret[i] = dim[i + begin];
struct ArityVisitor : boost::static_visitor<int> {
template <int D>
int operator()(Dim<D>) const {
return D;
} }
}; return ret;
}
/// \endcond
int arity(const DDim& d) { return boost::apply_visitor(ArityVisitor(), d); } int arity(const DDim& d) { return d.size(); }
/// \cond HIDDEN /// \cond HIDDEN
struct DDimPrinter : boost::static_visitor<void> { struct DDimPrinter {
std::ostream& os; std::ostream& os;
explicit DDimPrinter(std::ostream& os_) : os(os_) {} explicit DDimPrinter(std::ostream& os_) : os(os_) {}
...@@ -291,15 +183,10 @@ struct DDimPrinter : boost::static_visitor<void> { ...@@ -291,15 +183,10 @@ struct DDimPrinter : boost::static_visitor<void> {
/// \endcond /// \endcond
std::ostream& operator<<(std::ostream& os, const DDim& ddim) { std::ostream& operator<<(std::ostream& os, const DDim& ddim) {
DDimPrinter printer(os); ddim.apply_visitor(DDimPrinter(os));
boost::apply_visitor(printer, ddim);
return os; return os;
} }
DDim::DDim(std::initializer_list<int64_t> init_list) {
*this = make_ddim(init_list);
}
DDim flatten_to_2d(const DDim& src, int num_col_dims) { DDim flatten_to_2d(const DDim& src, int num_col_dims) {
int rank = src.size(); int rank = src.size();
return make_ddim({product(slice_ddim(src, 0, num_col_dims)), return make_ddim({product(slice_ddim(src, 0, num_col_dims)),
...@@ -309,21 +196,23 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) { ...@@ -309,21 +196,23 @@ DDim flatten_to_2d(const DDim& src, int num_col_dims) {
DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); } DDim flatten_to_1d(const DDim& src) { return make_ddim({product(src)}); }
DDim stride(const DDim& ddim) { DDim stride(const DDim& ddim) {
std::vector<int64_t> strides(ddim.size()); DDim strides;
strides.rank_ = ddim.size();
strides[ddim.size() - 1] = 1; strides[ddim.size() - 1] = 1;
for (int i = ddim.size() - 2; i >= 0; --i) { for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i + 1]; strides[i] = strides[i + 1] * ddim[i + 1];
} }
return framework::make_ddim(strides); return strides;
} }
DDim stride_numel(const framework::DDim& ddim) { DDim stride_numel(const framework::DDim& ddim) {
std::vector<int64_t> strides(ddim.size()); DDim strides;
strides.rank_ = ddim.size();
strides[ddim.size() - 1] = ddim[ddim.size() - 1]; strides[ddim.size() - 1] = ddim[ddim.size() - 1];
for (int i = ddim.size() - 2; i >= 0; --i) { for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i]; strides[i] = strides[i + 1] * ddim[i];
} }
return framework::make_ddim(strides); return strides;
} }
} // namespace framework } // namespace framework
......
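One consequence of the rewrite visible at the end of this file: stride() and stride_numel() now fill a DDim in place (they are declared friends of DDim so they can set rank_ directly) instead of building a std::vector and calling make_ddim. Their results are unchanged; a worked example whose values follow directly from the two loops above:

  DDim d = make_ddim({2, 3, 4});
  stride(d);        // {12, 4, 1}  since stride[i] = stride[i+1] * d[i+1], last element 1
  stride_numel(d);  // {24, 12, 4} since stride[i] = stride[i+1] * d[i],   last element d[2]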
...@@ -18,8 +18,6 @@ limitations under the License. */ ...@@ -18,8 +18,6 @@ limitations under the License. */
#include <stdexcept> #include <stdexcept>
#include <vector> #include <vector>
#include "paddle/fluid/framework/dim.h" #include "paddle/fluid/framework/dim.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/variant.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -29,51 +27,138 @@ namespace framework { ...@@ -29,51 +27,138 @@ namespace framework {
* *
* The number of dimensions must be between [1, 9]. * The number of dimensions must be between [1, 9].
*/ */
struct DDim { class DDim {
typedef boost::variant<Dim<0>, Dim<1>, Dim<2>, Dim<3>, Dim<4>, Dim<5>, Dim<6>, public:
Dim<7>, Dim<8>, Dim<9>> constexpr static int kMaxRank = 9;
DDimVar;
DDimVar var;
DDim() : var(Dim<1>()) {} DDim() : rank_(1) { dim_[0] = 0; }
DDim(const int* d, int n);
DDim(const int64_t* d, int n);
template <int D> template <int D>
explicit DDim(const Dim<D>& in) : var(in) {} /*implicit*/ DDim(const Dim<D>& in) : rank_(D) { // NOLINT
UnsafeCast<D>() = in;
}
/*implicit*/ DDim(std::initializer_list<int64_t> init_list); /*implicit*/ DDim(std::initializer_list<int64_t> init_list)
: DDim(init_list.begin(), init_list.size()) {}
template <int D> template <int D>
DDim& operator=(const Dim<D>& in) { inline DDim& operator=(const Dim<D>& in) {
var = in; rank_ = D;
UnsafeCast<D>() = in;
return *this; return *this;
} }
int64_t& operator[](int idx); inline int64_t& operator[](int idx) { return dim_[idx]; }
int64_t operator[](int idx) const;
template <typename Visitor> inline int64_t operator[](int idx) const { return dim_[idx]; }
typename Visitor::result_type apply_visitor(Visitor& visitor) {
return var.apply_visitor(visitor); inline int64_t& at(int idx) {
PADDLE_ENFORCE(idx >= 0 && idx < rank_);
return dim_[idx];
} }
template <typename Visitor> inline int64_t at(int idx) const {
typename Visitor::result_type apply_visitor(Visitor& visitor) const { PADDLE_ENFORCE(idx >= 0 && idx < rank_);
return var.apply_visitor(visitor); return dim_[idx];
} }
DDimVar getVar() { return var; } template <typename Visitor>
typename std::result_of<Visitor(Dim<0>&)>::type apply_visitor(
Visitor&& visitor);
template <typename Visitor>
typename std::result_of<Visitor(const Dim<0>&)>::type apply_visitor(
Visitor&& visitor) const;
bool operator==(const DDim& d) const;
bool operator!=(const DDim& d) const;
DDim operator+(const DDim& d) const;
bool operator==(DDim d) const; DDim operator*(const DDim& d) const;
bool operator!=(DDim d) const; // Make DDim act like std::vector<int64_t>
using iterator = int64_t*;
using const_iterator = const int64_t*;
DDim operator+(DDim d) const; int64_t* data() { return dim_.data(); }
const int64_t* data() const { return dim_.data(); }
DDim operator*(DDim d) const; iterator begin() { return data(); }
const_iterator begin() const { return data(); }
iterator end() { return data() + rank_; }
const_iterator end() const { return data() + rank_; }
int size() const { return rank_; }
private:
template <int M>
inline Dim<M>& UnsafeCast() {
return const_cast<Dim<M>&>(const_cast<const DDim*>(this)->UnsafeCast<M>());
}
int size() const; template <int M>
inline const Dim<M>& UnsafeCast() const {
static_assert(M >= 0 && M <= kMaxRank, "Invalid rank");
auto* p = static_cast<const void*>(&dim_);
return *reinterpret_cast<const Dim<M>*>(p);
}
friend DDim slice_ddim(const DDim& dim, int begin, int end);
friend DDim stride(const DDim& ddim);
friend DDim stride_numel(const DDim& ddim);
Dim<kMaxRank> dim_;
int rank_;
}; };
#define PADDLE_VISIT_DDIM(rank) \
case rank: \
return visitor(UnsafeCast<rank>())
template <typename Visitor>
typename std::result_of<Visitor(Dim<0>&)>::type DDim::apply_visitor(
Visitor&& visitor) {
switch (rank_) {
PADDLE_VISIT_DDIM(0);
PADDLE_VISIT_DDIM(1);
PADDLE_VISIT_DDIM(2);
PADDLE_VISIT_DDIM(3);
PADDLE_VISIT_DDIM(4);
PADDLE_VISIT_DDIM(5);
PADDLE_VISIT_DDIM(6);
PADDLE_VISIT_DDIM(7);
PADDLE_VISIT_DDIM(8);
PADDLE_VISIT_DDIM(9);
default:
PADDLE_THROW("Invalid rank %d", rank_);
}
}
template <typename Visitor>
typename std::result_of<Visitor(const Dim<0>&)>::type DDim::apply_visitor(
Visitor&& visitor) const {
switch (rank_) {
PADDLE_VISIT_DDIM(0);
PADDLE_VISIT_DDIM(1);
PADDLE_VISIT_DDIM(2);
PADDLE_VISIT_DDIM(3);
PADDLE_VISIT_DDIM(4);
PADDLE_VISIT_DDIM(5);
PADDLE_VISIT_DDIM(6);
PADDLE_VISIT_DDIM(7);
PADDLE_VISIT_DDIM(8);
PADDLE_VISIT_DDIM(9);
default:
PADDLE_THROW("Invalid rank %d", rank_);
}
}
#undef PADDLE_VISIT_DDIM
/** /**
* \brief Make a DDim from std::vector<int64_t> * \brief Make a DDim from std::vector<int64_t>
* *
...@@ -92,7 +177,7 @@ DDim make_ddim(const std::vector<int>& dims); ...@@ -92,7 +177,7 @@ DDim make_ddim(const std::vector<int>& dims);
DDim make_ddim(std::initializer_list<int64_t> dims); DDim make_ddim(std::initializer_list<int64_t> dims);
int64_t get(const DDim& dim, int idx); int64_t get(const DDim& dim, int idx);
void set(DDim& dim, int idx, int val); void set(DDim& dim, int idx, int val); // NOLINT
std::vector<int64_t> vectorize(const DDim& ddim); std::vector<int64_t> vectorize(const DDim& ddim);
std::vector<int> vectorize2int(const DDim& ddim); std::vector<int> vectorize2int(const DDim& ddim);
...@@ -129,12 +214,3 @@ DDim stride(const DDim& ddim); ...@@ -129,12 +214,3 @@ DDim stride(const DDim& ddim);
DDim stride_numel(const DDim& ddim); DDim stride_numel(const DDim& ddim);
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
namespace boost {
template <typename T>
T get(const paddle::framework::DDim& in) {
return boost::get<T>(in.var);
}
} // namespace boost
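Because boost::variant is gone, DDim::apply_visitor now dispatches on rank_ through the PADDLE_VISIT_DDIM switch above, and a visitor is simply any callable with a templated operator() over Dim<D>; it no longer derives from boost::static_visitor (the DDimPrinter and ProductVisitor changes in ddim.cc follow this pattern). A minimal sketch of such a visitor; RankVisitor is a hypothetical name used only for illustration:

  struct RankVisitor {
    template <int D>
    int operator()(const paddle::framework::Dim<D>& dim) const {
      return D;  // the static rank selected by the switch on rank_
    }
  };

  paddle::framework::DDim d = paddle::framework::make_ddim({4, 5});
  int rank = d.apply_visitor(RankVisitor());  // 2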
...@@ -16,328 +16,184 @@ ...@@ -16,328 +16,184 @@
#include <iostream> #include <iostream>
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <string>
#include <type_traits> #include <type_traits>
#include "paddle/fluid/framework/array.h"
#include "paddle/fluid/platform/assert.h" #include "paddle/fluid/platform/assert.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/hostdevice.h" #include "paddle/fluid/platform/hostdevice.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
// Statically sized, statically indexed dimension // Statically sized, statically indexed dimension
template <int i> template <int N>
struct Dim { class Dim : public Array<int64_t, N> {
static constexpr int dimensions = i; public:
static_assert(N >= 0, "N must be not less than 0");
template <typename... Args> static constexpr int kRank = N;
HOSTDEVICE Dim(int64_t _head, Args... _tail) : head(_head), tail(_tail...) { using BaseClass = Array<int64_t, N>;
static_assert(sizeof...(_tail) == i - 1,
"Dim initialized with the wrong number of parameters");
}
HOSTDEVICE inline Dim(int64_t head, const Dim<N - 1>& tail) {
Dim(int64_t _head, const Dim<i - 1>& _tail) : head(_head), tail(_tail) {} (*this)[0] = head;
new (this->GetMutable() + 1) Dim<N - 1>(tail);
}
HOSTDEVICE template <typename... Args>
Dim() : head(0), tail() {} HOSTDEVICE explicit Dim(int64_t head, Args... args)
: BaseClass(head, args...) {}
/** Construct a Dim from a linear index and size. Uses Fortran order /** Construct a Dim from a linear index and size. Uses Fortran order
* indexing. */ * indexing. */
HOSTDEVICE HOSTDEVICE Dim(int64_t idx, const Dim<N>& size);
Dim(int64_t idx, const Dim<i>& size)
: head(idx % size.head), tail(idx / size.head, size.tail) {}
/** Construct a Dim with each dimension set to the given index */ /** Construct a Dim with each dimension set to the given index */
HOSTDEVICE HOSTDEVICE explicit Dim(int64_t idx) { this->Fill(idx); }
Dim(int64_t idx) : head(idx), tail(idx) {}
HOSTDEVICE HOSTDEVICE Dim() = default;
bool operator==(const Dim<i>& o) const {
return (head == o.head) && (tail == o.tail);
}
HOSTDEVICE HOSTDEVICE int64_t* data() { return this->GetMutable(); }
bool operator!=(const Dim<i>& o) const { return !(*this == o); }
HOSTDEVICE HOSTDEVICE const int64_t* data() const { return this->Get(); }
int64_t& operator[](int idx);
HOSTDEVICE
int64_t operator[](int idx) const;
HOST std::string to_string() const; HOST std::string to_string() const;
int64_t head;
Dim<i - 1> tail;
};
// Base case specialization
template <>
struct Dim<0> {
static constexpr int dimensions = 0;
HOSTDEVICE
Dim(int64_t _head) {}
HOSTDEVICE
Dim() {}
HOSTDEVICE
Dim(int idx, const Dim<0>& size) {
#ifndef __CUDA_ARCH__
if (idx > 0) {
throw std::invalid_argument("Index out of range.");
}
#else
PADDLE_ASSERT(idx == 0);
#endif
}
HOSTDEVICE
bool operator==(const Dim<0>& o) const { return true; }
HOSTDEVICE
bool operator!=(const Dim<0>& o) const { return false; }
HOSTDEVICE
int64_t& operator[](int idx);
HOSTDEVICE
int64_t operator[](int idx) const;
}; };
namespace { namespace detail {
template <int kStart, int kEnd, bool kStop>
// Helper for accessing Dim classes struct FortranOrderIndexingConstructorFunctor {
template <int i> HOSTDEVICE inline static void Run(const int64_t* in, int64_t* idx,
struct DimGetter { int64_t* out) {
// Return a copy if Dim is const out[kStart] = (*idx) % in[kStart];
template <typename D> (*idx) /= in[kStart];
HOSTDEVICE static int64_t impl(const D& d) { FortranOrderIndexingConstructorFunctor<kStart + 1, kEnd,
return DimGetter<i - 1>::impl(d.tail); kStart + 1 == kEnd>::Run(in, idx,
} out);
// Return a reference if Dim is mutable
template <typename D>
HOSTDEVICE static int64_t& impl(D& d) {
return DimGetter<i - 1>::impl(d.tail);
} }
}; };
// Eureka! We found the element! template <int kStart, int kEnd>
template <> struct FortranOrderIndexingConstructorFunctor<kStart, kEnd, true> {
struct DimGetter<0> { HOSTDEVICE inline static void Run(const int64_t* in, int64_t* idx,
// Return a copy if Dim is const int64_t* out) {}
template <typename D>
HOSTDEVICE static int64_t impl(const D& d) {
return d.head;
}
// Return a reference if Dim is mutable
template <typename D>
HOSTDEVICE static int64_t& impl(D& d) {
return d.head;
}
}; };
} // namespace detail
template <int D> template <int N>
HOSTDEVICE int64_t& indexer(Dim<D>& dim, int idx) { HOSTDEVICE Dim<N>::Dim(int64_t idx, const Dim<N>& size) {
#ifndef __CUDA_ARCH__ detail::FortranOrderIndexingConstructorFunctor<0, N, N == 0>::Run(
if (idx < 0) { size.Get(), &idx, this->GetMutable());
throw std::invalid_argument("Tried to access a negative dimension");
}
#else
PADDLE_ASSERT(idx >= 0);
#endif
if (idx == 0) {
return dim.head;
}
return indexer(dim.tail, idx - 1);
} }
template <> template <int idx, int N>
HOSTDEVICE int64_t& indexer<0>(Dim<0>& dim, int idx) { HOSTDEVICE inline int64_t get(const Dim<N>& dim) {
#ifndef __CUDA_ARCH__ return dim[idx];
throw std::invalid_argument("Invalid index");
#else
PADDLE_ASSERT(false);
#if CUDA_VERSION < 8000
// On CUDA versions previous to 8.0, only __shared__ variables
// could be declared as static in the device code.
int64_t head = 0;
#else
static int64_t head = 0;
#endif
return head;
#endif
} }
template <int D> template <int idx, int N>
HOSTDEVICE int64_t indexer(const Dim<D>& dim, int idx) { HOSTDEVICE inline int64_t& get(Dim<N>& dim) { // NOLINT
#ifndef __CUDA_ARCH__ return dim[idx];
if (idx < 0) {
throw std::invalid_argument("Tried to access a negative dimension");
}
#else
PADDLE_ASSERT(idx >= 0);
#endif
if (idx == 0) {
return dim.head;
}
return indexer(dim.tail, idx - 1);
}
template <>
HOSTDEVICE int64_t indexer<0>(const Dim<0>& dim, int idx) {
#ifndef __CUDA_ARCH__
throw std::invalid_argument("Invalid index");
#else
PADDLE_ASSERT(false);
#if CUDA_VERSION < 8000
// On CUDA versions previous to 8.0, only __shared__ variables
// could be declared as static in the device code.
int64_t head = 0;
#else
static int64_t head = 0;
#endif
return head;
#endif
}
} // namespace
// Static access to constant Dim
template <int i, int l>
HOSTDEVICE int64_t get(const Dim<l>& d) {
return DimGetter<i>::impl(d);
}
// Static access to mutable Dim
template <int i, int l>
HOSTDEVICE int64_t& get(Dim<l>& d) {
return DimGetter<i>::impl(d);
}
// Dynamic access to constant Dim
template <int l>
HOSTDEVICE int64_t Dim<l>::operator[](int i) const {
return indexer(*this, i);
} }
// Dynamic access to mutable Dim template <int N>
template <int l> HOSTDEVICE inline int64_t get(const Dim<N>& dim, int idx) {
HOSTDEVICE int64_t& Dim<l>::operator[](int i) { return dim[idx];
return indexer(*this, i);
} }
// Dynamic access to constant Dim template <int N>
inline HOSTDEVICE int64_t Dim<0>::operator[](int i) const { HOSTDEVICE inline int64_t& get(Dim<N>& dim, int idx) { // NOLINT
return indexer(*this, i); return dim[idx];
}
// Dynamic access to mutable Dim
inline HOSTDEVICE int64_t& Dim<0>::operator[](int i) {
return indexer(*this, i);
}
// Dynamic access to constant Dim
// without std::enable_if will try to instantiate this on get<0>(d)
template <int l>
HOSTDEVICE typename std::enable_if<(l > 0), int64_t>::type get(const Dim<l>& d,
int i) {
return d[i];
}
// Dynamic access to mutable Dim
template <int l>
HOSTDEVICE typename std::enable_if<(l > 0), int64_t&>::type get(Dim<l>& d,
int i) {
return d[i];
} }
// Dot product of two dims // Dot product of two dims
template <int i> template <int N>
HOSTDEVICE int64_t linearize(const Dim<i>& a, const Dim<i>& b) { HOSTDEVICE inline int64_t linearize(const Dim<N>& a, const Dim<N>& b) {
return a.head * b.head + linearize(a.tail, b.tail); return UnrollProduct<N>::Run(a.Get(), b.Get());
}
// Base case dot product of two Dims
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline int64_t linearize(const Dim<0>& a, const Dim<0>& b) {
return 0;
} }
// Product of a Dim // Product of a Dim
template <int i> template <int N>
HOSTDEVICE int64_t product(const Dim<i>& a, int prod = 1) { HOSTDEVICE inline int64_t product(const Dim<N>& a) {
return prod * a.head * product(a.tail); return UnrollProduct<N>::Run(a.Get());
}
// Base case product of a Dim
// Notice it is inline because it is no longer a template
template <>
HOSTDEVICE inline int64_t product(const Dim<0>& a, int prod) {
return prod;
} }
// Is 0 <= idx_i < size_i for all i? // Is 0 <= idx_i < size_i for all i?
template <int i> namespace detail {
HOSTDEVICE bool contained(const Dim<i>& idx, const Dim<i>& size) { template <int kStart, int kEnd, bool kStop>
return ((0 <= idx.head) && (idx.head < size.head) && struct ContainedFunctor {
contained(idx.tail, size.tail)); HOSTDEVICE static inline bool Run(const int64_t* idx, const int64_t* size) {
} return (idx[kStart] >= 0 && idx[kStart] < size[kStart]) &&
ContainedFunctor<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(idx,
size);
}
};
template <int kStart, int kEnd>
struct ContainedFunctor<kStart, kEnd, true> {
HOSTDEVICE static constexpr inline bool Run(const int64_t* idx,
const int64_t* size) {
return true;
}
};
} // namespace detail
// Base case of is 0 <= idx_i < size_i ? template <int N>
// Notice it is inline because it is no longer a template HOSTDEVICE inline bool contained(const Dim<N>& idx, const Dim<N>& size) {
template <> return detail::ContainedFunctor<0, N, N == 0>::Run(idx.Get(), size.Get());
HOSTDEVICE inline bool contained(const Dim<0>& idx, const Dim<0>& size) {
return true;
} }
/** /**
* \brief Compute exclusive prefix-multiply of a Dim. * \brief Compute exclusive prefix-multiply of a Dim.
*/ */
template <int i> namespace detail {
HOSTDEVICE Dim<i> ex_prefix_mul(const Dim<i>& src, int mul = 1) { template <int kStart, int kEnd, bool kStop>
return Dim<i>(mul, ex_prefix_mul(src.tail, mul * src.head)); struct ExPrefixMulFunctor {
} HOSTDEVICE static inline void Run(const int64_t* in, int64_t* out) {
kStart == 0 ? out[kStart] = 1 : out[kStart] =
out[kStart - 1] * in[kStart - 1];
detail::ExPrefixMulFunctor<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(in,
out);
}
};
template <int kStart, int kEnd>
struct ExPrefixMulFunctor<kStart, kEnd, true> {
HOSTDEVICE static inline void Run(const int64_t* in, int64_t* out) {}
};
} // namespace detail
///\cond HIDDEN template <int N>
// Base case of ex_prefix_mul HOSTDEVICE inline Dim<N> ex_prefix_mul(const Dim<N>& src) {
// Notice it is inline because it is no longer a template Dim<N> ret;
template <> detail::ExPrefixMulFunctor<0, N, N == 0>::Run(src.Get(), ret.GetMutable());
HOSTDEVICE inline Dim<0> ex_prefix_mul(const Dim<0>& src, int mul) { return ret;
return Dim<0>();
} }
///\endcond
/** /**
* Add two dimensions together * Add two dimensions together
*/ */
template <int i> template <int N>
HOSTDEVICE Dim<i> dim_plus(const Dim<i>& a, const Dim<i>& b) { HOSTDEVICE inline Dim<N> dim_plus(const Dim<N>& a, const Dim<N>& b) {
return Dim<i>(a.head + b.head, dim_plus(a.tail, b.tail)); Dim<N> ret;
UnrollAdd<N>::Run(a.Get(), b.Get(), ret.GetMutable());
return ret;
} }
// Base case template <int N>
template <> HOSTDEVICE inline Dim<N> operator+(const Dim<N>& lhs, const Dim<N>& rhs) {
HOSTDEVICE inline Dim<0> dim_plus(const Dim<0>& a, const Dim<0>& b) {
return Dim<0>();
}
template <int i>
HOSTDEVICE Dim<i> operator+(const Dim<i>& lhs, const Dim<i>& rhs) {
return dim_plus(lhs, rhs); return dim_plus(lhs, rhs);
} }
/** /**
* Multiply two dimensions together * Multiply two dimensions together
*/ */
template <int i> template <int N>
HOSTDEVICE Dim<i> dim_mult(const Dim<i>& a, const Dim<i>& b) { HOSTDEVICE inline Dim<N> dim_mult(const Dim<N>& a, const Dim<N>& b) {
return Dim<i>(a.head * b.head, dim_mult(a.tail, b.tail)); Dim<N> ret;
} UnrollMul<N>::Run(a.Get(), b.Get(), ret.GetMutable());
return ret;
// Base case
template <>
HOSTDEVICE inline Dim<0> dim_mult(const Dim<0>& a, const Dim<0>& b) {
return Dim<0>();
} }
template <int i> template <int i>
...@@ -354,23 +210,32 @@ HOSTDEVICE Dim<i> operator*(const Dim<i>& lhs, const Dim<i>& rhs) { ...@@ -354,23 +210,32 @@ HOSTDEVICE Dim<i> operator*(const Dim<i>& lhs, const Dim<i>& rhs) {
* \return Dim object the same size as \p size with normalized strides * \return Dim object the same size as \p size with normalized strides
* *
*/ */
namespace detail {
template <int kStart, int kEnd, bool kStop>
struct NormalizeStridesFunctor {
HOSTDEVICE static void Run(const int64_t* size, const int64_t* stride,
int64_t* ret) {
ret[kStart] = (size[kStart] == 1 ? 0 : stride[kStart]);
NormalizeStridesFunctor<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(
size, stride, ret);
}
};
template <int i> template <int kStart, int kEnd>
HOSTDEVICE Dim<i> normalize_strides(const Dim<i>& size, const Dim<i>& stride) { struct NormalizeStridesFunctor<kStart, kEnd, true> {
int norm_stride = size.head == 1 ? 0 : stride.head; HOSTDEVICE static void Run(const int64_t* size, const int64_t* stride,
return Dim<i>(norm_stride, normalize_strides(size.tail, stride.tail)); int64_t* ret) {}
} };
} // namespace detail
///\cond HIDDEN
template <> template <int N>
HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size, HOSTDEVICE Dim<N> normalize_strides(const Dim<N>& size, const Dim<N>& stride) {
const Dim<0>& stride) { Dim<N> ret;
return Dim<0>(); detail::NormalizeStridesFunctor<0, N, N == 0>::Run(size.Get(), stride.Get(),
ret.GetMutable());
return ret;
} }
///\endcond
/** /**
* Helper function to create a Dim * Helper function to create a Dim
* *
...@@ -379,25 +244,17 @@ HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size, ...@@ -379,25 +244,17 @@ HOSTDEVICE inline Dim<0> normalize_strides(const Dim<0>& size,
*/ */
template <typename... Args> template <typename... Args>
HOSTDEVICE Dim<sizeof...(Args)> make_dim(Args... idxes) { HOSTDEVICE inline Dim<sizeof...(Args)> make_dim(Args... idxes) {
return Dim<sizeof...(Args)>(idxes...); return Dim<sizeof...(Args)>(idxes...);
} }
// Allows us to output a Dim // Allows us to output a Dim
// XXX For some reason, overloading fails to resolve this correctly template <int N>
template <int i> inline std::ostream& operator<<(std::ostream& os, const Dim<N>& d) {
typename std::enable_if<(i > 1), std::ostream&>::type operator<<( os << d[0];
std::ostream& os, const Dim<i>& d) { for (int i = 1; i < N; ++i) {
os << d.head << ", " << d.tail; os << ", " << d[i];
return os; }
}
// Base case that allows us to output a Dim
// XXX I wish this could be an overload instead of a template
template <int i>
typename std::enable_if<(i == 1), std::ostream&>::type operator<<(
std::ostream& os, const Dim<i>& d) {
os << d.head;
return os; return os;
} }
...@@ -405,25 +262,23 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) { ...@@ -405,25 +262,23 @@ inline std::ostream& operator<<(std::ostream& os, const Dim<0>& d) {
return os; return os;
} }
template <int i> template <int N>
HOST std::string Dim<i>::to_string() const { HOST std::string Dim<N>::to_string() const {
std::stringstream stream; std::stringstream stream;
stream << *this; stream << *this;
return stream.str(); return stream.str();
} }
template <int D> template <int N>
HOSTDEVICE Dim<D> linear_to_dimension(int linear_index, Dim<D> extents) { HOSTDEVICE Dim<N> linear_to_dimension(int linear_index, const Dim<N>& extents) {
Dim<D> result; Dim<N> result;
for (int i = 0; i < D - 1; ++i) { for (int i = 0; i < N - 1; ++i) {
result[i] = linear_index % extents[i]; result[i] = linear_index % extents[i];
linear_index /= extents[i]; linear_index /= extents[i];
} }
result[D - 1] = linear_index; result[N - 1] = linear_index;
return result; return result;
} }
......
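The recursive head/tail arithmetic of the old dim.h is replaced by index-based functors over the underlying array, but the math is unchanged. A worked example of the Fortran-order constructor and its inverse, with values that follow from the definitions above:

  using paddle::framework::Dim;

  Dim<3> size(2, 3, 4);
  Dim<3> idx(13, size);  // column-major unflatten: {13 % 2, (13 / 2) % 3, 13 / 6} = {1, 0, 2}
  Dim<3> strides = paddle::framework::ex_prefix_mul(size);    // {1, 2, 6}
  int64_t flat = paddle::framework::linearize(idx, strides);  // 1*1 + 0*2 + 2*6 = 13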
@@ -62,7 +62,7 @@ static DLDataType GetDLDataTypeFromTypeIndex(const std::type_index &type) {
 struct DLContextVisitor : public boost::static_visitor<::DLContext> {
   inline ::DLContext operator()(const platform::CPUPlace &place) const {
-    DLContext ctx;
+    ::DLContext ctx;
     ctx.device_type = kDLCPU;
     ctx.device_id = 0;
     return ctx;
@@ -70,7 +70,7 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> {
   inline ::DLContext operator()(const platform::CUDAPlace &place) const {
 #ifdef PADDLE_WITH_CUDA
-    DLContext ctx;
+    ::DLContext ctx;
     ctx.device_type = kDLGPU;
     ctx.device_id = place.device;
     return ctx;
@@ -81,7 +81,7 @@ struct DLContextVisitor : public boost::static_visitor<::DLContext> {
   inline ::DLContext operator()(const platform::CUDAPinnedPlace &place) const {
 #ifdef PADDLE_WITH_CUDA
-    DLContext ctx;
+    ::DLContext ctx;
     ctx.device_type = kDLCPUPinned;
     ctx.device_id = 0;
     return ctx;
......
@@ -38,7 +38,7 @@ class DLPackTensor {
   // The shape in DLTensor is defined as int64_t*
   // Add this member to make TVMTensor init without heap allocation
-  ShapeType shape_[9];
+  ShapeType shape_[DDim::kMaxRank];
 };
 
 }  // namespace framework
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace framework {
namespace detail {
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollFillConstant {
template <typename T>
HOSTDEVICE inline static void Run(T *data, T val) {
data[kStart] = val;
UnrollFillConstant<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(data, val);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollFillConstant<kStart, kEnd, true> {
template <typename T>
HOSTDEVICE inline static void Run(T *data, T val) {}
};
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollAssign {
template <typename Tin, typename Tout>
HOSTDEVICE inline static void Run(const Tin *d1, Tout *d2) {
d2[kStart] = static_cast<Tout>(d1[kStart]);
UnrollAssign<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollAssign<kStart, kEnd, true> {
template <typename Tin, typename Tout>
HOSTDEVICE inline static void Run(const Tin *d1, Tout *d2) {}
};
template <typename T, size_t kStart, size_t kEnd, bool kStop>
struct UnrollVarArgsAssign {
template <typename... Args>
HOSTDEVICE inline static void Run(T *d, T val, Args... args) {
static_assert(sizeof...(args) + 1 == kEnd - kStart, "Wrong argument");
d[kStart] = val;
UnrollVarArgsAssign<T, kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d,
args...);
}
};
template <typename T, size_t kStart, size_t kEnd>
struct UnrollVarArgsAssign<T, kStart, kEnd, true> {
HOSTDEVICE inline static void Run(T *d) {}
};
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollCompare {
template <typename T>
HOSTDEVICE inline static bool Run(const T *d1, const T *d2) {
return d1[kStart] == d2[kStart] &&
UnrollCompare<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollCompare<kStart, kEnd, true> {
template <typename T>
HOSTDEVICE inline constexpr static bool Run(const T *d1, const T *d2) {
return true;
}
};
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollAdd {
template <typename T>
HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) {
d3[kStart] = d1[kStart] + d2[kStart];
UnrollAdd<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2, d3);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollAdd<kStart, kEnd, true> {
template <typename T>
HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) {}
};
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollMul {
template <typename T>
HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) {
d3[kStart] = d1[kStart] * d2[kStart];
UnrollMul<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2, d3);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollMul<kStart, kEnd, true> {
template <typename T>
HOSTDEVICE inline static void Run(const T *d1, const T *d2, T *d3) {}
};
template <size_t kStart, size_t kEnd, bool kStop>
struct UnrollProduct {
template <typename T>
HOSTDEVICE inline static T Run(const T *d) {
return d[kStart] *
UnrollProduct<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d);
}
template <typename T>
HOSTDEVICE inline static T Run(const T *d1, const T *d2) {
return d1[kStart] * d2[kStart] +
UnrollProduct<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2);
}
};
template <size_t kStart, size_t kEnd>
struct UnrollProduct<kStart, kEnd, true> {
template <typename T>
HOSTDEVICE inline constexpr static T Run(const T *d) {
return 1;
}
template <typename T>
HOSTDEVICE inline constexpr static T Run(const T *d1, const T *d2) {
return 0;
}
};
} // namespace detail
template <size_t N>
using UnrollFillConstant = detail::UnrollFillConstant<0, N, N == 0>;
template <size_t N>
using UnrollAssign = detail::UnrollAssign<0, N, N == 0>;
template <typename T, size_t N>
using UnrollVarArgsAssign = detail::UnrollVarArgsAssign<T, 0, N, N == 0>;
template <size_t N>
using UnrollCompare = detail::UnrollCompare<0, N, N == 0>;
template <size_t N>
using UnrollAdd = detail::UnrollAdd<0, N, N == 0>;
template <size_t N>
using UnrollMul = detail::UnrollMul<0, N, N == 0>;
template <size_t N>
using UnrollProduct = detail::UnrollProduct<0, N, N == 0>;
} // namespace framework
} // namespace paddle
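Each functor in this header recurses on kStart at compile time and terminates at the kStop == true partial specialization, so for a fixed N the call collapses into straight-line code rather than a runtime loop. For example, with the aliases defined at the bottom of the file:

  // UnrollCompare<3>::Run(a, b) expands to
  //   a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && true
  // UnrollProduct<3>::Run(a) expands to
  //   a[0] * (a[1] * (a[2] * 1))
  int64_t a[3] = {1, 2, 3};
  int64_t b[3] = {1, 2, 3};
  bool eq = paddle::framework::UnrollCompare<3>::Run(a, b);    // true
  int64_t prod = paddle::framework::UnrollProduct<3>::Run(a);  // 6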
@@ -86,8 +86,6 @@ class UnaryLogicalOpInferShape : public framework::InferShapeBase {
     OpComment comment;
     PADDLE_ENFORCE(context->HasInput("X"),
                    "Input(X) of %s operator must not be null", comment.type);
-    auto dim_x = context->GetInputDim("X");
-
     context->SetOutputDim("Out", context->GetInputDim("X"));
     context->ShareLoD("X", "Out");
   }
......
@@ -68,7 +68,6 @@ void CropFunction(const framework::ExecutionContext& context) {
   }
   out->mutable_data<T>(out_dims, context.GetPlace());
   auto x_stride = framework::stride(x->dims());
-  auto out_stride = framework::stride(out->dims());
   auto offsets = GetOffsets(context);
   int64_t offset = 0;
   for (size_t i = 0; i < offsets.size(); ++i) {
......
@@ -378,7 +378,6 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel<T> {
                     ->GetMutable<CudnnRNNCache>();
 
     auto input_dims = input->dims();
-    auto weight_dims = weight->dims();
     auto init_h_dims = init_h->dims();
     auto init_c_dims = init_c->dims();
     in_grad->mutable_data<T>(ctx.GetPlace());
......
@@ -27,8 +27,8 @@ struct StridedMemcpyFunctor;
 template <typename T>
 struct StridedMemcpyFunctor<T, 0> {
   void operator()(const platform::DeviceContext& dev_ctx, const T* src,
-                  framework::Dim<0> src_stride, framework::Dim<0> dst_dim,
-                  framework::Dim<0> dst_stride, T* dst) const {
+                  const int64_t* src_stride, const int64_t* dst_dim,
+                  const int64_t* dst_stride, T* dst) const {
     auto place = dev_ctx.GetPlace();
     if (platform::is_cpu_place(place)) {
       auto& cpu_place = boost::get<platform::CPUPlace>(place);
@@ -50,18 +50,18 @@ struct StridedMemcpyFunctor<T, 0> {
 template <typename T>
 struct StridedMemcpyFunctor<T, 1> {
   void operator()(const platform::DeviceContext& dev_ctx, const T* src,
-                  framework::Dim<1> src_stride, framework::Dim<1> dst_dim,
-                  framework::Dim<1> dst_stride, T* dst) const {
+                  const int64_t* src_stride, const int64_t* dst_dim,
+                  const int64_t* dst_stride, T* dst) const {
     auto place = dev_ctx.GetPlace();
     if (platform::is_cpu_place(place)) {
       auto& cpu_place = boost::get<platform::CPUPlace>(place);
-      memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim.head);
+      memory::Copy(cpu_place, dst, cpu_place, src, sizeof(T) * dst_dim[0]);
     } else {
 #ifdef PADDLE_WITH_CUDA
       auto& gpu_place = boost::get<platform::CUDAPlace>(place);
       auto& cuda_ctx =
           reinterpret_cast<const platform::CUDADeviceContext&>(dev_ctx);
-      memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T) * dst_dim.head,
+      memory::Copy(gpu_place, dst, gpu_place, src, sizeof(T) * dst_dim[0],
                    cuda_ctx.stream());
 #else
       PADDLE_THROW("Paddle is not compiled with GPU");
@@ -73,19 +73,19 @@ struct StridedMemcpyFunctor<T, 1> {
 template <typename T, int Rank>
 struct StridedMemcpyFunctor {
   void operator()(const platform::DeviceContext& dev_ctx, const T* src,
-                  framework::Dim<Rank> src_stride, framework::Dim<Rank> dst_dim,
-                  framework::Dim<Rank> dst_stride, T* dst) const {
-    for (int64_t i = 0; i < dst_dim.head; ++i) {
+                  const int64_t* src_stride, const int64_t* dst_dim,
+                  const int64_t* dst_stride, T* dst) const {
+    for (int64_t i = 0; i < dst_dim[0]; ++i) {
       StridedMemcpyFunctor<T, Rank - 1> func;
-      func(dev_ctx, src, src_stride.tail, dst_dim.tail, dst_stride.tail, dst);
-      src += src_stride.head;
-      dst += dst_stride.head;
+      func(dev_ctx, src, src_stride + 1, dst_dim + 1, dst_stride + 1, dst);
+      src += src_stride[0];
+      dst += dst_stride[0];
     }
   }
 };
 
 template <typename T>
-struct StridedCopyDimVisitor : public boost::static_visitor<void> {
+struct StridedCopyDimVisitor {
   StridedCopyDimVisitor(const platform::DeviceContext& dev_ctx, const T* src,
                         const framework::DDim& src_stride,
                         const framework::DDim& dst_stride, T* dst)
@@ -95,13 +95,11 @@ struct StridedCopyDimVisitor : public boost::static_visitor<void> {
         dst_stride_(dst_stride),
         dst_(dst) {}
 
-  template <typename Dim>
-  void operator()(Dim dst_dim) const {
-    Dim src_stride = boost::get<Dim>(src_stride_);
-    Dim dst_stride = boost::get<Dim>(dst_stride_);
-    constexpr int dim = Dim::dimensions;
-    StridedMemcpyFunctor<T, dim> functor;
-    functor(dev_ctx_, src_, src_stride, dst_dim, dst_stride, dst_);
+  template <int D>
+  void operator()(const framework::Dim<D>& dst_dim) const {
+    StridedMemcpyFunctor<T, D> functor;
+    functor(dev_ctx_, src_, src_stride_.data(), dst_dim.data(),
+            dst_stride_.data(), dst_);
   }
 
   const platform::DeviceContext& dev_ctx_;
......
@@ -64,8 +64,6 @@ class GenerateProposalLabelsOp : public framework::OperatorWithKernel {
         "Output(BboxOutsideWeights) of RpnTargetAssignOp should not be null");
 
     auto rpn_rois_dims = ctx->GetInputDim("RpnRois");
-    auto gt_classes_dims = ctx->GetInputDim("GtClasses");
-    auto is_crowd_dims = ctx->GetInputDim("IsCrowd");
     auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
     auto im_info_dims = ctx->GetInputDim("ImInfo");
......
@@ -53,12 +53,6 @@ class GenerateProposalsOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE(ctx->HasInput("Variances"),
                    "Input(Variances) shouldn't be null.");
 
-    auto scores_dims = ctx->GetInputDim("Scores");
-    auto bbox_deltas_dims = ctx->GetInputDim("BboxDeltas");
-    auto im_info_dims = ctx->GetInputDim("ImInfo");
-    auto anchors_dims = ctx->GetInputDim("Anchors");
-    auto variances_dims = ctx->GetInputDim("Variances");
-
     ctx->SetOutputDim("RpnRois", {-1, 4});
     ctx->SetOutputDim("RpnRoiProbs", {-1, 1});
   }
......
@@ -58,7 +58,6 @@ class RpnTargetAssignOp : public framework::OperatorWithKernel {
     auto anchor_dims = ctx->GetInputDim("Anchor");
     auto gt_boxes_dims = ctx->GetInputDim("GtBoxes");
-    auto is_crowd_dims = ctx->GetInputDim("IsCrowd");
     auto im_info_dims = ctx->GetInputDim("ImInfo");
     PADDLE_ENFORCE_EQ(anchor_dims.size(), 2,
                       "The rank of Input(Anchor) must be 2.");
......
@@ -178,7 +178,6 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("X");
     auto y_dims = ctx->GetInputDim("Y");
-    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
 
     PADDLE_ENFORCE_GE(x_dims.size(), y_dims.size(),
                       "Rank of first input must >= rank of second input.");
......
@@ -77,7 +77,6 @@ class ExpandKernel : public framework::OpKernel<T> {
     auto& expand_times = context.Attr<std::vector<int>>("expand_times");
     auto* out0 = context.Output<Tensor>("Out");
     Eigen::DSizes<int, Rank> bcast_dims;
-    auto x_dims = in0->dims();
     for (size_t i = 0; i < expand_times.size(); ++i) {
       bcast_dims[i] = expand_times[i];
     }
......
@@ -148,7 +148,6 @@ class FCOpKernel : public framework::OpKernel<T> {
     auto w = ctx.Input<Tensor>("W");
     auto bias = ctx.Input<Tensor>("Bias");
     auto output = ctx.Output<Tensor>("Out");
-    auto in_dims = input->dims();
     auto w_dims = w->dims();
     auto out_dims = output->dims();
     int M = framework::product(out_dims) / out_dims[out_dims.size() - 1];
......
@@ -242,15 +242,15 @@ class FusedEmbeddingFCLSTMKernel : public framework::OpKernel<T> {
   bool is_reverse = ctx.Attr<bool>("is_reverse");                   \
   bool use_peepholes = ctx.Attr<bool>("use_peepholes");
 
 #define INIT_BASE_SIZES                                             \
   auto ids_dims = ids->dims();                    /* T x M*/        \
-  auto ids_numel = ids->numel();                  /* T x 1*/        \
+  auto ids_numel = framework::product(ids_dims);  /* T x 1*/        \
   auto wh_dims = wh->dims();                      /* D x 4D*/       \
   const int D = wh_dims[0];                                         \
   const int D2 = D * 2;                                             \
   const int D3 = D * 3;                                             \
   int64_t row_number = embeddings->dims()[0];                       \
   int64_t row_width = embeddings->dims()[1];                        \
   const int D4 = wh_dims[1];
 
 #define INIT_BASE_INPUT_DATAS                                       \
......
@@ -88,7 +88,6 @@ class HingeLossGradOp : public framework::OperatorWithKernel {
                    "Input(Logits@GRAD) should not be null.");
 
     auto pred_dims = ctx->GetInputDim("Logits");
-    auto lab_dims = ctx->GetInputDim("Labels");
     auto loss_grad_dims = ctx->GetInputDim(framework::GradVarName("Loss"));
 
     PADDLE_ENFORCE_EQ(loss_grad_dims, pred_dims);
......
@@ -92,7 +92,6 @@ class LogLossGradOp : public framework::OperatorWithKernel {
                    "Output(Predicted@GRAD) should not be null.");
 
     auto pred_dims = ctx->GetInputDim("Predicted");
-    auto label_dims = ctx->GetInputDim("Labels");
     auto loss_grad_dims = ctx->GetInputDim(framework::GradVarName("Loss"));
 
     PADDLE_ENFORCE_EQ(loss_grad_dims, pred_dims);
......
@@ -37,9 +37,6 @@ void Transpose<DeviceContext, T, Rank>::operator()(
   for (int i = 0; i < Rank; i++) {
     permute[i] = axis[i];
   }
-  auto in_dim = in.dims();
-  auto out_dim = out->dims();
   auto eigen_in = framework::EigenTensor<T, Rank>::From(in);
   auto eigen_out = framework::EigenTensor<T, Rank>::From(*out);
   auto* dev = context.eigen_device();
......
@@ -76,7 +76,6 @@ class SoftmaxFunctor<DeviceContext, float, true, enable_if_CPU<DeviceContext>> {
   void operator()(const DeviceContext& context, const framework::Tensor* X,
                   framework::Tensor* Y) {
     auto in_dims = X->dims();
-    auto out_dims = Y->dims();
     const float* in_data = X->data<float>();
     float* out_data = Y->data<float>();
     const int kBatchDim = 0;
......
@@ -87,7 +87,6 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
                    "Input(Out@Grad) must not be null.");
 
     auto x_dims = ctx->GetInputDim("X");
-    auto y_dims = ctx->GetInputDim("Y");
     auto intermediate_dims = ctx->GetInputDim("IntermediateVal");
     auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out"));
......
@@ -146,12 +146,6 @@ class MulGradOp : public framework::OperatorWithKernel {
                    "Input(Out@GRAD) should not be null");
     auto x_dims = ctx->GetInputDim("X");
     auto y_dims = ctx->GetInputDim("Y");
-    auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-
-    auto x_mat_dims = framework::flatten_to_2d(
-        x_dims, ctx->Attrs().Get<int>("x_num_col_dims"));
-    auto y_mat_dims = framework::flatten_to_2d(
-        y_dims, ctx->Attrs().Get<int>("y_num_col_dims"));
 
     auto x_grad_name = framework::GradVarName("X");
     auto y_grad_name = framework::GradVarName("Y");
......
@@ -36,7 +36,6 @@ class NCEOp : public framework::OperatorWithKernel {
     auto x_dims = ctx->GetInputDim("Input");
     auto label_dims = ctx->GetInputDim("Label");
-    auto w_dims = ctx->GetInputDim("Weight");
     PADDLE_ENFORCE_EQ(x_dims[0], label_dims[0]);
     int num_true_classes = label_dims.size() == 2 ? label_dims[1] : 1;
     if (ctx->HasInput("Bias")) {
......
@@ -43,7 +43,6 @@ class NormKernel : public framework::OpKernel<T> {
     out_norm->mutable_data<T>(ctx.GetPlace());
 
     auto xdim = in_x->dims();
-    auto ndim = out_norm->dims();
     T eps = static_cast<T>(ctx.Attr<float>("epsilon"));
     int axis = ctx.Attr<int>("axis");
     if (axis < 0) axis = xdim.size() + axis;
......
@@ -41,7 +41,6 @@ class CPUPSROIPoolOpKernel : public framework::OpKernel<T> {
     int rois_num = rois->dims()[0];
 
     auto in_stride = framework::stride(in_dims);
-    auto roi_stride = framework::stride(rois->dims());
     auto out_stride = framework::stride(out->dims());
 
     const T* input_data = in->data<T>();
......
@@ -143,8 +143,6 @@ class SequenceSliceGradOpKernel : public framework::OpKernel<T> {
       set_zero(ctx.template device_context<DeviceContext>(), x_grad,
                static_cast<T>(0));
 
-      auto out_grad_stride = framework::stride(out_grad->dims());
-
       for (size_t i = 0; i < out_lod[0].size() - 1; ++i) {
         Tensor out_grad_t =
             out_grad->Slice(static_cast<int>(out_lod[0][i]),
......
@@ -40,7 +40,7 @@ inline void StridedMemcpy(const platform::DeviceContext& dev_ctx, const T* src,
                           const framework::DDim& dst_stride, T* dst) {
   paddle::operators::detail::StridedCopyDimVisitor<T> func(
       dev_ctx, src, src_stride, dst_stride, dst);
-  boost::apply_visitor(func, dst_dim);
+  dst_dim.apply_visitor(func);
 }
 
 // Strided numel memory copy from src to dst by the specified axis
......
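The entry point in this last hunk keeps its shape; only the dispatch changes from boost::apply_visitor(func, dst_dim) to dst_dim.apply_visitor(func), and the per-rank functors above now receive raw int64_t* strides. A usage sketch under the assumption that StridedMemcpy takes (dev_ctx, src, src_stride, dst_dim, dst_stride, dst), which is consistent with the parameters visible in the hunk but not fully shown there:

  // copy the top-left 2x2 block out of a row-major 3x4 matrix (CPU context assumed)
  namespace fw = paddle::framework;
  float src[12] = {0};
  float dst[4];
  paddle::platform::CPUDeviceContext cpu_ctx;
  paddle::operators::StridedMemcpy<float>(cpu_ctx, src,
                                          fw::stride(fw::make_ddim({3, 4})),  // src_stride {4, 1}
                                          fw::make_ddim({2, 2}),              // dst_dim
                                          fw::stride(fw::make_ddim({2, 2})),  // dst_stride {2, 1}
                                          dst);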