Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8fbe97e4
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8fbe97e4
编写于
9月 21, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
9月 21, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Revert "SparseConv support duplicate coordinates (#44976)" (#45202)
This reverts commit
e8de9dfd
.
上级
a93a95bf
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
53 addition
and
168 deletion
+53
-168
paddle/phi/kernels/funcs/sparse/scatter.cu.h
paddle/phi/kernels/funcs/sparse/scatter.cu.h
+2
-7
paddle/phi/kernels/sparse/gpu/conv.cu.h
paddle/phi/kernels/sparse/gpu/conv.cu.h
+27
-120
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
+2
-38
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+22
-3
未找到文件。
paddle/phi/kernels/funcs/sparse/scatter.cu.h
浏览文件 @
8fbe97e4
...
...
@@ -79,7 +79,6 @@ __global__ void ScatterKernelV2(const T* input,
const
int
*
index_groups
,
const
int
non_zero_num
,
const
int
kernel_size
,
const
int
max_voxel
,
const
int
channels
,
const
int
buffer_counts
,
T
*
out
)
{
...
...
@@ -97,11 +96,10 @@ __global__ void ScatterKernelV2(const T* input,
&
sums
);
for
(
int
it
=
0
;
it
<
buffer_counts
;
it
++
)
{
int
len
=
index_counts
[
indices_i
+
it
*
non_zero_num
];
const
int
group_offset
=
it
*
max_voxel
*
kernel_size
*
non_zero_num
;
const
int
group_offset
=
it
*
kernel_size
*
non_zero_num
;
for
(
int
j
=
0
;
j
<
len
;
j
++
)
{
const
int
out_feature_i
=
index_groups
[
indices_i
*
max_voxel
*
kernel_size
+
j
+
group_offset
];
index_groups
[
indices_i
*
kernel_size
+
j
+
group_offset
];
LoadT
vec_in
;
phi
::
Load
<
T
,
VecSize
>
(
input
+
out_feature_i
*
channels
+
channels_i
*
VecSize
,
&
vec_in
);
...
...
@@ -123,7 +121,6 @@ void ScatterV2(const GPUContext& dev_ctx,
const
int
*
index_groups
,
const
int
non_zero_num
,
const
int
kernel_size
,
const
int
max_voxel
,
const
int
channels
,
const
int
buffer_counts
,
T
*
output
)
{
...
...
@@ -139,7 +136,6 @@ void ScatterV2(const GPUContext& dev_ctx,
index_groups
,
non_zero_num
,
kernel_size
,
max_voxel
,
channels
,
buffer_counts
,
output
);
...
...
@@ -154,7 +150,6 @@ void ScatterV2(const GPUContext& dev_ctx,
index_groups
,
non_zero_num
,
kernel_size
,
max_voxel
,
channels
,
buffer_counts
,
output
);
...
...
paddle/phi/kernels/sparse/gpu/conv.cu.h
浏览文件 @
8fbe97e4
...
...
@@ -66,7 +66,6 @@ __global__ void GatherKernelV2(const T* inputs,
const
int
*
index_groups
,
const
int
non_zero_num
,
const
int
kernel_size
,
const
int
max_voxel
,
const
int
channels
,
const
int
buffer_count
,
T
*
output
)
{
...
...
@@ -84,11 +83,10 @@ __global__ void GatherKernelV2(const T* inputs,
#pragma unroll
for
(
int
it
=
0
;
it
<
buffer_count
;
it
++
)
{
int
len
=
index_counts
[
indices_i
+
it
*
non_zero_num
];
const
int
group_offset
=
it
*
kernel_size
*
max_voxel
*
non_zero_num
;
const
int
group_offset
=
it
*
kernel_size
*
non_zero_num
;
#pragma unroll
for
(
int
j
=
0
;
j
<
len
;
j
++
)
{
int
out_i
=
index_groups
[
indices_i
*
kernel_size
*
max_voxel
+
j
+
group_offset
];
int
out_i
=
index_groups
[
indices_i
*
kernel_size
+
j
+
group_offset
];
phi
::
Store
<
T
,
VecSize
>
(
in_vec
,
output
+
out_i
*
channels
+
channels_i
*
VecSize
);
}
...
...
@@ -130,7 +128,6 @@ inline void GatherV2(const GPUContext& dev_ctx,
const
int
*
index_groups
,
const
int
non_zero_num
,
const
int
kernel_size
,
const
int
max_voxel
,
const
int
channels
,
const
int
buffer_count
,
T
*
output
)
{
...
...
@@ -146,7 +143,6 @@ inline void GatherV2(const GPUContext& dev_ctx,
index_groups
,
non_zero_num
,
kernel_size
,
max_voxel
,
channels
,
buffer_count
,
output
);
...
...
@@ -161,7 +157,6 @@ inline void GatherV2(const GPUContext& dev_ctx,
index_groups
,
non_zero_num
,
kernel_size
,
max_voxel
,
channels
,
buffer_count
,
output
);
...
...
@@ -207,7 +202,7 @@ __global__ void UniqueKernel(const IntT* in_indexs,
template
<
typename
IntT
>
__global__
void
GroupIndexs
(
const
int
*
out_index_table
,
const
int
n
,
const
int
offset
,
const
int
kernel_size
,
IntT
*
out_indexs
,
int
*
out_index_counts
,
int
*
out_index_groups
)
{
...
...
@@ -219,7 +214,7 @@ __global__ void GroupIndexs(const int* out_index_table,
// kernel_size at most
int
j
=
atomicAdd
(
out_index_counts
+
real_index
,
1
);
// nnz * kernel_size
out_index_groups
[
real_index
*
offset
+
j
]
=
i
;
out_index_groups
[
real_index
*
kernel_size
+
j
]
=
i
;
}
}
...
...
@@ -303,36 +298,18 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
}
}
template
<
typename
IntT
,
bool
save_out_index
=
true
>
template
<
typename
IntT
>
__global__
void
GetOutIndexTable
(
const
IntT
*
indices
,
const
IntT
non_zero_num
,
const
Dims4D
dims
,
int
*
out_index_table
,
int
*
out_index_table2
,
int
*
max_voxel
)
{
__shared__
int
cache_max
;
if
(
threadIdx
.
x
==
0
)
{
cache_max
=
0
;
}
__syncthreads
();
int
*
out_index_table
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
non_zero_num
,
int64_t
)
{
IntT
batch
=
indices
[
i
];
IntT
in_z
=
indices
[
i
+
non_zero_num
];
IntT
in_y
=
indices
[
i
+
2
*
non_zero_num
];
IntT
in_x
=
indices
[
i
+
3
*
non_zero_num
];
IntT
index
=
PointToIndex
(
batch
,
in_x
,
in_y
,
in_z
,
dims
);
if
(
save_out_index
)
{
out_index_table
[
index
]
=
i
==
0
?
-
1
:
i
;
}
int
count
=
atomicAdd
(
out_index_table2
+
index
,
1
);
atomicMax
(
&
cache_max
,
count
);
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
atomicMax
(
max_voxel
,
cache_max
+
1
);
out_index_table
[
index
]
=
i
==
0
?
-
1
:
i
;
}
}
...
...
@@ -341,22 +318,10 @@ __global__ void GetOutIndexTable(int* indexs,
const
int
non_zero_num
,
const
Dims4D
out_dims
,
int
*
out_index_table
,
int
*
out_index_table2
,
int
*
max_voxel
,
IntT
*
out_indices
)
{
__shared__
int
cache_max
;
if
(
threadIdx
.
x
==
0
)
{
cache_max
=
0
;
}
__syncthreads
();
CUDA_KERNEL_LOOP_TYPE
(
i
,
non_zero_num
,
int64_t
)
{
IntT
index
=
static_cast
<
IntT
>
(
indexs
[
i
]);
out_index_table
[
index
]
=
i
;
int
count
=
atomicAdd
(
out_index_table2
+
index
,
1
);
atomicMax
(
&
cache_max
,
count
);
IntT
batch
,
x
,
y
,
z
;
phi
::
funcs
::
sparse
::
IndexToPoint
<
Dims4D
>
(
index
,
out_dims
,
&
batch
,
&
x
,
&
y
,
&
z
);
...
...
@@ -367,11 +332,6 @@ __global__ void GetOutIndexTable(int* indexs,
out_indices
[
i
+
non_zero_num
*
3
]
=
x
;
indexs
[
i
]
=
0
;
}
__syncthreads
();
if
(
threadIdx
.
x
==
0
)
{
atomicMax
(
max_voxel
,
cache_max
+
1
);
}
}
template
<
typename
IntT
>
...
...
@@ -491,7 +451,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
template
<
typename
IntT
>
__global__
void
GroupIndexs
(
const
int
n
,
const
int
offset
,
const
int
kernel_size
,
const
IntT
*
indexs
,
int
*
index_counts
,
int
*
index_groups
)
{
...
...
@@ -500,7 +460,7 @@ __global__ void GroupIndexs(const int n,
// kernel_size at most
int
j
=
atomicAdd
(
index_counts
+
index
,
1
);
// nnz * kernel_size
index_groups
[
index
*
offset
+
j
]
=
i
;
index_groups
[
index
*
kernel_size
+
j
]
=
i
;
}
}
...
...
@@ -508,7 +468,7 @@ __global__ void GroupIndexs(const int n,
template
<
typename
IntT
>
__global__
void
GroupIndexsV2
(
const
int
rulebook_len
,
const
int
non_zero_num
,
const
int
offset
,
const
int
kernel_size
,
const
int
half_kernel_offset
,
const
IntT
*
indexs
,
int
*
index_counts
,
...
...
@@ -519,11 +479,11 @@ __global__ void GroupIndexsV2(const int rulebook_len,
i
<
half_kernel_offset
?
index_counts
:
index_counts
+
non_zero_num
;
int
*
groups_ptr
=
i
<
half_kernel_offset
?
index_groups
:
index_groups
+
non_zero_num
*
offset
;
:
index_groups
+
non_zero_num
*
kernel_size
;
// conflict kernel_size times at most
int
j
=
atomicAdd
(
counts_ptr
+
index
,
1
);
// nnz * kernel_size
groups_ptr
[
index
*
offset
+
j
]
=
i
;
groups_ptr
[
index
*
kernel_size
+
j
]
=
i
;
}
}
...
...
@@ -622,10 +582,6 @@ int ProductRuleBook(const Context& dev_ctx,
DenseTensor
out_index_table
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
});
int
*
out_index_table_ptr
=
out_index_table
.
data
<
int
>
();
DenseTensor
out_index_table2
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
+
1
});
int
*
out_index_table2_ptr
=
out_index_table2
.
data
<
int
>
();
int
*
h_max_voxel
=
h_counter
+
kernel_size
;
if
(
subm
)
{
DenseTensor
tmp_rulebook
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
rulebook_meta
));
IntT
*
rulebook_ptr
=
tmp_rulebook
.
data
<
IntT
>
();
...
...
@@ -636,29 +592,14 @@ int ProductRuleBook(const Context& dev_ctx,
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_table_ptr
,
0
,
sizeof
(
int
)
*
table_size
,
dev_ctx
.
stream
());
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_table2_ptr
,
0
,
sizeof
(
int
)
*
(
table_size
+
1
),
dev_ctx
.
stream
());
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
non_zero_num
,
1
);
GetOutIndexTable
<
IntT
>
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
out_indices
.
data
<
IntT
>
(),
non_zero_num
,
d_x_dims
,
out_index_table_ptr
,
out_index_table2_ptr
,
out_index_table2_ptr
+
table_size
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
h_max_voxel
,
out_index_table2_ptr
+
table_size
,
sizeof
(
int
),
gpuMemcpyDeviceToHost
,
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
GetOutIndexTable
<
IntT
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
out_indices
.
data
<
IntT
>
(),
non_zero_num
,
d_x_dims
,
out_index_table_ptr
);
size_t
cache_size
=
kernel_size
*
2
*
sizeof
(
int
)
+
...
...
@@ -712,22 +653,6 @@ int ProductRuleBook(const Context& dev_ctx,
out_rulebook_ptr
);
*
rulebook
=
out_rulebook
;
unique_value
->
ResizeAndAllocate
(
{
static_cast
<
int
>
(
non_zero_num
*
h_max_voxel
[
0
]
*
kernel_size
)});
int
*
unique_value_ptr
=
unique_value
->
data
<
int
>
();
out_index
->
ResizeAndAllocate
({
static_cast
<
int
>
(
rulebook_len
)});
int
*
out_index_ptr
=
out_index
->
data
<
int
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_ptr
,
0
,
sizeof
(
int
)
*
rulebook_len
,
dev_ctx
.
stream
());
GroupIndexs
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
kernel_size
*
h_max_voxel
[
0
],
out_rulebook_ptr
+
rulebook_len
,
out_index_ptr
,
unique_value_ptr
);
return
rulebook_len
;
}
else
{
...
...
@@ -811,35 +736,17 @@ int ProductRuleBook(const Context& dev_ctx,
IntT
*
out_indices_ptr
=
out_indices
.
data
<
IntT
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_table_ptr
,
0
,
sizeof
(
int
)
*
table_size
,
dev_ctx
.
stream
());
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_table2_ptr
,
0
,
sizeof
(
int
)
*
(
table_size
+
1
),
dev_ctx
.
stream
());
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
out_nnz
,
1
);
GetOutIndexTable
<
IntT
>
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
out_index_ptr
,
out_nnz
,
d_out_dims
,
out_index_table_ptr
,
out_index_table2_ptr
,
out_index_table2_ptr
+
table_size
,
out_indices_ptr
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
h_max_voxel
,
out_index_table2_ptr
+
table_size
,
sizeof
(
int
),
gpuMemcpyDeviceToHost
,
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
GetOutIndexTable
<
IntT
><<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
out_index_ptr
,
out_nnz
,
d_out_dims
,
out_index_table_ptr
,
out_indices_ptr
);
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
unique_value
->
ResizeAndAllocate
(
{
static_cast
<
int
>
(
out_nnz
*
h_max_voxel
[
0
]
*
kernel_size
)});
unique_value
->
ResizeAndAllocate
({
static_cast
<
int
>
(
out_nnz
*
kernel_size
)});
int
*
unique_value_ptr
=
unique_value
->
data
<
int
>
();
GroupIndexs
<<<
config
.
block_per_grid
,
...
...
@@ -847,7 +754,7 @@ int ProductRuleBook(const Context& dev_ctx,
0
,
dev_ctx
.
stream
()
>>>
(
out_index_table_ptr
,
rulebook_len
,
kernel_size
*
h_max_voxel
[
0
]
,
kernel_size
,
rulebook_ptr
+
rulebook_len
,
out_index_ptr
,
unique_value_ptr
);
...
...
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
浏览文件 @
8fbe97e4
...
...
@@ -119,44 +119,10 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
}
}
int
max_voxel
=
counter_ptr
[
kernel_size
];
if
(
!
subm
)
{
const
auto
&
x_dims
=
x
.
dims
();
Dims4D
d_x_dims
(
x_dims
[
0
],
x_dims
[
3
],
x_dims
[
2
],
x_dims
[
1
]);
int64_t
table_size
=
1
;
for
(
int
i
=
0
;
i
<
x_dims
.
size
()
-
1
;
i
++
)
{
table_size
*=
x_dims
[
i
];
}
DenseTensor
in_index_table
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
table_size
+
1
});
int
*
in_index_table_ptr
=
in_index_table
.
data
<
int
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
in_index_table_ptr
,
0
,
sizeof
(
int
)
*
(
table_size
+
1
),
dev_ctx
.
stream
());
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
x
.
nnz
(),
1
);
GetOutIndexTable
<
IntT
,
false
>
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
x
.
indices
().
data
<
IntT
>
(),
x
.
nnz
(),
d_x_dims
,
nullptr
,
in_index_table_ptr
,
in_index_table_ptr
+
table_size
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
max_voxel
,
in_index_table_ptr
+
table_size
,
sizeof
(
int
),
gpuMemcpyDeviceToHost
,
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
}
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
DenseTensor
unique_value
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
static_cast
<
int
>
(
x_grad
->
nnz
()
*
max_voxel
*
kernel_size
*
2
)});
dev_ctx
,
{
static_cast
<
int
>
(
x_grad
->
nnz
()
*
kernel_size
*
2
)});
DenseTensor
out_index
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
static_cast
<
int
>
(
x
.
nnz
()
*
2
)});
int
*
out_index_ptr
=
out_index
.
data
<
int
>
();
...
...
@@ -169,7 +135,7 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
x
.
nnz
(),
kernel_size
*
max_voxel
,
kernel_size
,
offsets
[
kernel_size
/
2
],
rulebook_ptr
,
out_index_ptr
,
...
...
@@ -181,7 +147,6 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
unique_value_ptr
,
x
.
nnz
(),
kernel_size
,
max_voxel
,
in_channels
,
2
,
in_features_ptr
);
...
...
@@ -242,7 +207,6 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
unique_value
.
data
<
int
>
(),
x_grad
->
nnz
(),
kernel_size
,
max_voxel
,
in_channels
,
2
,
x_grad_values_ptr
);
...
...
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
浏览文件 @
8fbe97e4
...
...
@@ -66,7 +66,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
const
int
in_channels
=
kernel_dims
[
3
];
const
int
out_channels
=
kernel_dims
[
4
];
DenseTensor
h_counter
,
h_offsets
;
h_counter
.
Resize
({
kernel_size
+
1
});
h_counter
.
Resize
({
kernel_size
});
h_offsets
.
Resize
({
kernel_size
+
1
});
int
*
h_counter_ptr
=
dev_ctx
.
template
HostAlloc
<
int
>(
&
h_counter
);
int
*
h_offsets_ptr
=
dev_ctx
.
template
HostAlloc
<
int
>(
&
h_offsets
);
...
...
@@ -74,7 +74,7 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
// Second algorithm:
// https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
// 1. product rulebook
DenseTensor
counter_per_kernel
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
kernel_size
+
1
});
DenseTensor
counter_per_kernel
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
kernel_size
});
DenseTensor
offsets_per_kernel
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
kernel_size
});
DenseTensor
out_index
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
1
});
DenseTensor
unique_value
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
1
});
...
...
@@ -143,6 +143,26 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
T
*
out_values_ptr
=
out_values
->
data
<
T
>
();
set_zero
(
dev_ctx
,
out_values
,
static_cast
<
T
>
(
0.0
f
));
if
(
subm
)
{
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
unique_value
.
ResizeAndAllocate
(
{
static_cast
<
int
>
(
out
->
nnz
()
*
kernel_size
)});
out_index
.
ResizeAndAllocate
({
static_cast
<
int
>
(
rulebook_len
)});
int
*
out_index_ptr
=
out_index
.
data
<
int
>
();
int
*
unique_value_ptr
=
unique_value
.
data
<
int
>
();
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
out_index_ptr
,
0
,
sizeof
(
int
)
*
rulebook_len
,
dev_ctx
.
stream
());
GroupIndexs
<<<
config
.
block_per_grid
,
config
.
thread_per_block
,
0
,
dev_ctx
.
stream
()
>>>
(
rulebook_len
,
kernel_size
,
rulebook_ptr
+
rulebook_len
,
out_index_ptr
,
unique_value_ptr
);
}
const
T
*
kernel_ptr
=
kernel
.
data
<
T
>
();
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
h_counter_ptr
[
i
]
<=
0
)
{
...
...
@@ -176,7 +196,6 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
unique_value
.
data
<
int
>
(),
out
->
nnz
(),
kernel_size
,
h_counter_ptr
[
kernel_size
],
out_channels
,
1
,
out_values_ptr
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录