提交 110a4022 编写于 作者: F fengjiayi

make up missing 'majel::' and import majel/util.h

上级 051f0b99
...@@ -4,14 +4,9 @@ ...@@ -4,14 +4,9 @@
#include <sstream> #include <sstream>
#include <stdexcept> #include <stdexcept>
#include <type_traits> #include <type_traits>
/*
#ifdef __CUDACC__
#include <host_defines.h>
#endif
*/
#include "hostdevice.h" #include "majel/hostdevice.h"
#include "paddle/utils/Logging.h" #include "majel/util.h"
namespace majel { namespace majel {
...@@ -79,7 +74,7 @@ struct Dim<1> { ...@@ -79,7 +74,7 @@ struct Dim<1> {
throw std::invalid_argument("Index out of range."); throw std::invalid_argument("Index out of range.");
} }
#else #else
CHECK(idx < size.head); MAJEL_ASSERT(idx < size.head);
#endif #endif
} }
...@@ -136,7 +131,7 @@ HOSTDEVICE int& indexer(Dim<D>& dim, int idx) { ...@@ -136,7 +131,7 @@ HOSTDEVICE int& indexer(Dim<D>& dim, int idx) {
throw std::invalid_argument("Tried to access a negative dimension"); throw std::invalid_argument("Tried to access a negative dimension");
} }
#else #else
CHECK(idx >= 0); MAJEL_ASSERT(idx >= 0);
#endif #endif
if (idx == 0) { if (idx == 0) {
return dim.head; return dim.head;
...@@ -151,7 +146,7 @@ HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) { ...@@ -151,7 +146,7 @@ HOSTDEVICE int& indexer<1>(Dim<1>& dim, int idx) {
throw std::invalid_argument("Invalid index"); throw std::invalid_argument("Invalid index");
} }
#else #else
CHECK(idx == 0); MAJEL_ASSERT(idx == 0);
#endif #endif
return dim.head; return dim.head;
} }
...@@ -163,7 +158,7 @@ HOSTDEVICE int indexer(const Dim<D>& dim, int idx) { ...@@ -163,7 +158,7 @@ HOSTDEVICE int indexer(const Dim<D>& dim, int idx) {
throw std::invalid_argument("Tried to access a negative dimension"); throw std::invalid_argument("Tried to access a negative dimension");
} }
#else #else
CHECK(idx >= 0); MAJEL_ASSERT(idx >= 0);
#endif #endif
if (idx == 0) { if (idx == 0) {
return dim.head; return dim.head;
...@@ -178,7 +173,7 @@ HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) { ...@@ -178,7 +173,7 @@ HOSTDEVICE int indexer<1>(const Dim<1>& dim, int idx) {
throw std::invalid_argument("Invalid index"); throw std::invalid_argument("Invalid index");
} }
#else #else
CHECK(idx == 0); MAJEL_ASSERT(idx == 0);
#endif #endif
return dim.head; return dim.head;
} }
......
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <sstream> #include <sstream>
#include "majel/dim.h"
#include "majel/dim.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
__global__ void test(majel::Dim<2>* o) { __global__ void test(majel::Dim<2>* o) {
...@@ -16,22 +16,22 @@ __global__ void dyn_idx_gpu(int* o) { ...@@ -16,22 +16,22 @@ __global__ void dyn_idx_gpu(int* o) {
TEST(Dim, Equality) { TEST(Dim, Equality) {
// construct a Dim on the CPU // construct a Dim on the CPU
auto a = majel::make_dim(3, 4); auto a = majel::make_dim(3, 4);
EXPECT_EQ(get<0>(a), 3); EXPECT_EQ(majel::get<0>(a), 3);
EXPECT_EQ(get<1>(a), 4); EXPECT_EQ(majel::get<1>(a), 4);
// construct a Dim on the GPU // construct a Dim on the GPU
thrust::device_vector<majel::Dim<2>> t(2); thrust::device_vector<majel::Dim<2>> t(2);
test<<<1,1>>>(thrust::raw_pointer_cast(t.data())); test<<<1,1>>>(thrust::raw_pointer_cast(t.data()));
a = t[0]; a = t[0];
EXPECT_EQ(get<0>(a), 5); EXPECT_EQ(majel::get<0>(a), 5);
EXPECT_EQ(get<1>(a), 6); EXPECT_EQ(majel::get<1>(a), 6);
// linearization // linearization
auto b = make_dim(7, 8); auto b = majel::make_dim(7, 8);
EXPECT_EQ(linearize(a, b), 83); EXPECT_EQ(majel::linearize(a, b), 83);
// product // product
EXPECT_EQ(product(a), 30); EXPECT_EQ(majel::product(a), 30);
// mutate a Dim // mutate a Dim
majel::get<1>(b) = 10; majel::get<1>(b) = 10;
...@@ -53,7 +53,7 @@ TEST(Dim, Equality) { ...@@ -53,7 +53,7 @@ TEST(Dim, Equality) {
EXPECT_EQ(res, 6); EXPECT_EQ(res, 6);
// ex_prefix_mul // ex_prefix_mul
majel::Dim<3> c = majel::ex_prefix_mul(Dim<3>(3, 4, 5)); majel::Dim<3> c = majel::ex_prefix_mul(majel::Dim<3>(3, 4, 5));
EXPECT_EQ(majel::get<0>(c), 1); EXPECT_EQ(majel::get<0>(c), 1);
EXPECT_EQ(majel::get<1>(c), 3); EXPECT_EQ(majel::get<1>(c), 3);
EXPECT_EQ(majel::get<2>(c), 12); EXPECT_EQ(majel::get<2>(c), 12);
......
#pragma once
#define STRINGIFY(x) #x
#define TOSTRING(x) STRINGIFY(x)
#if defined(__APPLE__) && defined(__CUDA_ARCH__) && !defined(NDEBUG)
#include <stdio.h>
#define MAJEL_ASSERT(e) \
do { \
if (!(e)) { \
printf( \
"%s:%d Assertion `%s` failed.\n", __FILE__, __LINE__, TOSTRING(e)); \
asm("trap;"); \
} \
} while (0)
#define MAJEL_ASSERT_MSG(e, m) \
do { \
if (!(e)) { \
printf("%s:%d Assertion `%s` failed (%s).\n", \
__FILE__, \
__LINE__, \
TOSTRING(e), \
m); \
asm("trap;"); \
} \
} while (0)
#else
#include <assert.h>
#define MAJEL_ASSERT(e) assert(e)
#define MAJEL_ASSERT_MSG(e, m) assert((e) && (m))
#endif
namespace majel {
namespace detail {
inline int div_up(int x, int y) { return (x + y - 1) / y; }
}
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册