Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
e52ffb70
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e52ffb70
编写于
3月 18, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
3月 18, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add Sparse OP Maxpool (#40569)
sparse maxpool; kernel_registry support sparse tensor
上级
aed6faf2
变更
11
隐藏空白更改
内联
并排
Showing
11 changed file
with
988 addition
and
1 deletion
+988
-1
paddle/phi/core/kernel_registry.h
paddle/phi/core/kernel_registry.h
+32
-0
paddle/phi/kernels/funcs/pooling.h
paddle/phi/kernels/funcs/pooling.h
+1
-1
paddle/phi/kernels/funcs/sparse/convolution.h
paddle/phi/kernels/funcs/sparse/convolution.h
+20
-0
paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
+73
-0
paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc
paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc
+108
-0
paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
+120
-0
paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu
paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu
+140
-0
paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h
paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h
+49
-0
paddle/phi/kernels/sparse/sparse_pool_kernel.h
paddle/phi/kernels/sparse/sparse_pool_kernel.h
+53
-0
paddle/phi/tests/kernels/CMakeLists.txt
paddle/phi/tests/kernels/CMakeLists.txt
+1
-0
paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
+391
-0
未找到文件。
paddle/phi/core/kernel_registry.h
浏览文件 @
e52ffb70
...
...
@@ -98,6 +98,28 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout
,
default_key
.
dtype
(),
arg_type
);
}
else
if
(
arg_type
==
std
::
type_index
(
typeid
(
const
SparseCooTensor
&
)))
{
args_def
->
AppendInput
(
default_key
.
backend
(),
default_tensor_layout
,
default_key
.
dtype
(),
arg_type
);
}
else
if
(
arg_type
==
std
::
type_index
(
typeid
(
paddle
::
optional
<
const
SparseCooTensor
&>
)))
{
args_def
->
AppendInput
(
default_key
.
backend
(),
default_tensor_layout
,
default_key
.
dtype
(),
arg_type
);
}
else
if
(
arg_type
==
std
::
type_index
(
typeid
(
const
SparseCsrTensor
&
)))
{
args_def
->
AppendInput
(
default_key
.
backend
(),
default_tensor_layout
,
default_key
.
dtype
(),
arg_type
);
}
else
if
(
arg_type
==
std
::
type_index
(
typeid
(
paddle
::
optional
<
const
SparseCsrTensor
&>
)))
{
args_def
->
AppendInput
(
default_key
.
backend
(),
default_tensor_layout
,
default_key
.
dtype
(),
arg_type
);
}
else
if
(
arg_type
==
std
::
type_index
(
typeid
(
DenseTensor
*
)))
{
args_def
->
AppendOutput
(
default_key
.
backend
(),
default_tensor_layout
,
...
...
@@ -114,6 +136,16 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
default_tensor_layout
,
default_key
.
dtype
(),
arg_type
);
}
else
if
(
arg_type
==
std
::
type_index
(
typeid
(
SparseCooTensor
*
)))
{
args_def
->
AppendOutput
(
default_key
.
backend
(),
default_tensor_layout
,
default_key
.
dtype
(),
arg_type
);
}
else
if
(
arg_type
==
std
::
type_index
(
typeid
(
SparseCsrTensor
*
)))
{
args_def
->
AppendOutput
(
default_key
.
backend
(),
default_tensor_layout
,
default_key
.
dtype
(),
arg_type
);
}
else
{
// Attribute deal with
// TODO(chenweihang): now here allow any types of attribute, maybe
...
...
paddle/phi/kernels/funcs/pooling.h
浏览文件 @
e52ffb70
...
...
@@ -43,7 +43,7 @@ template <class T>
class
MaxPool
{
public:
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
-
FLT_MAX
);
}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
=
*
y
>
x
?
*
y
:
x
;
}
HOST
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
=
*
y
>
x
?
*
y
:
x
;
}
DEVICE
inline
void
finalize
(
const
T
&
pool_field
,
T
*
y
)
{}
};
...
...
paddle/phi/kernels/funcs/sparse/convolution.h
浏览文件 @
e52ffb70
...
...
@@ -165,6 +165,26 @@ inline void SubmPreProcess(const Context& dev_ctx,
x_grad_ptr
);
}
inline
const
std
::
vector
<
int
>
PoolResetKernel
(
const
std
::
vector
<
int
>&
kernel_sizes
,
const
int
in_channels
,
const
int
out_channels
)
{
std
::
vector
<
int
>
res
(
kernel_sizes
);
res
.
resize
(
5
);
res
[
3
]
=
in_channels
;
res
[
4
]
=
out_channels
;
return
res
;
}
inline
void
PrefixSum
(
const
int
*
counter
,
int
*
offsets
,
const
int
n
)
{
int
offset
=
0
;
for
(
int
i
=
0
;
i
<
n
;
i
++
)
{
offsets
[
i
]
=
offset
;
offset
+=
counter
[
i
];
}
offsets
[
n
]
=
offset
;
}
}
// namespace sparse
}
// namespace funcs
}
// namespace phi
paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc
0 → 100644
浏览文件 @
e52ffb70
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
,
typename
Context
>
void
MaxPoolGradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
SparseCooTensor
&
out
,
const
DenseTensor
&
out_grad
,
const
std
::
vector
<
int
>&
kernel_sizes
,
DenseTensor
*
x_grad
)
{
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
const
int
channels
=
x
.
dims
()[
4
];
int
rulebook_len
=
rulebook
.
dims
()[
1
];
const
int
*
rulebook_ptr
=
rulebook
.
data
<
int
>
();
std
::
vector
<
int
>
offsets
(
kernel_size
+
1
),
counter
(
kernel_size
,
0
);
for
(
int
i
=
0
;
i
<
rulebook_len
;
i
++
)
{
counter
[
rulebook_ptr
[
i
]]
+=
1
;
}
phi
::
funcs
::
sparse
::
PrefixSum
(
&
counter
[
0
],
&
offsets
[
0
],
kernel_size
);
const
T
*
in_features_ptr
=
x
.
non_zero_elements
().
data
<
T
>
();
const
T
*
out_features_ptr
=
out
.
non_zero_elements
().
data
<
T
>
();
const
T
*
out_grad_ptr
=
out_grad
.
data
<
T
>
();
T
*
x_grad_ptr
=
x_grad
->
data
<
T
>
();
memset
(
x_grad_ptr
,
0
,
sizeof
(
T
)
*
x_grad
->
numel
());
phi
::
funcs
::
MaxPoolGrad
<
T
>
grad_functor
;
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
for
(
int
j
=
0
;
j
<
counter
[
i
];
j
++
)
{
int
in_i
=
rulebook_ptr
[
rulebook_len
+
offsets
[
i
]
+
j
];
int
out_i
=
rulebook_ptr
[
rulebook_len
*
2
+
offsets
[
i
]
+
j
];
for
(
int
c
=
0
;
c
<
channels
;
c
++
)
{
grad_functor
.
compute
(
in_features_ptr
[
in_i
*
channels
+
c
],
out_features_ptr
[
out_i
*
channels
+
c
],
out_grad_ptr
[
out_i
*
channels
+
c
],
1
,
&
x_grad_ptr
[
in_i
*
channels
+
c
]);
}
}
}
}
}
// namespace sparse
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_maxpool_grad
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
MaxPoolGradKernel
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc
0 → 100644
浏览文件 @
e52ffb70
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h"
#include "paddle/phi/kernels/sparse/cpu/convolution.h"
namespace
phi
{
namespace
sparse
{
/**
* x: (N, D, H, W, C)
* kernel: (D, H, W, C, OC)
* out: (N, D, H, W, OC)
**/
template
<
typename
T
,
typename
Context
>
void
MaxPoolKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
)
{
const
auto
&
x_dims
=
x
.
dims
();
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
const
std
::
vector
<
int
>&
real_kernel_sizes
=
phi
::
funcs
::
sparse
::
PoolResetKernel
(
kernel_sizes
,
x_dims
[
4
],
x_dims
[
4
]);
DDim
out_dims
=
{
1
,
1
,
1
,
1
,
1
};
phi
::
funcs
::
sparse
::
GetOutShape
(
x_dims
,
real_kernel_sizes
,
paddings
,
dilations
,
strides
,
&
out_dims
);
const
int
in_channels
=
real_kernel_sizes
[
3
];
DenseTensorMeta
counter_meta
(
DataType
::
INT32
,
{
kernel_size
},
DataLayout
::
NCHW
);
DenseTensor
counter_per_kernel
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
counter_meta
));
const
T
*
in_features_ptr
=
x
.
non_zero_elements
().
data
<
T
>
();
// 1. product rule book
ProductRuleBook
<
T
,
Context
>
(
dev_ctx
,
x
,
real_kernel_sizes
,
paddings
,
dilations
,
strides
,
out_dims
,
false
,
rulebook
,
&
counter_per_kernel
);
UpdateRulebookAndOutIndex
<
T
>
(
dev_ctx
,
x
,
kernel_size
,
in_channels
,
out_dims
,
rulebook
,
out
);
int
rulebook_len
=
rulebook
->
dims
()[
1
];
const
int
*
rulebook_ptr
=
rulebook
->
data
<
int
>
();
const
int
*
counter_ptr
=
counter_per_kernel
.
data
<
int
>
();
std
::
vector
<
int
>
offsets
(
kernel_size
+
1
);
phi
::
funcs
::
sparse
::
PrefixSum
(
counter_ptr
,
&
offsets
[
0
],
kernel_size
);
std
::
vector
<
bool
>
out_flags
(
out
->
nnz
(),
false
);
// 2. max pool
T
*
out_features_ptr
=
out
->
mutable_non_zero_elements
()
->
data
<
T
>
();
phi
::
funcs
::
MaxPool
<
T
>
max_pool_functor
;
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
for
(
int
j
=
0
;
j
<
counter_ptr
[
i
];
j
++
)
{
int
in_i
=
rulebook_ptr
[
rulebook_len
+
offsets
[
i
]
+
j
];
int
out_i
=
rulebook_ptr
[
rulebook_len
*
2
+
offsets
[
i
]
+
j
];
if
(
!
out_flags
[
out_i
])
{
out_flags
[
out_i
]
=
true
;
memcpy
(
&
out_features_ptr
[
out_i
*
in_channels
],
&
in_features_ptr
[
in_i
*
in_channels
],
in_channels
*
sizeof
(
T
));
}
else
{
for
(
int
c
=
0
;
c
<
in_channels
;
c
++
)
{
max_pool_functor
.
compute
(
in_features_ptr
[
in_i
*
in_channels
+
c
],
&
out_features_ptr
[
out_i
*
in_channels
+
c
]);
}
}
}
}
}
}
// namespace sparse
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_maxpool
,
CPU
,
ALL_LAYOUT
,
phi
::
sparse
::
MaxPoolKernel
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu
0 → 100644
浏览文件 @
e52ffb70
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h"
#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
>
__global__
void
MaxPoolGradCudaKernel
(
const
T
*
in_features_ptr
,
const
T
*
out_features_ptr
,
const
T
*
out_grad_ptr
,
const
int
*
rulebook_ptr
,
const
int
n
,
const
int
rulebook_len
,
const
int
channels
,
T
*
x_grad_ptr
)
{
phi
::
funcs
::
MaxPoolGrad
<
T
>
grad_functor
;
CUDA_KERNEL_LOOP_TYPE
(
i
,
n
*
channels
,
int64_t
)
{
int
real_i
=
i
/
channels
;
int
c
=
i
-
real_i
*
channels
;
int
in_i
=
rulebook_ptr
[
real_i
];
int
out_i
=
rulebook_ptr
[
real_i
+
rulebook_len
];
grad_functor
.
compute
(
in_features_ptr
[
in_i
*
channels
+
c
],
out_features_ptr
[
out_i
*
channels
+
c
],
out_grad_ptr
[
out_i
*
channels
+
c
],
1
,
&
x_grad_ptr
[
in_i
*
channels
+
c
]);
}
}
template
<
typename
T
,
typename
Context
>
void
MaxPoolGradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
SparseCooTensor
&
out
,
const
DenseTensor
&
out_grad
,
const
std
::
vector
<
int
>&
kernel_sizes
,
DenseTensor
*
x_grad
)
{
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
const
int
in_channels
=
x
.
dims
()[
4
];
int
rulebook_len
=
rulebook
.
dims
()[
1
];
const
int
*
rulebook_ptr
=
rulebook
.
data
<
int
>
();
std
::
vector
<
int
>
offsets
(
kernel_size
+
1
),
counter
(
kernel_size
,
0
),
h_counter
(
kernel_size
);
phi
::
backends
::
gpu
::
GpuMemcpyAsync
(
&
h_counter
[
0
],
rulebook_ptr
,
rulebook_len
*
sizeof
(
int
),
#ifdef PADDLE_WITH_HIP
hipMemcpyDeviceToHost
,
#else
cudaMemcpyDeviceToHost
,
#endif
dev_ctx
.
stream
());
dev_ctx
.
Wait
();
for
(
int
i
=
0
;
i
<
rulebook_len
;
i
++
)
{
counter
[
h_counter
[
i
]]
+=
1
;
}
phi
::
funcs
::
sparse
::
PrefixSum
(
&
counter
[
0
],
&
offsets
[
0
],
kernel_size
);
const
T
*
in_features_ptr
=
x
.
non_zero_elements
().
data
<
T
>
();
const
T
*
out_features_ptr
=
out
.
non_zero_elements
().
data
<
T
>
();
const
T
*
out_grad_ptr
=
out_grad
.
data
<
T
>
();
T
*
x_grad_ptr
=
x_grad
->
data
<
T
>
();
phi
::
funcs
::
SetConstant
<
Context
,
T
>
set_zero
;
set_zero
(
dev_ctx
,
x_grad
,
static_cast
<
T
>
(
0.0
f
));
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
counter
[
i
]
<=
0
)
{
continue
;
}
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
counter
[
i
]
*
in_channels
,
1
);
MaxPoolGradCudaKernel
<
T
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
in_features_ptr
,
out_features_ptr
,
out_grad_ptr
,
rulebook_ptr
+
offsets
[
i
]
+
rulebook_len
,
counter
[
i
],
rulebook_len
,
in_channels
,
x_grad_ptr
);
}
}
}
// namespace sparse
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_maxpool_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
sparse
::
MaxPoolGradKernel
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu
0 → 100644
浏览文件 @
e52ffb70
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/sparse/convolution.h"
#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h"
#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
>
__global__
void
MaxPoolCudaKernel
(
const
T
*
in_features_ptr
,
const
int
*
rulebook_ptr
,
const
int
n
,
const
int
rulebook_len
,
const
int
channels
,
T
*
out_features_ptr
)
{
phi
::
funcs
::
MaxPool
<
T
>
max_pool_functor
;
CUDA_KERNEL_LOOP_TYPE
(
i
,
n
*
channels
,
int64_t
)
{
int
real_i
=
i
/
channels
;
int
channel_i
=
i
-
real_i
*
channels
;
int
in_i
=
rulebook_ptr
[
real_i
];
int
out_i
=
rulebook_ptr
[
real_i
+
rulebook_len
];
max_pool_functor
.
compute
(
in_features_ptr
[
in_i
*
channels
+
channel_i
],
&
out_features_ptr
[
out_i
*
channels
+
channel_i
]);
}
}
/**
* x: (N, D, H, W, C)
* kernel: (D, H, W, C, OC)
* out: (N, D, H, W, OC)
**/
template
<
typename
T
,
typename
Context
>
void
MaxPoolKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
)
{
const
auto
&
x_dims
=
x
.
dims
();
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
const
std
::
vector
<
int
>&
real_kernel_sizes
=
phi
::
funcs
::
sparse
::
PoolResetKernel
(
kernel_sizes
,
x_dims
[
4
],
x_dims
[
4
]);
DDim
out_dims
=
{
1
,
1
,
1
,
1
,
1
};
phi
::
funcs
::
sparse
::
GetOutShape
(
x_dims
,
real_kernel_sizes
,
paddings
,
dilations
,
strides
,
&
out_dims
);
const
int
in_channels
=
real_kernel_sizes
[
3
];
std
::
vector
<
int
>
offsets
(
kernel_size
+
1
),
counter
(
kernel_size
);
DenseTensorMeta
counter_meta
(
DataType
::
INT32
,
{
kernel_size
},
DataLayout
::
NCHW
);
DenseTensor
counter_per_kernel
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
counter_meta
));
DenseTensor
offsets_per_kernel
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
counter_meta
));
DenseTensorMeta
index_meta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
);
DenseTensor
out_index
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
index_meta
));
DenseTensor
unique_key
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
index_meta
));
DenseTensor
unique_value
=
phi
::
Empty
(
dev_ctx
,
std
::
move
(
index_meta
));
// 1. product rulebook
int
rulebook_len
=
ProductRuleBook
<
T
,
Context
>
(
dev_ctx
,
x
,
real_kernel_sizes
,
paddings
,
dilations
,
strides
,
out_dims
,
false
,
rulebook
,
&
counter_per_kernel
,
&
offsets_per_kernel
,
&
out_index
,
&
unique_key
,
&
unique_value
,
out
,
&
counter
,
&
offsets
);
const
int
*
rulebook_ptr
=
rulebook
->
data
<
int
>
();
T
*
out_features_ptr
=
out
->
mutable_non_zero_elements
()
->
data
<
T
>
();
const
T
*
in_features_ptr
=
x
.
non_zero_elements
().
data
<
T
>
();
// 2. max pool
#ifdef PADDLE_WITH_HIP
thrust
::
fill
(
thrust
::
hip
::
par
.
on
(
dev_ctx
.
stream
()),
#else
thrust
::
fill
(
thrust
::
cuda
::
par
.
on
(
dev_ctx
.
stream
()),
#endif
out_features_ptr
,
out_features_ptr
+
out
->
non_zero_elements
().
numel
(),
static_cast
<
T
>
(
-
FLT_MAX
));
// TODO(zhangkaihuo) Replacing multiple calls with one kernel may be faster
for
(
int
i
=
0
;
i
<
kernel_size
;
i
++
)
{
if
(
counter
[
i
]
<=
0
)
{
continue
;
}
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
dev_ctx
,
counter
[
i
]
*
in_channels
,
1
);
MaxPoolCudaKernel
<
T
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
dev_ctx
.
stream
()
>>>
(
in_features_ptr
,
rulebook_ptr
+
offsets
[
i
]
+
rulebook_len
,
counter
[
i
],
rulebook_len
,
in_channels
,
out_features_ptr
);
}
}
}
// namespace sparse
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_maxpool
,
GPU
,
ALL_LAYOUT
,
phi
::
sparse
::
MaxPoolKernel
,
float
,
double
,
phi
::
dtype
::
float16
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
}
paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h
0 → 100644
浏览文件 @
e52ffb70
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
,
typename
Context
>
void
MaxPoolGradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
SparseCooTensor
&
out
,
const
DenseTensor
&
out_grad
,
const
std
::
vector
<
int
>&
kernel_sizes
,
DenseTensor
*
x_grad
);
template
<
typename
T
,
typename
Context
>
DenseTensor
MaxPoolGrad
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
SparseCooTensor
&
out
,
const
DenseTensor
&
out_grad
,
const
std
::
vector
<
int
>&
kernel_sizes
)
{
DenseTensor
x_grad
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
x
.
dtype
(),
x
.
non_zero_elements
().
dims
(),
x
.
layout
()));
MaxPoolGradKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
rulebook
,
out
,
out_grad
,
kernel_sizes
,
&
x_grad
);
return
x_grad
;
}
}
// namespace sparse
}
// namespace phi
paddle/phi/kernels/sparse/sparse_pool_kernel.h
0 → 100644
浏览文件 @
e52ffb70
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/sparse_coo_tensor.h"
#include "paddle/phi/kernels/empty_kernel.h"
namespace
phi
{
namespace
sparse
{
template
<
typename
T
,
typename
Context
>
void
MaxPoolKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
);
template
<
typename
T
,
typename
Context
>
SparseCooTensor
MaxPool
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
DenseTensor
*
rulebook
)
{
DenseTensor
indices
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
DenseTensor
values
=
phi
::
Empty
<
Context
>
(
dev_ctx
,
DenseTensorMeta
(
x
.
dtype
(),
{
1
},
x
.
layout
()));
SparseCooTensor
coo
(
indices
,
values
,
x
.
dims
());
MaxPoolKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
kernel_sizes
,
paddings
,
dilations
,
strides
,
&
coo
,
rulebook
);
return
coo
;
}
}
// namespace sparse
}
// namespace phi
paddle/phi/tests/kernels/CMakeLists.txt
浏览文件 @
e52ffb70
...
...
@@ -14,6 +14,7 @@ cc_test(test_concat_dev_api SRCS test_concat_dev_api.cc DEPS phi phi_api_utils)
cc_test
(
test_split_dev_api SRCS test_split_dev_api.cc DEPS phi phi_api_utils
)
cc_test
(
test_sparse_utils_dev_api SRCS test_sparse_utils_dev_api.cc DEPS phi phi_api_utils
)
cc_test
(
test_sparse_conv3d_dev_api SRCS test_sparse_conv3d_dev_api.cc DEPS phi phi_api_utils
)
cc_test
(
test_sparse_pool_dev_api SRCS test_sparse_pool_dev_api.cc DEPS phi phi_api_utils
)
cc_test
(
test_math_function SRCS test_math_function.cc DEPS math_function
)
if
(
WITH_GPU
)
...
...
paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
0 → 100644
浏览文件 @
e52ffb70
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <memory>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/copy_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h"
#include "paddle/phi/kernels/sparse/sparse_pool_kernel.h"
#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/phi/api/lib/utils/allocator.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
namespace
tests
{
template
<
typename
T1
,
typename
T2
>
std
::
vector
<
T2
>
cast
(
const
std
::
vector
<
T1
>&
in
)
{
std
::
vector
<
T2
>
out
(
in
.
size
());
for
(
uint64_t
i
=
0
;
i
<
in
.
size
();
i
++
)
{
out
[
i
]
=
static_cast
<
T2
>
(
in
[
i
]);
}
return
out
;
}
template
<
typename
T
>
void
TestMaxPoolBase
(
const
std
::
vector
<
int
>&
indices
,
const
std
::
vector
<
T
>&
features
,
const
DDim
&
x_dims
,
const
std
::
vector
<
int
>&
correct_out_indices
,
const
std
::
vector
<
T
>&
correct_out_features
,
const
DDim
&
correct_out_dims
,
const
int
non_zero_num
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
dilations
,
const
float
diff
=
1e-3
,
const
bool
backward
=
false
,
const
std
::
vector
<
T
>
features_grad
=
{})
{
phi
::
CPUContext
dev_ctx_cpu
;
dev_ctx_cpu
.
SetAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
paddle
::
platform
::
CPUPlace
())
.
get
());
dev_ctx_cpu
.
Init
();
const
int
in_channels
=
x_dims
[
4
];
const
int
out_channels
=
in_channels
;
DenseTensor
indices_tensor
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
DataType
::
INT32
,
{
4
,
non_zero_num
},
DataLayout
::
NCHW
));
memcpy
(
indices_tensor
.
data
<
int
>
(),
indices
.
data
(),
indices
.
size
()
*
sizeof
(
int
));
DenseTensor
features_tensor
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
paddle
::
experimental
::
CppTypeToDataType
<
T
>::
Type
(),
{
non_zero_num
,
in_channels
},
DataLayout
::
NHWC
));
memcpy
(
features_tensor
.
data
<
T
>
(),
features
.
data
(),
features
.
size
()
*
sizeof
(
T
));
SparseCooTensor
x_tensor
(
indices_tensor
,
features_tensor
,
x_dims
);
auto
f_verify
=
[
&
](
const
T
*
real_data
,
const
std
::
vector
<
T
>&
correct_data
)
{
for
(
uint64_t
i
=
0
;
i
<
correct_data
.
size
();
i
++
)
{
float
tmp
=
std
::
fabs
(
static_cast
<
float
>
(
correct_data
[
i
]
-
real_data
[
i
]));
ASSERT_LT
(
tmp
,
diff
);
}
};
if
(
!
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
{
DenseTensor
rulebook
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
SparseCooTensor
out
=
sparse
::
MaxPool
<
T
>
(
dev_ctx_cpu
,
x_tensor
,
kernel_sizes
,
paddings
,
dilations
,
strides
,
&
rulebook
);
ASSERT_EQ
(
correct_out_dims
.
size
(),
out
.
dims
().
size
());
for
(
int
i
=
0
;
i
<
correct_out_dims
.
size
();
i
++
)
{
ASSERT_EQ
(
correct_out_dims
[
i
],
out
.
dims
()[
i
]);
}
ASSERT_EQ
((
int64_t
)
correct_out_features
.
size
()
/
out_channels
,
out
.
nnz
());
int
cmp_indices
=
memcmp
(
correct_out_indices
.
data
(),
out
.
non_zero_indices
().
data
<
int
>
(),
correct_out_indices
.
size
()
*
sizeof
(
int
));
ASSERT_EQ
(
cmp_indices
,
0
);
f_verify
(
out
.
non_zero_elements
().
data
<
T
>
(),
correct_out_features
);
if
(
backward
)
{
DenseTensor
x_grad
=
sparse
::
MaxPoolGrad
<
T
>
(
dev_ctx_cpu
,
x_tensor
,
rulebook
,
out
,
out
.
non_zero_elements
(),
kernel_sizes
);
f_verify
(
x_grad
.
data
<
T
>
(),
features_grad
);
}
}
// test gpu
#if defined(PADDLE_WITH_CUDA)
phi
::
GPUContext
dev_ctx_gpu
;
dev_ctx_gpu
.
PartialInitWithoutAllocator
();
dev_ctx_gpu
.
SetAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
dev_ctx_gpu
.
GetPlace
(),
dev_ctx_gpu
.
stream
())
.
get
());
dev_ctx_gpu
.
SetHostAllocator
(
paddle
::
memory
::
allocation
::
AllocatorFacade
::
Instance
()
.
GetAllocator
(
phi
::
CPUPlace
())
.
get
());
dev_ctx_gpu
.
PartialInitWithAllocator
();
DenseTensor
d_indices_tensor
=
phi
::
Empty
(
dev_ctx_gpu
,
DenseTensorMeta
(
DataType
::
INT32
,
{
4
,
non_zero_num
},
DataLayout
::
NCHW
));
phi
::
Copy
(
dev_ctx_gpu
,
indices_tensor
,
phi
::
GPUPlace
(),
true
,
&
d_indices_tensor
);
DenseTensor
d_features_tensor
=
phi
::
Empty
(
dev_ctx_gpu
,
DenseTensorMeta
(
paddle
::
experimental
::
CppTypeToDataType
<
T
>::
Type
(),
{
non_zero_num
,
in_channels
},
DataLayout
::
NHWC
));
phi
::
Copy
(
dev_ctx_gpu
,
features_tensor
,
phi
::
GPUPlace
(),
true
,
&
d_features_tensor
);
SparseCooTensor
d_x_tensor
(
d_indices_tensor
,
d_features_tensor
,
x_dims
);
DenseTensor
d_rulebook
=
phi
::
Empty
(
dev_ctx_gpu
,
DenseTensorMeta
(
DataType
::
INT32
,
{
1
},
DataLayout
::
NCHW
));
SparseCooTensor
d_out
=
sparse
::
MaxPool
<
T
>
(
dev_ctx_gpu
,
d_x_tensor
,
kernel_sizes
,
paddings
,
dilations
,
strides
,
&
d_rulebook
);
ASSERT_EQ
(
correct_out_dims
.
size
(),
d_out
.
dims
().
size
());
ASSERT_EQ
((
int64_t
)
correct_out_features
.
size
()
/
out_channels
,
d_out
.
nnz
());
for
(
int
i
=
0
;
i
<
correct_out_dims
.
size
();
i
++
)
{
ASSERT_EQ
(
correct_out_dims
[
i
],
d_out
.
dims
()[
i
]);
}
DenseTensor
h_indices_tensor
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
DataType
::
INT32
,
{
4
,
d_out
.
nnz
()},
DataLayout
::
NCHW
));
phi
::
Copy
(
dev_ctx_gpu
,
d_out
.
non_zero_indices
(),
phi
::
CPUPlace
(),
true
,
&
h_indices_tensor
);
int
cmp_indices2
=
memcmp
(
correct_out_indices
.
data
(),
h_indices_tensor
.
data
<
int
>
(),
correct_out_indices
.
size
()
*
sizeof
(
int
));
ASSERT_EQ
(
cmp_indices2
,
0
);
DenseTensor
h_features_tensor
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
paddle
::
experimental
::
CppTypeToDataType
<
T
>::
Type
(),
{
d_out
.
nnz
()},
d_out
.
layout
()));
phi
::
Copy
(
dev_ctx_gpu
,
d_out
.
non_zero_elements
(),
phi
::
CPUPlace
(),
true
,
&
h_features_tensor
);
f_verify
(
h_features_tensor
.
data
<
T
>
(),
correct_out_features
);
if
(
backward
)
{
DenseTensor
x_grad
=
sparse
::
MaxPoolGrad
<
T
>
(
dev_ctx_gpu
,
d_x_tensor
,
d_rulebook
,
d_out
,
d_out
.
non_zero_elements
(),
kernel_sizes
);
DenseTensor
h_features_grad
=
phi
::
Empty
(
dev_ctx_cpu
,
DenseTensorMeta
(
x_grad
.
dtype
(),
x_grad
.
dims
(),
x_grad
.
layout
()));
phi
::
Copy
(
dev_ctx_gpu
,
x_grad
,
phi
::
CPUPlace
(),
true
,
&
h_features_grad
);
f_verify
(
h_features_grad
.
data
<
T
>
(),
features_grad
);
}
#endif
}
void
TestMaxPool
(
const
std
::
vector
<
int
>&
indices
,
const
std
::
vector
<
float
>&
features
,
const
DDim
&
x_dims
,
const
std
::
vector
<
int
>&
correct_out_indices
,
const
std
::
vector
<
float
>&
correct_out_features
,
const
DDim
&
correct_out_dims
,
const
int
non_zero_num
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
dilations
,
const
float
diff
=
1e-3
,
const
bool
backward
=
false
,
const
std
::
vector
<
float
>
features_grad
=
{})
{
// test float
TestMaxPoolBase
<
float
>
(
indices
,
features
,
x_dims
,
correct_out_indices
,
correct_out_features
,
correct_out_dims
,
non_zero_num
,
kernel_sizes
,
paddings
,
strides
,
dilations
,
diff
,
backward
,
features_grad
);
// test double
TestMaxPoolBase
<
double
>
(
indices
,
cast
<
float
,
double
>
(
features
),
x_dims
,
correct_out_indices
,
cast
<
float
,
double
>
(
correct_out_features
),
correct_out_dims
,
non_zero_num
,
kernel_sizes
,
paddings
,
strides
,
dilations
,
diff
,
backward
,
cast
<
float
,
double
>
(
features_grad
));
}
TEST
(
DEV_API
,
sparse_maxpool
)
{
const
int
channels
=
1
;
DDim
x_dims
=
{
1
,
1
,
4
,
4
,
channels
};
DDim
out_dims
=
{
1
,
1
,
2
,
2
,
channels
};
std
::
vector
<
int
>
kernel_sizes
=
{
1
,
3
,
3
};
std
::
vector
<
int
>
paddings
=
{
0
,
0
,
0
};
std
::
vector
<
int
>
strides
=
{
1
,
1
,
1
};
std
::
vector
<
int
>
dilations
=
{
1
,
1
,
1
};
const
int
non_zero_num
=
3
;
std
::
vector
<
int
>
indices
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
3
,
0
,
1
,
2
};
std
::
vector
<
float
>
features
=
{
1
,
2
,
3
};
std
::
vector
<
int
>
out_indices
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
0
,
1
,
0
,
1
,
};
std
::
vector
<
float
>
out_features
=
{
2
,
2
,
3
,
3
};
std
::
vector
<
float
>
x_grad
=
{
0
,
4
,
6
};
TestMaxPool
(
indices
,
features
,
x_dims
,
out_indices
,
out_features
,
out_dims
,
non_zero_num
,
kernel_sizes
,
paddings
,
strides
,
dilations
,
1e-6
,
true
,
x_grad
);
}
TEST
(
DEV_API
,
sparse_maxpool_stride
)
{
const
int
channels
=
1
;
DDim
x_dims
=
{
1
,
1
,
4
,
4
,
channels
};
DDim
out_dims
=
{
1
,
1
,
1
,
1
,
channels
};
std
::
vector
<
int
>
kernel_sizes
=
{
1
,
3
,
3
};
std
::
vector
<
int
>
paddings
=
{
0
,
0
,
0
};
std
::
vector
<
int
>
strides
=
{
2
,
2
,
2
};
std
::
vector
<
int
>
dilations
=
{
1
,
1
,
1
};
const
int
non_zero_num
=
3
;
std
::
vector
<
int
>
indices
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
3
,
0
,
1
,
2
};
std
::
vector
<
float
>
features
=
{
1
,
2
,
3
};
std
::
vector
<
int
>
out_indices
=
{
0
,
0
,
0
,
0
};
std
::
vector
<
float
>
out_features
=
{
2
};
std
::
vector
<
float
>
x_grad
=
{
0
,
2
,
0
};
TestMaxPool
(
indices
,
features
,
x_dims
,
out_indices
,
out_features
,
out_dims
,
non_zero_num
,
kernel_sizes
,
paddings
,
strides
,
dilations
,
1e-6
,
true
,
x_grad
);
}
TEST
(
DEV_API
,
sparse_maxpool_channel
)
{
const
int
channels
=
2
;
DDim
x_dims
=
{
1
,
1
,
4
,
4
,
channels
};
DDim
out_dims
=
{
1
,
1
,
2
,
2
,
channels
};
std
::
vector
<
int
>
kernel_sizes
=
{
1
,
3
,
3
};
std
::
vector
<
int
>
paddings
=
{
0
,
0
,
0
};
std
::
vector
<
int
>
strides
=
{
1
,
1
,
1
};
std
::
vector
<
int
>
dilations
=
{
1
,
1
,
1
};
const
int
non_zero_num
=
3
;
std
::
vector
<
int
>
indices
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
3
,
0
,
1
,
2
};
std
::
vector
<
float
>
features
=
{
1
,
1
,
2
,
2
,
3
,
3
};
std
::
vector
<
int
>
out_indices
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
0
,
1
,
0
,
1
,
};
std
::
vector
<
float
>
out_features
=
{
2
,
2
,
2
,
2
,
3
,
3
,
3
,
3
};
std
::
vector
<
float
>
x_grad
=
{
0
,
0
,
4
,
4
,
6
,
6
};
TestMaxPool
(
indices
,
features
,
x_dims
,
out_indices
,
out_features
,
out_dims
,
non_zero_num
,
kernel_sizes
,
paddings
,
strides
,
dilations
,
1e-6
,
true
,
x_grad
);
}
TEST
(
DEV_API
,
sparse_maxpool3d
)
{
const
int
channels
=
2
;
DDim
x_dims
=
{
1
,
5
,
4
,
4
,
channels
};
DDim
out_dims
=
{
1
,
3
,
2
,
2
,
channels
};
std
::
vector
<
int
>
kernel_sizes
=
{
3
,
3
,
3
};
std
::
vector
<
int
>
paddings
=
{
0
,
0
,
0
};
std
::
vector
<
int
>
strides
=
{
1
,
1
,
1
};
std
::
vector
<
int
>
dilations
=
{
1
,
1
,
1
};
const
int
non_zero_num
=
3
;
std
::
vector
<
int
>
indices
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
3
,
0
,
1
,
2
};
std
::
vector
<
float
>
features
=
{
1
,
1
,
2
,
2
,
3
,
3
};
std
::
vector
<
int
>
out_indices
=
{
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
0
,
1
,
1
,
0
,
1
,
0
,
1
,
};
std
::
vector
<
float
>
out_features
=
{
2
,
2
,
2
,
2
,
3
,
3
,
3
,
3
};
std
::
vector
<
float
>
x_grad
=
{
0
,
0
,
4
,
4
,
6
,
6
};
TestMaxPool
(
indices
,
features
,
x_dims
,
out_indices
,
out_features
,
out_dims
,
non_zero_num
,
kernel_sizes
,
paddings
,
strides
,
dilations
,
1e-6
,
true
,
x_grad
);
}
}
// namespace tests
}
// namespace phi
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录