提交 0a86a070 编写于 作者: M Megvii Engine Team

fix(mgb/dnn): fix cub potential issues

Wrap CUB in the megdnn::cuda namespace via CUB_NS_PREFIX/CUB_NS_POSTFIX and remove its dependency on Thrust to avoid potential linking issues.

GitOrigin-RevId: 53893b0a3957b9321fcdbca30a5c2659504f2553
上级 282dfc62
......@@ -41,14 +41,6 @@
#include "../util_device.cuh"
#include "../util_namespace.cuh"
#include <thrust/version.h>
#if (THRUST_VERSION >= 100700)
// This iterator is compatible with Thrust API 1.7 and newer
#include <thrust/iterator/iterator_facade.h>
#include <thrust/iterator/iterator_traits.h>
#endif // THRUST_VERSION
/// Optional outer namespace(s)
CUB_NS_PREFIX
......@@ -121,17 +113,7 @@ public:
typedef value_type* pointer; ///< The type of a pointer to an element the iterator can point to
typedef value_type reference; ///< The type of a reference to an element the iterator can point to
#if (THRUST_VERSION >= 100700)
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods
typedef typename thrust::detail::iterator_facade_category<
thrust::any_system_tag,
thrust::random_access_traversal_tag,
value_type,
reference
>::type iterator_category; ///< The iterator category
#else
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
#endif // THRUST_VERSION
private:
......
......@@ -41,12 +41,6 @@
#include "../util_device.cuh"
#include "../util_namespace.cuh"
#if (THRUST_VERSION >= 100700)
// This iterator is compatible with Thrust API 1.7 and newer
#include <thrust/iterator/iterator_facade.h>
#include <thrust/iterator/iterator_traits.h>
#endif // THRUST_VERSION
/// Optional outer namespace(s)
CUB_NS_PREFIX
......@@ -115,17 +109,7 @@ public:
typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to
typedef ValueType reference; ///< The type of a reference to an element the iterator can point to
#if (THRUST_VERSION >= 100700)
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods
typedef typename thrust::detail::iterator_facade_category<
thrust::device_system_tag,
thrust::random_access_traversal_tag,
value_type,
reference
>::type iterator_category; ///< The iterator category
#else
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
#endif // THRUST_VERSION
public:
......
......@@ -41,13 +41,6 @@
#include "../util_device.cuh"
#include "../util_namespace.cuh"
#if (THRUST_VERSION >= 100700)
// This iterator is compatible with Thrust API 1.7 and newer
#include <thrust/iterator/iterator_facade.h>
#include <thrust/iterator/iterator_traits.h>
#endif // THRUST_VERSION
/// Optional outer namespace(s)
CUB_NS_PREFIX
......@@ -135,17 +128,7 @@ public:
typedef void pointer; ///< The type of a pointer to an element the iterator can point to
typedef Reference reference; ///< The type of a reference to an element the iterator can point to
#if (THRUST_VERSION >= 100700)
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods
typedef typename thrust::detail::iterator_facade_category<
thrust::device_system_tag,
thrust::random_access_traversal_tag,
value_type,
reference
>::type iterator_category; ///< The iterator category
#else
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
#endif // THRUST_VERSION
private:
......
......@@ -40,13 +40,6 @@
#include "../thread/thread_store.cuh"
#include "../util_namespace.cuh"
#if (THRUST_VERSION >= 100700)
// This iterator is compatible with Thrust API 1.7 and newer
#include <thrust/iterator/iterator_facade.h>
#include <thrust/iterator/iterator_traits.h>
#endif // THRUST_VERSION
/// Optional outer namespace(s)
CUB_NS_PREFIX
......@@ -104,17 +97,7 @@ public:
typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to
typedef ValueType reference; ///< The type of a reference to an element the iterator can point to
#if (THRUST_VERSION >= 100700)
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods
typedef typename thrust::detail::iterator_facade_category<
thrust::any_system_tag,
thrust::random_access_traversal_tag,
value_type,
reference
>::type iterator_category; ///< The iterator category
#else
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
#endif // THRUST_VERSION
private:
......
......@@ -41,13 +41,6 @@
#include "../util_device.cuh"
#include "../util_namespace.cuh"
#if (THRUST_VERSION >= 100700)
// This iterator is compatible with Thrust API 1.7 and newer
#include <thrust/iterator/iterator_facade.h>
#include <thrust/iterator/iterator_traits.h>
#endif // THRUST_VERSION
/// Optional outer namespace(s)
CUB_NS_PREFIX
......@@ -102,17 +95,7 @@ public:
typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to
typedef ValueType reference; ///< The type of a reference to an element the iterator can point to
#if (THRUST_VERSION >= 100700)
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods
typedef typename thrust::detail::iterator_facade_category<
thrust::any_system_tag,
thrust::random_access_traversal_tag,
value_type,
reference
>::type iterator_category; ///< The iterator category
#else
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
#endif // THRUST_VERSION
private:
......
......@@ -39,13 +39,6 @@
#include "../util_namespace.cuh"
#include "../util_macro.cuh"
#if (THRUST_VERSION >= 100700)
// This iterator is compatible with Thrust API 1.7 and newer
#include <thrust/iterator/iterator_facade.h>
#include <thrust/iterator/iterator_traits.h>
#endif // THRUST_VERSION
/// Optional outer namespace(s)
CUB_NS_PREFIX
......@@ -74,17 +67,7 @@ public:
typedef void pointer; ///< The type of a pointer to an element the iterator can point to
typedef void reference; ///< The type of a reference to an element the iterator can point to
#if (THRUST_VERSION >= 100700)
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods
typedef typename thrust::detail::iterator_facade_category<
thrust::any_system_tag,
thrust::random_access_traversal_tag,
value_type,
reference
>::type iterator_category; ///< The iterator category
#else
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
#endif // THRUST_VERSION
private:
......
......@@ -42,13 +42,6 @@
#include "../util_debug.cuh"
#include "../util_namespace.cuh"
#if (THRUST_VERSION >= 100700)
// This iterator is compatible with Thrust API 1.7 and newer
#include <thrust/iterator/iterator_facade.h>
#include <thrust/iterator/iterator_traits.h>
#endif // THRUST_VERSION
/// Optional outer namespace(s)
CUB_NS_PREFIX
......@@ -119,17 +112,7 @@ public:
typedef T* pointer; ///< The type of a pointer to an element the iterator can point to
typedef T reference; ///< The type of a reference to an element the iterator can point to
#if (THRUST_VERSION >= 100700)
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods
typedef typename thrust::detail::iterator_facade_category<
thrust::device_system_tag,
thrust::random_access_traversal_tag,
value_type,
reference
>::type iterator_category; ///< The iterator category
#else
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
#endif // THRUST_VERSION
private:
......
......@@ -44,12 +44,6 @@
#if (CUDA_VERSION >= 5050) || defined(DOXYGEN_ACTIVE) // This iterator is compatible with CUDA 5.5 and newer
#if (THRUST_VERSION >= 100700) // This iterator is compatible with Thrust API 1.7 and newer
#include <thrust/iterator/iterator_facade.h>
#include <thrust/iterator/iterator_traits.h>
#endif // THRUST_VERSION
/// Optional outer namespace(s)
CUB_NS_PREFIX
......@@ -212,17 +206,7 @@ public:
typedef T* pointer; ///< The type of a pointer to an element the iterator can point to
typedef T reference; ///< The type of a reference to an element the iterator can point to
#if (THRUST_VERSION >= 100700)
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods
typedef typename thrust::detail::iterator_facade_category<
thrust::device_system_tag,
thrust::random_access_traversal_tag,
value_type,
reference
>::type iterator_category; ///< The iterator category
#else
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
#endif // THRUST_VERSION
private:
......
......@@ -41,12 +41,6 @@
#include "../util_device.cuh"
#include "../util_namespace.cuh"
#if (THRUST_VERSION >= 100700)
// This iterator is compatible with Thrust API 1.7 and newer
#include <thrust/iterator/iterator_facade.h>
#include <thrust/iterator/iterator_traits.h>
#endif // THRUST_VERSION
/// Optional outer namespace(s)
CUB_NS_PREFIX
......@@ -125,17 +119,7 @@ public:
typedef ValueType* pointer; ///< The type of a pointer to an element the iterator can point to
typedef ValueType reference; ///< The type of a reference to an element the iterator can point to
#if (THRUST_VERSION >= 100700)
// Use Thrust's iterator categories so we can use these iterators in Thrust 1.7 (or newer) methods
typedef typename thrust::detail::iterator_facade_category<
thrust::any_system_tag,
thrust::random_access_traversal_tag,
value_type,
reference
>::type iterator_category; ///< The iterator category
#else
typedef std::random_access_iterator_tag iterator_category; ///< The iterator category
#endif // THRUST_VERSION
private:
......
......@@ -38,9 +38,9 @@
//#define CUB_NS_POSTFIX } }
// Outer namespace wrapper for all CUB declarations.
// CUB_NS_PREFIX opens the enclosing namespaces and CUB_NS_POSTFIX closes them;
// the two must stay balanced (two opening braces, two closing braces).
// Each macro is defined exactly once, and only when the including translation
// unit has not already supplied its own value.
#ifndef CUB_NS_PREFIX
#define CUB_NS_PREFIX namespace megdnn { namespace cuda {
#endif

#ifndef CUB_NS_POSTFIX
#define CUB_NS_POSTFIX } }
#endif
......@@ -16,6 +16,7 @@
namespace {
using namespace megdnn;
using namespace cuda;
template <typename T> __global__ void kernel(const T *a, const T *b,
dt_float32 *c,
......
......@@ -43,6 +43,8 @@
namespace {
using namespace megdnn::cuda;
template <int block_size_log2, int max_nr_threads_per_row>
__global__ void reduce_column_with_scale_u4(const uint8_t* src, int32_t scale,
int rows, int cols_int32,
......
......@@ -355,8 +355,8 @@ static __global__ void kern_reduce_block_cnt(const ctype* input_data,
/*!
 * \brief Forward to cub::DeviceScan::InclusiveSum for a uint64 sequence.
 *
 * \param input          values to be scanned
 * \param output         destination passed to CUB for the scan result
 * \param workspace      temporary storage handed to CUB
 * \param workspace_size size of \p workspace in bytes; taken by reference and
 *                       forwarded as CUB's temp-storage-bytes argument, so CUB
 *                       may overwrite it (per the CUB documentation, a null
 *                       \p workspace makes CUB only report the required size)
 * \param size           number of items to scan
 * \param stream         CUDA stream forwarded to CUB
 * \return the cudaError_t produced by CUB, unchanged
 *
 * NOTE(review): `cub` must be reachable from this scope; with CUB wrapped by
 * CUB_NS_PREFIX it is presumably megdnn::cuda::cub -- confirm the enclosing
 * namespace / using-directives of this file.
 */
static MEGDNN_NOINLINE cudaError_t
invoke_cub_scan(const uint64_t* input, uint64_t* output, void* workspace,
                size_t& workspace_size, uint32_t size, cudaStream_t stream) {
    // Single call site; the result is returned as-is so the caller decides
    // how to report a failure.
    return cub::DeviceScan::InclusiveSum(workspace, workspace_size, input,
                                         output, size, stream);
}
static __global__ void kern_init_zero(uint64_t* dst) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册