未验证 提交 c3001324 编写于 作者: Connor Holmes 提交者: GitHub

Add predicated global load (#2373)

Co-authored-by: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com>
上级 f4a92a19
......@@ -25,6 +25,9 @@ enum class StorePolicy {
template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
__device__ __forceinline__ void load_global(void* dst, const void* src);
template <int AccessSize, LoadPolicy policy = LoadPolicy::CacheAll>
__device__ __forceinline__ void load_global(void* dst, const void* src, bool do_access);
// Shared accesses have no cache policy
template <int AccessSize>
__device__ __forceinline__ void load_shared(void* dst, const void* src);
......@@ -98,6 +101,36 @@ __device__ __forceinline__ void load_global<16>(void* dst, const void* src)
#endif
}
template <>
__device__ __forceinline__ void load_global<16>(void* dst, const void* src, bool do_access)
{
uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
asm volatile(
"{\n"
"\t.reg .pred p;\n"
"\tsetp.ne.b32 p, %5, 0;\n"
"\tmov.b32 %0, 0;\n"
"\tmov.b32 %1, 0;\n"
"\tmov.b32 %2, 0;\n"
"\tmov.b32 %3, 0;\n"
"\t@p ld.global.v4.u32 {%0, %1, %2, %3}, [%4];\n"
"}\n"
: "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
: "l"(src), "r"((int)do_access));
#else
const uint4* src_cast = reinterpret_cast<const uint4*>(src);
if (do_access) {
data[0] = src_cast[0];
} else {
data[0].x = 0;
data[0].y = 0;
data[0].z = 0;
data[0].w = 0;
}
#endif
}
template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst, const void* src)
{
......@@ -112,6 +145,38 @@ __device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* d
#endif
}
template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheGlobal>(void* dst,
const void* src,
bool do_access)
{
uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
asm volatile(
"{\n"
"\t.reg .pred p;\n"
"\tsetp.ne.b32 p, %5, 0;\n"
"\tmov.b32 %0, 0;\n"
"\tmov.b32 %1, 0;\n"
"\tmov.b32 %2, 0;\n"
"\tmov.b32 %3, 0;\n"
"\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n"
"}\n"
: "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
: "l"(src), "r"((int)do_access));
#else
const uint4* src_cast = reinterpret_cast<const uint4*>(src);
if (do_access) {
data[0] = src_cast[0];
} else {
data[0].x = 0;
data[0].y = 0;
data[0].z = 0;
data[0].w = 0;
}
#endif
}
template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst,
const void* src)
......@@ -127,6 +192,38 @@ __device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void
#endif
}
template <>
__device__ __forceinline__ void load_global<16, LoadPolicy::CacheStreaming>(void* dst,
const void* src,
bool do_access)
{
uint4* data = reinterpret_cast<uint4*>(dst);
#ifdef PTX_AVAILABLE
asm volatile(
"{\n"
"\t.reg .pred p;\n"
"\tsetp.ne.b32 p, %5, 0;\n"
"\tmov.b32 %0, 0;\n"
"\tmov.b32 %1, 0;\n"
"\tmov.b32 %2, 0;\n"
"\tmov.b32 %3, 0;\n"
"\t@p ld.global.cg.v4.u32 {%0, %1, %2, %3}, [%4];\n"
"}\n"
: "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w)
: "l"(src), "r"((int)do_access));
#else
const uint4* src_cast = reinterpret_cast<const uint4*>(src);
if (do_access) {
data[0] = src_cast[0];
} else {
data[0].x = 0;
data[0].y = 0;
data[0].z = 0;
data[0].w = 0;
}
#endif
}
template <>
__device__ __forceinline__ void load_global<8>(void* dst, const void* src)
{
......@@ -141,6 +238,32 @@ __device__ __forceinline__ void load_global<8>(void* dst, const void* src)
#endif
}
template <>
__device__ __forceinline__ void load_global<8>(void* dst, const void* src, bool do_access)
{
uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
asm volatile(
"{\n"
"\t.reg .pred p;\n"
"\tsetp.ne.b32 p, %3, 0;\n"
"\tmov.b32 %0, 0;\n"
"\tmov.b32 %1, 0;\n"
"\t@p ld.global.v2.u32 {%0, %1}, [%2];\n"
"}\n"
: "=r"(data[0].x), "=r"(data[0].y)
: "l"(src), "r"((int)do_access));
#else
const uint2* src_cast = reinterpret_cast<const uint2*>(src);
if (do_access) {
data[0] = src_cast[0];
} else {
data[0].x = 0;
data[0].y = 0;
}
#endif
}
template <>
__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst, const void* src)
{
......@@ -155,6 +278,34 @@ __device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* ds
#endif
}
template <>
__device__ __forceinline__ void load_global<8, LoadPolicy::CacheGlobal>(void* dst,
const void* src,
bool do_access)
{
uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
asm volatile(
"{\n"
"\t.reg .pred p;\n"
"\tsetp.ne.b32 p, %3, 0;\n"
"\tmov.b32 %0, 0;\n"
"\tmov.b32 %1, 0;\n"
"\t@p ld.global.cg.v2.u32 {%0, %1}, [%2];\n"
"}\n"
: "=r"(data[0].x), "=r"(data[0].y)
: "l"(src), "r"((int)do_access));
#else
const uint2* src_cast = reinterpret_cast<const uint2*>(src);
if (do_access) {
data[0] = src_cast[0];
} else {
data[0].x = 0;
data[0].y = 0;
}
#endif
}
template <>
__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst,
const void* src)
......@@ -170,6 +321,34 @@ __device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void*
#endif
}
template <>
__device__ __forceinline__ void load_global<8, LoadPolicy::CacheStreaming>(void* dst,
const void* src,
bool do_access)
{
uint2* data = reinterpret_cast<uint2*>(dst);
#ifdef PTX_AVAILABLE
asm volatile(
"{\n"
"\t.reg .pred p;\n"
"\tsetp.ne.b32 p, %3, 0;\n"
"\tmov.b32 %0, 0;\n"
"\tmov.b32 %1, 0;\n"
"\t@p ld.global.cs.v2.u32 {%0, %1}, [%2];\n"
"}\n"
: "=r"(data[0].x), "=r"(data[0].y)
: "l"(src), "r"((int)do_access));
#else
const uint2* src_cast = reinterpret_cast<const uint2*>(src);
if (do_access) {
data[0] = src_cast[0];
} else {
data[0].x = 0;
data[0].y = 0;
}
#endif
}
template <>
__device__ __forceinline__ void load_global<4>(void* dst, const void* src)
{
......@@ -182,6 +361,30 @@ __device__ __forceinline__ void load_global<4>(void* dst, const void* src)
#endif
}
template <>
__device__ __forceinline__ void load_global<4>(void* dst, const void* src, bool do_access)
{
int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
asm volatile(
"{\n"
"\t.reg .pred p;\n"
"\tsetp.ne.b32 p, %2, 0;\n"
"\tmov.b32 %0, 0;\n"
"\t@p ld.global.u32 {%0}, [%1];\n"
"}\n"
: "=r"(data[0])
: "l"(src), "r"((int)do_access));
#else
const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
if (do_access) {
data[0] = src_cast[0];
} else {
data[0] = 0;
}
#endif
}
template <>
__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst, const void* src)
{
......@@ -194,6 +397,32 @@ __device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* ds
#endif
}
template <>
__device__ __forceinline__ void load_global<4, LoadPolicy::CacheGlobal>(void* dst,
const void* src,
bool do_access)
{
int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
asm volatile(
"{\n"
"\t.reg .pred p;\n"
"\tsetp.ne.b32 p, %2, 0;\n"
"\tmov.b32 %0, 0;\n"
"\t@p ld.global.cg.u32 {%0}, [%1];\n"
"}\n"
: "=r"(data[0])
: "l"(src), "r"((int)do_access));
#else
const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
if (do_access) {
data[0] = src_cast[0];
} else {
data[0] = 0;
}
#endif
}
template <>
__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst,
const void* src)
......@@ -207,6 +436,32 @@ __device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void*
#endif
}
template <>
__device__ __forceinline__ void load_global<4, LoadPolicy::CacheStreaming>(void* dst,
const void* src,
bool do_access)
{
int32_t* data = reinterpret_cast<int32_t*>(dst);
#ifdef PTX_AVAILABLE
asm volatile(
"{\n"
"\t.reg .pred p;\n"
"\tsetp.ne.b32 p, %2, 0;\n"
"\tmov.b32 %0, 0;\n"
"\t@p ld.global.cs.u32 {%0}, [%1];\n"
"}\n"
: "=r"(data[0])
: "l"(src), "r"((int)do_access));
#else
const int32_t* src_cast = reinterpret_cast<const int32_t*>(src);
if (do_access) {
data[0] = src_cast[0];
} else {
data[0] = 0;
}
#endif
}
/////////// Load Shared ///////////
namespace internal {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册