From 646334b558d617b393bb6583162fb27baa7371cc Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 18 May 2017 15:34:18 +0800 Subject: [PATCH] import DDim of Majel into Paddle --- paddle/majel/CMakeLists.txt | 2 +- paddle/majel/ddim.cc | 222 +++++++++++++++ paddle/majel/ddim.h | 109 ++++++++ paddle/majel/dim.h | 456 +++++++++++++++++++++++++++++++ paddle/majel/hostdevice.h | 9 + paddle/majel/test/CMakeLists.txt | 5 + paddle/majel/test/ddim_test.cc | 65 +++++ paddle/majel/test/dim_test.cu | 128 +++++++++ 8 files changed, 995 insertions(+), 1 deletion(-) create mode 100644 paddle/majel/ddim.cc create mode 100644 paddle/majel/ddim.h create mode 100644 paddle/majel/dim.h create mode 100644 paddle/majel/hostdevice.h create mode 100644 paddle/majel/test/ddim_test.cc create mode 100644 paddle/majel/test/dim_test.cu diff --git a/paddle/majel/CMakeLists.txt b/paddle/majel/CMakeLists.txt index d4bce38906e..47f33dd8321 100644 --- a/paddle/majel/CMakeLists.txt +++ b/paddle/majel/CMakeLists.txt @@ -1,4 +1,4 @@ -cc_library(majel SRCS place.cc) +cc_library(majel SRCS place.cc ddim.cc) if(WITH_TESTING) add_subdirectory(test) diff --git a/paddle/majel/ddim.cc b/paddle/majel/ddim.cc new file mode 100644 index 00000000000..e92d9afa0b6 --- /dev/null +++ b/paddle/majel/ddim.cc @@ -0,0 +1,222 @@ +#include + +namespace majel { + +///@cond HIDDEN + +template +Dim make_dim(const int* d) { + return Dim(*d, make_dim(d + 1)); +} + +template <> +Dim<1> make_dim<1>(const int* d) { + return Dim<1>(*d); +} + +void make_ddim(DDim& ddim, const int* dims, int n) { + switch (n) { + case 1: + ddim = make_dim<1>(dims); + break; + case 2: + ddim = make_dim<2>(dims); + break; + case 3: + ddim = make_dim<3>(dims); + break; + case 4: + ddim = make_dim<4>(dims); + break; + case 5: + ddim = make_dim<5>(dims); + break; + case 6: + ddim = make_dim<6>(dims); + break; + case 7: + ddim = make_dim<7>(dims); + break; + case 8: + ddim = make_dim<8>(dims); + break; + case 9: + ddim = make_dim<9>(dims); + break; + default: + throw std::invalid_argument( + "Dynamic dimensions must have between [1, 9] dimensions."); + } +} + +///@endcond + +DDim make_ddim(std::initializer_list dims) { + DDim result(make_dim(0)); + make_ddim(result, dims.begin(), dims.size()); + return result; +} + +DDim make_ddim(const std::vector& dims) { + DDim result(make_dim(0)); + make_ddim(result, &dims[0], dims.size()); + return result; +} + +///@cond HIDDEN +// XXX For some reason, putting this in an anonymous namespace causes errors +class DynamicMutableIndexer : public boost::static_visitor { +public: + DynamicMutableIndexer(int idx) : idx_(idx) {} + + template + int& operator()(Dim& dim) const { + return dim[idx_]; + } + +private: + int idx_; +}; + +class DynamicConstIndexer : public boost::static_visitor { +public: + DynamicConstIndexer(int idx) : idx_(idx) {} + + template + int operator()(const Dim& dim) const { + return dim[idx_]; + } + +private: + int idx_; +}; + +///@endcond + +int& DDim::operator[](int idx) { + return boost::apply_visitor(DynamicMutableIndexer(idx), var); +} + +int DDim::operator[](int idx) const { + return boost::apply_visitor(DynamicConstIndexer(idx), var); +} + +bool DDim::operator==(DDim d) const { + if (var.which() != d.getVar().which()) { + return false; + } else { + std::vector v1 = vectorize(*this); + std::vector v2 = vectorize(d); + + for (unsigned int i = 0; i < v1.size(); i++) { + if (v1[i] != v2[i]) { + return false; + } + } + + return true; + } +} + +bool DDim::operator!=(DDim d) const { return !(*this == d); } + +DDim DDim::operator+(DDim d) const { + std::vector v1 = vectorize(*this); + std::vector v2 = vectorize(d); + + std::vector v3; + + assert(v1.size() == v2.size()); + + for (unsigned int i = 0; i < v1.size(); i++) { + v3.push_back(v1[i] + v2[i]); + } + + return make_ddim(v3); +} + +DDim DDim::operator*(DDim d) const { + std::vector v1 = vectorize(*this); + std::vector v2 = vectorize(d); + + std::vector v3; + + assert(v1.size() == v2.size()); + + for (unsigned int i = 0; i < v1.size(); i++) { + v3.push_back(v1[i] * v2[i]); + } + + return make_ddim(v3); +} + +int get(const DDim& ddim, int idx) { return ddim[idx]; } + +void set(DDim& ddim, int idx, int value) { ddim[idx] = value; } + +///@cond HIDDEN +struct VectorizeVisitor : public boost::static_visitor<> { + std::vector& vector; + + VectorizeVisitor(std::vector& v) : vector(v) {} + + template + void operator()(const T& t) { + vector.push_back(t.head); + this->operator()(t.tail); + } + + void operator()(const Dim<1>& t) { vector.push_back(t.head); } +}; +///@endcond + +std::vector vectorize(const DDim& ddim) { + std::vector result; + VectorizeVisitor visitor(result); + boost::apply_visitor(visitor, ddim); + return result; +} + +ssize_t product(const DDim& ddim) { + ssize_t result = 1; + std::vector v = vectorize(ddim); + for (auto i : v) { + result *= i; + } + return result; +} + +///\cond HIDDEN + +struct ArityVisitor : boost::static_visitor { + template + int operator()(Dim) const { + return D; + } +}; + +///\endcond + +int arity(const DDim& d) { return boost::apply_visitor(ArityVisitor(), d); } + +///\cond HIDDEN + +struct DDimPrinter : boost::static_visitor { + std::ostream& os; + DDimPrinter(std::ostream& os_) : os(os_) {} + + template + void operator()(const T& t) { + os << t; + } +}; + +///\endcond + +std::ostream& operator<<(std::ostream& os, const majel::DDim& ddim) { + DDimPrinter printer(os); + boost::apply_visitor(printer, ddim); + return os; +} + +} // namespace majel diff --git a/paddle/majel/ddim.h b/paddle/majel/ddim.h new file mode 100644 index 00000000000..64cebf89581 --- /dev/null +++ b/paddle/majel/ddim.h @@ -0,0 +1,109 @@ +#pragma once + +#include +#include +#include +#include + +#include "majel/dim.h" + +namespace majel { + +namespace { +typedef boost::variant, + Dim<2>, + Dim<3>, + Dim<4>, + Dim<5>, + Dim<6>, + Dim<7>, + Dim<8>, + Dim<9>> + DDimVar; +} + +/** + * \brief A dynamically sized dimension. + * + * The number of dimensions must be between [1, 9]. + */ +struct DDim { + DDimVar var; + + DDim() : var(Dim<1>()) {} + + template + DDim(const Dim& in) : var(in) {} + + template + DDim& operator=(const Dim& in) { + var = in; + return *this; + } + + int& operator[](int idx); + int operator[](int idx) const; + + template + typename Visitor::result_type apply_visitor(Visitor& visitor) { + return var.apply_visitor(visitor); + } + + template + typename Visitor::result_type apply_visitor(Visitor& visitor) const { + return var.apply_visitor(visitor); + } + + DDimVar getVar() { return var; } + + bool operator==(DDim d) const; + + bool operator!=(DDim d) const; + + DDim operator+(DDim d) const; + + DDim operator*(DDim d) const; +}; + +/** + * \brief Make a DDim from std::vector + * + * \param dims An vector of ints. Must be sized between [1, 9] + */ +DDim make_ddim(const std::vector& dims); + +/** + * \brief Make a DDim from an initializer list + * + * \param dims An initializer list of ints. Must be sized between [1, 9] + * + */ +DDim make_ddim(std::initializer_list dims); + +int get(const DDim& dim, int idx); +void set(DDim& dim, int idx, int val); + +std::vector vectorize(const DDim& ddim); + +ssize_t product(const DDim& ddim); + +/** + * \brief What is the length of this dimension? + * + * \param Dynamic dimension to inspect + */ + +int arity(const DDim& ddim); + +std::ostream& operator<<(std::ostream&, const majel::DDim&); + +} // namespace majel + +namespace boost { + +template +T get(const majel::DDim& in) { + return boost::get(in.var); +} + +} // namespace boost diff --git a/paddle/majel/dim.h b/paddle/majel/dim.h new file mode 100644 index 00000000000..cf7682b6865 --- /dev/null +++ b/paddle/majel/dim.h @@ -0,0 +1,456 @@ +#pragma once + +#include +#include +#include +#include +/* +#ifdef __CUDACC__ + #include +#endif +*/ + +#include "hostdevice.h" +#include "paddle/utils/Logging.h" + +namespace majel { + +// Statically sized, statically indexed dimension +template +struct Dim { + static constexpr int dimensions = i; + + template + HOSTDEVICE Dim(int _head, Args... _tail) : head(_head), tail(_tail...) { + static_assert(sizeof...(_tail) == i - 1, + "Dim initialized with the wrong number of parameters"); + } + + HOSTDEVICE + Dim(int _head, const Dim& _tail) : head(_head), tail(_tail) {} + + HOSTDEVICE + Dim() : head(0), tail() {} + + /** Construct a Dim from a linear index and size. Uses Fortran order + * indexing. */ + HOSTDEVICE + Dim(int idx, const Dim& size) + : head(idx % size.head), tail(idx / size.head, size.tail) {} + + /** Construct a Dim with each dimension set to the given index */ + HOSTDEVICE + Dim(int idx) : head(idx), tail(idx) {} + + HOSTDEVICE + bool operator==(const Dim& o) const { + return (head == o.head) && (tail == o.tail); + } + + HOSTDEVICE + bool operator!=(const Dim& o) const { return !(*this == o); } + + HOSTDEVICE + int& operator[](int idx); + HOSTDEVICE + int operator[](int idx) const; + + HOST std::string to_string() const; + + int head; + Dim tail; +}; + +// Base case specialization +template <> +struct Dim<1> { + static constexpr int dimensions = 1; + + HOSTDEVICE + Dim(int _head) : head(_head) {} + + HOSTDEVICE + Dim() : head(0) {} + + HOSTDEVICE + Dim(int idx, const Dim<1>& size) : head(idx) { +#ifndef __CUDA_ARCH__ + if (idx >= size.head) { + throw std::invalid_argument("Index out of range."); + } +#else + CHECK(idx < size.head); +#endif + } + + HOSTDEVICE + bool operator==(const Dim<1>& o) const { return (head == o.head); } + + HOSTDEVICE + bool operator!=(const Dim<1>& o) const { return !(*this == o); } + + HOSTDEVICE + int& operator[](int idx); + HOSTDEVICE + int operator[](int idx) const; + + int head; +}; + +namespace { + +// Helper for accessing Dim classes +template +struct DimGetter { + // Return a copy if Dim is const + template + HOSTDEVICE static int impl(const D& d) { + return DimGetter::impl(d.tail); + } + // Return a reference if Dim is mutable + template + HOSTDEVICE static int& impl(D& d) { + return DimGetter::impl(d.tail); + } +}; + +// Eureka! We found the element! +template <> +struct DimGetter<0> { + // Return a copy if Dim is const + template + HOSTDEVICE static int impl(const D& d) { + return d.head; + } + // Return a reference if Dim is mutable + template + HOSTDEVICE static int& impl(D& d) { + return d.head; + } +}; + +template +HOSTDEVICE int& indexer(Dim& dim, int idx) { +#ifndef __CUDA_ARCH__ + if (idx < 0) { + throw std::invalid_argument("Tried to access a negative dimension"); + } +#else + CHECK(idx >= 0); +#endif + if (idx == 0) { + return dim.head; + } + return indexer(dim.tail, idx - 1); +} + +template <> +HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) { +#ifndef __CUDA_ARCH__ + if (idx != 0) { + throw std::invalid_argument("Invalid index"); + } +#else + CHECK(idx == 0); +#endif + return dim.head; +} + +template +HOSTDEVICE int indexer(const Dim& dim, int idx) { +#ifndef __CUDA_ARCH__ + if (idx < 0) { + throw std::invalid_argument("Tried to access a negative dimension"); + } +#else + CHECK(idx >= 0); +#endif + if (idx == 0) { + return dim.head; + } + return indexer(dim.tail, idx - 1); +} + +template <> +HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) { +#ifndef __CUDA_ARCH__ + if (idx != 0) { + throw std::invalid_argument("Invalid index"); + } +#else + CHECK(idx == 0); +#endif + return dim.head; +} + +} // namespace +// Static access to constant Dim +template +HOSTDEVICE int get(const Dim& d) { + return DimGetter::impl(d); +} + +// Static access to mutable Dim +template +HOSTDEVICE int& get(Dim& d) { + return DimGetter::impl(d); +} + +// Dynamic access to constant Dim +template +HOSTDEVICE int Dim::operator[](int i) const { + return indexer(*this, i); +} + +// Dynamic access to mutable Dim +template +HOSTDEVICE int& Dim::operator[](int i) { + return indexer(*this, i); +} + +// Dynamic access to constant Dim +inline HOSTDEVICE int Dim<1>::operator[](int i) const { + return indexer(*this, i); +} + +// Dynamic access to mutable Dim +inline HOSTDEVICE int& Dim<1>::operator[](int i) { return indexer(*this, i); } + +// Dynamic access to constant Dim +// without std::enable_if will try to instantiate this on get<0>(d) +template +HOSTDEVICE typename std::enable_if<(l > 0), int>::type get(const Dim& d, + int i) { + return d[i]; +} + +// Dynamic access to mutable Dim +template +HOSTDEVICE typename std::enable_if<(l > 0), int&>::type get(Dim& d, int i) { + return d[i]; +} + +// Dot product of two dims +template +HOSTDEVICE int linearize(const Dim& a, const Dim& b) { + return a.head * b.head + linearize(a.tail, b.tail); +} + +// Base case dot product of two Dims +// Notice it is inline because it is no longer a template +template <> +HOSTDEVICE inline int linearize(const Dim<1>& a, const Dim<1>& b) { + return a.head * b.head; +} + +// Product of a Dim +template +HOSTDEVICE int product(const Dim& a, int prod = 1) { + return prod * a.head * product(a.tail); +} + +// Base case product of a Dim +// Notice it is inline because it is no longer a template +template <> +HOSTDEVICE inline int product(const Dim<1>& a, int prod) { + return prod * a.head; +} + +// Is 0 <= idx_i < size_i for all i? +template +HOSTDEVICE bool contained(const Dim& idx, const Dim& size) { + return ((0 <= idx.head) && (idx.head < size.head) && + contained(idx.tail, size.tail)); +} + +// Base case of is 0 <= idx_i < size_i ? +// Notice it is inline because it is no longer a template +template <> +HOSTDEVICE inline bool contained(const Dim<1>& idx, const Dim<1>& size) { + return ((0 <= idx.head) && (idx.head < size.head)); +} + +/** + * \brief Check if a size and a stride create a Fortran order contiguous + * block of memory. + */ +template +HOST bool contiguous(const Dim& size, const Dim& stride, int mul = 1) { + if (product(size) == 0) return true; + int contiguous_stride = get<0>(size) == 1 ? 0 : mul; + return (get<0>(stride) == contiguous_stride && + contiguous(size.tail, stride.tail, mul * get<0>(size))); +} + +///\cond HIDDEN +// Base case of contiguous, check the nth stride is the size of +// the prefix multiply of n-1 dims. +template <> +inline bool contiguous(const Dim<1>& size, const Dim<1>& stride, int mul) { + if (get<0>(size) == 0) return true; + int contiguous_stride = get<0>(size) == 1 ? 0 : mul; + return get<0>(stride) == contiguous_stride; +} +///\endcond + +/** + * \brief Compute exclusive prefix-multiply of a Dim. + */ +template +HOSTDEVICE Dim ex_prefix_mul(const Dim& src, int mul = 1) { + return Dim(mul, ex_prefix_mul(src.tail, mul * src.head)); +} + +///\cond HIDDEN +// Base case of ex_prefix_mul +// Notice it is inline because it is no longer a template +template <> +HOSTDEVICE inline Dim<1> ex_prefix_mul(const Dim<1>& src, int mul) { + return Dim<1>(mul); +} +///\endcond + +/** + * \brief Calculate strides of a contiguous array of the given size + * + * Sets the stride for any dimension with an extent of 1 to 0. + * \param size Dim object containing the size of the array. + * \param base The base stride to use. + * \return Dim object the same size as \p size with the strides. + */ +template +HOSTDEVICE Dim contiguous_strides(const Dim& size, int base = 1) { + int stride = size.head == 1 ? 0 : base; + return Dim(stride, contiguous_strides(size.tail, base * size.head)); +} + +///\cond HIDDEN + +// Base case of contiguous_strides +template <> +HOSTDEVICE inline Dim<1> contiguous_strides(const Dim<1>& size, int base) { + int stride = size.head == 1 ? 0 : base; + return Dim<1>(stride); +} + +///\endcond + +/** + * Add two dimensions together + */ +template +HOSTDEVICE Dim dim_plus(const Dim& a, const Dim& b) { + return Dim(a.head + b.head, dim_plus(a.tail, b.tail)); +} + +// Base case +template <> +HOSTDEVICE inline Dim<1> dim_plus(const Dim<1>& a, const Dim<1>& b) { + return Dim<1>(a.head + b.head); +} + +template +HOSTDEVICE Dim operator+(const Dim& lhs, const Dim& rhs) { + return dim_plus(lhs, rhs); +} + +/** + * Multiply two dimensions together + */ +template +HOSTDEVICE Dim dim_mult(const Dim& a, const Dim& b) { + return Dim(a.head * b.head, dim_mult(a.tail, b.tail)); +} + +// Base case +template <> +HOSTDEVICE inline Dim<1> dim_mult(const Dim<1>& a, const Dim<1>& b) { + return Dim<1>(a.head * b.head); +} + +template +HOSTDEVICE Dim operator*(const Dim& lhs, const Dim& rhs) { + return dim_mult(lhs, rhs); +} + +/** + * \brief Normalize strides to ensure any dimension with extent 1 + * has stride 0. + * + * \param size Dim object containing the size of an array + * \param stride Dim object containing stride of an array + * \return Dim object the same size as \p size with normalized strides + * + */ + +template +HOSTDEVICE Dim normalize_strides(const Dim& size, const Dim& stride) { + int norm_stride = size.head == 1 ? 0 : stride.head; + return Dim(norm_stride, normalize_strides(size.tail, stride.tail)); +} + +///\cond HIDDEN + +template <> +HOSTDEVICE inline Dim<1> normalize_strides(const Dim<1>& size, + const Dim<1>& stride) { + int norm_stride = size.head == 1 ? 0 : stride.head; + return Dim<1>(norm_stride); +} + +///\endcond + +/** + * Helper function to create a Dim + * + * \param idxes The type of Dim constructed depends on the number of params + * + */ + +template +HOSTDEVICE Dim make_dim(Args... idxes) { + return Dim(idxes...); +} + +// Allows us to output a Dim +// XXX For some reason, overloading fails to resolve this correctly +template +typename std::enable_if<(i > 1), std::ostream&>::type operator<<( + std::ostream& os, const majel::Dim& d) { + os << d.head << ", " << d.tail; + return os; +} + +// Base case that allows us to output a Dim +// XXX I wish this could be an overload instead of a template +template +typename std::enable_if<(i == 1), std::ostream&>::type operator<<( + std::ostream& os, const majel::Dim& d) { + os << d.head; + return os; +} + +template +HOST std::string Dim::to_string() const { + std::stringstream stream; + + stream << *this; + + return stream.str(); +} + +template +HOSTDEVICE Dim linear_to_dimension(int linear_index, Dim extents) { + Dim result; + + for (int i = 0; i < D - 1; ++i) { + result[i] = linear_index % extents[i]; + linear_index /= extents[i]; + } + + result[D - 1] = linear_index; + + return result; +} + +} // namespace majel diff --git a/paddle/majel/hostdevice.h b/paddle/majel/hostdevice.h new file mode 100644 index 00000000000..e7de86b7b2f --- /dev/null +++ b/paddle/majel/hostdevice.h @@ -0,0 +1,9 @@ +#pragma once + +#ifdef __CUDACC__ +#define HOSTDEVICE __host__ __device__ +#define HOST __host__ +#else +#define HOSTDEVICE +#define HOST +#endif diff --git a/paddle/majel/test/CMakeLists.txt b/paddle/majel/test/CMakeLists.txt index 68f9059874a..42453799526 100644 --- a/paddle/majel/test/CMakeLists.txt +++ b/paddle/majel/test/CMakeLists.txt @@ -2,6 +2,11 @@ cc_test(place_test SRCS place_test.cc DEPS majel) +cc_test(ddim_test + SRCS ddim_test.cc + DEPS majel) + if(WITH_GPU) nv_test(cuda_test SRCS cuda_test.cu) + nv_test(dim_test SRCS dim_test.cu DEPS majel) endif() diff --git a/paddle/majel/test/ddim_test.cc b/paddle/majel/test/ddim_test.cc new file mode 100644 index 00000000000..5c604517961 --- /dev/null +++ b/paddle/majel/test/ddim_test.cc @@ -0,0 +1,65 @@ +//#include +//#include +#include +#include + +#include "gtest/gtest.h" +#include "majel/ddim.h" + +TEST(DDim, Equality) { + // construct a DDim from an initialization list + majel::DDim ddim = majel::make_ddim({9, 1, 5}); + EXPECT_EQ(ddim[0], 9); + EXPECT_EQ(ddim[1], 1); + EXPECT_EQ(ddim[2], 5); + + // construct a DDim from a vector + std::vector vec({9, 1, 5}); + majel::DDim vddim = majel::make_ddim(vec); + EXPECT_EQ(ddim[0], 9); + EXPECT_EQ(ddim[1], 1); + EXPECT_EQ(ddim[2], 5); + + // mutate a DDim + ddim[1] = 2; + EXPECT_EQ(ddim[1], 2); + majel::set(ddim, 0, 6); + EXPECT_EQ(majel::get(ddim, 0), 6); + + // vectorize a DDim + std::vector res_vec = majel::vectorize(vddim); + EXPECT_EQ(res_vec[0], 9); + EXPECT_EQ(res_vec[1], 1); + EXPECT_EQ(res_vec[2], 5); + majel::Dim<3> d(3, 2, 1); + res_vec = majel::vectorize(majel::DDim(d)); + EXPECT_EQ(res_vec[0], 3); + EXPECT_EQ(res_vec[1], 2); + EXPECT_EQ(res_vec[2], 1); + + // add two DDims + majel::DDim ddim_sum = ddim + vddim; + EXPECT_EQ(ddim_sum[0], 15); + EXPECT_EQ(ddim_sum[1], 3); + EXPECT_EQ(ddim_sum[2], 10); + + // multiply two DDims + majel::DDim ddim_mul = ddim * vddim; + EXPECT_EQ(ddim_mul[0], 54); + EXPECT_EQ(ddim_mul[1], 2); + EXPECT_EQ(ddim_mul[2], 25); + + // arity of a DDim + EXPECT_EQ(majel::arity(ddim), 3); + + // product of a DDim + EXPECT_EQ(majel::product(vddim), 45); +} + +TEST(DDim, Print) { + // print a DDim + std::stringstream ss; + majel::DDim ddim = majel::make_ddim({2, 3, 4}); + ss << ddim; + EXPECT_EQ("2, 3, 4", ss.str()); +} diff --git a/paddle/majel/test/dim_test.cu b/paddle/majel/test/dim_test.cu new file mode 100644 index 00000000000..380204531d8 --- /dev/null +++ b/paddle/majel/test/dim_test.cu @@ -0,0 +1,128 @@ +#include +#include +#include "majel/dim.h" + +#include "gtest/gtest.h" + +__global__ void test(majel::Dim<2>* o) { + o[0] = majel::make_dim(5, 6); +} + +__global__ void dyn_idx_gpu(int* o) { + auto d = majel::make_dim(5, 6); + o[0] = d[1]; +} + +TEST(Dim, Equality) { + // construct a Dim on the CPU + auto a = majel::make_dim(3, 4); + EXPECT_EQ(get<0>(a), 3); + EXPECT_EQ(get<1>(a), 4); + + // construct a Dim on the GPU + thrust::device_vector> t(2); + test<<<1,1>>>(thrust::raw_pointer_cast(t.data())); + a = t[0]; + EXPECT_EQ(get<0>(a), 5); + EXPECT_EQ(get<1>(a), 6); + + // linearization + auto b = make_dim(7, 8); + EXPECT_EQ(linearize(a, b), 83); + + // product + EXPECT_EQ(product(a), 30); + + // mutate a Dim + majel::get<1>(b) = 10; + EXPECT_EQ(majel::get<0>(b), 7); + EXPECT_EQ(majel::get<1>(b), 10); + + // dynamic access + majel::get(b, 0) = 8; + b[1] = 11; + EXPECT_EQ(majel::get<0>(b), 8); + EXPECT_EQ(majel::get<1>(b), 11); + EXPECT_EQ(majel::get(b, 0), 8); + EXPECT_EQ(b[1], 11); + + // dynamic access on GPU + thrust::device_vector r(1); + dyn_idx_gpu<<<1,1>>>(thrust::raw_pointer_cast(r.data())); + int res = r[0]; + EXPECT_EQ(res, 6); + + // ex_prefix_mul + majel::Dim<3> c = majel::ex_prefix_mul(Dim<3>(3, 4, 5)); + EXPECT_EQ(majel::get<0>(c), 1); + EXPECT_EQ(majel::get<1>(c), 3); + EXPECT_EQ(majel::get<2>(c), 12); + + // contiguous_strides + c = majel::contiguous_strides(majel::Dim<3>(10, 1, 10)); + EXPECT_EQ(majel::get<0>(c), 1); + EXPECT_EQ(majel::get<1>(c), 0); + EXPECT_EQ(majel::get<2>(c), 10); + c = majel::contiguous_strides(majel::Dim<3>(10, 10, 1)); + EXPECT_EQ(majel::get<0>(c), 1); + EXPECT_EQ(majel::get<1>(c), 10); + EXPECT_EQ(majel::get<2>(c), 0); + c = majel::contiguous_strides(majel::Dim<3>(1, 10, 10)); + EXPECT_EQ(majel::get<0>(c), 0); + EXPECT_EQ(majel::get<1>(c), 1); + EXPECT_EQ(majel::get<2>(c), 10); + c = majel::contiguous_strides(majel::Dim<3>(2, 3, 4)); + EXPECT_EQ(majel::get<0>(c), 1); + EXPECT_EQ(majel::get<1>(c), 2); + EXPECT_EQ(majel::get<2>(c), 6); + + // generate from an index + auto size = majel::make_dim(4, 5, 2); + c = majel::Dim<3>(14, size); + EXPECT_EQ(majel::get<0>(c), 2); + EXPECT_EQ(majel::get<1>(c), 3); + EXPECT_EQ(majel::get<2>(c), 0); + c = majel::Dim<3>(25, size); + EXPECT_EQ(majel::get<0>(c), 1); + EXPECT_EQ(majel::get<1>(c), 1); + EXPECT_EQ(majel::get<2>(c), 1); +} + +TEST(Dim, Bool) { + auto a = majel::make_dim(3, 4); + auto b = majel::make_dim(5, 6); + auto c = majel::make_dim(3, 4); + + // in_bounds check + EXPECT_TRUE(majel::contained(a, b)); + EXPECT_FALSE(majel::contained(b, a)); + + // comparison + EXPECT_TRUE(a == a); + EXPECT_FALSE(a == b); + EXPECT_TRUE(a == c); + + // contiguous check + int x = 4, y = 5, z = 2; + majel::Dim<3> sizef(x, y, z); + majel::Dim<3> stridea(1, x, x*y); + majel::Dim<3> strideb(2, 2*x, 2*x*y); + majel::Dim<3> stridec(1, x, 2*x*y); + EXPECT_TRUE(majel::contiguous(sizef, stridea)); + EXPECT_FALSE(majel::contiguous(sizef, strideb)); + EXPECT_FALSE(majel::contiguous(sizef, stridec)); +} + +TEST(Dim, Print) { + { + std::stringstream ss; + auto a = majel::make_dim(2, 3); + ss << a; + EXPECT_EQ(ss.str(), "2, 3"); + } + { + std::stringstream ss; + ss << majel::make_dim(8); + EXPECT_EQ(ss.str(), "8"); + } +} -- GitLab