Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
3e3f5d90
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3e3f5d90
编写于
7月 05, 2023
作者:
L
LUZY0726
提交者:
GitHub
7月 05, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[sparse] Add backend conv2d support (#54707)
上级
e05d5c0a
变更
11
显示空白变更内容
内联
并排
Showing
11 changed file
with
443 addition
and
135 deletion
+443
-135
paddle/phi/infermeta/sparse/binary.cc
paddle/phi/infermeta/sparse/binary.cc
+44
-18
paddle/phi/kernels/funcs/sparse/convolution.h
paddle/phi/kernels/funcs/sparse/convolution.h
+38
-16
paddle/phi/kernels/sparse/conv_kernel.h
paddle/phi/kernels/sparse/conv_kernel.h
+0
-1
paddle/phi/kernels/sparse/cpu/conv.h
paddle/phi/kernels/sparse/cpu/conv.h
+98
-31
paddle/phi/kernels/sparse/cpu/conv_grad_kernel.cc
paddle/phi/kernels/sparse/cpu/conv_grad_kernel.cc
+6
-4
paddle/phi/kernels/sparse/cpu/conv_kernel.cc
paddle/phi/kernels/sparse/cpu/conv_kernel.cc
+10
-6
paddle/phi/kernels/sparse/gpu/conv.cu.h
paddle/phi/kernels/sparse/gpu/conv.cu.h
+90
-25
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
+6
-4
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
+18
-6
python/paddle/sparse/nn/functional/conv.py
python/paddle/sparse/nn/functional/conv.py
+0
-24
test/legacy_test/test_sparse_conv_op.py
test/legacy_test/test_sparse_conv_op.py
+133
-0
未找到文件。
paddle/phi/infermeta/sparse/binary.cc
浏览文件 @
3e3f5d90
...
...
@@ -23,10 +23,31 @@ inline void GetOutShape(const DDim& x_dims,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
DDim
*
out_dims
)
{
const
bool
is2D
=
out_dims
->
size
()
==
4
?
true
:
false
;
if
(
is2D
)
{
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
4
,
phi
::
errors
::
InvalidArgument
(
"the shape of x should be (N, H, W, C)"
));
PADDLE_ENFORCE_EQ
(
kernel_sizes
.
size
(),
4
,
phi
::
errors
::
InvalidArgument
(
"the shape of kernel should be (H, W, C, OC)"
));
// infer out shape
(
*
out_dims
)[
0
]
=
x_dims
[
0
];
(
*
out_dims
)[
3
]
=
kernel_sizes
[
3
];
for
(
int
i
=
1
;
i
<
3
;
i
++
)
{
(
*
out_dims
)[
i
]
=
(
x_dims
[
i
]
+
2
*
paddings
[
i
-
1
]
-
dilations
[
i
-
1
]
*
(
kernel_sizes
[
i
-
1
]
-
1
)
-
1
)
/
strides
[
i
-
1
]
+
1
;
}
}
else
{
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
5
,
phi
::
errors
::
InvalidArgument
(
"the shape of x should be (N, D, H, W, C)"
));
phi
::
errors
::
InvalidArgument
(
"the shape of x should be (N, D, H, W, C)"
));
PADDLE_ENFORCE_EQ
(
kernel_sizes
.
size
(),
5
,
phi
::
errors
::
InvalidArgument
(
...
...
@@ -41,6 +62,7 @@ inline void GetOutShape(const DDim& x_dims,
strides
[
i
-
1
]
+
1
;
}
}
}
inline
void
ResetSubmKernelSizeAndStrides
(
const
DDim
&
kernel_dims
,
...
...
@@ -64,8 +86,12 @@ void Conv3dInferMeta(const MetaTensor& x,
MetaTensor
*
rulebook
,
MetaTensor
*
counter
)
{
const
auto
&
x_dims
=
x
.
dims
();
const
bool
is2D
=
x_dims
.
size
()
==
4
?
true
:
false
;
const
auto
&
kernel_dims
=
kernel
.
dims
();
DDim
out_dims
=
{
1
,
1
,
1
,
1
,
1
};
int
rank
=
is2D
?
4
:
5
;
std
::
vector
<
int
>
out_dims_vec
(
rank
,
1
);
DDim
out_dims
=
make_ddim
(
out_dims_vec
);
std
::
vector
<
int
>
kernel_sizes
(
kernel_dims
.
size
());
for
(
int
i
=
0
;
i
<
kernel_dims
.
size
();
i
++
)
{
...
...
paddle/phi/kernels/funcs/sparse/convolution.h
浏览文件 @
3e3f5d90
...
...
@@ -101,10 +101,31 @@ inline void GetOutShape(const DDim& x_dims,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
DDim
*
out_dims
)
{
const
bool
is2D
=
out_dims
->
size
()
==
4
?
true
:
false
;
if
(
is2D
)
{
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
4
,
phi
::
errors
::
InvalidArgument
(
"the shape of x should be (N, H, W, C)"
));
PADDLE_ENFORCE_EQ
(
kernel_sizes
.
size
(),
4
,
phi
::
errors
::
InvalidArgument
(
"the shape of kernel should be (H, W, C, OC)"
));
// infer out shape
(
*
out_dims
)[
0
]
=
x_dims
[
0
];
(
*
out_dims
)[
3
]
=
kernel_sizes
[
3
];
for
(
int
i
=
1
;
i
<
3
;
i
++
)
{
(
*
out_dims
)[
i
]
=
(
x_dims
[
i
]
+
2
*
paddings
[
i
-
1
]
-
dilations
[
i
-
1
]
*
(
kernel_sizes
[
i
-
1
]
-
1
)
-
1
)
/
strides
[
i
-
1
]
+
1
;
}
}
else
{
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
5
,
phi
::
errors
::
InvalidArgument
(
"the shape of x should be (N, D, H, W, C)"
));
phi
::
errors
::
InvalidArgument
(
"the shape of x should be (N, D, H, W, C)"
));
PADDLE_ENFORCE_EQ
(
kernel_sizes
.
size
(),
5
,
phi
::
errors
::
InvalidArgument
(
...
...
@@ -119,6 +140,7 @@ inline void GetOutShape(const DDim& x_dims,
strides
[
i
-
1
]
+
1
;
}
}
}
inline
void
ResetSubmKernelSizeAndStrides
(
const
DDim
&
kernel_dims
,
...
...
paddle/phi/kernels/sparse/conv_kernel.h
浏览文件 @
3e3f5d90
...
...
@@ -63,6 +63,5 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx,
counter
);
return
coo
;
}
}
// namespace sparse
}
// namespace phi
paddle/phi/kernels/sparse/cpu/conv.h
浏览文件 @
3e3f5d90
...
...
@@ -42,50 +42,101 @@ void ProductRuleBook(const Context& dev_ctx,
const
bool
subm
,
DenseTensor
*
rulebook
,
int
*
counter_per_kernel
)
{
const
bool
is2D
=
out_dims
.
size
()
==
4
?
true
:
false
;
const
int64_t
non_zero_num
=
x
.
nnz
();
const
auto
&
indices
=
x
.
indices
();
const
IntT
*
indices_ptr
=
indices
.
data
<
IntT
>
();
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
int
kernel_size
=
is2D
?
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
:
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
memset
(
counter_per_kernel
,
0
,
kernel_size
*
sizeof
(
int
));
int
rulebook_len
=
0
;
// calc the rulebook_len
const
auto
&
x_dims
=
x
.
dims
();
const
Dims4D
c_x_dims
(
x_dims
[
0
],
x_dims
[
3
],
x_dims
[
2
],
x_dims
[
1
]);
const
Dims4D
c_kernel_dims
(
1
,
kernel_sizes
[
2
],
kernel_sizes
[
1
],
kernel_sizes
[
0
]);
const
Dims4D
c_out_dims
(
out_dims
[
0
],
out_dims
[
3
],
out_dims
[
2
],
out_dims
[
1
]);
const
Dims4D
c_paddings
(
1
,
paddings
[
2
],
paddings
[
1
],
paddings
[
0
]);
const
Dims4D
c_strides
(
1
,
strides
[
2
],
strides
[
1
],
strides
[
0
]);
const
Dims4D
c_dilations
(
1
,
dilations
[
2
],
dilations
[
1
],
dilations
[
0
]);
int
xdim0
,
xdim1
,
xdim2
,
xdim3
;
int
kdim0
,
kdim1
,
kdim2
,
kdim3
;
int
odim0
,
odim1
,
odim2
,
odim3
;
int
pdim0
,
pdim1
,
pdim2
,
pdim3
;
int
sdim0
,
sdim1
,
sdim2
,
sdim3
;
int
ddim0
,
ddim1
,
ddim2
,
ddim3
;
xdim0
=
x_dims
[
0
];
xdim1
=
is2D
?
x_dims
[
2
]
:
x_dims
[
3
];
xdim2
=
is2D
?
x_dims
[
1
]
:
x_dims
[
2
];
xdim3
=
is2D
?
1
:
x_dims
[
1
];
kdim0
=
1
;
kdim1
=
is2D
?
kernel_sizes
[
1
]
:
kernel_sizes
[
2
];
kdim2
=
is2D
?
kernel_sizes
[
0
]
:
kernel_sizes
[
1
];
kdim3
=
is2D
?
1
:
kernel_sizes
[
0
];
odim0
=
out_dims
[
0
];
odim1
=
is2D
?
out_dims
[
2
]
:
out_dims
[
3
];
odim2
=
is2D
?
out_dims
[
1
]
:
out_dims
[
2
];
odim3
=
is2D
?
1
:
out_dims
[
1
];
pdim0
=
1
;
pdim1
=
is2D
?
paddings
[
1
]
:
paddings
[
2
];
pdim2
=
is2D
?
paddings
[
0
]
:
paddings
[
1
];
pdim3
=
is2D
?
1
:
paddings
[
0
];
sdim0
=
1
;
sdim1
=
is2D
?
strides
[
1
]
:
strides
[
2
];
sdim2
=
is2D
?
strides
[
0
]
:
strides
[
1
];
sdim3
=
is2D
?
1
:
strides
[
0
];
ddim0
=
1
;
ddim1
=
is2D
?
dilations
[
1
]
:
dilations
[
2
];
ddim2
=
is2D
?
dilations
[
0
]
:
dilations
[
1
];
ddim3
=
is2D
?
1
:
dilations
[
0
];
const
Dims4D
c_x_dims
(
xdim0
,
xdim1
,
xdim2
,
xdim3
);
const
Dims4D
c_kernel_dims
(
kdim0
,
kdim1
,
kdim2
,
kdim3
);
const
Dims4D
c_out_dims
(
odim0
,
odim1
,
odim2
,
odim3
);
const
Dims4D
c_paddings
(
pdim0
,
pdim1
,
pdim2
,
pdim3
);
const
Dims4D
c_strides
(
sdim0
,
sdim1
,
sdim2
,
sdim3
);
const
Dims4D
c_dilations
(
ddim0
,
ddim1
,
ddim2
,
ddim3
);
std
::
set
<
IntT
>
hash_in
;
if
(
subm
)
{
for
(
int
i
=
0
;
i
<
non_zero_num
;
i
++
)
{
IntT
batch
=
indices_ptr
[
i
];
IntT
in_z
=
indices_ptr
[
i
+
non_zero_num
];
IntT
in_y
=
indices_ptr
[
i
+
2
*
non_zero_num
];
IntT
in_x
=
indices_ptr
[
i
+
3
*
non_zero_num
];
IntT
index
=
phi
::
funcs
::
sparse
::
PointToIndex
<
DDim
>
(
batch
,
in_x
,
in_y
,
in_z
,
x_dims
);
IntT
in_z
=
is2D
?
0
:
indices_ptr
[
i
+
non_zero_num
];
IntT
in_y
=
is2D
?
indices_ptr
[
i
+
non_zero_num
]
:
indices_ptr
[
i
+
2
*
non_zero_num
];
IntT
in_x
=
is2D
?
indices_ptr
[
i
+
2
*
non_zero_num
]
:
indices_ptr
[
i
+
3
*
non_zero_num
];
IntT
index
=
phi
::
funcs
::
sparse
::
PointToIndex
<
Dims4D
>
(
batch
,
in_x
,
in_y
,
in_z
,
c_x_dims
);
hash_in
.
insert
(
index
);
}
}
auto
f_calc_rulebook
=
[
&
](
IntT
*
rulebook_ptr
)
{
int
kernel_index
=
0
,
rulebook_index
=
0
;
for
(
int
kz
=
0
;
kz
<
kernel_sizes
[
0
];
kz
++
)
{
for
(
int
ky
=
0
;
ky
<
kernel_sizes
[
1
];
ky
++
)
{
for
(
int
kx
=
0
;
kx
<
kernel_sizes
[
2
];
kx
++
)
{
int
zceil
=
is2D
?
1
:
kernel_sizes
[
0
];
int
yceil
=
is2D
?
kernel_sizes
[
0
]
:
kernel_sizes
[
1
];
int
xceil
=
is2D
?
kernel_sizes
[
1
]
:
kernel_sizes
[
2
];
for
(
int
kz
=
0
;
kz
<
zceil
;
kz
++
)
{
for
(
int
ky
=
0
;
ky
<
yceil
;
ky
++
)
{
for
(
int
kx
=
0
;
kx
<
xceil
;
kx
++
)
{
++
kernel_index
;
for
(
int64_t
i
=
0
;
i
<
non_zero_num
;
i
++
)
{
IntT
batch
=
indices_ptr
[
i
];
IntT
in_z
=
indices_ptr
[
i
+
non_zero_num
];
IntT
in_y
=
indices_ptr
[
i
+
2
*
non_zero_num
];
IntT
in_x
=
indices_ptr
[
i
+
3
*
non_zero_num
];
IntT
out_z
=
(
in_z
+
paddings
[
0
]
-
kz
*
dilations
[
0
])
/
strides
[
0
];
IntT
out_y
=
(
in_y
+
paddings
[
1
]
-
ky
*
dilations
[
1
])
/
strides
[
1
];
IntT
out_x
=
(
in_x
+
paddings
[
2
]
-
kx
*
dilations
[
2
])
/
strides
[
2
];
IntT
in_z
=
is2D
?
0
:
indices_ptr
[
i
+
non_zero_num
];
IntT
in_y
=
is2D
?
indices_ptr
[
i
+
non_zero_num
]
:
indices_ptr
[
i
+
2
*
non_zero_num
];
IntT
in_x
=
is2D
?
indices_ptr
[
i
+
2
*
non_zero_num
]
:
indices_ptr
[
i
+
3
*
non_zero_num
];
IntT
out_z
=
is2D
?
0
:
(
in_z
+
paddings
[
0
]
-
kz
*
dilations
[
0
])
/
strides
[
0
];
IntT
out_y
=
(
in_y
+
c_paddings
[
2
]
-
ky
*
c_dilations
[
2
])
/
c_strides
[
2
];
IntT
out_x
=
(
in_x
+
c_paddings
[
3
]
-
kx
*
c_dilations
[
3
])
/
c_strides
[
3
];
if
(
phi
::
funcs
::
sparse
::
Check
(
c_x_dims
,
c_kernel_dims
,
c_paddings
,
...
...
@@ -98,8 +149,8 @@ void ProductRuleBook(const Context& dev_ctx,
ky
,
kz
))
{
if
(
subm
)
{
IntT
out_index
=
phi
::
funcs
::
sparse
::
PointToIndex
<
D
Dim
>
(
batch
,
out_x
,
out_y
,
out_z
,
out_dims
);
IntT
out_index
=
phi
::
funcs
::
sparse
::
PointToIndex
<
D
ims4D
>
(
batch
,
out_x
,
out_y
,
out_z
,
c_
out_dims
);
if
(
hash_in
.
find
(
out_index
)
==
hash_in
.
end
())
{
continue
;
}
...
...
@@ -112,8 +163,8 @@ void ProductRuleBook(const Context& dev_ctx,
rulebook_ptr
[
rulebook_index
]
=
kernel_index
-
1
;
rulebook_ptr
[
rulebook_index
+
rulebook_len
]
=
i
;
// in_i
rulebook_ptr
[
rulebook_index
+
rulebook_len
*
2
]
=
phi
::
funcs
::
sparse
::
PointToIndex
<
D
Dim
>
(
batch
,
out_x
,
out_y
,
out_z
,
out_dims
);
// out_index
phi
::
funcs
::
sparse
::
PointToIndex
<
D
ims4D
>
(
batch
,
out_x
,
out_y
,
out_z
,
c_
out_dims
);
// out_index
++
rulebook_index
;
}
}
...
...
@@ -141,6 +192,8 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx,
const
DDim
&
out_dims
,
DenseTensor
*
rulebook
,
SparseCooTensor
*
out
)
{
const
bool
is2D
=
out_dims
.
size
()
==
4
?
true
:
false
;
std
::
set
<
IntT
>
out_indexs
;
int
n
=
rulebook
->
dims
()[
1
];
IntT
*
rulebook_ptr
=
rulebook
->
data
<
IntT
>
();
...
...
@@ -149,7 +202,7 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx,
}
int
out_non_zero_num
=
out_indexs
.
size
();
const
int64_t
sparse_dim
=
4
;
const
int64_t
sparse_dim
=
is2D
?
3
:
4
;
DenseTensorMeta
indices_meta
(
phi
::
CppTypeToDataType
<
IntT
>::
Type
(),
{
sparse_dim
,
out_non_zero_num
},
DataLayout
::
NCHW
);
...
...
@@ -159,15 +212,29 @@ void UpdateRulebookAndOutIndex(const Context& dev_ctx,
phi
::
DenseTensor
out_values
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
values_meta
));
IntT
*
out_indices_ptr
=
out_indices
.
data
<
IntT
>
();
int
i
=
0
;
int
odim0
,
odim1
,
odim2
,
odim3
;
odim0
=
out_dims
[
0
];
odim1
=
is2D
?
out_dims
[
2
]
:
out_dims
[
3
];
odim2
=
is2D
?
out_dims
[
1
]
:
out_dims
[
2
];
odim3
=
is2D
?
1
:
out_dims
[
1
];
const
Dims4D
c_out_dims
(
odim0
,
odim1
,
odim2
,
odim3
);
for
(
auto
it
=
out_indexs
.
begin
();
it
!=
out_indexs
.
end
();
it
++
,
i
++
)
{
const
IntT
index
=
*
it
;
IntT
batch
,
x
,
y
,
z
;
phi
::
funcs
::
sparse
::
IndexToPoint
<
DDim
>
(
index
,
out_dims
,
&
batch
,
&
x
,
&
y
,
&
z
);
phi
::
funcs
::
sparse
::
IndexToPoint
<
Dims4D
>
(
index
,
c_out_dims
,
&
batch
,
&
x
,
&
y
,
&
z
);
out_indices_ptr
[
i
]
=
batch
;
if
(
is2D
)
{
out_indices_ptr
[
i
+
out_non_zero_num
]
=
y
;
out_indices_ptr
[
i
+
out_non_zero_num
*
2
]
=
x
;
}
else
{
out_indices_ptr
[
i
+
out_non_zero_num
]
=
z
;
out_indices_ptr
[
i
+
out_non_zero_num
*
2
]
=
y
;
out_indices_ptr
[
i
+
out_non_zero_num
*
3
]
=
x
;
}
}
for
(
i
=
0
;
i
<
n
;
i
++
)
{
IntT
out_index
=
rulebook_ptr
[
i
+
n
*
2
];
rulebook_ptr
[
i
+
n
*
2
]
=
...
...
paddle/phi/kernels/sparse/cpu/conv_grad_kernel.cc
浏览文件 @
3e3f5d90
...
...
@@ -47,9 +47,12 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx,
SparseCooTensor
*
x_grad
,
DenseTensor
*
kernel_grad
)
{
const
auto
&
kernel_dims
=
kernel
.
dims
();
const
int
kernel_size
=
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
const
int
in_channels
=
kernel_dims
[
3
];
const
int
out_channels
=
kernel_dims
[
4
];
const
bool
is2D
=
kernel_dims
.
size
()
==
4
?
true
:
false
;
const
int
kernel_size
=
is2D
?
kernel_dims
[
0
]
*
kernel_dims
[
1
]
:
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
const
int
in_channels
=
is2D
?
kernel_dims
[
2
]
:
kernel_dims
[
3
];
const
int
out_channels
=
is2D
?
kernel_dims
[
3
]
:
kernel_dims
[
4
];
int
rulebook_len
=
0
;
const
IntT
*
rulebook_ptr
=
phi
::
funcs
::
sparse
::
GetRulebookPtr
<
IntT
>
(
...
...
@@ -210,7 +213,6 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
kernel_grad
);
}));
}
}
// namespace sparse
}
// namespace phi
...
...
paddle/phi/kernels/sparse/cpu/conv_kernel.cc
浏览文件 @
3e3f5d90
...
...
@@ -45,9 +45,15 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
// if x.layout != NDHWC then transpose(x), transpose(weight)
const
auto
&
x_dims
=
x
.
dims
();
const
bool
is2D
=
x_dims
.
size
()
==
4
?
true
:
false
;
const
auto
&
kernel_dims
=
kernel
.
dims
();
int
kernel_size
=
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
DDim
out_dims
=
{
1
,
1
,
1
,
1
,
1
};
int
kernel_size
=
is2D
?
kernel_dims
[
0
]
*
kernel_dims
[
1
]
:
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
int
count_tmp
=
is2D
?
4
:
5
;
std
::
vector
<
int
>
out_dims_vec
(
count_tmp
,
1
);
DDim
out_dims
=
make_ddim
(
out_dims_vec
);
std
::
vector
<
int
>
kernel_sizes
(
kernel_dims
.
size
());
for
(
int
i
=
0
;
i
<
kernel_dims
.
size
();
i
++
)
{
kernel_sizes
[
i
]
=
kernel_dims
[
i
];
...
...
@@ -63,8 +69,8 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
phi
::
funcs
::
sparse
::
GetOutShape
(
x_dims
,
kernel_sizes
,
subm_paddings
,
dilations
,
subm_strides
,
&
out_dims
);
const
int
in_channels
=
kernel_dims
[
3
];
const
int
out_channels
=
kernel_dims
[
4
];
const
int
in_channels
=
is2D
?
kernel_dims
[
2
]
:
kernel_dims
[
3
];
const
int
out_channels
=
is2D
?
kernel_dims
[
3
]
:
kernel_dims
[
4
];
// Second algorithm:
// https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf
...
...
@@ -112,7 +118,6 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx,
phi
::
funcs
::
sparse
::
SaveToTable
(
dev_ctx
,
x
,
key
,
tmp_rulebook
,
h_counter
,
out
,
rulebook
,
counter
);
}
// int n = rulebook->dims()[1];
// 2. gather
DenseTensorMeta
in_features_meta
(
...
...
@@ -198,7 +203,6 @@ void Conv3dCooKernel(const Context& dev_ctx,
counter
);
}));
}
}
// namespace sparse
}
// namespace phi
...
...
paddle/phi/kernels/sparse/gpu/conv.cu.h
浏览文件 @
3e3f5d90
...
...
@@ -331,6 +331,7 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
const
Dims4D
paddings
,
const
Dims4D
dilations
,
const
Dims4D
strides
,
const
bool
is2D
,
T
*
rulebook
,
int
*
counter
)
{
int
tid
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
...
...
@@ -345,9 +346,11 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
for
(
int
i
=
tid
;
i
<
non_zero_num
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
int
kernel_index
=
0
;
T
batch
=
x_indices
[
i
];
T
in_z
=
x_indices
[
i
+
non_zero_num
];
T
in_y
=
x_indices
[
i
+
2
*
non_zero_num
];
T
in_x
=
x_indices
[
i
+
3
*
non_zero_num
];
T
in_z
=
is2D
?
0
:
x_indices
[
i
+
non_zero_num
];
T
in_y
=
is2D
?
x_indices
[
i
+
non_zero_num
]
:
x_indices
[
i
+
2
*
non_zero_num
];
T
in_x
=
is2D
?
x_indices
[
i
+
2
*
non_zero_num
]
:
x_indices
[
i
+
3
*
non_zero_num
];
for
(
int
kz
=
0
;
kz
<
kernel_dims
[
1
];
kz
++
)
{
for
(
int
ky
=
0
;
ky
<
kernel_dims
[
2
];
ky
++
)
{
for
(
int
kx
=
0
;
kx
<
kernel_dims
[
3
];
kx
++
)
{
...
...
@@ -363,7 +366,9 @@ __global__ void ProductRuleBookKernel(const T* x_indices,
kx
,
ky
,
kz
))
{
T
out_z
=
(
in_z
+
paddings
[
1
]
-
kz
*
dilations
[
1
])
/
strides
[
1
];
T
out_z
=
is2D
?
0
:
(
in_z
+
paddings
[
1
]
-
kz
*
dilations
[
1
])
/
strides
[
1
];
T
out_y
=
(
in_y
+
paddings
[
2
]
-
ky
*
dilations
[
2
])
/
strides
[
2
];
T
out_x
=
(
in_x
+
paddings
[
3
]
-
kx
*
dilations
[
3
])
/
strides
[
3
];
in_i
=
i
;
...
...
@@ -390,12 +395,15 @@ __global__ void GetOutIndexTable1(const IntT* indices,
const
IntT
non_zero_num
,
const
Dims4D
dims
,
int
*
index_flags
,
const
bool
is2D
,
int
*
out_index_table
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
non_zero_num
,
int64_t
)
{
IntT
batch
=
indices
[
i
];
IntT
in_z
=
indices
[
i
+
non_zero_num
];
IntT
in_y
=
indices
[
i
+
2
*
non_zero_num
];
IntT
in_x
=
indices
[
i
+
3
*
non_zero_num
];
IntT
in_z
=
is2D
?
0
:
indices
[
i
+
non_zero_num
];
IntT
in_y
=
is2D
?
indices
[
i
+
non_zero_num
]
:
indices
[
i
+
2
*
non_zero_num
];
IntT
in_x
=
is2D
?
indices
[
i
+
2
*
non_zero_num
]
:
indices
[
i
+
3
*
non_zero_num
];
IntT
index
=
PointToIndex
(
batch
,
in_x
,
in_y
,
in_z
,
dims
);
phi
::
funcs
::
sparse
::
SetBits
(
index
,
index_flags
);
out_index_table
[
index
]
=
i
;
...
...
@@ -406,6 +414,7 @@ template <typename IntT>
__global__
void
GetOutIndexTable
(
int
*
indexs
,
const
int
non_zero_num
,
const
Dims4D
out_dims
,
const
bool
is2D
,
int
*
out_index_table
,
IntT
*
out_indices
)
{
CUDA_KERNEL_LOOP_TYPE
(
i
,
non_zero_num
,
int64_t
)
{
...
...
@@ -416,9 +425,14 @@ __global__ void GetOutIndexTable(int* indexs,
index
,
out_dims
,
&
batch
,
&
x
,
&
y
,
&
z
);
// get out indices
out_indices
[
i
]
=
batch
;
if
(
is2D
)
{
out_indices
[
i
+
non_zero_num
]
=
y
;
out_indices
[
i
+
non_zero_num
*
2
]
=
x
;
}
else
{
out_indices
[
i
+
non_zero_num
]
=
z
;
out_indices
[
i
+
non_zero_num
*
2
]
=
y
;
out_indices
[
i
+
non_zero_num
*
3
]
=
x
;
}
indexs
[
i
]
=
0
;
}
}
...
...
@@ -464,6 +478,7 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
const
Dims4D
paddings
,
const
Dims4D
dilations
,
const
Dims4D
strides
,
const
bool
is2D
,
const
int
*
index_flags
,
const
int
*
out_index_table
,
T
*
rulebook
,
...
...
@@ -472,7 +487,6 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
const
int
kernel_size
=
kernel_dims
[
3
]
*
kernel_dims
[
2
]
*
kernel_dims
[
1
];
extern
__shared__
int
counter_buf
[];
// kernel_size
int
*
counter_buf2
=
counter_buf
+
kernel_size
;
// length = kernel_size * blockDim.x * 2;
int
*
rulebook_buf
=
counter_buf
+
kernel_size
*
2
;
const
int
offset
=
kernel_size
*
non_zero_num
;
...
...
@@ -484,9 +498,11 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
for
(
int
i
=
tid
;
i
<
non_zero_num
;
i
+=
gridDim
.
x
*
blockDim
.
x
)
{
int
kernel_index
=
0
;
T
batch
=
x_indices
[
i
];
T
in_z
=
x_indices
[
i
+
non_zero_num
];
T
in_y
=
x_indices
[
i
+
2
*
non_zero_num
];
T
in_x
=
x_indices
[
i
+
3
*
non_zero_num
];
T
in_z
=
is2D
?
0
:
x_indices
[
i
+
non_zero_num
];
T
in_y
=
is2D
?
x_indices
[
i
+
non_zero_num
]
:
x_indices
[
i
+
2
*
non_zero_num
];
T
in_x
=
is2D
?
x_indices
[
i
+
2
*
non_zero_num
]
:
x_indices
[
i
+
3
*
non_zero_num
];
for
(
int
kz
=
0
;
kz
<
kernel_dims
[
1
];
kz
++
)
{
for
(
int
ky
=
0
;
ky
<
kernel_dims
[
2
];
ky
++
)
{
for
(
int
kx
=
0
;
kx
<
kernel_dims
[
3
];
kx
++
)
{
...
...
@@ -502,7 +518,9 @@ __global__ void ProductSubmRuleBookKernel(const T* x_indices,
kx
,
ky
,
kz
))
{
T
out_z
=
(
in_z
+
paddings
[
1
]
-
kz
*
dilations
[
1
])
/
strides
[
1
];
T
out_z
=
is2D
?
0
:
(
in_z
+
paddings
[
1
]
-
kz
*
dilations
[
1
])
/
strides
[
1
];
T
out_y
=
(
in_y
+
paddings
[
2
]
-
ky
*
dilations
[
2
])
/
strides
[
2
];
T
out_x
=
(
in_x
+
paddings
[
3
]
-
kx
*
dilations
[
3
])
/
strides
[
3
];
out_index
=
phi
::
funcs
::
sparse
::
PointToIndex
<
Dims4D
>
(
...
...
@@ -637,21 +655,62 @@ int ProductRuleBook(const Context& dev_ctx,
SparseCooTensor
*
out
,
int
*
h_counter
,
int
*
h_offsets
)
{
const
bool
is2D
=
out_dims
.
size
()
==
4
?
true
:
false
;
auto
indices_dtype
=
phi
::
CppTypeToDataType
<
IntT
>::
Type
();
const
int64_t
non_zero_num
=
x
.
nnz
();
const
auto
&
indices
=
x
.
indices
();
const
IntT
*
indices_ptr
=
indices
.
data
<
IntT
>
();
int
*
counter_ptr
=
counter_per_kernel
->
data
<
int
>
();
int
*
offsets_ptr
=
offsets_per_kernel
->
data
<
int
>
();
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
int
kernel_size
=
is2D
?
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
:
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
const
auto
x_dims
=
x
.
dims
();
Dims4D
d_x_dims
(
x_dims
[
0
],
x_dims
[
3
],
x_dims
[
2
],
x_dims
[
1
]);
Dims4D
d_kernel_dims
(
1
,
kernel_sizes
[
2
],
kernel_sizes
[
1
],
kernel_sizes
[
0
]);
Dims4D
d_out_dims
(
out_dims
[
0
],
out_dims
[
3
],
out_dims
[
2
],
out_dims
[
1
]);
Dims4D
d_paddings
(
1
,
paddings
[
2
],
paddings
[
1
],
paddings
[
0
]);
Dims4D
d_strides
(
1
,
strides
[
2
],
strides
[
1
],
strides
[
0
]);
Dims4D
d_dilations
(
1
,
dilations
[
2
],
dilations
[
1
],
dilations
[
0
]);
int
xdim0
,
xdim1
,
xdim2
,
xdim3
;
int
kdim0
,
kdim1
,
kdim2
,
kdim3
;
int
odim0
,
odim1
,
odim2
,
odim3
;
int
pdim0
,
pdim1
,
pdim2
,
pdim3
;
int
sdim0
,
sdim1
,
sdim2
,
sdim3
;
int
ddim0
,
ddim1
,
ddim2
,
ddim3
;
xdim0
=
x_dims
[
0
];
xdim1
=
is2D
?
x_dims
[
2
]
:
x_dims
[
3
];
xdim2
=
is2D
?
x_dims
[
1
]
:
x_dims
[
2
];
xdim3
=
is2D
?
1
:
x_dims
[
1
];
kdim0
=
1
;
kdim1
=
is2D
?
kernel_sizes
[
1
]
:
kernel_sizes
[
2
];
kdim2
=
is2D
?
kernel_sizes
[
0
]
:
kernel_sizes
[
1
];
kdim3
=
is2D
?
1
:
kernel_sizes
[
0
];
odim0
=
out_dims
[
0
];
odim1
=
is2D
?
out_dims
[
2
]
:
out_dims
[
3
];
odim2
=
is2D
?
out_dims
[
1
]
:
out_dims
[
2
];
odim3
=
is2D
?
1
:
out_dims
[
1
];
pdim0
=
1
;
pdim1
=
is2D
?
paddings
[
1
]
:
paddings
[
2
];
pdim2
=
is2D
?
paddings
[
0
]
:
paddings
[
1
];
pdim3
=
is2D
?
1
:
paddings
[
0
];
sdim0
=
1
;
sdim1
=
is2D
?
strides
[
1
]
:
strides
[
2
];
sdim2
=
is2D
?
strides
[
0
]
:
strides
[
1
];
sdim3
=
is2D
?
1
:
strides
[
0
];
ddim0
=
1
;
ddim1
=
is2D
?
dilations
[
1
]
:
dilations
[
2
];
ddim2
=
is2D
?
dilations
[
0
]
:
dilations
[
1
];
ddim3
=
is2D
?
1
:
dilations
[
0
];
const
Dims4D
d_x_dims
(
xdim0
,
xdim1
,
xdim2
,
xdim3
);
const
Dims4D
d_kernel_dims
(
kdim0
,
kdim1
,
kdim2
,
kdim3
);
const
Dims4D
d_out_dims
(
odim0
,
odim1
,
odim2
,
odim3
);
const
Dims4D
d_paddings
(
pdim0
,
pdim1
,
pdim2
,
pdim3
);
const
Dims4D
d_strides
(
sdim0
,
sdim1
,
sdim2
,
sdim3
);
const
Dims4D
d_dilations
(
ddim0
,
ddim1
,
ddim2
,
ddim3
);
// 1. product rule book
phi
::
backends
::
gpu
::
GpuMemsetAsync
(
counter_ptr
,
0
,
...
...
@@ -682,7 +741,9 @@ int ProductRuleBook(const Context& dev_ctx,
DenseTensor
tmp_rulebook
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
rulebook_meta
));
IntT
*
rulebook_ptr
=
tmp_rulebook
.
data
<
IntT
>
();
DenseTensor
out_indices
=
phi
::
EmptyLike
<
IntT
>
(
dev_ctx
,
x
.
indices
());
DenseTensor
out_values
=
phi
::
Empty
<
T
>
(
dev_ctx
,
{
x
.
nnz
(),
kernel_sizes
[
4
]});
int
tmpidx
=
is2D
?
3
:
4
;
DenseTensor
out_values
=
phi
::
Empty
<
T
>
(
dev_ctx
,
{
x
.
nnz
(),
kernel_sizes
[
tmpidx
]});
phi
::
Copy
(
dev_ctx
,
x
.
indices
(),
dev_ctx
.
GetPlace
(),
false
,
&
out_indices
);
...
...
@@ -695,6 +756,7 @@ int ProductRuleBook(const Context& dev_ctx,
non_zero_num
,
d_x_dims
,
index_flags_ptr
,
is2D
,
out_index_table_ptr
);
size_t
cache_size
=
...
...
@@ -721,6 +783,7 @@ int ProductRuleBook(const Context& dev_ctx,
d_paddings
,
d_dilations
,
d_strides
,
is2D
,
index_flags_ptr
,
out_index_table_ptr
,
rulebook_ptr
,
...
...
@@ -766,6 +829,7 @@ int ProductRuleBook(const Context& dev_ctx,
d_paddings
,
d_dilations
,
d_strides
,
is2D
,
rulebook_ptr
,
counter_ptr
);
...
...
@@ -833,11 +897,11 @@ int ProductRuleBook(const Context& dev_ctx,
out_nnz
,
out_index_ptr
);
const
int64_t
sparse_dim
=
4
;
const
int64_t
sparse_dim
=
is2D
?
3
:
4
;
phi
::
DenseTensor
out_indices
=
phi
::
Empty
<
IntT
>
(
dev_ctx
,
{
sparse_dim
,
out_nnz
});
phi
::
DenseTensor
out_values
=
phi
::
Empty
<
T
>
(
dev_ctx
,
{
out_nnz
,
kernel_sizes
[
4
]});
phi
::
Empty
<
T
>
(
dev_ctx
,
{
out_nnz
,
kernel_sizes
[
sparse_dim
]});
out
->
SetMember
(
out_indices
,
out_values
,
out_dims
,
false
);
IntT
*
out_indices_ptr
=
out_indices
.
data
<
IntT
>
();
...
...
@@ -849,6 +913,7 @@ int ProductRuleBook(const Context& dev_ctx,
dev_ctx
.
stream
()
>>>
(
out_index_ptr
,
out_nnz
,
d_out_dims
,
is2D
,
out_index_table_ptr
,
out_indices_ptr
);
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
rulebook_len
,
1
);
...
...
paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu
浏览文件 @
3e3f5d90
...
...
@@ -57,9 +57,12 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx,
SparseCooTensor
*
x_grad
,
DenseTensor
*
kernel_grad
)
{
const
auto
&
kernel_dims
=
kernel
.
dims
();
const
int
kernel_size
=
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
const
int
in_channels
=
kernel_dims
[
3
];
const
int
out_channels
=
kernel_dims
[
4
];
const
bool
is2D
=
kernel_dims
.
size
()
==
4
?
true
:
false
;
const
int
kernel_size
=
is2D
?
kernel_dims
[
0
]
*
kernel_dims
[
1
]
:
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
const
int
in_channels
=
is2D
?
kernel_dims
[
2
]
:
kernel_dims
[
3
];
const
int
out_channels
=
is2D
?
kernel_dims
[
3
]
:
kernel_dims
[
4
];
int
rulebook_len
=
0
;
const
IntT
*
rulebook_ptr
=
phi
::
funcs
::
sparse
::
GetRulebookPtr
<
IntT
>
(
...
...
@@ -324,7 +327,6 @@ void Conv3dCooGradKernel(const Context& dev_ctx,
kernel_grad
);
}));
}
}
// namespace sparse
}
// namespace phi
...
...
paddle/phi/kernels/sparse/gpu/conv_kernel.cu
浏览文件 @
3e3f5d90
...
...
@@ -85,8 +85,14 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
// if x.layout != NDHWC then transpose(x), transpose(weight)
const
auto
&
x_dims
=
x
.
dims
();
const
auto
&
kernel_dims
=
kernel
.
dims
();
int
kernel_size
=
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
DDim
out_dims
=
{
1
,
1
,
1
,
1
,
1
};
const
bool
is2D
=
x_dims
.
size
()
==
4
?
true
:
false
;
int
kernel_size
=
is2D
?
kernel_dims
[
0
]
*
kernel_dims
[
1
]
:
kernel_dims
[
0
]
*
kernel_dims
[
1
]
*
kernel_dims
[
2
];
int
rank
=
is2D
?
4
:
5
;
std
::
vector
<
int
>
out_dims_vec
(
rank
,
1
);
DDim
out_dims
=
make_ddim
(
out_dims_vec
);
std
::
vector
<
int
>
kernel_sizes
(
kernel_dims
.
size
());
for
(
int
i
=
0
;
i
<
kernel_dims
.
size
();
i
++
)
{
kernel_sizes
[
i
]
=
kernel_dims
[
i
];
...
...
@@ -102,8 +108,8 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
phi
::
funcs
::
sparse
::
GetOutShape
(
x_dims
,
kernel_sizes
,
subm_paddings
,
dilations
,
subm_strides
,
&
out_dims
);
const
int
in_channels
=
kernel_dims
[
3
];
const
int
out_channels
=
kernel_dims
[
4
];
const
int
in_channels
=
is2D
?
kernel_dims
[
2
]
:
kernel_dims
[
3
];
const
int
out_channels
=
is2D
?
kernel_dims
[
3
]
:
kernel_dims
[
4
];
DenseTensor
h_counter
,
h_offsets
;
h_counter
.
Resize
({
kernel_size
});
h_offsets
.
Resize
({
kernel_size
+
1
});
...
...
@@ -118,7 +124,14 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx,
DenseTensor
out_index
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
1
});
DenseTensor
unique_value
=
phi
::
Empty
<
int
>
(
dev_ctx
,
{
1
});
VLOG
(
6
)
<<
"call SubmConv3D or Conv3D "
<<
subm
<<
" and the key is "
<<
key
;
if
(
is2D
)
{
VLOG
(
6
)
<<
"call SubmConv2D or Conv2D "
<<
subm
<<
" and the key is "
<<
key
;
}
else
{
VLOG
(
6
)
<<
"call SubmConv3D or Conv3D "
<<
subm
<<
" and the key is "
<<
key
;
}
int
rulebook_len
=
0
;
const
IntT
*
rulebook_ptr
=
nullptr
;
bool
need_product_rulebook
=
true
;
...
...
@@ -313,7 +326,6 @@ void Conv3dCooKernel(const Context& dev_ctx,
counter
);
}));
}
}
// namespace sparse
}
// namespace phi
...
...
python/paddle/sparse/nn/functional/conv.py
浏览文件 @
3e3f5d90
...
...
@@ -14,14 +14,12 @@
__all__
=
[]
import
paddle
from
paddle
import
_C_ops
,
in_dynamic_mode
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.nn.functional.conv
import
_update_padding_nd
from
paddle.utils
import
convert_to_list
from
...binary
import
add
from
...unary
import
reshape
def
_conv3d
(
...
...
@@ -148,20 +146,13 @@ def _conv2d(
)
channel_last
=
data_format
==
"NHWC"
n_dim
=
0
channel_dim
=
-
1
if
channel_last
else
1
h_dim
=
1
if
channel_last
else
2
w_dim
=
2
if
channel_last
else
-
1
if
len
(
x
.
shape
)
!=
4
:
raise
ValueError
(
"Input x should be 4D tensor, but received x with the shape of {}"
.
format
(
x
.
shape
)
)
n
=
x
.
shape
[
n_dim
]
d
=
1
h
=
x
.
shape
[
h_dim
]
w
=
x
.
shape
[
w_dim
]
num_channels
=
x
.
shape
[
channel_dim
]
if
num_channels
<
0
:
raise
ValueError
(
...
...
@@ -173,16 +164,6 @@ def _conv2d(
stride
=
convert_to_list
(
stride
,
dims
,
'stride'
)
dilation
=
convert_to_list
(
dilation
,
dims
,
'dilation'
)
padding
.
insert
(
0
,
0
)
stride
.
insert
(
0
,
1
)
dilation
.
insert
(
0
,
1
)
x
=
reshape
(
x
,
[
n
,
d
,
h
,
w
,
num_channels
])
h_filter
=
weight
.
shape
[
0
]
w_filter
=
weight
.
shape
[
1
]
c_filter
=
weight
.
shape
[
2
]
m_filter
=
weight
.
shape
[
3
]
weight
=
paddle
.
reshape
(
weight
,
[
d
,
h_filter
,
w_filter
,
c_filter
,
m_filter
])
if
in_dynamic_mode
():
pre_bias
=
_C_ops
.
sparse_conv3d
(
x
,
...
...
@@ -217,11 +198,6 @@ def _conv2d(
helper
.
append_op
(
type
=
op_type
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
)
n_out
=
pre_bias
.
shape
[
0
]
h_out
=
pre_bias
.
shape
[
2
]
w_out
=
pre_bias
.
shape
[
3
]
channels_out
=
pre_bias
.
shape
[
4
]
pre_bias
=
reshape
(
pre_bias
,
[
n_out
,
h_out
,
w_out
,
channels_out
])
if
bias
is
not
None
:
return
add
(
pre_bias
,
bias
)
else
:
...
...
test/legacy_test/test_sparse_conv_op.py
浏览文件 @
3e3f5d90
...
...
@@ -378,6 +378,75 @@ class TestStatic(unittest.TestCase):
self
.
assertTrue
(
out_indices
.
dtype
==
paddle
.
int32
)
paddle
.
disable_static
()
def
test_cpu
(
self
):
paddle
.
enable_static
()
main
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main
):
indices
=
paddle
.
static
.
data
(
name
=
'indices'
,
shape
=
[
4
,
4
],
dtype
=
'int32'
)
values
=
paddle
.
static
.
data
(
name
=
'values'
,
shape
=
[
4
,
1
],
dtype
=
'float32'
)
dense_shape
=
[
1
,
1
,
3
,
4
,
1
]
sp_x
=
sparse
.
sparse_coo_tensor
(
indices
,
values
,
dense_shape
)
weight_shape
=
[
1
,
3
,
3
,
1
,
1
]
weight
=
paddle
.
static
.
data
(
name
=
'weight'
,
shape
=
weight_shape
,
dtype
=
'float32'
)
bias_shape
=
[
1
]
bias
=
paddle
.
static
.
data
(
name
=
'bias'
,
shape
=
bias_shape
,
dtype
=
'float32'
)
out
=
sparse
.
nn
.
functional
.
conv3d
(
sp_x
,
weight
,
bias
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
,
data_format
=
"NDHWC"
,
)
sp_out
=
sparse
.
nn
.
functional
.
relu
(
out
)
out_indices
=
sp_out
.
indices
()
out_values
=
sp_out
.
values
()
out
=
sp_out
.
to_dense
()
place
=
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
()
indices_data
=
[
[
0
,
0
,
0
,
0
],
[
0
,
0
,
0
,
0
],
[
0
,
0
,
1
,
2
],
[
1
,
3
,
2
,
3
],
]
values_data
=
[[
1.0
],
[
2.0
],
[
3.0
],
[
4.0
]]
weight_data
=
np
.
array
(
[[[[[
1
],
[
1
],
[
1
]],
[[
1
],
[
1
],
[
1
]],
[[
1
],
[
1
],
[
1
]]]]]
).
astype
(
'float32'
)
weight_data
=
weight_data
.
reshape
(
weight_shape
)
bias_data
=
np
.
array
([
1
]).
astype
(
'float32'
)
fetch
=
exe
.
run
(
feed
=
{
'indices'
:
indices_data
,
'values'
:
values_data
,
'weight'
:
weight_data
,
'bias'
:
bias_data
,
},
fetch_list
=
[
out
,
out_indices
,
out_values
],
return_numpy
=
True
,
)
correct_out
=
np
.
array
([[[[[
5.0
],
[
11.0
]]]]]).
astype
(
'float64'
)
correct_out_values
=
[[
5.0
],
[
11.0
]]
np
.
testing
.
assert_array_equal
(
correct_out
,
fetch
[
0
])
np
.
testing
.
assert_array_equal
(
correct_out_values
,
fetch
[
2
])
self
.
assertTrue
(
out_indices
.
dtype
==
paddle
.
int32
)
paddle
.
disable_static
()
def
test2D
(
self
):
paddle
.
enable_static
()
main
=
paddle
.
static
.
Program
()
...
...
@@ -441,6 +510,70 @@ class TestStatic(unittest.TestCase):
self
.
assertTrue
(
out_indices
.
dtype
==
paddle
.
int32
)
paddle
.
disable_static
()
def
test2D_cpu
(
self
):
paddle
.
enable_static
()
main
=
paddle
.
static
.
Program
()
with
paddle
.
static
.
program_guard
(
main
):
indices
=
paddle
.
static
.
data
(
name
=
'indices'
,
shape
=
[
3
,
4
],
dtype
=
'int32'
)
values
=
paddle
.
static
.
data
(
name
=
'values'
,
shape
=
[
4
,
1
],
dtype
=
'float32'
)
dense_shape
=
[
1
,
3
,
4
,
1
]
sp_x
=
sparse
.
sparse_coo_tensor
(
indices
,
values
,
dense_shape
)
weight_shape
=
[
3
,
3
,
1
,
1
]
weight
=
paddle
.
static
.
data
(
name
=
'weight'
,
shape
=
weight_shape
,
dtype
=
'float32'
)
bias_shape
=
[
1
]
bias
=
paddle
.
static
.
data
(
name
=
'bias'
,
shape
=
bias_shape
,
dtype
=
'float32'
)
out
=
sparse
.
nn
.
functional
.
conv2d
(
sp_x
,
weight
,
bias
,
stride
=
1
,
padding
=
0
,
dilation
=
1
,
groups
=
1
,
data_format
=
"NHWC"
,
)
sp_out
=
sparse
.
nn
.
functional
.
relu
(
out
)
out_indices
=
sp_out
.
indices
()
out_values
=
sp_out
.
values
()
out
=
sp_out
.
to_dense
()
place
=
paddle
.
CPUPlace
()
exe
=
paddle
.
static
.
Executor
()
indices_data
=
[[
0
,
0
,
0
,
0
],
[
0
,
0
,
1
,
2
],
[
1
,
3
,
2
,
3
]]
values_data
=
[[
1.0
],
[
2.0
],
[
3.0
],
[
4.0
]]
weight_data
=
np
.
array
(
[[[[[
1
],
[
1
],
[
1
]],
[[
1
],
[
1
],
[
1
]],
[[
1
],
[
1
],
[
1
]]]]]
).
astype
(
'float32'
)
weight_data
=
weight_data
.
reshape
(
weight_shape
)
bias_data
=
np
.
array
([
1
]).
astype
(
'float32'
)
fetch
=
exe
.
run
(
feed
=
{
'indices'
:
indices_data
,
'values'
:
values_data
,
'weight'
:
weight_data
,
'bias'
:
bias_data
,
},
fetch_list
=
[
out
,
out_indices
,
out_values
],
return_numpy
=
True
,
)
correct_out
=
np
.
array
([[[[
5.0
],
[
11.0
]]]]).
astype
(
'float64'
)
correct_out_values
=
[[
5.0
],
[
11.0
]]
np
.
testing
.
assert_array_equal
(
correct_out
,
fetch
[
0
])
np
.
testing
.
assert_array_equal
(
correct_out_values
,
fetch
[
2
])
self
.
assertTrue
(
out_indices
.
dtype
==
paddle
.
int32
)
paddle
.
disable_static
()
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录