Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
1aa64d13
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
1aa64d13
编写于
11月 09, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
11月 09, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Sparse]optimize sparse convolution and fix MaskHelper bug (#47703)
上级
5c7fce47
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
69 addition
and
36 deletion
+69
-36
paddle/phi/kernels/funcs/sparse/utils.cu.h
paddle/phi/kernels/funcs/sparse/utils.cu.h
+13
-0
paddle/phi/kernels/sparse/gpu/conv.cu.h
paddle/phi/kernels/sparse/gpu/conv.cu.h
+32
-26
paddle/phi/kernels/sparse/gpu/mask_kernel.cu
paddle/phi/kernels/sparse/gpu/mask_kernel.cu
+24
-10
未找到文件。
paddle/phi/kernels/funcs/sparse/utils.cu.h
浏览文件 @
1aa64d13
...
@@ -26,6 +26,19 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) {
...
@@ -26,6 +26,19 @@ __global__ void DistanceKernel(const T* start, const T* end, T* distance) {
}
}
}
}
inline
__device__
bool
SetBits
(
const
int
value
,
int
*
ptr
)
{
const
int
index
=
value
>>
5
;
const
int
mask
=
1
<<
(
value
&
31
);
const
int
old
=
atomicOr
(
ptr
+
index
,
mask
);
return
(
mask
&
old
)
!=
0
;
}
inline
__device__
bool
TestBits
(
const
int
value
,
const
int
*
ptr
)
{
const
int
index
=
value
>>
5
;
const
int
mask
=
1
<<
(
value
&
31
);
return
(
mask
&
ptr
[
index
])
!=
0
;
}
}
// namespace sparse
}
// namespace sparse
}
// namespace funcs
}
// namespace funcs
}
// namespace phi
}
// namespace phi
paddle/phi/kernels/sparse/gpu/conv.cu.h
浏览文件 @
1aa64d13
...
@@ -167,7 +167,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
...
@@ -167,7 +167,7 @@ inline void GatherV2(const GPUContext& dev_ctx,
template
<
typename
IntT
>
template
<
typename
IntT
>
__global__
void
UniqueKernel
(
const
IntT
*
in_indexs
,
__global__
void
UniqueKernel
(
const
IntT
*
in_indexs
,
const
int
rulebook_len
,
const
int
rulebook_len
,
int
*
out_index_table
,
int
*
index_flags
,
int
*
out_indexs
,
int
*
out_indexs
,
int
*
nnz
)
{
int
*
nnz
)
{
extern
__shared__
int
cache
[];
extern
__shared__
int
cache
[];
...
@@ -182,8 +182,8 @@ __global__ void UniqueKernel(const IntT* in_indexs,
...
@@ -182,8 +182,8 @@ __global__ void UniqueKernel(const IntT* in_indexs,
if
(
i
<
rulebook_len
)
{
if
(
i
<
rulebook_len
)
{
// atomicOr only support int
// atomicOr only support int
int
index
=
static_cast
<
int
>
(
in_indexs
[
i
]);
int
index
=
static_cast
<
int
>
(
in_indexs
[
i
]);
int
flag
=
atomicOr
(
out_index_table
+
index
,
1
);
const
bool
flag
=
phi
::
funcs
::
sparse
::
SetBits
(
index
,
index_flags
);
if
(
flag
==
0
)
{
if
(
!
flag
)
{
int
j
=
atomicAdd
(
&
count
,
1
);
int
j
=
atomicAdd
(
&
count
,
1
);
cache
[
j
]
=
index
;
cache
[
j
]
=
index
;
}
}
...
@@ -284,7 +284,6 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
...
@@ -284,7 +284,6 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
atomicAdd
(
&
counter_buf
[
kernel_index
],
1
);
atomicAdd
(
&
counter_buf
[
kernel_index
],
1
);
kernel_i
=
kernel_index
;
kernel_i
=
kernel_index
;
}
}
// rulebook[kernel_index * non_zero_num + i] = kernel_i;
rulebook
[
kernel_index
*
non_zero_num
+
i
]
=
in_i
;
rulebook
[
kernel_index
*
non_zero_num
+
i
]
=
in_i
;
rulebook
[
kernel_index
*
non_zero_num
+
offset
+
i
]
=
out_index
;
rulebook
[
kernel_index
*
non_zero_num
+
offset
+
i
]
=
out_index
;
++
kernel_index
;
++
kernel_index
;
...
@@ -299,17 +298,19 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
...
@@ -299,17 +298,19 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
}
}
template
<
typename
IntT
>
template
<
typename
IntT
>
__global__
void
GetOutIndexTable
(
const
IntT
*
indices
,
__global__
void
GetOutIndexTable1
(
const
IntT
*
indices
,
const
IntT
non_zero_num
,
const
IntT
non_zero_num
,
const
Dims4D
dims
,
const
Dims4D
dims
,
int
*
out_index_table
)
{
int
*
index_flags
,
int
*
out_index_table
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
non_zero_num
,
int64_t
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
non_zero_num
,
int64_t
)
{
IntT
batch
=
indices
[
i
];
IntT
batch
=
indices
[
i
];
IntT
in_z
=
indices
[
i
+
non_zero_num
];
IntT
in_z
=
indices
[
i
+
non_zero_num
];
IntT
in_y
=
indices
[
i
+
2
*
non_zero_num
];
IntT
in_y
=
indices
[
i
+
2
*
non_zero_num
];
IntT
in_x
=
indices
[
i
+
3
*
non_zero_num
];
IntT
in_x
=
indices
[
i
+
3
*
non_zero_num
];
IntT
index
=
PointToIndex
(
batch
,
in_x
,
in_y
,
in_z
,
dims
);
IntT
index
=
PointToIndex
(
batch
,
in_x
,
in_y
,
in_z
,
dims
);
out_index_table
[
index
]
=
i
==
0
?
-
1
:
i
;
phi
::
funcs
::
sparse
::
SetBits
(
index
,
index_flags
);
out_index_table
[
index
]
=
i
;
}
}
}
}
...
@@ -375,6 +376,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
...
@@ -375,6 +376,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
const
Dims4D
paddings
,
const
Dims4D
paddings
,
const
Dims4D
dilations
,
const
Dims4D
dilations
,
const
Dims4D
strides
,
const
Dims4D
strides
,
const
int
*
index_flags
,
const
int
*
out_index_table
,
const
int
*
out_index_table
,
T
*
rulebook
,
T
*
rulebook
,
int
*
counter
)
{
int
*
counter
)
{
...
@@ -417,9 +419,10 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
...
@@ -417,9 +419,10 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
T
out_x
=
(
in_x
+
paddings
[
3
]
-
kx
*
dilations
[
3
])
/
strides
[
3
];
T
out_x
=
(
in_x
+
paddings
[
3
]
-
kx
*
dilations
[
3
])
/
strides
[
3
];
out_index
=
phi
::
funcs
::
sparse
::
PointToIndex
<
Dims4D
>
(
out_index
=
phi
::
funcs
::
sparse
::
PointToIndex
<
Dims4D
>
(
batch
,
out_x
,
out_y
,
out_z
,
out_dims
);
batch
,
out_x
,
out_y
,
out_z
,
out_dims
);
int
real_out_index
=
out_index_table
[
out_index
];
const
bool
flag
=
if
(
real_out_index
!=
0
)
{
phi
::
funcs
::
sparse
::
TestBits
(
out_index
,
index_flags
);
real_out_index
=
real_out_index
==
-
1
?
0
:
real_out_index
;
if
(
flag
)
{
int
real_out_index
=
out_index_table
[
out_index
];
in_i
=
i
;
in_i
=
i
;
int
buf_i
=
atomicAdd
(
&
counter_buf
[
kernel_index
],
1
);
int
buf_i
=
atomicAdd
(
&
counter_buf
[
kernel_index
],
1
);
kernel_i
=
kernel_index
;
kernel_i
=
kernel_index
;
...
@@ -440,7 +443,6 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
...
@@ -440,7 +443,6 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
__syncthreads
();
__syncthreads
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
threadIdx
.
x
<
counter_buf
[
i
])
{
if
(
threadIdx
.
x
<
counter_buf
[
i
])
{
// rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] = i;
rulebook
[
i
*
non_zero_num
+
counter_buf2
[
i
]
+
threadIdx
.
x
]
=
rulebook
[
i
*
non_zero_num
+
counter_buf2
[
i
]
+
threadIdx
.
x
]
=
rulebook_buf
[
i
*
blockDim
.
x
+
threadIdx
.
x
];
rulebook_buf
[
i
*
blockDim
.
x
+
threadIdx
.
x
];
rulebook
[
i
*
non_zero_num
+
offset
+
counter_buf2
[
i
]
+
threadIdx
.
x
]
=
rulebook
[
i
*
non_zero_num
+
offset
+
counter_buf2
[
i
]
+
threadIdx
.
x
]
=
...
@@ -575,12 +577,18 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -575,12 +577,18 @@ int ProductRuleBook(const Context& dev_ctx,
DenseTensorMeta
rulebook_meta
(
DenseTensorMeta
rulebook_meta
(
indices_dtype
,
{
rulebook_rows
,
rulebook_cols
},
DataLayout
::
NCHW
);
indices_dtype
,
{
rulebook_rows
,
rulebook_cols
},
DataLayout
::
NCHW
);
int
64_t
table_size
=
1
;
int
table_size
=
1
;
for
(
int
i
=
0
;
i
<
out_dims
.
size
()
-
1
;
i
++
)
{
for
(
int
i
=
0
;
i
<
out_dims
.
size
()
-
1
;
i
++
)
{
table_size
*=
out_dims
[
i
];
table_size
*=
out_dims
[
i
];
}
}
DenseTensor
out_index_table
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
});
DenseTensor
out_index_table
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
});
int
*
out_index_table_ptr
=
out_index_table
.
data
<
int
>
();
int
*
out_index_table_ptr
=
out_index_table
.
data
<
int
>
();
// index_flags: flag the indices exist or not
int
index_flags_size
=
(
table_size
+
31
)
/
32
;
DenseTensor
index_flags
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
index_flags_size
});
int
*
index_flags_ptr
=
index_flags
.
data
<
int
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
index_flags_ptr
,
0
,
sizeof
(
int
)
*
index_flags
.
numel
(),
dev_ctx
.
stream
());
if
(
subm
)
{
if
(
subm
)
{
DenseTensor
tmp_rulebook
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
rulebook_meta
));
DenseTensor
tmp_rulebook
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
rulebook_meta
));
...
@@ -590,16 +598,16 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -590,16 +598,16 @@ int ProductRuleBook(const Context& dev_ctx,
phi
::
Copy
(
dev_ctx
,
x
.
indices
(),
dev_ctx
.
GetPlace
(),
false
,
&
out_indices
);
phi
::
Copy
(
dev_ctx
,
x
.
indices
(),
dev_ctx
.
GetPlace
(),
false
,
&
out_indices
);
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_table_ptr
,
0
,
sizeof
(
int
)
*
table_size
,
dev_ctx
.
stream
());
auto
config
=
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
,
1
);
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
,
1
);
GetOutIndexTable
<
IntT
><<<
config
.
block_per_grid
,
GetOutIndexTable1
<
IntT
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
config
.
thread_per_block
,
0
,
0
,
dev_ctx
.
stream
()
>>>
(
dev_ctx
.
stream
()
>>>
(
out_indices
.
data
<
IntT
>
(),
out_indices
.
data
<
IntT
>
(),
non_zero_num
,
d_x_dims
,
out_index_table_ptr
);
non_zero_num
,
d_x_dims
,
index_flags_ptr
,
out_index_table_ptr
);
size_t
cache_size
=
size_t
cache_size
=
kernel_size
*
2
*
sizeof
(
int
)
+
kernel_size
*
2
*
sizeof
(
int
)
+
...
@@ -625,6 +633,7 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -625,6 +633,7 @@ int ProductRuleBook(const Context& dev_ctx,
d_paddings
,
d_paddings
,
d_dilations
,
d_dilations
,
d_strides
,
d_strides
,
index_flags_ptr
,
out_index_table_ptr
,
out_index_table_ptr
,
rulebook_ptr
,
rulebook_ptr
,
counter_ptr
);
counter_ptr
);
...
@@ -695,9 +704,6 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -695,9 +704,6 @@ int ProductRuleBook(const Context& dev_ctx,
int
*
out_index_ptr
=
out_index
->
data
<
int
>
();
int
*
out_index_ptr
=
out_index
->
data
<
int
>
();
int
*
unique_key_ptr
=
unique_key
.
data
<
int
>
();
int
*
unique_key_ptr
=
unique_key
.
data
<
int
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_table_ptr
,
0
,
sizeof
(
int
)
*
table_size
,
dev_ctx
.
stream
());
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
unique_key_ptr
,
0
,
sizeof
(
int
),
dev_ctx
.
stream
());
unique_key_ptr
,
0
,
sizeof
(
int
),
dev_ctx
.
stream
());
...
@@ -708,7 +714,7 @@ int ProductRuleBook(const Context& dev_ctx,
...
@@ -708,7 +714,7 @@ int ProductRuleBook(const Context& dev_ctx,
cache_size
,
cache_size
,
dev_ctx
.
stream
()
>>>
(
rulebook_ptr
+
rulebook_len
,
dev_ctx
.
stream
()
>>>
(
rulebook_ptr
+
rulebook_len
,
rulebook_len
,
rulebook_len
,
out_index_table
_ptr
,
index_flags
_ptr
,
out_index_ptr
,
out_index_ptr
,
unique_key_ptr
);
unique_key_ptr
);
...
...
paddle/phi/kernels/sparse/gpu/mask_kernel.cu
浏览文件 @
1aa64d13
...
@@ -25,6 +25,7 @@ limitations under the License. */
...
@@ -25,6 +25,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
#include "paddle/phi/kernels/funcs/sparse/flatten_indices.cu.h"
#include "paddle/phi/kernels/funcs/sparse/utils.cu.h"
namespace
phi
{
namespace
phi
{
namespace
sparse
{
namespace
sparse
{
...
@@ -118,15 +119,20 @@ void SparseMaskKernel(const Context& dev_ctx,
...
@@ -118,15 +119,20 @@ void SparseMaskKernel(const Context& dev_ctx,
}
}
template
<
typename
IntT
>
template
<
typename
IntT
>
__global__
void
MaskTable
(
const
IntT
*
x_indexs
,
const
int
n
,
int
*
table
)
{
__global__
void
MaskTable
(
const
IntT
*
x_indexs
,
const
int
n
,
int
*
index_flags
,
int
*
table
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
n
,
int64_t
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
n
,
int64_t
)
{
int
index
=
x_indexs
[
i
];
int
index
=
x_indexs
[
i
];
table
[
index
]
=
i
==
0
?
-
1
:
i
;
phi
::
funcs
::
sparse
::
SetBits
(
index
,
index_flags
);
table
[
index
]
=
i
;
}
}
}
}
template
<
typename
T
,
typename
IntT
,
int
VecSize
>
template
<
typename
T
,
typename
IntT
,
int
VecSize
>
__global__
void
MaskCopy
(
const
IntT
*
mask_indexs
,
__global__
void
MaskCopy
(
const
IntT
*
mask_indexs
,
const
int
*
index_flags
,
const
int
*
table
,
const
int
*
table
,
const
int
n
,
const
int
n
,
const
int
stride
,
const
int
stride
,
...
@@ -135,9 +141,10 @@ __global__ void MaskCopy(const IntT* mask_indexs,
...
@@ -135,9 +141,10 @@ __global__ void MaskCopy(const IntT* mask_indexs,
using
LoadT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
using
LoadT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
using
StoreT
=
phi
::
AlignedVector
<
T
,
VecSize
>
;
CUDA_KERNEL_LOOP_TYPE
(
i
,
n
,
int64_t
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
n
,
int64_t
)
{
int
j
=
table
[
mask_indexs
[
i
]];
const
int
mask_index
=
mask_indexs
[
i
];
if
(
j
!=
0
)
{
const
bool
flag
=
phi
::
funcs
::
sparse
::
TestBits
(
mask_index
,
index_flags
);
if
(
j
==
-
1
)
j
=
0
;
if
(
flag
)
{
int
j
=
table
[
mask_index
];
for
(
int
k
=
0
;
k
<
stride
;
k
+=
VecSize
)
{
for
(
int
k
=
0
;
k
<
stride
;
k
+=
VecSize
)
{
LoadT
vec_x
;
LoadT
vec_x
;
phi
::
Load
<
T
,
VecSize
>
(
x_values
+
j
*
stride
+
k
,
&
vec_x
);
phi
::
Load
<
T
,
VecSize
>
(
x_values
+
j
*
stride
+
k
,
&
vec_x
);
...
@@ -217,12 +224,15 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
...
@@ -217,12 +224,15 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
int
table_size
=
1
;
int
table_size
=
1
;
auto
x_dims
=
x
.
dims
();
auto
x_dims
=
x
.
dims
();
for
(
int
i
=
0
;
i
<
x_dims
.
size
()
-
1
;
i
++
)
{
for
(
int
i
=
0
;
i
<
sparse_dim
;
i
++
)
{
table_size
*=
x_dims
[
i
];
table_size
*=
x_dims
[
i
];
}
}
DenseTensor
table
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
});
DenseTensor
table
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
});
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
DenseTensor
index_flags
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{(
table_size
+
31
)
/
32
});
table
.
data
<
int
>
(),
0
,
table_size
*
sizeof
(
int
),
dev_ctx
.
stream
());
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
index_flags
.
data
<
int
>
(),
0
,
index_flags
.
numel
()
*
sizeof
(
int
),
dev_ctx
.
stream
());
const
int64_t
stride
=
const
int64_t
stride
=
x
.
dims
().
size
()
==
sparse_dim
?
1
:
x
.
values
().
dims
()[
1
];
x
.
dims
().
size
()
==
sparse_dim
?
1
:
x
.
values
().
dims
()[
1
];
*
out
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
x
.
values
());
*
out
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
x
.
values
());
...
@@ -234,8 +244,10 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
...
@@ -234,8 +244,10 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
MaskTable
<<<
config
.
block_per_grid
,
MaskTable
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
config
.
thread_per_block
,
0
,
0
,
dev_ctx
.
stream
()
>>>
(
dev_ctx
.
stream
()
>>>
(
x_indexs_ptr
,
x_indexs_ptr
,
x_indexs
.
numel
(),
table
.
data
<
int
>
());
x_indexs
.
numel
(),
index_flags
.
data
<
int
>
(),
table
.
data
<
int
>
());
config
=
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
mask_indexs
.
numel
(),
1
);
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
mask_indexs
.
numel
(),
1
);
...
@@ -246,6 +258,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
...
@@ -246,6 +258,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
config
.
thread_per_block
,
config
.
thread_per_block
,
0
,
0
,
dev_ctx
.
stream
()
>>>
(
mask_indexs_ptr
,
dev_ctx
.
stream
()
>>>
(
mask_indexs_ptr
,
index_flags
.
data
<
int
>
(),
table
.
data
<
int
>
(),
table
.
data
<
int
>
(),
mask_indexs
.
numel
(),
mask_indexs
.
numel
(),
stride
,
stride
,
...
@@ -256,6 +269,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
...
@@ -256,6 +269,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
config
.
thread_per_block
,
config
.
thread_per_block
,
0
,
0
,
dev_ctx
.
stream
()
>>>
(
mask_indexs_ptr
,
dev_ctx
.
stream
()
>>>
(
mask_indexs_ptr
,
index_flags
.
data
<
int
>
(),
table
.
data
<
int
>
(),
table
.
data
<
int
>
(),
mask_indexs
.
numel
(),
mask_indexs
.
numel
(),
stride
,
stride
,
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录