机器未来 / Paddle (forked from PaddlePaddle / Paddle)
Commit 9841b308 (unverified)
Authored Jul 26, 2022 by zhangkaihuo; committed via GitHub on Jul 26, 2022
Optimize sparse convolution (#43576)
Parent: 22342d51

Showing 27 changed files with 1474 additions and 345 deletions (+1474, -345)
  paddle/phi/api/yaml/sparse_api.yaml                          +9    -9
  paddle/phi/api/yaml/sparse_bw_api.yaml                       +7    -7
  paddle/phi/core/sparse_coo_tensor.h                          +50   -0
  paddle/phi/kernels/funcs/sparse/convolution.h                +83   -0
  paddle/phi/kernels/funcs/sparse/scatter.cu.h                 +110  -12
  paddle/phi/kernels/sparse/conv_grad_kernel.h                 +10   -1
  paddle/phi/kernels/sparse/conv_kernel.h                      +9    -3
  paddle/phi/kernels/sparse/cpu/conv.h                         +3    -4
  paddle/phi/kernels/sparse/cpu/conv_grad_kernel.cc            +22   -13
  paddle/phi/kernels/sparse/cpu/conv_kernel.cc                 +65   -38
  paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc            +8    -7
  paddle/phi/kernels/sparse/cpu/pool_kernel.cc                 +13   -9
  paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu             +29   -10
  paddle/phi/kernels/sparse/gpu/conv.cu.h                      +760  -0
  paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu            +77   -73
  paddle/phi/kernels/sparse/gpu/conv_kernel.cu                 +110  -87
  paddle/phi/kernels/sparse/gpu/mask_kernel.cu                 +1    -0
  paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu            +10   -22
  paddle/phi/kernels/sparse/gpu/pool_kernel.cu                 +24   -19
  paddle/phi/kernels/sparse/pool_grad_kernel.h                 +3    -1
  paddle/phi/kernels/sparse/pool_kernel.h                      +13   -4
  paddle/phi/tests/api/test_sparse_conv_api.cc                 +2    -2
  paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc       +16   -9
  paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc         +15   -7
  python/paddle/fluid/tests/unittests/test_sparse_conv_op.py   +4    -4
  python/paddle/incubate/sparse/nn/functional/conv.py          +11   -4
  python/paddle/incubate/sparse/nn/layer/conv.py               +10   -0
paddle/phi/api/yaml/sparse_api.yaml
@@ -80,14 +80,14 @@
     data_type : x
   backward : cast_grad

-- api : conv3d
-  args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
-  output : Tensor(out), Tensor(rulebook)
+- api : conv3d_coo
+  args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
+  output : Tensor(out), Tensor(rulebook), Tensor(counter)
   kernel :
-    func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense}
+    func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense}
     layout : x
-  intermediate : rulebook
-  backward : conv3d_grad
+  intermediate : rulebook, counter
+  backward : conv3d_coo_grad

 - api : coo_to_dense
   args : (Tensor x)
@@ -352,11 +352,11 @@
 - api : maxpool
   args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides)
-  output : Tensor(out), Tensor(rulebook)
+  output : Tensor(out), Tensor(rulebook), Tensor(counter)
   kernel :
-    func : maxpool_coo{sparse_coo -> sparse_coo, dense}
+    func : maxpool_coo{sparse_coo -> sparse_coo, dense, dense}
     layout : x
-  intermediate : rulebook
+  intermediate : rulebook, counter
   backward : maxpool_grad

 - api : mv
paddle/phi/api/yaml/sparse_bw_api.yaml
@@ -81,12 +81,12 @@
     cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr}
     data_type : out_grad

-- backward_api : conv3d_grad
-  forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor)
-  args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm)
+- backward_api : conv3d_coo_grad
+  forward : conv3d_coo (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter)
+  args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key)
   output : Tensor(x_grad), Tensor(kernel_grad)
   kernel :
-    func : conv3d_coo_grad{sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense}
+    func : conv3d_coo_grad{sparse_coo, dense, sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense}

 - backward_api : coo_to_dense_grad
   forward : coo_to_dense(Tensor x) -> Tensor(out)
@@ -164,11 +164,11 @@
     matmul_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}

 - backward_api : maxpool_grad
-  forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook)
-  args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
+  forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook), Tensor(counter)
+  args : (Tensor x, Tensor rulebook, Tensor counter, Tensor out, Tensor out_grad, int[] kernel_sizes)
   output : Tensor(x_grad)
   kernel :
-    func : maxpool_coo_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo}
+    func : maxpool_coo_grad {sparse_coo, dense, dense, sparse_coo, sparse_coo -> sparse_coo}

 - backward_api : multiply_grad
   forward : multiply(Tensor x, Tensor y) -> Tensor(out)
paddle/phi/core/sparse_coo_tensor.h
@@ -156,6 +156,48 @@ class SparseCooTensor : public TensorBase,
   /// \brief get the dnese dim
   int32_t dense_dim() const;

+  /// \brief query table according to key
+  const std::pair<DenseTensor, DenseTensor>* IndicesPairs(
+      const std::string& key) const {
+    if (indices_dict_ == nullptr) {
+      return nullptr;
+    }
+    const auto& iter = indices_dict_->find(key);
+    if (iter == indices_dict_->end()) {
+      return nullptr;
+    }
+    return &iter->second;
+  }
+
+  /// \brief save (key, indices_pairs)
+  void SaveIndicesPairs(
+      const std::string& key,
+      const std::pair<DenseTensor, DenseTensor>& indices_pairs) {
+    if (indices_dict_ == nullptr) {
+      indices_dict_ = std::make_shared<
+          std::map<std::string, std::pair<DenseTensor, DenseTensor>>>();
+    }
+    auto ret = indices_dict_->insert({key, indices_pairs});
+    if (ret.second == false) {
+      ret.first->second = indices_pairs;
+    }
+  }
+
+  /// \brief get indices_dict_
+  const std::shared_ptr<
+      std::map<std::string, std::pair<DenseTensor, DenseTensor>>>&
+  GetIndicesDict() const {
+    return indices_dict_;
+  }
+
+  /// \brief set indices_dict_
+  void SetIndicesDict(
+      const std::shared_ptr<
+          std::map<std::string, std::pair<DenseTensor, DenseTensor>>>&
+          indices_dict) {
+    indices_dict_ = indices_dict;
+  }
+
 private:
   // save the indices of non zero elements in original dense tensor
   DenseTensor non_zero_indices_;
@@ -165,6 +207,14 @@ class SparseCooTensor : public TensorBase,
   bool coalesced_ = false;
   // save the number of non zero elements in each batch
   DDim dims_;

+  // for submanifold conv
+  // SubmConv will generate a rulebook and a counter, which can be
+  // reused by different SubmConv.
+  // refer to sparse/gpu/convolution_kernel.cu.
+  std::shared_ptr<std::map<std::string, std::pair<DenseTensor, DenseTensor>>>
+      indices_dict_ = nullptr;
+
   /* --------------------------- */
   /* example: non zero element is scalar */
   /* --------------------------- */
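Note: the methods added above turn SparseCooTensor into a small per-key cache of (rulebook, counter) pairs shared between submanifold convolutions. The round trip looks roughly like the following sketch (not part of the commit; the key string is hypothetical and tensor setup is omitted):

// Sketch only: intended use of the cache added above.
#include <string>
#include <utility>
#include "paddle/phi/core/sparse_coo_tensor.h"

void CacheRulebookSketch(phi::SparseCooTensor* out,
                         const phi::DenseTensor& rulebook,
                         const phi::DenseTensor& counter) {
  const std::string key = "subm_conv_1";  // hypothetical layer key

  // First SubmConv with this key: store the freshly built pair.
  out->SaveIndicesPairs(key, std::make_pair(rulebook, counter));

  // A later SubmConv with the same key: reuse instead of rebuilding.
  const auto* cached = out->IndicesPairs(key);
  if (cached != nullptr) {
    const phi::DenseTensor& cached_rulebook = cached->first;   // rulebook
    const phi::DenseTensor& cached_counter = cached->second;   // counter
    (void)cached_rulebook;
    (void)cached_counter;
  }
}

Because indices_dict_ is a shared_ptr, SetIndicesDict lets the forward output carry the same table forward through a chain of SubmConv layers without copying it.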
paddle/phi/kernels/funcs/sparse/convolution.h
@@ -15,6 +15,7 @@ limitations under the License. */
 #pragma once

 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"

 namespace phi {
@@ -188,6 +189,88 @@ inline void PrefixSum(const T* counter, T* offsets, const int n) {
   offsets[n] = offset;
 }

+template <typename IntT>
+inline const IntT* GetRulebookPtr(const SparseCooTensor& coo,
+                                  const DenseTensor& rulebook,
+                                  const std::string& key,
+                                  int* rulebook_len) {
+  if (!key.empty()) {
+    const auto* indices_pairs = coo.IndicesPairs(key);
+    if (indices_pairs != nullptr) {
+      const DenseTensor& tmp_rulebook = indices_pairs->first;
+      *rulebook_len = tmp_rulebook.dims()[1];
+      return tmp_rulebook.data<IntT>();
+    }
+  }
+  *rulebook_len = rulebook.dims()[1];
+  return rulebook.data<IntT>();
+}
+
+inline const int* GetCounterPtr(const SparseCooTensor& coo,
+                                const DenseTensor& counter,
+                                const std::string& key) {
+  if (!key.empty()) {
+    const auto* indices_pairs = coo.IndicesPairs(key);
+    if (indices_pairs != nullptr) {
+      return indices_pairs->second.data<int>();
+    }
+  }
+  return counter.data<int>();
+}
+
+template <typename T, typename IntT, typename Context>
+inline const IntT* PrepareSubm(const Context& dev_ctx,
+                               const SparseCooTensor& x,
+                               const std::string& key,
+                               const DDim& out_dims,
+                               SparseCooTensor* out,
+                               int* counter,
+                               int* offsets,
+                               int* rulebook_len,
+                               bool* need_product_rulebook) {
+  const auto* indices_pairs = x.IndicesPairs(key);
+  if (indices_pairs != nullptr) {
+    *need_product_rulebook = false;
+    const DenseTensor& rulebook = indices_pairs->first;
+    const int counter_size = indices_pairs->second.numel();
+    memcpy(
+        counter, indices_pairs->second.data<int>(), counter_size * sizeof(int));
+    out->SetIndicesDict(x.GetIndicesDict());
+    *rulebook_len = rulebook.dims()[1];
+
+    DenseTensor out_indices =
+        phi::EmptyLike<IntT>(dev_ctx, x.non_zero_indices());
+    DenseTensor out_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
+    phi::Copy(
+        dev_ctx, x.non_zero_indices(), dev_ctx.GetPlace(), false, &out_indices);
+    out->SetMember(out_indices, out_values, out_dims, false);
+    PrefixSum<int>(counter, offsets, counter_size);
+    return rulebook.data<IntT>();
+  }
+  return nullptr;
+}
+
+template <typename Context>
+inline void SaveToTable(const Context& dev_ctx,
+                        const SparseCooTensor& x,
+                        const std::string& key,
+                        const DenseTensor& in_rulebook,
+                        const DenseTensor& h_counter,
+                        SparseCooTensor* out,
+                        DenseTensor* out_rulebook,
+                        DenseTensor* counter) {
+  out->SetIndicesDict(x.GetIndicesDict());
+  if (!key.empty()) {
+    out->SaveIndicesPairs(key, std::make_pair(in_rulebook, h_counter));
+  } else {
+    *out_rulebook = in_rulebook;
+    counter->Resize({h_counter.numel()});
+    int* counter_ptr = dev_ctx.template HostAlloc<int>(counter);
+    memcpy(counter_ptr, h_counter.data<int>(), h_counter.numel() * sizeof(int));
+  }
+}
+
 }  // namespace sparse
 }  // namespace funcs
 }  // namespace phi
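Note: GetRulebookPtr and GetCounterPtr fall back to the explicitly passed tensors when the key is empty or nothing is cached, so every kernel call site can be written the same way. A minimal sketch of the lookup contract, mirroring the call sites later in this commit (types only; Paddle setup omitted):

// Sketch: how the grad kernels in this commit resolve rulebook/counter.
// `out` is the forward output that may carry the cached pair under `key`.
#include <string>
#include "paddle/phi/kernels/funcs/sparse/convolution.h"

template <typename IntT>
void ResolveRulebookSketch(const phi::SparseCooTensor& out,
                           const phi::DenseTensor& rulebook,
                           const phi::DenseTensor& counter,
                           const std::string& key) {
  int rulebook_len = 0;
  // Prefers the pair cached under `key`; otherwise uses the arguments.
  const IntT* rulebook_ptr = phi::funcs::sparse::GetRulebookPtr<IntT>(
      out, rulebook, key, &rulebook_len);
  const int* counter_ptr =
      phi::funcs::sparse::GetCounterPtr(out, counter, key);
  (void)rulebook_ptr;
  (void)counter_ptr;
}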
paddle/phi/kernels/funcs/sparse/scatter.cu.h
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/kernels/funcs/aligned_vector.h"
+
+#define VecBytes 16

 namespace phi {
 namespace funcs {
@@ -28,33 +33,126 @@ namespace sparse {
 * channels: the output channel size
 * out: the outputs
**/
-template <typename T>
+template <typename T, int VecSize>
 __global__ void ScatterKernel(const T* input,
                               const int* unique_value,
                               const int* out_index,
                               const int non_zero_num,
                               const int rulebook_len,
                               const int channels,
-                              T* out,
-                              const bool subm = false) {
+                              T* out) {
   int tid = threadIdx.x + blockIdx.x * blockDim.x;
-  for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) {
-    int indices_i = i / channels;
-    int channels_i = i - indices_i * channels;
+  const int vec_channels = channels / VecSize;
+  using LoadT = phi::AlignedVector<T, VecSize>;
+  using StoreT = phi::AlignedVector<T, VecSize>;
+  for (int i = tid; i < non_zero_num * vec_channels;
+       i += gridDim.x * blockDim.x) {
+    int indices_i = i / vec_channels;
+    int channels_i = i - indices_i * vec_channels;

     int start = unique_value[indices_i];
     int end = indices_i == non_zero_num - 1 ? rulebook_len
                                             : unique_value[indices_i + 1];
     // max(end-start) = kernel_size
-    T sum = static_cast<T>(0);
-    if (subm) {
-      sum = out[indices_i * channels + channels_i];
-    }
+    StoreT sums = {static_cast<T>(0)};
     for (int j = start; j < end; j++) {
       const int out_feature_i = out_index[j];
-      sum += input[out_feature_i * channels + channels_i];
+      LoadT vec_in;
+      phi::Load<T, VecSize>(
+          input + out_feature_i * channels + channels_i * VecSize, &vec_in);
+#pragma unroll
+      for (int k = 0; k < VecSize; k++) {
+        sums[k] += vec_in[k];
+      }
     }
-    out[indices_i * channels + channels_i] = sum;
+    phi::Store<T, VecSize>(sums,
+                           out + indices_i * channels + channels_i * VecSize);
   }
 }

+// scatter's index has been grouped in advance
+// index_counts record the count of each group
+// index_groups save the index of each group
+template <typename T, int VecSize>
+__global__ void ScatterKernelV2(const T* input,
+                                const int* index_counts,
+                                const int* index_groups,
+                                const int non_zero_num,
+                                const int kernel_size,
+                                const int channels,
+                                const int buffer_counts,
+                                T* out) {
+  int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  const int vec_channels = channels / VecSize;
+  using LoadT = phi::AlignedVector<T, VecSize>;
+  using StoreT = phi::AlignedVector<T, VecSize>;
+  for (int i = tid; i < non_zero_num * vec_channels;
+       i += gridDim.x * blockDim.x) {
+    int indices_i = i / vec_channels;
+    int channels_i = i - indices_i * vec_channels;
+
+    StoreT sums = {static_cast<T>(0)};
+    phi::Load<T, VecSize>(out + indices_i * channels + channels_i * VecSize,
+                          &sums);
+    for (int it = 0; it < buffer_counts; it++) {
+      int len = index_counts[indices_i + it * non_zero_num];
+      const int group_offset = it * kernel_size * non_zero_num;
+      for (int j = 0; j < len; j++) {
+        const int out_feature_i =
+            index_groups[indices_i * kernel_size + j + group_offset];
+        LoadT vec_in;
+        phi::Load<T, VecSize>(
+            input + out_feature_i * channels + channels_i * VecSize, &vec_in);
+#pragma unroll
+        for (int k = 0; k < VecSize; k++) {
+          sums[k] += vec_in[k];
+        }
+      }
+    }
+    phi::Store<T, VecSize>(sums,
+                           out + indices_i * channels + channels_i * VecSize);
+  }
+}
+
+template <typename T>
+void ScatterV2(const GPUContext& dev_ctx,
+               const T* input,
+               const int* index_counts,
+               const int* index_groups,
+               const int non_zero_num,
+               const int kernel_size,
+               const int channels,
+               const int buffer_counts,
+               T* output) {
+  const int VecSize = VecBytes / sizeof(T);
+  if (channels % VecSize == 0) {
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+        dev_ctx, non_zero_num * channels / VecSize, 1);
+    ScatterKernelV2<T, VecSize><<<config.block_per_grid.x,
+                                  config.thread_per_block.x,
+                                  0,
+                                  dev_ctx.stream()>>>(input,
+                                                      index_counts,
+                                                      index_groups,
+                                                      non_zero_num,
+                                                      kernel_size,
+                                                      channels,
+                                                      buffer_counts,
+                                                      output);
+  } else {
+    auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
+        dev_ctx, non_zero_num * channels, 1);
+    ScatterKernelV2<T, 1><<<config.block_per_grid.x,
+                            config.thread_per_block.x,
+                            0,
+                            dev_ctx.stream()>>>(input,
+                                                index_counts,
+                                                index_groups,
+                                                non_zero_num,
+                                                kernel_size,
+                                                channels,
+                                                buffer_counts,
+                                                output);
+  }
+}
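Note: the scatter kernels above now move VecSize elements per thread through phi::AlignedVector, and ScatterV2 picks the vector width at launch time: VecBytes = 16 targets 128-bit transactions, with a scalar fallback when the channel count is not a whole number of vectors. A self-contained sketch of that selection rule (plain C++, independent of Paddle):

// Sketch of the width rule used by ScatterV2 above:
//   float  -> VecSize = 16 / 4 = 4, double -> VecSize = 16 / 8 = 2.
// The scalar path (VecSize = 1) is taken when `channels` is not a
// multiple of VecSize, since each row must be a whole number of vectors.
#include <cstdio>

template <typename T>
int PickVecSize(int channels) {
  const int kVecBytes = 16;  // same constant as VecBytes above
  const int vec_size = kVecBytes / static_cast<int>(sizeof(T));
  return (channels % vec_size == 0) ? vec_size : 1;
}

int main() {
  std::printf("float,  channels=64: VecSize=%d\n", PickVecSize<float>(64));
  std::printf("float,  channels=3 : VecSize=%d\n", PickVecSize<float>(3));
  std::printf("double, channels=64: VecSize=%d\n", PickVecSize<double>(64));
  return 0;  // prints 4, 1, 2
}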
paddle/phi/kernels/sparse/conv_grad_kernel.h
@@ -25,13 +25,16 @@ template <typename T, typename Context>
 void Conv3dCooGradKernel(const Context& dev_ctx,
                          const SparseCooTensor& x,
                          const DenseTensor& kernel,
+                         const SparseCooTensor& out,
                          const DenseTensor& rulebook,
+                         const DenseTensor& counter,
                          const SparseCooTensor& out_grad,
                          const std::vector<int>& paddings,
                          const std::vector<int>& dilations,
                          const std::vector<int>& strides,
                          const int groups,
                          const bool subm,
+                         const std::string& key,
                          SparseCooTensor* x_grad,
                          DenseTensor* kernel_grad);
@@ -40,13 +43,16 @@ std::tuple<SparseCooTensor, DenseTensor> Conv3dCooGrad(
     const Context& dev_ctx,
     const SparseCooTensor& x,
     const DenseTensor& kernel,
+    const SparseCooTensor& out,
     const DenseTensor& rulebook,
+    const DenseTensor& counter,
     const SparseCooTensor& out_grad,
     const std::vector<int>& paddings,
     const std::vector<int>& dilations,
     const std::vector<int>& strides,
     const int groups,
-    const bool subm) {
+    const bool subm,
+    const std::string& key) {
   SparseCooTensor x_grad;
   DenseTensor kernel_grad;
@@ -54,13 +60,16 @@ std::tuple<SparseCooTensor, DenseTensor> Conv3dCooGrad(
   Conv3dCooGradKernel<T, Context>(dev_ctx,
                                   x,
                                   kernel,
+                                  out,
                                   rulebook,
+                                  counter,
                                   out_grad,
                                   paddings,
                                   dilations,
                                   strides,
                                   groups,
                                   subm,
+                                  key,
                                   &x_grad,
                                   &kernel_grad);
   return std::make_tuple(x_grad, kernel_grad);
paddle/phi/kernels/sparse/conv_kernel.h
@@ -31,8 +31,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
                      const std::vector<int>& strides,
                      const int groups,
                      const bool subm,
+                     const std::string& key,
                      SparseCooTensor* out,
-                     DenseTensor* rulebook);
+                     DenseTensor* rulebook,
+                     DenseTensor* counter);

 template <typename T, typename Context>
 SparseCooTensor Conv3dCoo(const Context& dev_ctx,
@@ -43,7 +45,9 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx,
                           const std::vector<int>& strides,
                           const int groups,
                           const bool subm,
-                          DenseTensor* rulebook) {
+                          const std::string& key,
+                          DenseTensor* rulebook,
+                          DenseTensor* counter) {
   SparseCooTensor coo;
   Conv3dCooKernel<T, Context>(dev_ctx,
                               x,
@@ -53,8 +57,10 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx,
                               strides,
                               groups,
                               subm,
+                              key,
                               &coo,
-                              rulebook);
+                              rulebook,
+                              counter);
   return coo;
 }
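Note: callers of Conv3dCoo now receive the counter next to the rulebook and thread the cache key through. A hedged sketch of a call site (the earlier parameters of Conv3dCoo are outside the hunk above and assumed; tensor and context setup omitted, the key string is hypothetical):

// Sketch of a Conv3dCoo call with the new key/rulebook/counter parameters.
#include <string>
#include <vector>
#include "paddle/phi/kernels/sparse/conv_kernel.h"

template <typename T, typename Context>
phi::SparseCooTensor RunSubmConv3dSketch(const Context& dev_ctx,
                                         const phi::SparseCooTensor& x,
                                         const phi::DenseTensor& kernel) {
  const std::vector<int> paddings = {1, 1, 1};
  const std::vector<int> dilations = {1, 1, 1};
  const std::vector<int> strides = {1, 1, 1};
  // Filled by the kernel, or bypassed by the cache when `key` is non-empty.
  phi::DenseTensor rulebook, counter;
  return phi::sparse::Conv3dCoo<T, Context>(dev_ctx,
                                            x,
                                            kernel,
                                            paddings,
                                            dilations,
                                            strides,
                                            /*groups=*/1,
                                            /*subm=*/true,
                                            /*key=*/"subm_conv_1",
                                            &rulebook,
                                            &counter);
}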
paddle/phi/kernels/sparse/cpu/convolution.h → paddle/phi/kernels/sparse/cpu/conv.h
@@ -41,13 +41,12 @@ void ProductRuleBook(const Context& dev_ctx,
                      const DDim& out_dims,
                      const bool subm,
                      DenseTensor* rulebook,
-                     DenseTensor* counter_per_kernel) {
+                     int* counter_per_kernel) {
   const int64_t non_zero_num = x.nnz();
   const auto& non_zero_indices = x.non_zero_indices();
   const IntT* indices_ptr = non_zero_indices.data<IntT>();
-  int* counter_ptr = counter_per_kernel->data<int>();
   int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
-  memset(counter_ptr, 0, kernel_size * sizeof(int));
+  memset(counter_per_kernel, 0, kernel_size * sizeof(int));

   int rulebook_len = 0;
   // calc the rulebook_len
@@ -107,7 +106,7 @@ void ProductRuleBook(const Context& dev_ctx,
       }
       if (rulebook_ptr == nullptr) {
-        counter_ptr[kernel_index - 1] += 1;
+        counter_per_kernel[kernel_index - 1] += 1;
         ++rulebook_len;
       } else {
         rulebook_ptr[rulebook_index] = kernel_index - 1;
paddle/phi/kernels/sparse/cpu/conv_grad_kernel.cc
@@ -17,7 +17,7 @@ limitations under the License. */
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
-#include "paddle/phi/kernels/sparse/cpu/convolution.h"
+#include "paddle/phi/kernels/sparse/cpu/conv.h"

 namespace phi {
 namespace sparse {
@@ -34,22 +34,27 @@ template <typename T, typename IntT = int>
 void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
                             const SparseCooTensor& x,
                             const DenseTensor& kernel,
+                            const SparseCooTensor& out,
                             const DenseTensor& rulebook,
+                            const DenseTensor& counter,
                             const SparseCooTensor& out_grad,
                             const std::vector<int>& paddings,
                             const std::vector<int>& dilations,
                             const std::vector<int>& strides,
                             const int groups,
                             const bool subm,
+                            const std::string& key,
                             SparseCooTensor* x_grad,
                             DenseTensor* kernel_grad) {
   const auto& kernel_dims = kernel.dims();
   const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
   const int in_channels = kernel_dims[3];
   const int out_channels = kernel_dims[4];
-  const IntT* rulebook_ptr = rulebook.data<IntT>();

-  const int rulebook_len = rulebook.dims()[1];
+  int rulebook_len = 0;
+  const IntT* rulebook_ptr = phi::funcs::sparse::GetRulebookPtr<IntT>(
+      out, rulebook, key, &rulebook_len);
+  const int* counter_ptr = phi::funcs::sparse::GetCounterPtr(out, counter, key);
+
   DenseTensorMeta in_features_meta(
       x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
@@ -86,16 +91,14 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
                   &x_grad_indices);
   x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true);

-  std::vector<IntT> offsets(kernel_size + 1), counter(kernel_size, 0);
-  for (int i = 0; i < rulebook_len; i++) {
-    counter[rulebook_ptr[i]] += 1;
-  }
-  IntT offset = 0, max_count = 0;
+  std::vector<IntT> offsets(kernel_size + 1);
+  IntT offset = 0;
+  int max_count = 0;
   for (int i = 0; i < kernel_size; i++) {
     offsets[i] = offset;
-    offset += counter[i];
+    offset += counter_ptr[i];
     if (i < half_kernel_size) {
-      max_count = std::max(max_count, counter[i]);
+      max_count = std::max(max_count, counter_ptr[i]);
     }
   }
   offsets[kernel_size] = offset;
@@ -129,11 +132,11 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
   const T* kernel_ptr = kernel.data<T>();
   for (int i = 0; i < kernel_size; i++) {
-    if (counter[i] <= 0 || (subm && i == half_kernel_size)) {
+    if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
       continue;
     }

-    const int M = counter[i];
+    const int M = counter_ptr[i];
     const int K = in_channels;
     const int N = out_channels;
     T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
@@ -171,7 +174,7 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
   // 4. scatter
   Scatter<T, IntT>(d_x_features_ptr,
-                   rulebook.data<IntT>() + rulebook_len,
+                   rulebook_ptr + rulebook_len,
                    rulebook_len,
                    in_channels,
                    x_grad_values_ptr);
@@ -181,13 +184,16 @@ template <typename T, typename Context>
 void Conv3dCooGradKernel(const Context& dev_ctx,
                          const SparseCooTensor& x,
                          const DenseTensor& kernel,
+                         const SparseCooTensor& out,
                          const DenseTensor& rulebook,
+                         const DenseTensor& counter,
                          const SparseCooTensor& out_grad,
                          const std::vector<int>& paddings,
                          const std::vector<int>& dilations,
                          const std::vector<int>& strides,
                          const int groups,
                          const bool subm,
+                         const std::string& key,
                          SparseCooTensor* x_grad,
                          DenseTensor* kernel_grad) {
   PD_VISIT_INTEGRAL_TYPES(
@@ -195,13 +201,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
         Conv3dCooGradCPUKernel<T, data_t>(dev_ctx,
                                           x,
                                           kernel,
+                                          out,
                                           rulebook,
+                                          counter,
                                           out_grad,
                                           paddings,
                                           dilations,
                                           strides,
                                           groups,
                                           subm,
+                                          key,
                                           x_grad,
                                           kernel_grad);
       }));
paddle/phi/kernels/sparse/cpu/conv_kernel.cc
@@ -14,9 +14,10 @@ limitations under the License. */
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/core/tensor_meta.h"
+#include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
-#include "paddle/phi/kernels/sparse/cpu/convolution.h"
+#include "paddle/phi/kernels/sparse/cpu/conv.h"

 namespace phi {
 namespace sparse {
@@ -35,8 +36,10 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
                         const std::vector<int>& strides,
                         const int groups,
                         const bool subm,
+                        const std::string& key,
                         SparseCooTensor* out,
-                        DenseTensor* rulebook) {
+                        DenseTensor* rulebook,
+                        DenseTensor* counter) {
   // update padding and dilation
   // Currently, only support x.layout is NDHWC, groups = 1
   // if x.layout != NDHWC then transpose(x), transpose(weight)
@@ -66,26 +69,50 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
   // Second algorithm:
   // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
   // 1. product rulebook
-  DenseTensorMeta counter_meta(
-      DataType::INT32, {kernel_size}, DataLayout::NCHW);
-  DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
-
-  ProductRuleBook<T, CPUContext, IntT>(dev_ctx,
-                                       x,
-                                       kernel_sizes,
-                                       subm_paddings,
-                                       dilations,
-                                       subm_strides,
-                                       out_dims,
-                                       subm,
-                                       rulebook,
-                                       &counter_per_kernel);
-
-  UpdateRulebookAndOutIndex<T, CPUContext, IntT>(
-      dev_ctx, x, kernel_size, out_channels, out_dims, rulebook, out);
-
-  int n = rulebook->dims()[1];
-  const int* counter_ptr = counter_per_kernel.data<int>();
+  DenseTensor h_counter, h_offsets;
+  h_counter.Resize({kernel_size});
+  h_offsets.Resize({kernel_size + 1});
+  int* h_counter_ptr = dev_ctx.template HostAlloc<int>(&h_counter);
+  int* h_offsets_ptr = dev_ctx.template HostAlloc<int>(&h_offsets);
+
+  // DenseTensor* rulebook = nullptr;
+  const IntT* rulebook_ptr = nullptr;
+  int n = 0;
+  bool need_product_rulebook = true;
+  if (subm && !key.empty()) {
+    rulebook_ptr = phi::funcs::sparse::PrepareSubm<T, IntT, CPUContext>(
+        dev_ctx,
+        x,
+        key,
+        out_dims,
+        out,
+        h_counter_ptr,
+        h_offsets_ptr,
+        &n,
+        &need_product_rulebook);
+  }
+  if (need_product_rulebook) {
+    DenseTensor tmp_rulebook;
+    ProductRuleBook<T, CPUContext, IntT>(dev_ctx,
+                                         x,
+                                         kernel_sizes,
+                                         subm_paddings,
+                                         dilations,
+                                         subm_strides,
+                                         out_dims,
+                                         subm,
+                                         &tmp_rulebook,
+                                         h_counter_ptr);
+    UpdateRulebookAndOutIndex<T, CPUContext, IntT>(
+        dev_ctx, x, kernel_size, out_channels, out_dims, &tmp_rulebook, out);
+    n = tmp_rulebook.dims()[1];
+    rulebook_ptr = tmp_rulebook.data<IntT>();
+
+    phi::funcs::sparse::SaveToTable(
+        dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
+  }
+  // int n = rulebook->dims()[1];

   // 2. gather
   DenseTensorMeta in_features_meta(
@@ -100,34 +127,33 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
   T* out_features_ptr = out_features.data<T>();
   Gather<T, IntT>(x.non_zero_elements().data<T>(),
-                  rulebook->data<IntT>() + n,
+                  rulebook_ptr + n,
                   n,
                   in_channels,
                   in_features_ptr);

   // 3. call gemm for every werght
   auto blas = phi::funcs::GetBlas<CPUContext, T>(dev_ctx);
-  std::vector<int> offsets(kernel_size + 1);
   int offset = 0;
   for (int i = 0; i < kernel_size; i++) {
-    offsets[i] = offset;
-    offset += counter_ptr[i];
+    h_offsets_ptr[i] = offset;
+    offset += h_counter_ptr[i];
   }
-  offsets[kernel_size] = offset;
+  h_offsets_ptr[kernel_size] = offset;

   const T* kernel_ptr = kernel.data<T>();
   for (int i = 0; i < kernel_size; i++) {
-    if (counter_ptr[i] <= 0) {
+    if (h_counter_ptr[i] <= 0) {
       continue;
     }

     // call gemm: (n, in_channels) * (in_channels, out_channels)
-    const int M = counter_ptr[i];
+    const int M = h_counter_ptr[i];
     const int K = in_channels;   // in_channels
     const int N = out_channels;  // out_channels
-    T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
+    T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels;
     const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
-    T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels;
+    T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;

     blas.GEMM(CblasNoTrans,
               CblasNoTrans,
               M,
@@ -143,11 +169,8 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
   // 4. scatter
   T* out_values_ptr = out->mutable_non_zero_elements()->data<T>();
   memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels);
   Scatter<T, IntT>(out_features_ptr,
-                   rulebook->data<IntT>() + n * 2,
+                   rulebook_ptr + n * 2,
                    n,
                    out_channels,
                    out_values_ptr);
 }
@@ -159,8 +182,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
                      const std::vector<int>& strides,
                      const int groups,
                      const bool subm,
+                     const std::string& key,
                      SparseCooTensor* out,
-                     DenseTensor* rulebook) {
+                     DenseTensor* rulebook,
+                     DenseTensor* counter) {
   PD_VISIT_INTEGRAL_TYPES(
       x.non_zero_indices().dtype(), "Conv3dCooCPUKernel", ([&] {
         Conv3dCooCPUKernel<T, data_t>(dev_ctx,
@@ -171,8 +196,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
                                       strides,
                                       groups,
                                       subm,
+                                      key,
                                       out,
-                                      rulebook);
+                                      rulebook,
+                                      counter);
       }));
 }
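Note: the rewritten CPU forward above is essentially a get-or-build protocol: PrepareSubm answers a cache hit, otherwise ProductRuleBook rebuilds and SaveToTable stores (or exports) the pair. A toy, self-contained model of that protocol in plain C++ (Paddle types replaced by std stand-ins; a sketch, not the real kernel):

// Toy model: first submanifold conv under a key pays for rulebook
// construction; later convs with the same key skip it.
#include <iostream>
#include <map>
#include <string>
#include <utility>
#include <vector>

using Rulebook = std::vector<int>;
using Counter = std::vector<int>;
std::map<std::string, std::pair<Rulebook, Counter>> table;  // indices_dict_

Rulebook BuildRulebook() {  // stands in for ProductRuleBook
  std::cout << "building rulebook\n";
  return {0, 1, 2};
}

const Rulebook& GetOrBuild(const std::string& key) {
  auto it = table.find(key);  // stands in for PrepareSubm
  if (it != table.end()) return it->second.first;
  Rulebook rb = BuildRulebook();
  Counter cnt = {3};
  // stands in for SaveToTable
  return table.emplace(key, std::make_pair(std::move(rb), std::move(cnt)))
      .first->second.first;
}

int main() {
  GetOrBuild("subm_conv_1");  // prints "building rulebook"
  GetOrBuild("subm_conv_1");  // cache hit: silent
  return 0;
}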
paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc
@@ -28,6 +28,7 @@ template <typename T, typename IntT = int>
 void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx,
                              const SparseCooTensor& x,
                              const DenseTensor& rulebook,
+                             const DenseTensor& counter,
                              const SparseCooTensor& out,
                              const SparseCooTensor& out_grad,
                              const std::vector<int>& kernel_sizes,
@@ -36,11 +37,10 @@ void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx,
   const int channels = x.dims()[4];
   int rulebook_len = rulebook.dims()[1];
   const IntT* rulebook_ptr = rulebook.data<IntT>();
-  std::vector<int> offsets(kernel_size + 1), counter(kernel_size, 0);
-  for (int i = 0; i < rulebook_len; i++) {
-    counter[rulebook_ptr[i]] += 1;
-  }
-  phi::funcs::sparse::PrefixSum(&counter[0], &offsets[0], kernel_size);
+  std::vector<int> offsets(kernel_size + 1);
+  const int* counter_ptr = counter.data<int>();
+  phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size);

   const T* in_features_ptr = x.non_zero_elements().data<T>();
   const T* out_features_ptr = out.non_zero_elements().data<T>();
@@ -60,7 +60,7 @@ void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx,
   phi::funcs::MaxPoolGrad<T> grad_functor;
   for (int i = 0; i < kernel_size; i++) {
-    for (int j = 0; j < counter[i]; j++) {
+    for (int j = 0; j < counter_ptr[i]; j++) {
       IntT in_i = rulebook_ptr[rulebook_len + offsets[i] + j];
       IntT out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j];
       for (int c = 0; c < channels; c++) {
@@ -78,6 +78,7 @@ template <typename T, typename Context>
 void MaxPoolCooGradKernel(const Context& dev_ctx,
                           const SparseCooTensor& x,
                           const DenseTensor& rulebook,
+                          const DenseTensor& counter,
                           const SparseCooTensor& out,
                           const SparseCooTensor& out_grad,
                           const std::vector<int>& kernel_sizes,
@@ -85,7 +86,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx,
   PD_VISIT_INTEGRAL_TYPES(
       x.non_zero_indices().dtype(), "MaxPoolCooGradCPUKernel", ([&] {
         MaxPoolCooGradCPUKernel<T, data_t>(
-            dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad);
+            dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad);
       }));
 }
paddle/phi/kernels/sparse/cpu/pool_kernel.cc
@@ -19,7 +19,7 @@ limitations under the License. */
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/pooling.h"
 #include "paddle/phi/kernels/funcs/sparse/convolution.h"
-#include "paddle/phi/kernels/sparse/cpu/convolution.h"
+#include "paddle/phi/kernels/sparse/cpu/conv.h"

 namespace phi {
 namespace sparse {
@@ -37,7 +37,8 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx,
                          const std::vector<int>& dilations,
                          const std::vector<int>& strides,
                          SparseCooTensor* out,
-                         DenseTensor* rulebook) {
+                         DenseTensor* rulebook,
+                         DenseTensor* counter) {
   const auto& x_dims = x.dims();
   int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
   const std::vector<int>& real_kernel_sizes =
@@ -47,9 +48,7 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx,
       x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims);
   const int in_channels = real_kernel_sizes[3];

-  DenseTensorMeta counter_meta(
-      DataType::INT32, {kernel_size}, DataLayout::NCHW);
-  DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
+  std::vector<int> counter_per_kernel(kernel_size, 0);

   const T* in_features_ptr = x.non_zero_elements().data<T>();
   // 1. product rule book
@@ -62,14 +61,17 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx,
                                        out_dims,
                                        false,
                                        rulebook,
-                                       &counter_per_kernel);
+                                       counter_per_kernel.data());

   UpdateRulebookAndOutIndex<T, CPUContext, IntT>(
       dev_ctx, x, kernel_size, in_channels, out_dims, rulebook, out);

   int rulebook_len = rulebook->dims()[1];
   const IntT* rulebook_ptr = rulebook->data<IntT>();
-  const int* counter_ptr = counter_per_kernel.data<int>();
+
+  counter->Resize({kernel_size});
+  int* counter_ptr = dev_ctx.template HostAlloc<int>(counter);
+  memcpy(counter_ptr, counter_per_kernel.data(), kernel_size * sizeof(int));

   std::vector<int> offsets(kernel_size + 1);
   phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size);
@@ -105,7 +107,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
                       const std::vector<int>& dilations,
                       const std::vector<int>& strides,
                       SparseCooTensor* out,
-                      DenseTensor* rulebook) {
+                      DenseTensor* rulebook,
+                      DenseTensor* counter) {
   PD_VISIT_INTEGRAL_TYPES(
       x.non_zero_indices().dtype(), "MaxPoolCooCPUKernel", ([&] {
         MaxPoolCooCPUKernel<T, data_t>(dev_ctx,
@@ -115,7 +118,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
                                        dilations,
                                        strides,
                                        out,
-                                       rulebook);
+                                       rulebook,
+                                       counter);
       }));
 }
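Note: both pool kernels now take the counter as an explicit tensor and convert it to exclusive offsets with PrefixSum, so the rulebook entries for kernel element i occupy [offsets[i], offsets[i+1]). A tiny self-contained illustration of that conversion (plain C++, mirroring phi::funcs::sparse::PrefixSum shown earlier in this commit):

// counter[i] = number of rulebook entries for kernel element i;
// offsets[i] = where those entries start; offsets[n] = total length.
#include <cstdio>

void PrefixSum(const int* counter, int* offsets, int n) {
  int offset = 0;
  for (int i = 0; i < n; i++) {
    offsets[i] = offset;
    offset += counter[i];
  }
  offsets[n] = offset;  // total rulebook length
}

int main() {
  const int counter[3] = {2, 0, 3};  // hits per kernel element
  int offsets[4];
  PrefixSum(counter, offsets, 3);
  for (int i = 0; i <= 3; i++) std::printf("%d ", offsets[i]);
  std::printf("\n");  // prints: 0 2 2 5
  return 0;
}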
paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu
@@ -125,16 +125,35 @@ void CoalesceGPUKernel(const GPUContext& dev_ctx,
   }

   // 5. scatter the values
-  config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1);
-  phi::funcs::sparse::ScatterKernel<T>
-      <<<config.block_per_grid, config.thread_per_block, 0, dev_ctx.stream()>>>(
-          x_values_ptr,
-          public_indexs.data<int>(),
-          values_indexs_ptr,
-          out_nnz,
-          nnz,
-          stride,
-          out_values.data<T>());
+  const int VecSize = VecBytes / sizeof(T);
+  if (stride % VecSize == 0) {
+    config = phi::backends::gpu::GetGpuLaunchConfig1D(
+        dev_ctx, nnz * stride / VecSize, 1);
+    phi::funcs::sparse::ScatterKernel<T, VecSize>
+        <<<config.block_per_grid,
+           config.thread_per_block,
+           0,
+           dev_ctx.stream()>>>(x_values_ptr,
+                               public_indexs.data<int>(),
+                               values_indexs_ptr,
+                               out_nnz,
+                               nnz,
+                               stride,
+                               out_values.data<T>());
+  } else {
+    config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1);
+    phi::funcs::sparse::ScatterKernel<T, 1>
+        <<<config.block_per_grid,
+           config.thread_per_block,
+           0,
+           dev_ctx.stream()>>>(x_values_ptr,
+                               public_indexs.data<int>(),
+                               values_indexs_ptr,
+                               out_nnz,
+                               nnz,
+                               stride,
+                               out_values.data<T>());
+  }

   // 6. convert index to coordinate
   Dim<DDim::kMaxRank> const_dims;
paddle/phi/kernels/sparse/gpu/conv.cu.h (new file, mode 100644; +760 lines — diff collapsed in this view)
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
@@ -19,13 +19,11 @@ limitations under the License. */
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/phi/core/tensor_meta.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
-#include "paddle/phi/kernels/funcs/scatter.cu.h"
+#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
-#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"

 namespace phi {
 namespace sparse {
@@ -42,43 +40,42 @@ template <typename T, typename IntT>
 void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
                             const SparseCooTensor& x,
                             const DenseTensor& kernel,
+                            const SparseCooTensor& out,
                             const DenseTensor& rulebook,
+                            const DenseTensor& counter,
                             const SparseCooTensor& out_grad,
                             const std::vector<int>& paddings,
                             const std::vector<int>& dilations,
                             const std::vector<int>& strides,
                             const int groups,
                             const bool subm,
+                            const std::string& key,
                             SparseCooTensor* x_grad,
                             DenseTensor* kernel_grad) {
   const auto& kernel_dims = kernel.dims();
   const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2];
   const int in_channels = kernel_dims[3];
   const int out_channels = kernel_dims[4];
-  const IntT* rulebook_ptr = rulebook.data<IntT>();

-  const int rulebook_len = rulebook.dims()[1];
+  int rulebook_len = 0;
+  const IntT* rulebook_ptr = phi::funcs::sparse::GetRulebookPtr<IntT>(
+      out, rulebook, key, &rulebook_len);
+  const int* counter_ptr = phi::funcs::sparse::GetCounterPtr(out, counter, key);

-  DenseTensorMeta in_features_meta(
-      x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
-  DenseTensorMeta d_x_features_meta(
-      x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW);
-  DenseTensorMeta out_grad_features_meta(
-      x.dtype(), {rulebook_len, out_channels}, DataLayout::NCHW);
   phi::DenseTensor in_features =
-      phi::Empty(dev_ctx, std::move(in_features_meta));
+      phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
   phi::DenseTensor d_x_features =
-      phi::Empty(dev_ctx, std::move(d_x_features_meta));
+      phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
   phi::DenseTensor out_grad_features =
-      phi::Empty(dev_ctx, std::move(out_grad_features_meta));
+      phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});

   T* in_features_ptr = in_features.data<T>();
   T* d_x_features_ptr = d_x_features.data<T>();
   T* out_grad_features_ptr = out_grad_features.data<T>();
   *kernel_grad = phi::EmptyLike<T>(dev_ctx, kernel);
   T* d_kernel_ptr = kernel_grad->data<T>();
-  phi::funcs::SetConstant<GPUContext, T> set_zero;
-  set_zero(dev_ctx, kernel_grad, static_cast<T>(0.0f));
+  phi::backends::gpu::GpuMemsetAsync(
+      d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel(), dev_ctx.stream());

   int half_kernel_size = kernel_size / 2;
   auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
@@ -86,8 +83,12 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
       phi::EmptyLike<IntT>(dev_ctx, x.non_zero_indices());
   DenseTensor x_grad_values = phi::EmptyLike<T>(dev_ctx, x.non_zero_elements());
   T* x_grad_values_ptr = x_grad_values.data<T>();
-  set_zero(dev_ctx, &x_grad_values, static_cast<T>(0.0f));
-  set_zero(dev_ctx, &d_x_features, static_cast<T>(0.0f));
+  phi::backends::gpu::GpuMemsetAsync(x_grad_values_ptr,
+                                     0,
+                                     sizeof(T) * x_grad_values.numel(),
+                                     dev_ctx.stream());
+  phi::backends::gpu::GpuMemsetAsync(
+      d_x_features_ptr, 0, sizeof(T) * d_x_features.numel(), dev_ctx.stream());
   phi::Copy<GPUContext>(dev_ctx,
                         x.non_zero_indices(),
                         dev_ctx.GetPlace(),
@@ -95,29 +96,14 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
                         &x_grad_indices);
   x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true);

-  std::vector<IntT> offsets(kernel_size + 1), counter(kernel_size, 0),
-      h_counter(rulebook_len, 0);
-  phi::backends::gpu::GpuMemcpyAsync(&h_counter[0],
-                                     rulebook_ptr,
-                                     rulebook_len * sizeof(IntT),
-#ifdef PADDLE_WITH_HIP
-                                     hipMemcpyDeviceToHost,
-#else
-                                     cudaMemcpyDeviceToHost,
-#endif
-                                     dev_ctx.stream());
-  dev_ctx.Wait();
-  for (int i = 0; i < rulebook_len; i++) {
-    counter[h_counter[i]] += 1;
-  }
-  IntT offset = 0, max_count = 0;
+  std::vector<int> offsets(kernel_size + 1);
+  int offset = 0, max_count = 0;
   for (int i = 0; i < kernel_size; i++) {
     offsets[i] = offset;
-    offset += counter[i];
+    offset += counter_ptr[i];
     if (i < half_kernel_size) {
-      max_count = std::max(max_count, counter[i]);
+      max_count = std::max(max_count, counter_ptr[i]);
     }
   }
   offsets[kernel_size] = offset;
@@ -138,36 +124,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
     }
   }

-  auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
-      dev_ctx, rulebook_len * in_channels, 1);
-  GatherKernel<T, IntT><<<config.block_per_grid.x,
-                          config.thread_per_block.x,
-                          0,
-                          dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
-                                              rulebook_ptr + rulebook_len,
-                                              in_features_ptr,
-                                              rulebook_len,
-                                              in_channels);
-
-  config = phi::backends::gpu::GetGpuLaunchConfig1D(
-      dev_ctx, rulebook_len * out_channels, 1);
-  GatherKernel<T, IntT>
-      <<<config.block_per_grid.x,
-         config.thread_per_block.x,
-         0,
-         dev_ctx.stream()>>>(out_grad.non_zero_elements().data<T>(),
-                             rulebook_ptr + rulebook_len * 2,
-                             out_grad_features_ptr,
-                             rulebook_len,
-                             out_channels);
+  auto config =
+      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
+  DenseTensor unique_value = phi::Empty<int>(
+      dev_ctx, {static_cast<int>(x_grad->nnz() * kernel_size * 2)});
+  DenseTensor out_index =
+      phi::Empty<int>(dev_ctx, {static_cast<int>(x.nnz() * 2)});
+  int* out_index_ptr = out_index.data<int>();
+  int* unique_value_ptr = unique_value.data<int>();
+  phi::backends::gpu::GpuMemsetAsync(
+      out_index_ptr, 0, sizeof(int) * x.nnz() * 2, dev_ctx.stream());
+
+  GroupIndexsV2<<<config.block_per_grid,
+                  config.thread_per_block,
+                  0,
+                  dev_ctx.stream()>>>(rulebook_len,
+                                      x.nnz(),
+                                      kernel_size,
+                                      offsets[kernel_size / 2],
+                                      rulebook_ptr,
+                                      out_index_ptr,
+                                      unique_value_ptr);
+
+  GatherV2<T, IntT>(dev_ctx,
+                    x.non_zero_elements().data<T>(),
+                    out_index_ptr,
+                    unique_value_ptr,
+                    x.nnz(),
+                    kernel_size,
+                    in_channels,
+                    2,
+                    in_features_ptr);
+
+  Gather<T, IntT>(dev_ctx,
+                  out_grad.non_zero_elements().data<T>(),
+                  rulebook_ptr + rulebook_len,
+                  rulebook_len,
+                  out_channels,
+                  out_grad_features_ptr);

   const T* kernel_ptr = kernel.data<T>();
   for (int i = 0; i < kernel_size; i++) {
-    if (counter[i] <= 0 || (subm && i == half_kernel_size)) {
+    if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) {
       continue;
     }

-    const int M = counter[i];
+    const int M = counter_ptr[i];
     const int K = in_channels;
     const int N = out_channels;
     T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
@@ -204,32 +206,31 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
   }

   // 4. scatter
-  config = phi::backends::gpu::GetGpuLaunchConfig1D(
-      dev_ctx, rulebook_len * in_channels, 1);
-  phi::funcs::ScatterCUDAKernel<<<config.block_per_grid,
-                                  config.thread_per_block,
-                                  0,
-                                  dev_ctx.stream()>>>(d_x_features_ptr,
-                                                      rulebook_ptr + rulebook_len,
-                                                      x_grad_values_ptr,
-                                                      rulebook_len,
-                                                      in_channels,
-                                                      false);
+  phi::funcs::sparse::ScatterV2<T>(dev_ctx,
+                                   d_x_features_ptr,
+                                   out_index.data<int>(),
+                                   unique_value.data<int>(),
+                                   x_grad->nnz(),
+                                   kernel_size,
+                                   in_channels,
+                                   2,
+                                   x_grad_values_ptr);
 }

 template <typename T, typename Context>
 void Conv3dCooGradKernel(const Context& dev_ctx,
                          const SparseCooTensor& x,
                          const DenseTensor& kernel,
+                         const SparseCooTensor& out,
                          const DenseTensor& rulebook,
+                         const DenseTensor& counter,
                          const SparseCooTensor& out_grad,
                          const std::vector<int>& paddings,
                          const std::vector<int>& dilations,
                          const std::vector<int>& strides,
                          const int groups,
                          const bool subm,
+                         const std::string& key,
                          SparseCooTensor* x_grad,
x_grad
,
DenseTensor
*
kernel_grad
)
{
DenseTensor
*
kernel_grad
)
{
PD_VISIT_INTEGRAL_TYPES
(
PD_VISIT_INTEGRAL_TYPES
(
...
@@ -237,13 +238,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
...
@@ -237,13 +238,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
Conv3dCooGradGPUKernel
<
T
,
data_t
>
(
dev_ctx
,
Conv3dCooGradGPUKernel
<
T
,
data_t
>
(
dev_ctx
,
x
,
x
,
kernel
,
kernel
,
out
,
rulebook
,
rulebook
,
counter
,
out_grad
,
out_grad
,
paddings
,
paddings
,
dilations
,
dilations
,
strides
,
strides
,
groups
,
groups
,
subm
,
subm
,
key
,
x_grad
,
x_grad
,
kernel_grad
);
kernel_grad
);
}));
}));
...
...
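A note on the bookkeeping these hunks rely on: `counter[i]` holds the number of active input/output pairs produced by kernel offset `i`, and an exclusive prefix sum over it yields `offsets`, so the gathered feature rows for each offset are contiguous and every non-empty offset reduces to one dense GEMM of shape (M, K) x (K, N). The following is a standalone sketch of that segmentation outside of Paddle; all names and sizes here are illustrative only.

    // Sketch: per-offset counts -> segment offsets -> one GEMM per offset.
    #include <iostream>
    #include <vector>

    int main() {
      const int kernel_size = 27;                // 3 x 3 x 3 kernel offsets
      std::vector<int> counter(kernel_size, 0);  // active pairs per offset
      counter[13] = 5;                           // e.g. center offset: 5 pairs
      counter[4] = 2;

      // Exclusive prefix sum: offsets[i] is where offset i's segment starts.
      std::vector<int> offsets(kernel_size + 1, 0);
      for (int i = 0; i < kernel_size; ++i) {
        offsets[i + 1] = offsets[i] + counter[i];
      }

      // Each non-empty segment maps to one dense GEMM.
      const int in_channels = 16, out_channels = 32;
      for (int i = 0; i < kernel_size; ++i) {
        if (counter[i] <= 0) continue;
        const int M = counter[i], K = in_channels, N = out_channels;
        std::cout << "offset " << i << ": rows [" << offsets[i] << ", "
                  << offsets[i + 1] << "), GEMM " << M << "x" << K << " * "
                  << K << "x" << N << "\n";
      }
      return 0;
    }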
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
@@ -21,7 +21,9 @@ limitations under the License. */
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/scatter.cu.h"
 #include "paddle/phi/kernels/funcs/sparse/scatter.cu.h"
-#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
+#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"
+
+#include "glog/logging.h"

 namespace phi {
 namespace sparse {
@@ -35,8 +37,10 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
                         const std::vector<int>& strides,
                         const int groups,
                         const bool subm,
+                        const std::string& key,
                         SparseCooTensor* out,
-                        DenseTensor* rulebook) {
+                        DenseTensor* rulebook,
+                        DenseTensor* counter) {
   // update padding and dilation
   // Currently, only support x.layout is NDHWC, groups = 1
   // if x.layout != NDHWC then transpose(x), transpose(weight)
@@ -61,85 +65,117 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
       x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims);
   const int in_channels = kernel_dims[3];
   const int out_channels = kernel_dims[4];
-  std::vector<int> offsets(kernel_size + 1), h_counter(kernel_size);
+  DenseTensor h_counter, h_offsets;
+  h_counter.Resize({kernel_size});
+  h_offsets.Resize({kernel_size + 1});
+  int* h_counter_ptr = dev_ctx.template HostAlloc<int>(&h_counter);
+  int* h_offsets_ptr = dev_ctx.template HostAlloc<int>(&h_offsets);

   // Second algorithm:
   // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
   // 1. product rulebook
-  DenseTensorMeta counter_meta(
-      DataType::INT32, {kernel_size}, DataLayout::NCHW);
-  DenseTensorMeta offsets_meta(
-      DataType::INT32, {kernel_size}, DataLayout::NCHW);
-  DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
-  DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(offsets_meta));
-  DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW);
-  DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta));
-  DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta));
+  DenseTensor counter_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size});
+  DenseTensor offsets_per_kernel = phi::Empty<int>(dev_ctx, {kernel_size});
+  DenseTensor out_index = phi::Empty<int>(dev_ctx, {1});
+  DenseTensor unique_value = phi::Empty<int>(dev_ctx, {1});

-  int n = ProductRuleBook<T, GPUContext, IntT>(dev_ctx,
-                                               x,
-                                               kernel_sizes,
-                                               subm_paddings,
-                                               dilations,
-                                               subm_strides,
-                                               out_dims,
-                                               subm,
-                                               rulebook,
-                                               &counter_per_kernel,
-                                               &offsets_per_kernel,
-                                               &out_index,
-                                               &unique_value,
-                                               out,
-                                               &h_counter,
-                                               &offsets);
-  const int* counter_ptr = counter_per_kernel.data<int>();
-  const int* offsets_ptr = counter_per_kernel.data<int>();
-  const IntT* rulebook_ptr = rulebook->data<IntT>();
+  VLOG(6) << "call SubmConv3D or Conv3D " << subm << " and the key is " << key;
+  int rulebook_len = 0;
+  const IntT* rulebook_ptr = nullptr;
+  bool need_product_rulebook = true;
+  if (subm && !key.empty()) {
+    rulebook_ptr = phi::funcs::sparse::PrepareSubm<T, IntT, GPUContext>(
+        dev_ctx,
+        x,
+        key,
+        out_dims,
+        out,
+        h_counter.data<int>(),
+        h_offsets.data<int>(),
+        &rulebook_len,
+        &need_product_rulebook);
+  }
+  if (need_product_rulebook) {
+    DenseTensor tmp_rulebook;
+    rulebook_len = ProductRuleBook<T, GPUContext, IntT>(dev_ctx,
                                                        x,
                                                        kernel_sizes,
                                                        subm_paddings,
                                                        dilations,
                                                        subm_strides,
                                                        out_dims,
                                                        subm,
                                                        &tmp_rulebook,
                                                        &counter_per_kernel,
                                                        &offsets_per_kernel,
                                                        &out_index,
                                                        &unique_value,
                                                        out,
                                                        h_counter_ptr,
                                                        h_offsets_ptr);
+    rulebook_ptr = tmp_rulebook.data<IntT>();
+    phi::funcs::sparse::SaveToTable(
+        dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter);
+  }

   // 2. gather
-  DenseTensorMeta in_features_meta(
-      x.dtype(), {n, in_channels}, DataLayout::NCHW);
-  DenseTensorMeta out_features_meta(
-      x.dtype(), {n, out_channels}, DataLayout::NCHW);
-  phi::DenseTensor in_features =
-      phi::Empty(dev_ctx, std::move(in_features_meta));
-  phi::DenseTensor out_features =
-      phi::Empty(dev_ctx, std::move(out_features_meta));
+  phi::DenseTensor in_features =
+      phi::Empty<T>(dev_ctx, {rulebook_len, in_channels});
+  phi::DenseTensor out_features =
+      phi::Empty<T>(dev_ctx, {rulebook_len, out_channels});
   T* in_features_ptr = in_features.data<T>();
   T* out_features_ptr = out_features.data<T>();
   phi::funcs::SetConstant<GPUContext, T> set_zero;
   set_zero(dev_ctx, &out_features, static_cast<T>(0.0f));

-  auto config =
-      phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1);
-  GatherKernel<T, IntT><<<config.block_per_grid.x,
-                          config.thread_per_block.x,
-                          0,
-                          dev_ctx.stream()>>>(x.non_zero_elements().data<T>(),
-                                              rulebook_ptr + n,
-                                              in_features_ptr,
-                                              n,
-                                              in_channels);
+  Gather<T, IntT>(dev_ctx,
+                  x.non_zero_elements().data<T>(),
+                  rulebook_ptr,
+                  rulebook_len,
+                  in_channels,
+                  in_features_ptr);

   // 3. call gemm for every weight
   auto blas = phi::funcs::GetBlas<GPUContext, T>(dev_ctx);
   auto* out_values = out->mutable_non_zero_elements();
   T* out_values_ptr = out_values->data<T>();
+  set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
+
+  if (subm) {
+    auto config =
+        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1);
+    unique_value.ResizeAndAllocate(
+        {static_cast<int>(out->nnz() * kernel_size)});
+    out_index.ResizeAndAllocate({static_cast<int>(rulebook_len)});
+    int* out_index_ptr = out_index.data<int>();
+    int* unique_value_ptr = unique_value.data<int>();
+    phi::backends::gpu::GpuMemsetAsync(
+        out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream());
+    GroupIndexs<<<config.block_per_grid,
+                  config.thread_per_block,
+                  0,
+                  dev_ctx.stream()>>>(rulebook_len,
+                                      kernel_size,
+                                      rulebook_ptr + rulebook_len,
+                                      out_index_ptr,
+                                      unique_value_ptr);
+  }

   const T* kernel_ptr = kernel.data<T>();
   for (int i = 0; i < kernel_size; i++) {
-    if (h_counter[i] <= 0) {
+    if (h_counter_ptr[i] <= 0) {
       continue;
     }

     // call gemm: (n, in_channels) * (in_channels, out_channels)
-    const int M = h_counter[i];
+    const int M = h_counter_ptr[i];
     const int K = in_channels;
     const int N = out_channels;
-    T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels;
+    T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels;
     const T* tmp_kernel_ptr = kernel_ptr + i * K * N;
-    T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels;
+    T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels;
     blas.GEMM(CblasNoTrans,
               CblasNoTrans,
@@ -154,40 +190,23 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
   }

   // 4. scatter
-  if (subm) {
-    set_zero(dev_ctx, out_values, static_cast<T>(0.0f));
-    config =
-        phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * out_channels, 1);
-    phi::funcs::ScatterCUDAKernel<T, IntT>
-        <<<config.block_per_grid,
-           config.thread_per_block,
-           0,
-           dev_ctx.stream()>>>(out_features_ptr,
-                               rulebook_ptr + 2 * n,
-                               out_values_ptr,
-                               n,
-                               out_channels,
-                               false);
-  } else {
-    config = phi::backends::gpu::GetGpuLaunchConfig1D(
-        dev_ctx, out->nnz() * out_channels, 1);
-    phi::funcs::sparse::ScatterKernel<T>
-        <<<config.block_per_grid.x,
-           config.thread_per_block.x,
-           0,
-           dev_ctx.stream()>>>(out_features_ptr,
-                               unique_value.data<int>(),
-                               out_index.data<int>(),
-                               out->nnz(),
-                               n,
-                               out_channels,
-                               out_values_ptr);
-  }
+  phi::funcs::sparse::ScatterV2<T>(dev_ctx,
+                                   out_features_ptr,
+                                   out_index.data<int>(),
+                                   unique_value.data<int>(),
+                                   out->nnz(),
+                                   kernel_size,
+                                   out_channels,
+                                   1,
+                                   out_values_ptr);
 }

 /**
- * x: (N, D, H, W, C)
- * kernel: (D, H, W, C, OC)
- * out: (N, D, H, W, OC)
+ * x: the input SparseCooTensor, shape is (N, D, H, W, C)
+ * kernel: the weight data, shape is (D, H, W, C, OC)
+ * out: the output SparseCooTensor, shape is (N, D, H, W, OC)
+ * rulebook: return rulebook if key is not valid else return nullptr
+ * counter: return counter if key is not valid else return nullptr
 **/
 template <typename T, typename Context>
 void Conv3dCooKernel(const Context& dev_ctx,
@@ -198,8 +217,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
                      const std::vector<int>& strides,
                      const int groups,
                      const bool subm,
+                     const std::string& key,
                      SparseCooTensor* out,
-                     DenseTensor* rulebook) {
+                     DenseTensor* rulebook,
+                     DenseTensor* counter) {
   PD_VISIT_INTEGRAL_TYPES(
       x.non_zero_indices().dtype(), "Conv3dCooGPUKernel", ([&] {
         Conv3dCooGPUKernel<T, data_t>(dev_ctx,
@@ -210,8 +231,10 @@ void Conv3dCooKernel(const Context& dev_ctx,
                                       strides,
                                       groups,
                                       subm,
+                                      key,
                                       out,
-                                      rulebook);
+                                      rulebook,
+                                      counter);
     }));
 }
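The new `key` argument feeds the `PrepareSubm`/`SaveToTable` pair above: for submanifold convolution the sparsity pattern (and therefore the rulebook) is identical across layers that share a key, so the rulebook built by the first layer can simply be looked up instead of rebuilt. Below is a minimal sketch of that caching idea; the cache type and field layout are illustrative and simplified, not Paddle's actual table.

    // Sketch: a rulebook cache keyed by a user-supplied string.
    #include <string>
    #include <unordered_map>
    #include <utility>
    #include <vector>

    struct RulebookEntry {
      std::vector<int> rulebook;  // flattened (2, len): in indices, out indices
      std::vector<int> counter;   // active pairs per kernel offset
    };

    class RulebookCache {
     public:
      // Returns nullptr when the key is unknown and the rulebook must be built.
      const RulebookEntry* Find(const std::string& key) const {
        auto it = table_.find(key);
        return it == table_.end() ? nullptr : &it->second;
      }
      void Save(const std::string& key, RulebookEntry entry) {
        table_[key] = std::move(entry);
      }

     private:
      std::unordered_map<std::string, RulebookEntry> table_;
    };

    int main() {
      RulebookCache cache;
      if (cache.Find("subm_conv") == nullptr) {
        // ... build the rulebook once (ProductRuleBook in the real kernel) ...
        cache.Save("subm_conv", RulebookEntry{{0, 1, 0, 1}, {2}});
      }
      // Later layers with the same key reuse the cached rulebook directly.
      const RulebookEntry* hit = cache.Find("subm_conv");
      return hit != nullptr ? 0 : 1;
    }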
paddle/phi/kernels/sparse/gpu/mask_kernel.cu
@@ -238,6 +238,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
           x_indexs_ptr, x_indexs.numel(), table.data<int>());
   config =
       phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1);
   const int VecBytes = 16;
   const int VecSize = VecBytes / sizeof(T);
   if (stride % VecSize == 0) {
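The `VecBytes`/`VecSize` pair in this hunk picks a vector width: aim for 16-byte memory transactions, derive the element count from `sizeof(T)`, and fall back to element-wise copies whenever the row stride is not a multiple of the vector size. The sketch below is a plain CPU illustration of that dispatch logic, not the CUDA kernel itself; names are illustrative.

    // Sketch: choose a vector width from 16 bytes / sizeof(T), else scalar.
    #include <cstring>
    #include <iostream>

    template <typename T, int VecSize>
    void CopyRowVectorized(const T* src, T* dst, int stride) {
      // One "vector" at a time; on the GPU this becomes a single wide load/store.
      for (int j = 0; j < stride; j += VecSize) {
        std::memcpy(dst + j, src + j, VecSize * sizeof(T));
      }
    }

    template <typename T>
    void CopyRow(const T* src, T* dst, int stride) {
      constexpr int VecBytes = 16;
      constexpr int VecSize = VecBytes / sizeof(T);  // 4 floats or 2 doubles
      if (stride % VecSize == 0) {
        CopyRowVectorized<T, VecSize>(src, dst, stride);
      } else {
        for (int j = 0; j < stride; ++j) dst[j] = src[j];  // scalar fallback
      }
    }

    int main() {
      float src[8] = {1, 2, 3, 4, 5, 6, 7, 8}, dst[8] = {};
      CopyRow(src, dst, 8);  // 8 % 4 == 0, takes the vectorized path
      std::cout << dst[7] << "\n";
      return 0;
    }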
paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu
@@ -55,6 +55,7 @@ template <typename T, typename IntT = int>
 void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
                              const SparseCooTensor& x,
                              const DenseTensor& rulebook,
+                             const DenseTensor& counter,
                              const SparseCooTensor& out,
                              const SparseCooTensor& out_grad,
                              const std::vector<int>& kernel_sizes,
@@ -63,23 +64,9 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
   const int in_channels = x.dims()[4];
   int rulebook_len = rulebook.dims()[1];
   const IntT* rulebook_ptr = rulebook.data<IntT>();
-  std::vector<IntT> offsets(kernel_size + 1), counter(kernel_size, 0),
-      h_counter(rulebook_len, 0);
-  phi::backends::gpu::GpuMemcpyAsync(&h_counter[0],
-                                     rulebook_ptr,
-                                     rulebook_len * sizeof(IntT),
-#ifdef PADDLE_WITH_HIP
-                                     hipMemcpyDeviceToHost,
-#else
-                                     cudaMemcpyDeviceToHost,
-#endif
-                                     dev_ctx.stream());
-  dev_ctx.Wait();
-  for (int i = 0; i < rulebook_len; i++) {
-    counter[h_counter[i]] += 1;
-  }
-  phi::funcs::sparse::PrefixSum(&counter[0], &offsets[0], kernel_size);
+  std::vector<int> offsets(kernel_size + 1);
+  const int* counter_ptr = counter.data<int>();
+  phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size);
   const T* in_features_ptr = x.non_zero_elements().data<T>();
   const T* out_features_ptr = out.non_zero_elements().data<T>();
@@ -99,12 +86,12 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
                         &x_grad_indices);
   for (int i = 0; i < kernel_size; i++) {
-    if (counter[i] <= 0) {
+    if (counter_ptr[i] <= 0) {
       continue;
     }
     auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
-        dev_ctx, counter[i] * in_channels, 1);
+        dev_ctx, counter_ptr[i] * in_channels, 1);
     MaxPoolGradCudaKernel<T, IntT>
         <<<config.block_per_grid.x,
            config.thread_per_block.x,
@@ -112,8 +99,8 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx,
            dev_ctx.stream()>>>(in_features_ptr,
                                out_features_ptr,
                                out_grad_ptr,
-                               rulebook_ptr + offsets[i] + rulebook_len,
-                               counter[i],
+                               rulebook_ptr + offsets[i],
+                               counter_ptr[i],
                                rulebook_len,
                                in_channels,
                                x_grad_ptr);
@@ -124,6 +111,7 @@ template <typename T, typename Context>
 void MaxPoolCooGradKernel(const Context& dev_ctx,
                           const SparseCooTensor& x,
                           const DenseTensor& rulebook,
+                          const DenseTensor& counter,
                           const SparseCooTensor& out,
                           const SparseCooTensor& out_grad,
                           const std::vector<int>& kernel_sizes,
@@ -131,7 +119,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx,
   PD_VISIT_INTEGRAL_TYPES(
       x.non_zero_indices().dtype(), "MaxPoolCooGradGPUKernel", ([&] {
         MaxPoolCooGradGPUKernel<T, data_t>(
-            dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad);
+            dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad);
     }));
 }
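For context, the routing that `MaxPoolGradCudaKernel` performs per kernel offset is simple: a rulebook entry pairs an input position with an output position, and the output gradient flows back only to the input that produced the maximum. Here is a plain, single-channel C++ sketch of that rule; all names and values are illustrative.

    // Sketch: max-pool backward routing over rulebook (in, out) pairs.
    #include <iostream>
    #include <vector>

    int main() {
      std::vector<float> in = {1.0f, 5.0f, 3.0f};   // input features
      std::vector<float> out = {5.0f};              // out[0] = max(in[0..2])
      std::vector<float> out_grad = {2.0f};
      std::vector<float> in_grad(in.size(), 0.0f);

      // rulebook rows: in index, out index (one row per contributing pair)
      std::vector<int> rule_in = {0, 1, 2};
      std::vector<int> rule_out = {0, 0, 0};

      for (size_t r = 0; r < rule_in.size(); ++r) {
        const int i = rule_in[r], o = rule_out[r];
        if (in[i] == out[o]) {        // only the argmax receives the gradient
          in_grad[i] += out_grad[o];
        }
      }
      std::cout << in_grad[1] << "\n";  // prints 2, the routed gradient
      return 0;
    }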
paddle/phi/kernels/sparse/gpu/pool_kernel.cu
@@ -19,7 +19,7 @@ limitations under the License. */
 #include "paddle/phi/core/visit_type.h"
 #include "paddle/phi/kernels/funcs/pooling.h"
 #include "paddle/phi/kernels/funcs/sparse/convolution.h"
-#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
+#include "paddle/phi/kernels/sparse/gpu/conv.cu.h"

 namespace phi {
 namespace sparse {
@@ -55,7 +55,8 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
                          const std::vector<int>& dilations,
                          const std::vector<int>& strides,
                          SparseCooTensor* out,
-                         DenseTensor* rulebook) {
+                         DenseTensor* rulebook,
+                         DenseTensor* counter) {
   const auto& x_dims = x.dims();
   int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2];
   const std::vector<int>& real_kernel_sizes =
@@ -65,7 +66,7 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
       x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims);
   const int in_channels = real_kernel_sizes[3];
-  std::vector<int> offsets(kernel_size + 1), counter(kernel_size);
+  std::vector<int> offsets(kernel_size + 1), h_counter(kernel_size);
   DenseTensorMeta counter_meta(
       DataType::INT32, {kernel_size}, DataLayout::NCHW);
   DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta));
@@ -89,13 +90,16 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
                                         &out_index,
                                         &unique_value,
                                         out,
-                                        &counter,
-                                        &offsets);
+                                        h_counter.data(),
+                                        offsets.data());
   const IntT* rulebook_ptr = rulebook->data<IntT>();

   T* out_features_ptr = out->mutable_non_zero_elements()->data<T>();
   const T* in_features_ptr = x.non_zero_elements().data<T>();
+  counter->Resize({kernel_size});
+  int* counter_ptr = dev_ctx.template HostAlloc<int>(counter);
+  memcpy(counter_ptr, h_counter.data(), h_counter.size() * sizeof(int));
 // 2. max pool
 #ifdef PADDLE_WITH_HIP
   thrust::fill(thrust::hip::par.on(dev_ctx.stream()),
@@ -107,22 +111,21 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx,
                static_cast<T>(0));
   // TODO(zhangkaihuo) Replacing multiple calls with one kernel may be faster
   for (int i = 0; i < kernel_size; i++) {
-    if (counter[i] <= 0) {
+    if (h_counter[i] <= 0) {
       continue;
     }
     auto config = phi::backends::gpu::GetGpuLaunchConfig1D(
-        dev_ctx, counter[i] * in_channels, 1);
+        dev_ctx, h_counter[i] * in_channels, 1);
-    MaxPoolCudaKernel<T, IntT>
-        <<<config.block_per_grid.x,
-           config.thread_per_block.x,
-           0,
-           dev_ctx.stream()>>>(in_features_ptr,
-                               rulebook_ptr + offsets[i] + rulebook_len,
-                               counter[i],
-                               rulebook_len,
-                               in_channels,
-                               out_features_ptr);
+    MaxPoolCudaKernel<T, IntT><<<config.block_per_grid.x,
+                                 config.thread_per_block.x,
+                                 0,
+                                 dev_ctx.stream()>>>(in_features_ptr,
+                                                     rulebook_ptr + offsets[i],
+                                                     h_counter[i],
+                                                     rulebook_len,
+                                                     in_channels,
+                                                     out_features_ptr);
   }
 }
@@ -134,7 +137,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
                       const std::vector<int>& dilations,
                       const std::vector<int>& strides,
                       SparseCooTensor* out,
-                      DenseTensor* rulebook) {
+                      DenseTensor* rulebook,
+                      DenseTensor* counter) {
   PD_VISIT_INTEGRAL_TYPES(
       x.non_zero_indices().dtype(), "MaxPoolCooGPUKernel", ([&] {
         MaxPoolCooGPUKernel<T, data_t>(dev_ctx,
@@ -144,7 +148,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
                                        dilations,
                                        strides,
                                        out,
-                                       rulebook);
+                                       rulebook,
+                                       counter);
     }));
 }
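The per-offset launches above implement a straightforward reduction: every rulebook pair folds one input feature into its output position with max(). A compact C++ sketch of the forward rule follows; it initializes outputs to the lowest representable value for clarity (the kernel here uses thrust::fill with 0), and all names are illustrative.

    // Sketch: max-pool forward over rulebook (in, out) pairs.
    #include <algorithm>
    #include <iostream>
    #include <limits>
    #include <vector>

    int main() {
      std::vector<float> in = {1.0f, 5.0f, 3.0f, -2.0f};
      std::vector<int> rule_in = {0, 1, 2, 3};   // input index per pair
      std::vector<int> rule_out = {0, 0, 0, 1};  // output index per pair

      std::vector<float> out(2, std::numeric_limits<float>::lowest());
      for (size_t r = 0; r < rule_in.size(); ++r) {
        out[rule_out[r]] = std::max(out[rule_out[r]], in[rule_in[r]]);
      }
      std::cout << out[0] << " " << out[1] << "\n";  // prints: 5 -2
      return 0;
    }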
paddle/phi/kernels/sparse/pool_grad_kernel.h
@@ -25,6 +25,7 @@ template <typename T, typename Context>
 void MaxPoolCooGradKernel(const Context& dev_ctx,
                           const SparseCooTensor& x,
                           const DenseTensor& rulebook,
+                          const DenseTensor& counter,
                           const SparseCooTensor& out,
                           const SparseCooTensor& out_grad,
                           const std::vector<int>& kernel_sizes,
@@ -34,12 +35,13 @@ template <typename T, typename Context>
 SparseCooTensor MaxPoolCooGrad(const Context& dev_ctx,
                                const SparseCooTensor& x,
                                const DenseTensor& rulebook,
+                               const DenseTensor& counter,
                                const SparseCooTensor& out,
                                const SparseCooTensor& out_grad,
                                const std::vector<int>& kernel_sizes) {
   SparseCooTensor x_grad;
   MaxPoolCooGradKernel<T, Context>(
-      dev_ctx, x, rulebook, out, out_grad, kernel_sizes, &x_grad);
+      dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, &x_grad);
   return x_grad;
 }
paddle/phi/kernels/sparse/pool_kernel.h
@@ -29,7 +29,8 @@ void MaxPoolCooKernel(const Context& dev_ctx,
                       const std::vector<int>& dilations,
                       const std::vector<int>& strides,
                       SparseCooTensor* out,
-                      DenseTensor* rulebook);
+                      DenseTensor* rulebook,
+                      DenseTensor* counter);

 template <typename T, typename Context>
 SparseCooTensor MaxPoolCoo(const Context& dev_ctx,
@@ -38,10 +39,18 @@ SparseCooTensor MaxPoolCoo(const Context& dev_ctx,
                            const std::vector<int>& paddings,
                            const std::vector<int>& dilations,
                            const std::vector<int>& strides,
-                           DenseTensor* rulebook) {
+                           DenseTensor* rulebook,
+                           DenseTensor* counter) {
   SparseCooTensor coo;
   MaxPoolCooKernel<T, Context>(
       dev_ctx,
-      x, kernel_sizes, paddings, dilations, strides, &coo, rulebook);
+      x, kernel_sizes, paddings, dilations, strides, &coo, rulebook, counter);
   return coo;
 }
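Threading `counter` through these signatures is what lets the backward kernels above drop their device-to-host copy, `dev_ctx.Wait()`, and host-side recount of the rulebook's kernel-index row. The sketch below contrasts the two ways of obtaining the per-offset counts; the data and sizes are illustrative only.

    // Sketch: recounting from the rulebook vs. reusing the saved counter.
    #include <iostream>
    #include <vector>

    int main() {
      const int kernel_size = 27;
      // Kernel-index row of an old-layout rulebook (one entry per pair).
      std::vector<int> kernel_index_row = {13, 13, 4, 13, 4};

      // Old approach: histogram on the host, once per backward call.
      std::vector<int> recounted(kernel_size, 0);
      for (int k : kernel_index_row) recounted[k] += 1;

      // New approach: the counts were already produced by the forward pass.
      std::vector<int> counter(kernel_size, 0);
      counter[13] = 3;
      counter[4] = 2;

      std::cout << (recounted == counter ? "same counts\n" : "mismatch\n");
      return 0;
    }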
paddle/phi/tests/api/test_sparse_conv_api.cc
@@ -76,8 +76,8 @@ void TestConv3dBase(const std::vector<int>& indices,
                 kernel.size() * sizeof(T));

   if (!std::is_same<T, phi::dtype::float16>::value) {
-    auto tensor_out = paddle::experimental::sparse::conv3d(
-        x, weight, paddings, dilations, strides, 1, false);
+    auto tensor_out = paddle::experimental::sparse::conv3d_coo(
+        x, weight, paddings, dilations, strides, 1, false, "Conv3d");
     auto out =
         std::dynamic_pointer_cast<phi::SparseCooTensor>(tensor_out.impl());
paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc
@@ -112,8 +112,7 @@ void TestConv3dBase(const std::vector<IntT>& indices,
   };

   if (!std::is_same<T, phi::dtype::float16>::value) {
-    DenseTensor rulebook = phi::Empty(
-        dev_ctx_cpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW));
+    DenseTensor rulebook, counter;
     SparseCooTensor out = sparse::Conv3dCoo<T>(dev_ctx_cpu,
                                                x_tensor,
                                                kernel_tensor,
@@ -122,7 +121,9 @@ void TestConv3dBase(const std::vector<IntT>& indices,
                                                strides,
                                                1,
                                                subm,
-                                               &rulebook);
+                                               "Conv3d",
+                                               &rulebook,
+                                               &counter);
   ASSERT_EQ(correct_out_dims.size(), out.dims().size());
   for (int i = 0; i < correct_out_dims.size(); i++) {
@@ -142,13 +143,16 @@ void TestConv3dBase(const std::vector<IntT>& indices,
         sparse::Conv3dCooGrad<T>(dev_ctx_cpu,
                                  x_tensor,
                                  kernel_tensor,
+                                 out,
                                  rulebook,
+                                 counter,
                                  out,
                                  paddings,
                                  dilations,
                                  strides,
                                  1,
-                                 subm);
+                                 subm,
+                                 "Conv3d");
     f_verify(std::get<0>(grads).non_zero_elements().data<T>(), features_grad);
     f_verify(std::get<1>(grads).data<T>(), kernel_grad);
   }
@@ -196,8 +200,7 @@ void TestConv3dBase(const std::vector<IntT>& indices,
     phi::Copy(
         dev_ctx_gpu, kernel_tensor, phi::GPUPlace(), true, &d_kernel_tensor);
-    DenseTensor d_rulebook = phi::Empty(
-        dev_ctx_gpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW));
+    DenseTensor d_rulebook, d_counter;
     SparseCooTensor d_out = sparse::Conv3dCoo<T>(dev_ctx_gpu,
                                                  d_x_tensor,
                                                  d_kernel_tensor,
@@ -206,8 +209,9 @@ void TestConv3dBase(const std::vector<IntT>& indices,
                                                  strides,
                                                  1,
                                                  subm,
-                                                 &d_rulebook);
+                                                 "Conv3d",
+                                                 &d_rulebook,
+                                                 &d_counter);
     SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);
     ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
@@ -245,13 +249,16 @@ void TestConv3dBase(const std::vector<IntT>& indices,
         sparse::Conv3dCooGrad<T>(dev_ctx_gpu,
                                  d_x_tensor,
                                  d_kernel_tensor,
+                                 d_out,
                                  d_rulebook,
+                                 d_counter,
                                  d_out,
                                  paddings,
                                  dilations,
                                  strides,
                                  1,
-                                 subm);
+                                 subm,
+                                 "Conv3d");
     DenseTensor d_features_grad = std::get<0>(grads).non_zero_elements();
     DenseTensor d_kernel_grad = std::get<1>(grads);
     DenseTensor h_features_grad =
paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
@@ -90,14 +90,15 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
   };

   if (!std::is_same<T, phi::dtype::float16>::value) {
-    DenseTensor rulebook;
+    DenseTensor rulebook, counter;
     SparseCooTensor out = sparse::MaxPoolCoo<T>(dev_ctx_cpu,
                                                 x_tensor,
                                                 kernel_sizes,
                                                 paddings,
                                                 dilations,
                                                 strides,
-                                                &rulebook);
+                                                &rulebook,
+                                                &counter);
     ASSERT_EQ(correct_out_dims.size(), out.dims().size());
     for (int i = 0; i < correct_out_dims.size(); i++) {
@@ -114,7 +115,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
     if (backward) {
       SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>(
-          dev_ctx_cpu, x_tensor, rulebook, out, out, kernel_sizes);
+          dev_ctx_cpu, x_tensor, rulebook, counter, out, out, kernel_sizes);
       f_verify(x_grad.non_zero_elements().data<T>(), features_grad);
     }
   }
@@ -150,14 +151,16 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
     SparseCooTensor d_x_tensor(d_indices_tensor, d_features_tensor, x_dims);

-    DenseTensor d_rulebook;
+    DenseTensor d_rulebook, d_counter;
     SparseCooTensor d_out = sparse::MaxPoolCoo<T>(dev_ctx_gpu,
                                                   d_x_tensor,
                                                   kernel_sizes,
                                                   paddings,
                                                   dilations,
                                                   strides,
-                                                  &d_rulebook);
+                                                  &d_rulebook,
+                                                  &d_counter);
     SparseCooTensor tmp_d_out = sparse::Coalesce<T>(dev_ctx_gpu, d_out);
     ASSERT_EQ(correct_out_dims.size(), d_out.dims().size());
@@ -191,8 +194,13 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
     f_verify(h_features_tensor.data<T>(), correct_out_features);

     if (backward) {
-      SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>(
-          dev_ctx_gpu, d_x_tensor, d_rulebook, d_out, d_out, kernel_sizes);
+      SparseCooTensor x_grad = sparse::MaxPoolCooGrad<T>(dev_ctx_gpu,
+                                                         d_x_tensor,
+                                                         d_rulebook,
+                                                         d_counter,
+                                                         d_out,
+                                                         d_out,
+                                                         kernel_sizes);
       DenseTensor h_features_grad =
           phi::EmptyLike<T>(dev_ctx_cpu, x_grad.non_zero_elements());
       phi::Copy(dev_ctx_gpu,
python/paddle/fluid/tests/unittests/test_sparse_conv_op.py
@@ -67,7 +67,7 @@ class TestSparseConv(unittest.TestCase):
                 indices, values, dense_shape, stop_gradient=True)
             weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32')
             y = paddle.incubate.sparse.nn.functional.subm_conv3d(
-                sparse_x, weight)
+                sparse_x, weight, key='subm_conv')
             assert np.array_equal(sparse_x.indices().numpy(),
                                   y.indices().numpy())
@@ -91,7 +91,7 @@ class TestSparseConv(unittest.TestCase):
         with self.assertRaises(ValueError):
             #Currently, only support data_format='NDHWC'
             conv3d = paddle.incubate.sparse.nn.SubmConv3D(
-                1, 1, (1, 3, 3), data_format='NCDHW')
+                1, 1, (1, 3, 3), data_format='NCDHW', key='subm_conv')

     def test_SubmConv3D(self):
         with _test_eager_guard():
@@ -105,7 +105,7 @@ class TestSparseConv(unittest.TestCase):
                 indices, values, dense_shape, False)
             subm_conv3d = paddle.incubate.sparse.nn.SubmConv3D(
-                1, 1, (1, 3, 3), data_format='NDHWC')
+                1, 1, (1, 3, 3), data_format='NDHWC', key='subm_conv')
             # test extra_repr
             print(subm_conv3d.extra_repr())
@@ -117,7 +117,7 @@ class TestSparseConv(unittest.TestCase):
         with self.assertRaises(ValueError):
             #Currently, only support data_format='NDHWC'
             conv3d = paddle.incubate.sparse.nn.SubmConv3D(
-                1, 1, (1, 3, 3), data_format='NCDHW')
+                1, 1, (1, 3, 3), data_format='NCDHW', key='subm_conv')

     def test_Conv3D_bias(self):
         with _test_eager_guard():
python/paddle/incubate/sparse/nn/functional/conv.py
@@ -29,6 +29,7 @@ def _conv3d(x,
             dilation=1,
             groups=1,
             subm=False,
+            key=None,
             data_format="NDHWC",
             name=None):
     assert in_dynamic_mode(), "Currently, only support dynamic mode"
@@ -62,8 +63,9 @@ def _conv3d(x,
     dilation = convert_to_list(dilation, dims, 'dilation')
     op_type = "conv3d"

-    pre_bias = _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation,
-                                                stride, groups, subm)
+    pre_bias = _C_ops.final_state_sparse_conv3d_coo(
+        x, weight, padding, dilation, stride, groups, subm,
+        key if key is not None else "")
     if bias is not None:
         values = pre_bias.values()
         add_bias = elementwise_add(values, bias, axis=1)
@@ -186,7 +188,7 @@ def conv3d(x,
             # (1, 1, 1, 2, 1)
     """
     return _conv3d(x, weight, bias, stride, padding, dilation, groups, False,
-                   data_format, name)
+                   None, data_format, name)

 def subm_conv3d(x,
@@ -197,6 +199,7 @@ def subm_conv3d(x,
                 dilation=1,
                 groups=1,
                 data_format="NDHWC",
+                key=None,
                 name=None):
     r"""
@@ -274,6 +277,10 @@ def subm_conv3d(x,
             will be consistent with that of the input. An optional string from: `"NCDHW"`, `"NDHWC"`.
             The default is `"NDHWC"`. When it is `"NDHWC"`, the data is stored in the order of:
             `[batch_size, input_depth, input_height, input_width, input_channels]`.
+        key(str, optional): the key is used to save or use the same rulebook,
+            the definition and role of rulebook refers to
+            https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf. The
+            default value is None.
         name(str|None): For detailed information, please refer
             to :ref:`api_guide_Name`. Usually name is no need to set and
             None by default.
@@ -301,4 +308,4 @@ def subm_conv3d(x,
             #(1, 1, 3, 4, 1)
     """
     return _conv3d(x, weight, bias, stride, padding, dilation, groups, True,
-                   data_format, name)
+                   key, data_format, name)
python/paddle/incubate/sparse/nn/layer/conv.py
@@ -33,6 +33,7 @@ class _Conv3D(Layer):
                  dilation=1,
                  groups=1,
                  subm=False,
+                 key=None,
                  padding_mode='zeros',
                  weight_attr=None,
                  bias_attr=None,
@@ -46,6 +47,7 @@ class _Conv3D(Layer):
         self._out_channels = out_channels
         self._data_format = data_format
         self._subm = subm
+        self._key = key

         assert padding_mode == 'zeros', "Currently, only support padding_mode='zeros'"
         assert groups == 1, "Currently, only support groups=1"
@@ -95,6 +97,7 @@ class _Conv3D(Layer):
                       dilation=self._dilation,
                       groups=self._groups,
                       subm=self._subm,
+                      key=self._key,
                       data_format=self._data_format)
         return out
@@ -240,6 +243,7 @@ class Conv3D(_Conv3D):
             dilation=dilation,
             groups=groups,
             subm=False,
+            key=None,
             padding_mode=padding_mode,
             weight_attr=weight_attr,
             bias_attr=bias_attr,
@@ -293,6 +297,10 @@ class SubmConv3D(_Conv3D):
             of the input channels, while the second half of the filters is only
             connected to the second half of the input channels. The default value is 1.
         padding_mode(str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Currently only support ``'zeros'``.
+        key(str, optional): the key is used to save or use the same rulebook,
+            the definition and role of rulebook refers to
+            https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf. The
+            default value is None.
         weight_attr(ParamAttr, optional): The parameter attribute for learnable parameters/weights
             of conv3d. If it is set to None or one attribute of ParamAttr, conv3d
             will create ParamAttr as param_attr. If it is set to None, the parameter
@@ -361,6 +369,7 @@ class SubmConv3D(_Conv3D):
                  dilation=1,
                  groups=1,
                  padding_mode='zeros',
+                 key=None,
                  weight_attr=None,
                  bias_attr=None,
                  data_format="NDHWC"):
@@ -372,6 +381,7 @@ class SubmConv3D(_Conv3D):
                  dilation=dilation,
                  groups=groups,
                  subm=True,
+                 key=key,
                  padding_mode=padding_mode,
                  weight_attr=weight_attr,
                  bias_attr=bias_attr,