Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
9e307229
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
9e307229
编写于
7月 19, 2022
作者:
Z
zhangkaihuo
提交者:
GitHub
7月 19, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Standard name of sparse pool (#44344)
上级
f382eb06
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
152 addition
and
152 deletion
+152
-152
paddle/phi/api/yaml/sparse_api.yaml
paddle/phi/api/yaml/sparse_api.yaml
+2
-2
paddle/phi/api/yaml/sparse_bw_api.yaml
paddle/phi/api/yaml/sparse_bw_api.yaml
+8
-8
paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc
paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc
+19
-19
paddle/phi/kernels/sparse/cpu/pool_kernel.cc
paddle/phi/kernels/sparse/cpu/pool_kernel.cc
+28
-28
paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu
paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu
+19
-19
paddle/phi/kernels/sparse/gpu/pool_kernel.cu
paddle/phi/kernels/sparse/gpu/pool_kernel.cu
+28
-28
paddle/phi/kernels/sparse/pool_grad_kernel.h
paddle/phi/kernels/sparse/pool_grad_kernel.h
+14
-14
paddle/phi/kernels/sparse/pool_kernel.h
paddle/phi/kernels/sparse/pool_kernel.h
+16
-16
paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
+18
-18
未找到文件。
paddle/phi/api/yaml/sparse_api.yaml
浏览文件 @
9e307229
...
@@ -316,10 +316,10 @@
...
@@ -316,10 +316,10 @@
args
:
(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides)
args
:
(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides)
output
:
Tensor(out), Tensor(rulebook)
output
:
Tensor(out), Tensor(rulebook)
kernel
:
kernel
:
func
:
sparse_maxpool
{sparse_coo -> sparse_coo, dense}
func
:
maxpool_coo
{sparse_coo -> sparse_coo, dense}
layout
:
x
layout
:
x
intermediate
:
rulebook
intermediate
:
rulebook
backward
:
sparse_
maxpool_grad
backward
:
maxpool_grad
-
api
:
mv
-
api
:
mv
args
:
(Tensor x, Tensor vec)
args
:
(Tensor x, Tensor vec)
...
...
paddle/phi/api/yaml/sparse_bw_api.yaml
浏览文件 @
9e307229
...
@@ -137,6 +137,13 @@
...
@@ -137,6 +137,13 @@
matmul_coo_dense_grad {sparse_coo, dense, dense -> sparse_coo, dense},
matmul_coo_dense_grad {sparse_coo, dense, dense -> sparse_coo, dense},
matmul_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}
matmul_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo}
-
backward_api
:
maxpool_grad
forward
:
maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook)
args
:
(Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
output
:
Tensor(x_grad)
kernel
:
func
:
maxpool_coo_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo}
-
backward_api
:
multiply_grad
-
backward_api
:
multiply_grad
forward
:
multiply(Tensor x, Tensor y) -> Tensor(out)
forward
:
multiply(Tensor x, Tensor y) -> Tensor(out)
args
:
(Tensor x, Tensor y, Tensor out_grad)
args
:
(Tensor x, Tensor y, Tensor out_grad)
...
@@ -198,13 +205,6 @@
...
@@ -198,13 +205,6 @@
kernel
:
kernel
:
func
:
softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
func
:
softmax_csr_grad{sparse_csr, sparse_csr -> sparse_csr}
-
backward_api
:
sparse_maxpool_grad
forward
:
sparse_maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook)
args
:
(Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes)
output
:
Tensor(x_grad)
kernel
:
func
:
sparse_maxpool_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo}
-
backward_api
:
sqrt_grad
-
backward_api
:
sqrt_grad
forward
:
sqrt(Tensor x) -> Tensor(out)
forward
:
sqrt(Tensor x) -> Tensor(out)
args
:
(Tensor out, Tensor out_grad)
args
:
(Tensor out, Tensor out_grad)
...
@@ -255,7 +255,7 @@
...
@@ -255,7 +255,7 @@
-
backward_api
:
fused_attention_grad
-
backward_api
:
fused_attention_grad
forward
:
fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
forward
:
fused_attention_csr(Tensor query, Tensor key, Tensor value, Tensor sparse_mask, Tensor key_padding_mask, Tensor attn_mask) -> Tensor(out), Tensor(softmax)
args
:
(Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
args
:
(Tensor query, Tensor key, Tensor value, Tensor softmax, Tensor out_grad)
output
:
Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
output
:
Tensor(query_grad), Tensor(key_grad), Tensor(value_grad)
kernel
:
kernel
:
func
:
fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
func
:
fused_attention_csr_grad{dense, dense, dense, sparse_csr, dense -> dense, dense, dense}
layout
:
softmax
layout
:
softmax
...
...
paddle/phi/kernels/sparse/cpu/
sparse_
pool_grad_kernel.cc
→
paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc
浏览文件 @
9e307229
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/phi/kernels/sparse/
sparse_
pool_grad_kernel.h"
#include "paddle/phi/kernels/sparse/pool_grad_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
...
@@ -25,13 +25,13 @@ namespace phi {
...
@@ -25,13 +25,13 @@ namespace phi {
namespace
sparse
{
namespace
sparse
{
template
<
typename
T
,
typename
IntT
=
int
>
template
<
typename
T
,
typename
IntT
=
int
>
void
MaxPoolGradCPUKernel
(
const
CPUContext
&
dev_ctx
,
void
MaxPool
Coo
GradCPUKernel
(
const
CPUContext
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
DenseTensor
&
rulebook
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out_grad
,
const
SparseCooTensor
&
out_grad
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
SparseCooTensor
*
x_grad
)
{
SparseCooTensor
*
x_grad
)
{
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
const
int
channels
=
x
.
dims
()[
4
];
const
int
channels
=
x
.
dims
()[
4
];
int
rulebook_len
=
rulebook
.
dims
()[
1
];
int
rulebook_len
=
rulebook
.
dims
()[
1
];
...
@@ -75,16 +75,16 @@ void MaxPoolGradCPUKernel(const CPUContext& dev_ctx,
...
@@ -75,16 +75,16 @@ void MaxPoolGradCPUKernel(const CPUContext& dev_ctx,
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
MaxPoolGradKernel
(
const
Context
&
dev_ctx
,
void
MaxPool
Coo
GradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
DenseTensor
&
rulebook
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out_grad
,
const
SparseCooTensor
&
out_grad
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
SparseCooTensor
*
x_grad
)
{
SparseCooTensor
*
x_grad
)
{
PD_VISIT_INTEGRAL_TYPES
(
PD_VISIT_INTEGRAL_TYPES
(
x
.
non_zero_indices
().
dtype
(),
"MaxPoolGradCPUKernel"
,
([
&
]
{
x
.
non_zero_indices
().
dtype
(),
"MaxPool
Coo
GradCPUKernel"
,
([
&
]
{
MaxPoolGradCPUKernel
<
T
,
data_t
>
(
MaxPool
Coo
GradCPUKernel
<
T
,
data_t
>
(
dev_ctx
,
x
,
rulebook
,
out
,
out_grad
,
kernel_sizes
,
x_grad
);
dev_ctx
,
x
,
rulebook
,
out
,
out_grad
,
kernel_sizes
,
x_grad
);
}));
}));
}
}
...
@@ -92,10 +92,10 @@ void MaxPoolGradKernel(const Context& dev_ctx,
...
@@ -92,10 +92,10 @@ void MaxPoolGradKernel(const Context& dev_ctx,
}
// namespace sparse
}
// namespace sparse
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_maxpool
_grad
,
PD_REGISTER_KERNEL
(
maxpool_coo
_grad
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
sparse
::
MaxPoolGradKernel
,
phi
::
sparse
::
MaxPool
Coo
GradKernel
,
float
,
float
,
double
)
{
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
...
...
paddle/phi/kernels/sparse/cpu/
sparse_
pool_kernel.cc
→
paddle/phi/kernels/sparse/cpu/pool_kernel.cc
浏览文件 @
9e307229
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/phi/kernels/sparse/
sparse_
pool_kernel.h"
#include "paddle/phi/kernels/sparse/pool_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_meta.h"
...
@@ -30,14 +30,14 @@ namespace sparse {
...
@@ -30,14 +30,14 @@ namespace sparse {
* out: (N, D, H, W, OC)
* out: (N, D, H, W, OC)
**/
**/
template
<
typename
T
,
typename
IntT
=
int
>
template
<
typename
T
,
typename
IntT
=
int
>
void
MaxPoolCPUKernel
(
const
CPUContext
&
dev_ctx
,
void
MaxPoolC
ooC
PUKernel
(
const
CPUContext
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
SparseCooTensor
*
out
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
)
{
DenseTensor
*
rulebook
)
{
const
auto
&
x_dims
=
x
.
dims
();
const
auto
&
x_dims
=
x
.
dims
();
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
const
std
::
vector
<
int
>&
real_kernel_sizes
=
const
std
::
vector
<
int
>&
real_kernel_sizes
=
...
@@ -98,34 +98,34 @@ void MaxPoolCPUKernel(const CPUContext& dev_ctx,
...
@@ -98,34 +98,34 @@ void MaxPoolCPUKernel(const CPUContext& dev_ctx,
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
MaxPoolKernel
(
const
Context
&
dev_ctx
,
void
MaxPool
Coo
Kernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
SparseCooTensor
*
out
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
)
{
DenseTensor
*
rulebook
)
{
PD_VISIT_INTEGRAL_TYPES
(
PD_VISIT_INTEGRAL_TYPES
(
x
.
non_zero_indices
().
dtype
(),
"MaxPoolCPUKernel"
,
([
&
]
{
x
.
non_zero_indices
().
dtype
(),
"MaxPoolC
ooC
PUKernel"
,
([
&
]
{
MaxPoolCPUKernel
<
T
,
data_t
>
(
dev_ctx
,
MaxPoolC
ooC
PUKernel
<
T
,
data_t
>
(
dev_ctx
,
x
,
x
,
kernel_sizes
,
kernel_sizes
,
paddings
,
paddings
,
dilations
,
dilations
,
strides
,
strides
,
out
,
out
,
rulebook
);
rulebook
);
}));
}));
}
}
}
// namespace sparse
}
// namespace sparse
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_maxpool
,
PD_REGISTER_KERNEL
(
maxpool_coo
,
CPU
,
CPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
sparse
::
MaxPoolKernel
,
phi
::
sparse
::
MaxPool
Coo
Kernel
,
float
,
float
,
double
)
{
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
...
...
paddle/phi/kernels/sparse/gpu/
sparse_
pool_grad_kernel.cu
→
paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu
浏览文件 @
9e307229
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/phi/kernels/sparse/
sparse_
pool_grad_kernel.h"
#include "paddle/phi/kernels/sparse/pool_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
...
@@ -52,13 +52,13 @@ __global__ void MaxPoolGradCudaKernel(const T* in_features_ptr,
...
@@ -52,13 +52,13 @@ __global__ void MaxPoolGradCudaKernel(const T* in_features_ptr,
}
}
template
<
typename
T
,
typename
IntT
=
int
>
template
<
typename
T
,
typename
IntT
=
int
>
void
MaxPoolGradGPUKernel
(
const
GPUContext
&
dev_ctx
,
void
MaxPool
Coo
GradGPUKernel
(
const
GPUContext
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
DenseTensor
&
rulebook
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out_grad
,
const
SparseCooTensor
&
out_grad
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
SparseCooTensor
*
x_grad
)
{
SparseCooTensor
*
x_grad
)
{
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
const
int
in_channels
=
x
.
dims
()[
4
];
const
int
in_channels
=
x
.
dims
()[
4
];
int
rulebook_len
=
rulebook
.
dims
()[
1
];
int
rulebook_len
=
rulebook
.
dims
()[
1
];
...
@@ -121,16 +121,16 @@ void MaxPoolGradGPUKernel(const GPUContext& dev_ctx,
...
@@ -121,16 +121,16 @@ void MaxPoolGradGPUKernel(const GPUContext& dev_ctx,
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
MaxPoolGradKernel
(
const
Context
&
dev_ctx
,
void
MaxPool
Coo
GradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
DenseTensor
&
rulebook
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out_grad
,
const
SparseCooTensor
&
out_grad
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
SparseCooTensor
*
x_grad
)
{
SparseCooTensor
*
x_grad
)
{
PD_VISIT_INTEGRAL_TYPES
(
PD_VISIT_INTEGRAL_TYPES
(
x
.
non_zero_indices
().
dtype
(),
"MaxPoolGradGPUKernel"
,
([
&
]
{
x
.
non_zero_indices
().
dtype
(),
"MaxPool
Coo
GradGPUKernel"
,
([
&
]
{
MaxPoolGradGPUKernel
<
T
,
data_t
>
(
MaxPool
Coo
GradGPUKernel
<
T
,
data_t
>
(
dev_ctx
,
x
,
rulebook
,
out
,
out_grad
,
kernel_sizes
,
x_grad
);
dev_ctx
,
x
,
rulebook
,
out
,
out_grad
,
kernel_sizes
,
x_grad
);
}));
}));
}
}
...
@@ -138,10 +138,10 @@ void MaxPoolGradKernel(const Context& dev_ctx,
...
@@ -138,10 +138,10 @@ void MaxPoolGradKernel(const Context& dev_ctx,
}
// namespace sparse
}
// namespace sparse
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_maxpool
_grad
,
PD_REGISTER_KERNEL
(
maxpool_coo
_grad
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
sparse
::
MaxPoolGradKernel
,
phi
::
sparse
::
MaxPool
Coo
GradKernel
,
float
,
float
,
double
)
{
double
)
{
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
kernel
->
InputAt
(
0
).
SetDataLayout
(
phi
::
DataLayout
::
SPARSE_COO
);
...
...
paddle/phi/kernels/sparse/gpu/
sparse_
pool_kernel.cu
→
paddle/phi/kernels/sparse/gpu/pool_kernel.cu
浏览文件 @
9e307229
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/phi/kernels/sparse/
sparse_
pool_kernel.h"
#include "paddle/phi/kernels/sparse/pool_kernel.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_meta.h"
#include "paddle/phi/core/tensor_meta.h"
...
@@ -48,14 +48,14 @@ __global__ void MaxPoolCudaKernel(const T* in_features_ptr,
...
@@ -48,14 +48,14 @@ __global__ void MaxPoolCudaKernel(const T* in_features_ptr,
* out: (N, D, H, W, OC)
* out: (N, D, H, W, OC)
**/
**/
template
<
typename
T
,
typename
IntT
=
int
>
template
<
typename
T
,
typename
IntT
=
int
>
void
MaxPoolGPUKernel
(
const
GPUContext
&
dev_ctx
,
void
MaxPool
Coo
GPUKernel
(
const
GPUContext
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
SparseCooTensor
*
out
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
)
{
DenseTensor
*
rulebook
)
{
const
auto
&
x_dims
=
x
.
dims
();
const
auto
&
x_dims
=
x
.
dims
();
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
int
kernel_size
=
kernel_sizes
[
0
]
*
kernel_sizes
[
1
]
*
kernel_sizes
[
2
];
const
std
::
vector
<
int
>&
real_kernel_sizes
=
const
std
::
vector
<
int
>&
real_kernel_sizes
=
...
@@ -127,34 +127,34 @@ void MaxPoolGPUKernel(const GPUContext& dev_ctx,
...
@@ -127,34 +127,34 @@ void MaxPoolGPUKernel(const GPUContext& dev_ctx,
}
}
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
MaxPoolKernel
(
const
Context
&
dev_ctx
,
void
MaxPool
Coo
Kernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
SparseCooTensor
*
out
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
)
{
DenseTensor
*
rulebook
)
{
PD_VISIT_INTEGRAL_TYPES
(
PD_VISIT_INTEGRAL_TYPES
(
x
.
non_zero_indices
().
dtype
(),
"MaxPoolGPUKernel"
,
([
&
]
{
x
.
non_zero_indices
().
dtype
(),
"MaxPool
Coo
GPUKernel"
,
([
&
]
{
MaxPoolGPUKernel
<
T
,
data_t
>
(
dev_ctx
,
MaxPool
Coo
GPUKernel
<
T
,
data_t
>
(
dev_ctx
,
x
,
x
,
kernel_sizes
,
kernel_sizes
,
paddings
,
paddings
,
dilations
,
dilations
,
strides
,
strides
,
out
,
out
,
rulebook
);
rulebook
);
}));
}));
}
}
}
// namespace sparse
}
// namespace sparse
}
// namespace phi
}
// namespace phi
PD_REGISTER_KERNEL
(
sparse_maxpool
,
PD_REGISTER_KERNEL
(
maxpool_coo
,
GPU
,
GPU
,
ALL_LAYOUT
,
ALL_LAYOUT
,
phi
::
sparse
::
MaxPoolKernel
,
phi
::
sparse
::
MaxPool
Coo
Kernel
,
float
,
float
,
double
,
double
,
phi
::
dtype
::
float16
)
{
phi
::
dtype
::
float16
)
{
...
...
paddle/phi/kernels/sparse/
sparse_
pool_grad_kernel.h
→
paddle/phi/kernels/sparse/pool_grad_kernel.h
浏览文件 @
9e307229
...
@@ -22,23 +22,23 @@ namespace phi {
...
@@ -22,23 +22,23 @@ namespace phi {
namespace
sparse
{
namespace
sparse
{
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
MaxPoolGradKernel
(
const
Context
&
dev_ctx
,
void
MaxPool
Coo
GradKernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
DenseTensor
&
rulebook
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out_grad
,
const
SparseCooTensor
&
out_grad
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
SparseCooTensor
*
x_grad
);
SparseCooTensor
*
x_grad
);
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
SparseCooTensor
MaxPoolGrad
(
const
Context
&
dev_ctx
,
SparseCooTensor
MaxPool
Coo
Grad
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
DenseTensor
&
rulebook
,
const
DenseTensor
&
rulebook
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out
,
const
SparseCooTensor
&
out_grad
,
const
SparseCooTensor
&
out_grad
,
const
std
::
vector
<
int
>&
kernel_sizes
)
{
const
std
::
vector
<
int
>&
kernel_sizes
)
{
SparseCooTensor
x_grad
;
SparseCooTensor
x_grad
;
MaxPoolGradKernel
<
T
,
Context
>
(
MaxPool
Coo
GradKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
rulebook
,
out
,
out_grad
,
kernel_sizes
,
&
x_grad
);
dev_ctx
,
x
,
rulebook
,
out
,
out_grad
,
kernel_sizes
,
&
x_grad
);
return
x_grad
;
return
x_grad
;
}
}
...
...
paddle/phi/kernels/sparse/
sparse_
pool_kernel.h
→
paddle/phi/kernels/sparse/pool_kernel.h
浏览文件 @
9e307229
...
@@ -22,25 +22,25 @@ namespace phi {
...
@@ -22,25 +22,25 @@ namespace phi {
namespace
sparse
{
namespace
sparse
{
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
void
MaxPoolKernel
(
const
Context
&
dev_ctx
,
void
MaxPool
Coo
Kernel
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
SparseCooTensor
*
out
,
SparseCooTensor
*
out
,
DenseTensor
*
rulebook
);
DenseTensor
*
rulebook
);
template
<
typename
T
,
typename
Context
>
template
<
typename
T
,
typename
Context
>
SparseCooTensor
MaxPool
(
const
Context
&
dev_ctx
,
SparseCooTensor
MaxPool
Coo
(
const
Context
&
dev_ctx
,
const
SparseCooTensor
&
x
,
const
SparseCooTensor
&
x
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
kernel_sizes
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
dilations
,
const
std
::
vector
<
int
>&
strides
,
const
std
::
vector
<
int
>&
strides
,
DenseTensor
*
rulebook
)
{
DenseTensor
*
rulebook
)
{
SparseCooTensor
coo
;
SparseCooTensor
coo
;
MaxPoolKernel
<
T
,
Context
>
(
MaxPool
Coo
Kernel
<
T
,
Context
>
(
dev_ctx
,
x
,
kernel_sizes
,
paddings
,
dilations
,
strides
,
&
coo
,
rulebook
);
dev_ctx
,
x
,
kernel_sizes
,
paddings
,
dilations
,
strides
,
&
coo
,
rulebook
);
return
coo
;
return
coo
;
}
}
...
...
paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc
浏览文件 @
9e307229
...
@@ -23,8 +23,8 @@ limitations under the License. */
...
@@ -23,8 +23,8 @@ limitations under the License. */
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/sparse/coalesce_kernel.h"
#include "paddle/phi/kernels/sparse/coalesce_kernel.h"
#include "paddle/phi/kernels/sparse/
sparse_
pool_grad_kernel.h"
#include "paddle/phi/kernels/sparse/pool_grad_kernel.h"
#include "paddle/phi/kernels/sparse/
sparse_
pool_kernel.h"
#include "paddle/phi/kernels/sparse/pool_kernel.h"
namespace
phi
{
namespace
phi
{
namespace
tests
{
namespace
tests
{
...
@@ -91,13 +91,13 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
...
@@ -91,13 +91,13 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
if
(
!
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
{
if
(
!
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
{
DenseTensor
rulebook
;
DenseTensor
rulebook
;
SparseCooTensor
out
=
sparse
::
MaxPool
<
T
>
(
dev_ctx_cpu
,
SparseCooTensor
out
=
sparse
::
MaxPool
Coo
<
T
>
(
dev_ctx_cpu
,
x_tensor
,
x_tensor
,
kernel_sizes
,
kernel_sizes
,
paddings
,
paddings
,
dilations
,
dilations
,
strides
,
strides
,
&
rulebook
);
&
rulebook
);
ASSERT_EQ
(
correct_out_dims
.
size
(),
out
.
dims
().
size
());
ASSERT_EQ
(
correct_out_dims
.
size
(),
out
.
dims
().
size
());
for
(
int
i
=
0
;
i
<
correct_out_dims
.
size
();
i
++
)
{
for
(
int
i
=
0
;
i
<
correct_out_dims
.
size
();
i
++
)
{
...
@@ -113,7 +113,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
...
@@ -113,7 +113,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
f_verify
(
out
.
non_zero_elements
().
data
<
T
>
(),
correct_out_features
);
f_verify
(
out
.
non_zero_elements
().
data
<
T
>
(),
correct_out_features
);
if
(
backward
)
{
if
(
backward
)
{
SparseCooTensor
x_grad
=
sparse
::
MaxPoolGrad
<
T
>
(
SparseCooTensor
x_grad
=
sparse
::
MaxPool
Coo
Grad
<
T
>
(
dev_ctx_cpu
,
x_tensor
,
rulebook
,
out
,
out
,
kernel_sizes
);
dev_ctx_cpu
,
x_tensor
,
rulebook
,
out
,
out
,
kernel_sizes
);
f_verify
(
x_grad
.
non_zero_elements
().
data
<
T
>
(),
features_grad
);
f_verify
(
x_grad
.
non_zero_elements
().
data
<
T
>
(),
features_grad
);
}
}
...
@@ -151,13 +151,13 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
...
@@ -151,13 +151,13 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
SparseCooTensor
d_x_tensor
(
d_indices_tensor
,
d_features_tensor
,
x_dims
);
SparseCooTensor
d_x_tensor
(
d_indices_tensor
,
d_features_tensor
,
x_dims
);
DenseTensor
d_rulebook
;
DenseTensor
d_rulebook
;
SparseCooTensor
d_out
=
sparse
::
MaxPool
<
T
>
(
dev_ctx_gpu
,
SparseCooTensor
d_out
=
sparse
::
MaxPool
Coo
<
T
>
(
dev_ctx_gpu
,
d_x_tensor
,
d_x_tensor
,
kernel_sizes
,
kernel_sizes
,
paddings
,
paddings
,
dilations
,
dilations
,
strides
,
strides
,
&
d_rulebook
);
&
d_rulebook
);
SparseCooTensor
tmp_d_out
=
sparse
::
Coalesce
<
T
>
(
dev_ctx_gpu
,
d_out
);
SparseCooTensor
tmp_d_out
=
sparse
::
Coalesce
<
T
>
(
dev_ctx_gpu
,
d_out
);
ASSERT_EQ
(
correct_out_dims
.
size
(),
d_out
.
dims
().
size
());
ASSERT_EQ
(
correct_out_dims
.
size
(),
d_out
.
dims
().
size
());
...
@@ -191,7 +191,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
...
@@ -191,7 +191,7 @@ void TestMaxPoolBase(const std::vector<IntT>& indices,
f_verify
(
h_features_tensor
.
data
<
T
>
(),
correct_out_features
);
f_verify
(
h_features_tensor
.
data
<
T
>
(),
correct_out_features
);
if
(
backward
)
{
if
(
backward
)
{
SparseCooTensor
x_grad
=
sparse
::
MaxPoolGrad
<
T
>
(
SparseCooTensor
x_grad
=
sparse
::
MaxPool
Coo
Grad
<
T
>
(
dev_ctx_gpu
,
d_x_tensor
,
d_rulebook
,
d_out
,
d_out
,
kernel_sizes
);
dev_ctx_gpu
,
d_x_tensor
,
d_rulebook
,
d_out
,
d_out
,
kernel_sizes
);
DenseTensor
h_features_grad
=
DenseTensor
h_features_grad
=
phi
::
EmptyLike
<
T
>
(
dev_ctx_cpu
,
x_grad
.
non_zero_elements
());
phi
::
EmptyLike
<
T
>
(
dev_ctx_cpu
,
x_grad
.
non_zero_elements
());
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录