Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
e8e3b997
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e8e3b997
编写于
5月 05, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
5月 05, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix sparse mask (#42305)
上级
e51fad5f
变更
8
隐藏空白更改
内联
并排
Showing
8 changed files
with
35 additions
and
18 deletions
+35
-18
paddle/phi/core/sparse_coo_tensor.cc
paddle/phi/core/sparse_coo_tensor.cc
+8
-0
paddle/phi/core/sparse_coo_tensor.h
paddle/phi/core/sparse_coo_tensor.h
+6
-0
paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc
paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc
+2
-2
paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
+1
-1
paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
+1
-1
paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu
paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu
+5
-8
paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
+1
-1
python/paddle/fluid/tests/unittests/test_sparse_pooling_op.py
...on/paddle/fluid/tests/unittests/test_sparse_pooling_op.py
+11
-5
未找到文件。
paddle/phi/core/sparse_coo_tensor.cc
浏览文件 @
e8e3b997
...
...
@@ -115,4 +115,12 @@ void SparseCooTensor::SetMember(const DenseTensor& non_zero_indices,
this
->
coalesced_
=
coalesced
;
}
/// \brief Get the sparse dim.
/// \return Extent of axis 0 of non_zero_indices_, converted to int32_t.
int32_t SparseCooTensor::sparse_dim() const {
  const auto& indices_shape = non_zero_indices_.dims();
  // Explicit cast to the declared return type; same result as the implicit
  // conversion a plain `return indices_shape[0];` would perform.
  return static_cast<int32_t>(indices_shape[0]);
}
/// \brief Get the dense dim.
/// \return dims_.size() minus sparse_dim(), i.e. the number of dimensions of
/// dims_ that are not counted as sparse dimensions.
int32_t SparseCooTensor::dense_dim() const {
  const auto total_rank = dims_.size();
  const auto sparse_rank = sparse_dim();
  return total_rank - sparse_rank;
}
}
// namespace phi
paddle/phi/core/sparse_coo_tensor.h
浏览文件 @
e8e3b997
...
...
@@ -150,6 +150,12 @@ class SparseCooTensor : public TensorBase,
/// \brief set the dims of original dense tensor
void
set_dims
(
const
DDim
&
dims
)
{
this
->
dims_
=
dims
;
}
/// \brief get the sparse dim
int32_t
sparse_dim
()
const
;
/// \brief get the dense dim
int32_t
dense_dim
()
const
;
private:
// save the indices of non zero elements in original dense tensor
DenseTensor
non_zero_indices_
;
...
...
paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc
浏览文件 @
e8e3b997
...
...
@@ -39,7 +39,7 @@ void SparseMaskCPUKernel(const CPUContext& dev_ctx,
phi
::
errors
::
InvalidArgument
(
"the input x and mask must have the shape"
));
const
DenseTensor
&
indices
=
mask
.
non_zero_indices
();
const
DenseTensor
&
values
=
mask
.
non_zero_elements
();
int
sparse_dim
=
indices
.
dims
().
size
();
const
int
sparse_dim
=
mask
.
sparse_dim
();
DenseTensor
out_indices
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
indices
);
DenseTensor
out_values
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
values
);
...
...
@@ -95,7 +95,7 @@ void SparseMaskHelperCPUKernel(const CPUContext& dev_ctx,
2
,
phi
::
errors
::
InvalidArgument
(
"the mask_indices must be 2-D tensor"
));
const
int
64_t
sparse_dim
=
x
.
non_zero_indices
().
dims
()[
0
]
;
const
int
32_t
sparse_dim
=
x
.
sparse_dim
()
;
std
::
vector
<
IntT
>
sparse_offsets
(
sparse_dim
),
x_indexs
(
x
.
nnz
()),
mask_indexs
(
mask_indices
.
dims
()[
1
]);
...
...
paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
浏览文件 @
e8e3b997
...
...
@@ -50,7 +50,7 @@ void MaxPoolGradCPUKernel(const CPUContext& dev_ctx,
DenseTensor
x_grad_values
=
phi
::
EmptyLike
<
T
>
(
dev_ctx
,
x
.
non_zero_elements
());
x_grad
->
SetMember
(
x_grad_indices
,
x_grad_values
,
x
.
dims
(),
true
);
T
*
x_grad_ptr
=
x_grad_values
.
data
<
T
>
();
memset
(
x_grad_ptr
,
0
,
sizeof
(
T
)
*
x_grad
->
numel
());
memset
(
x_grad_ptr
,
0
,
sizeof
(
T
)
*
x_grad
_values
.
numel
());
phi
::
Copy
<
CPUContext
>
(
dev_ctx
,
x
.
non_zero_indices
(),
dev_ctx
.
GetPlace
(),
...
...
paddle/phi/kernels/sparse/cpu/sparse_utils_kernel.cc
浏览文件 @
e8e3b997
...
...
@@ -254,7 +254,7 @@ void SparseCooToDenseKernel(const Context& dev_ctx,
if
(
indices_dims
.
size
()
==
1
)
{
sparse_dim
=
1
;
}
const
int64_t
dense_dim
=
values
.
dims
().
size
()
-
1
;
const
int64_t
dense_dim
=
x
.
dense_dim
()
;
const
T
*
x_data
=
values
.
data
<
T
>
();
*
out
=
phi
::
Empty
(
...
...
paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu
浏览文件 @
e8e3b997
...
...
@@ -42,7 +42,7 @@ __global__ void MaskKernel(const T* x_ptr,
int64_t
col_i
=
i
-
out_i
*
cols
;
int64_t
index
=
0
;
for
(
int
j
=
0
;
j
<
sparse_dim
;
j
++
)
{
index
+=
indices_ptr
[
j
*
non_zero_num
+
i
]
*
sparse_offsets
[
j
];
index
+=
indices_ptr
[
j
*
non_zero_num
+
out_
i
]
*
sparse_offsets
[
j
];
}
out_values_ptr
[
out_i
*
cols
+
col_i
]
=
x_ptr
[
index
*
cols
+
col_i
];
}
...
...
@@ -60,16 +60,13 @@ void SparseMaskGPUKernel(const GPUContext& dev_ctx,
phi
::
errors
::
InvalidArgument
(
"the input x and mask must have the shape"
));
const
DenseTensor
&
indices
=
mask
.
non_zero_indices
();
const
DenseTensor
&
values
=
mask
.
non_zero_elements
();
int
sparse_dim
=
indices
.
dims
().
size
();
const
int
sparse_dim
=
mask
.
sparse_dim
();
DenseTensor
sparse_offsets
=
phi
::
Empty
<
GPUContext
>
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT64
,
{
sparse_dim
},
DataLayout
::
NCHW
));
std
::
vector
<
int64_t
>
h_sparse_offsets
(
sparse_dim
);
int64_t
offset
=
1
;
for
(
int
i
=
sparse_dim
-
1
;
i
>=
0
;
i
--
)
{
h_sparse_offsets
[
i
]
=
offset
;
offset
*=
dims
[
i
];
}
phi
::
funcs
::
sparse
::
CalcOffsetsPerDim
(
dims
,
sparse_dim
,
h_sparse_offsets
.
data
());
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
sparse_offsets
.
data
<
int64_t
>
(),
&
h_sparse_offsets
[
0
],
...
...
@@ -151,7 +148,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& dev_ctx,
2
,
phi
::
errors
::
InvalidArgument
(
"the mask_indices must be 2-D tensor"
));
const
int
64_t
sparse_dim
=
x
.
non_zero_indices
().
dims
()[
0
]
;
const
int
32_t
sparse_dim
=
x
.
sparse_dim
()
;
auto
indices_dtype
=
paddle
::
experimental
::
CppTypeToDataType
<
IntT
>::
Type
();
std
::
vector
<
IntT
>
sparse_offsets
(
sparse_dim
);
...
...
paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
浏览文件 @
e8e3b997
...
...
@@ -64,7 +64,7 @@ void MaxPoolGradGPUKernel(const GPUContext& dev_ctx,
int
rulebook_len
=
rulebook
.
dims
()[
1
];
const
IntT
*
rulebook_ptr
=
rulebook
.
data
<
IntT
>
();
std
::
vector
<
IntT
>
offsets
(
kernel_size
+
1
),
counter
(
kernel_size
,
0
),
h_counter
(
kernel_size
);
h_counter
(
rulebook_len
,
0
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
h_counter
[
0
],
rulebook_ptr
,
rulebook_len
*
sizeof
(
IntT
),
...
...
python/paddle/fluid/tests/unittests/test_sparse_pooling_op.py
浏览文件 @
e8e3b997
...
...
@@ -19,6 +19,7 @@ import paddle
import
paddle.fluid.core
as
core
from
paddle
import
_C_ops
from
paddle.fluid.framework
import
_test_eager_guard
import
copy
class
TestMaxPool3DFunc
(
unittest
.
TestCase
):
...
...
@@ -44,23 +45,28 @@ class TestMaxPool3DFunc(unittest.TestCase):
def
test
(
self
):
with
_test_eager_guard
():
self
.
setUp
()
self
.
dense_x
.
stop_gradient
=
False
sparse_x
=
self
.
dense_x
.
to_sparse_coo
(
4
)
out
=
paddle
.
sparse
.
functional
.
max_pool3d
(
sparse_
out
=
paddle
.
sparse
.
functional
.
max_pool3d
(
sparse_x
,
self
.
kernel_sizes
,
stride
=
self
.
strides
,
padding
=
self
.
paddings
)
out
=
out
.
to_dense
()
out
=
sparse_out
.
to_dense
()
out
.
backward
(
out
)
dense_x
=
copy
.
deepcopy
(
self
.
dense_x
)
dense_out
=
paddle
.
nn
.
functional
.
max_pool3d
(
self
.
dense_x
,
dense_x
,
self
.
kernel_sizes
,
stride
=
self
.
strides
,
padding
=
self
.
paddings
,
data_format
=
'NDHWC'
)
dense_out
.
backward
(
dense_out
)
#compare with dense
assert
np
.
allclose
(
dense_out
.
flatten
().
numpy
(),
out
.
flatten
()
.
numpy
())
assert
np
.
allclose
(
dense_out
.
numpy
(),
out
.
numpy
())
assert
np
.
allclose
(
dense_x
.
grad
.
numpy
(),
self
.
dense_x
.
grad
.
numpy
())
class
TestStride
(
TestMaxPool3DFunc
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录