PaddlePaddle / Paddle

Commit a07f19ee (unverified)
Authored Mar 10, 2022 by Zhong Hui · Committed via GitHub on Mar 10, 2022
[PHI] Move segment_pool to phi. (#40099)
* move segment_pool to phi.
* mark summed ids as optional tensor.
* fix as reviews.
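The operator being migrated pools consecutive rows of X that share the same (sorted) segment id into one output row, with SUM, MEAN, MAX, and MIN variants. For orientation before the diff, here is a minimal plain-C++ sketch of the SUM variant; the function name segment_sum and the flattened row-major layout are illustrative assumptions, not code from this commit:

#include <cstdint>
#include <vector>

// x: n rows of width w, flattened row-major; ids: sorted, ids[i] is the
// segment of row i. Returns (max_id + 1) rows, each the sum of its segment.
std::vector<double> segment_sum(const std::vector<double>& x,
                                const std::vector<int64_t>& ids, int64_t w) {
  int64_t n = static_cast<int64_t>(ids.size());
  int64_t num_segments = ids.empty() ? 0 : ids.back() + 1;
  std::vector<double> out(num_segments * w, 0.0);
  for (int64_t i = 0; i < n; ++i)
    for (int64_t j = 0; j < w; ++j)
      out[ids[i] * w + j] += x[i * w + j];
  return out;
}

For example, ids = {0, 0, 1} with w = 1 and x = {1, 2, 5} yields out = {3, 5}.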
Parent commit: 548f2be4

Showing 22 changed files with 666 additions and 405 deletions (+666 −405).
Changed files:

paddle/fluid/operators/CMakeLists.txt                     +1    -1
paddle/fluid/operators/math/CMakeLists.txt                +0    -1
paddle/fluid/operators/segment_pool_op.cc                 +9    -28
paddle/fluid/operators/segment_pool_op.cu                 +0    -27
paddle/fluid/operators/segment_pool_op.h                  +0    -176
paddle/fluid/operators/unity_build_rule.cmake             +1    -3
paddle/phi/infermeta/binary.cc                            +19   -0
paddle/phi/infermeta/binary.h                             +8    -0
paddle/phi/kernels/CMakeLists.txt                         +3    -1
paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc        +26   -0
paddle/phi/kernels/cpu/segment_pool_kernel.cc             +22   -0
paddle/phi/kernels/funcs/CMakeLists.txt                   +1    -0
paddle/phi/kernels/funcs/segment_pooling.cc               +48   -36
paddle/phi/kernels/funcs/segment_pooling.cu               +173  -116
paddle/phi/kernels/funcs/segment_pooling.h                +19   -16
paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu        +27   -0
paddle/phi/kernels/gpu/segment_pool_kernel.cu             +23   -0
paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h   +51   -0
paddle/phi/kernels/impl/segment_pool_kernel_impl.h        +142  -0
paddle/phi/kernels/segment_pool_grad_kernel.h             +31   -0
paddle/phi/kernels/segment_pool_kernel.h                  +29   -0
paddle/phi/ops/compat/segment_pool_sig.cc                 +33   -0
paddle/fluid/operators/CMakeLists.txt

@@ -161,7 +161,7 @@ cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEP
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows_utils lapack_function
     lod_tensor maxouting unpooling pooling lod_rank_table context_project
-    sequence_pooling segment_pooling executor device_memory_aligment generator)
+    sequence_pooling executor device_memory_aligment generator)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col)
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse matrix_solve)
paddle/fluid/operators/math/CMakeLists.txt

@@ -46,7 +46,6 @@ math_library(vol2col)
 math_library(prelu)
 math_library(bert_encoder_functor)
 math_library(tree2col DEPS math_function)
-math_library(segment_pooling)
 math_library(matrix_solve)
 
 cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor)
paddle/fluid/operators/segment_pool_op.cc

@@ -12,9 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/segment_pool_op.h"
 #include <memory>
 #include <string>
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/phi/core/infermeta_utils.h"
+#include "paddle/phi/infermeta/binary.h"
 
 namespace paddle {
 namespace operators {
@@ -23,22 +26,6 @@ class SegmentPoolOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
 
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SegmentPool");
-    OP_INOUT_CHECK(ctx->HasInput("SegmentIds"), "Input", "SegmentIds",
-                   "SegmentPool");
-    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SegmentPool");
-    auto dims = ctx->GetInputDim("X");
-    dims[0] = -1;
-    ctx->SetOutputDim("Out", dims);
-
-    if (ctx->Attrs().Get<std::string>("pooltype") == "MEAN") {
-      OP_INOUT_CHECK(ctx->HasOutput("SummedIds"), "Output", "SummedIds",
-                     "SegmentPool");
-      ctx->SetOutputDim("SummedIds", {-1, 1});
-    }
-  }
-
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
@@ -150,17 +137,11 @@ class SegmentPoolGradOpMaker : public framework::SingleGradOpMaker<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+DECLARE_INFER_SHAPE_FUNCTOR(segment_pool, SegmentPoolInferShapeFunctor,
+                            PD_INFER_META(phi::SegmentPoolInferMeta));
 REGISTER_OPERATOR(segment_pool, ops::SegmentPoolOp, ops::SegmentPoolOpMaker,
                   ops::SegmentPoolGradOpMaker<paddle::framework::OpDesc>,
-                  ops::SegmentPoolGradOpMaker<paddle::imperative::OpBase>);
+                  ops::SegmentPoolGradOpMaker<paddle::imperative::OpBase>,
+                  SegmentPoolInferShapeFunctor);
 REGISTER_OPERATOR(segment_pool_grad, ops::SegmentPoolGradOp);
-
-REGISTER_OP_CPU_KERNEL(
-    segment_pool,
-    ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SegmentPoolKernel<paddle::platform::CPUDeviceContext, double>);
-REGISTER_OP_CPU_KERNEL(
-    segment_pool_grad,
-    ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, float>,
-    ops::SegmentPoolGradKernel<paddle::platform::CPUDeviceContext, double>);
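The removed REGISTER_OP_CPU_KERNEL registrations are replaced by phi kernels registered later in this commit (paddle/phi/kernels/cpu/segment_pool_kernel.cc and friends). As a hedged sketch of what the new entry point looks like, based on the headers this commit adds (the exact parameter list is a best-effort reading of the diff, not a verbatim copy):

// Sketch of the phi-style kernel declaration that replaces the fluid
// OpKernel above; see paddle/phi/kernels/segment_pool_kernel.h in this
// commit for the authoritative version.
namespace phi {
template <typename T, typename Context>
void SegmentPoolKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& segment_ids,
                       const std::string& pooltype,
                       DenseTensor* out,
                       DenseTensor* summed_ids);
}  // namespace phi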
paddle/fluid/operators/segment_pool_op.cu (deleted, mode 100644 → 0)

/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/segment_pool_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
    segment_pool,
    ops::SegmentPoolKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SegmentPoolKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    segment_pool_grad,
    ops::SegmentPoolGradKernel<paddle::platform::CUDADeviceContext, float>,
    ops::SegmentPoolGradKernel<paddle::platform::CUDADeviceContext, double>);
paddle/fluid/operators/segment_pool_op.h (deleted, mode 100644 → 0)

/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/segment_pooling.h"
#include "paddle/fluid/platform/macros.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/math_function.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T, typename IndexT>
void SegmentKernelLaunchHelper(const framework::ExecutionContext& context) {
  auto* input = context.Input<Tensor>("X");
  auto* segment = context.Input<Tensor>("SegmentIds");
  auto* output = context.Output<Tensor>("Out");
  std::string pooltype = context.Attr<std::string>("pooltype");
  Tensor* summed_ids = nullptr;

  int64_t num_indices = segment->numel();
  PADDLE_ENFORCE_EQ(
      num_indices, input->dims()[0],
      platform::errors::InvalidArgument(
          "Segment_ids should be the same size as dimension 0 of input X."));
  PADDLE_ENFORCE_EQ(num_indices, segment->dims()[0],
                    platform::errors::InvalidArgument(
                        "Segment_ids should be 1-D tensor, or it's other "
                        "dimension size is 1. Segment_ids's shape is: [%s].",
                        segment->dims()));

  if (input->numel() == 0 || segment->numel() == 0) {
    return;
  }

  bool cpu_place = context.GetPlace().GetType() == phi::AllocationType::CPU;
  if (cpu_place) {
    auto dims = input->dims();
    auto* segment_ids = segment->data<IndexT>();
    dims[0] = static_cast<int64_t>(segment_ids[segment->numel() - 1] + 1);
    PADDLE_ENFORCE_GT(
        dims[0], 0,
        platform::errors::InvalidArgument(
            "Segment ids must be >= 0, but got last id %d", dims[0]));
    output->Resize({dims});
    output->mutable_data<T>(context.GetPlace());
    phi::funcs::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    set_zero(dev_ctx, output, static_cast<T>(0));
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (!cpu_place) {
    Tensor length;
    length.mutable_data<IndexT>(phi::make_ddim({1}), platform::CPUPlace());
    IndexT* length_data = length.data<IndexT>();
    const IndexT* segment_ids = segment->data<IndexT>();

#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(
        hipMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT),
                  hipMemcpyDeviceToHost));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(
        cudaMemcpy(length_data, segment_ids + num_indices - 1, sizeof(IndexT),
                   cudaMemcpyDeviceToHost));
#endif

    IndexT length_host = length_data[0];
    length_host++;
    PADDLE_ENFORCE_GT(
        length_host, 0,
        platform::errors::InvalidArgument(
            "Segment ids must be >= 0, but got last id %d", length_data[0]));
    auto dims = input->dims();
    dims[0] = static_cast<int64_t>(length_host);
    output->Resize({dims});
    output->mutable_data<T>(context.GetPlace());
    T init_value = 0;
    if (pooltype == "MAX") {
      init_value = static_cast<T>(-FLT_MAX);
    } else if (pooltype == "MIN") {
      init_value = static_cast<T>(FLT_MAX);
    }
    phi::funcs::SetConstant<DeviceContext, T> setconst;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    setconst(dev_ctx, output, static_cast<T>(init_value));
    // the gpu kernel of mean pool record the counts of segment_ids
    if (pooltype == "MEAN") {
      summed_ids = context.Output<Tensor>("SummedIds");
      summed_ids->Resize({dims[0], 1});
      summed_ids->mutable_data<T>(context.GetPlace());
      setconst(dev_ctx, summed_ids, static_cast<T>(1e-12));
    }
  }
#endif

  SegmentPoolFunctor<DeviceContext, T, IndexT> pool;
  pool(context.template device_context<DeviceContext>(), *input, *segment,
       output, summed_ids, pooltype);
}

template <typename DeviceContext, typename T>
class SegmentPoolKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* segment = context.Input<Tensor>("SegmentIds");
    auto index_type = framework::TransToProtoVarType(segment->dtype());
    if (index_type == framework::proto::VarType::INT32) {
      SegmentKernelLaunchHelper<DeviceContext, T, int>(context);
    } else if (index_type == framework::proto::VarType::INT64) {
      SegmentKernelLaunchHelper<DeviceContext, T, int64_t>(context);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported index type, Expected int, int64, but got %s.",
          index_type));
    }
  }
};

template <typename DeviceContext, typename T>
class SegmentPoolGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* input = context.Input<Tensor>("X");
    auto* output = context.Input<Tensor>("Out");
    auto* segment = context.Input<Tensor>("SegmentIds");
    auto* out_g = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* in_g = context.Output<Tensor>(framework::GradVarName("X"));
    std::string pooltype = context.Attr<std::string>("pooltype");

    const Tensor* summed_ids = nullptr;
    if (pooltype == "MEAN") {
      summed_ids = context.Input<Tensor>("SummedIds");
    }

    in_g->mutable_data<T>(context.GetPlace());
    phi::funcs::SetConstant<DeviceContext, T> set_zero;
    auto& dev_ctx = context.template device_context<DeviceContext>();
    set_zero(dev_ctx, in_g, static_cast<T>(0));

    auto index_type = framework::TransToProtoVarType(segment->dtype());
    if (index_type == framework::proto::VarType::INT32) {
      SegmentPoolGradFunctor<DeviceContext, T, int> pool;
      pool(context.template device_context<DeviceContext>(), *input, *output,
           *out_g, *segment, in_g, summed_ids, pooltype);
    } else if (index_type == framework::proto::VarType::INT64) {
      SegmentPoolGradFunctor<DeviceContext, T, int64_t> pool;
      pool(context.template device_context<DeviceContext>(), *input, *output,
           *out_g, *segment, in_g, summed_ids, pooltype);
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Unsupported index type, Expected int, int64, but got %s.",
          index_type));
    }
  }
};

}  // namespace operators
}  // namespace paddle
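The launch-helper logic in this deleted header does not disappear: it is ported, in phi style, into paddle/phi/kernels/impl/segment_pool_kernel_impl.h (+142 lines in this commit). The DeviceContext template parameter becomes a phi Context, ExecutionContext lookups become explicit DenseTensor arguments, and summed_ids becomes an optional output. A hedged sketch of the ported helper's shape, inferred from the impl header's name and the functor calls above rather than copied verbatim:

// Illustrative signature only; the real helper lives in
// paddle/phi/kernels/impl/segment_pool_kernel_impl.h.
template <typename T, typename Context, typename IndexT>
void SegmentKernelLaunchHelper(const Context& dev_ctx,
                               const DenseTensor& x,
                               const DenseTensor& segment_ids,
                               const std::string& pooltype,
                               DenseTensor* out,
                               DenseTensor* summed_ids);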
paddle/fluid/operators/unity_build_rule.cmake

@@ -236,7 +236,6 @@ register_unity_group(cc
     scatter_nd_add_op.cc
     scatter_op.cc
     seed_op.cc
-    segment_pool_op.cc
     select_input_op.cc
     select_output_op.cc)
 register_unity_group(cc
@@ -496,8 +495,7 @@ register_unity_group(cu
     scale_op.cu
     scatter_nd_add_op.cu
     scatter_op.cu
-    seed_op.cu
-    segment_pool_op.cu)
+    seed_op.cu)
 register_unity_group(cu
     roi_pool_op.cu
     selu_op.cu
paddle/phi/infermeta/binary.cc

@@ -417,6 +417,25 @@ void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
   out->share_meta(x);
 }
 
+void SegmentPoolInferMeta(const MetaTensor& x,
+                          const MetaTensor& segment_ids,
+                          const std::string& pooltype,
+                          MetaTensor* out,
+                          MetaTensor* summed_ids,
+                          MetaConfig config) {
+  auto dims = x.dims();
+  dims[0] = -1;
+  out->set_dims(dims);
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+
+  if (pooltype == "MEAN") {
+    summed_ids->set_dims({-1, 1});
+    summed_ids->set_dtype(x.dtype());
+    summed_ids->set_layout(x.layout());
+  }
+}
+
 void BCELossInferMeta(const MetaTensor& input,
                       const MetaTensor& label,
                       MetaTensor* out,
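The shape rule here is simple but worth spelling out: only dim 0 of X is replaced (the segment count is data-dependent, hence -1 at infer time), and MEAN additionally promises a column vector of per-segment counts. A standalone sketch of the same rule in plain C++, with no MetaTensor machinery; SegmentPoolOutDims is an illustrative name, not a Paddle API:

#include <cstdint>
#include <vector>

// Mirrors SegmentPoolInferMeta's dims logic: dim 0 becomes -1 (unknown
// number of segments), trailing dims are inherited from X unchanged.
std::vector<int64_t> SegmentPoolOutDims(std::vector<int64_t> x_dims) {
  x_dims[0] = -1;
  return x_dims;  // e.g. {10, 64} -> {-1, 64}
}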
paddle/phi/infermeta/binary.h

@@ -80,6 +80,14 @@ void CrossInferMeta(const MetaTensor& x,
                     MetaTensor* out);
 
 void Atan2InferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);
 
+void SegmentPoolInferMeta(const MetaTensor& x,
+                          const MetaTensor& segment_ids,
+                          const std::string& pooltype,
+                          MetaTensor* out,
+                          MetaTensor* summed_ids,
+                          MetaConfig config = MetaConfig());
+
 void BCELossInferMeta(const MetaTensor& input,
                       const MetaTensor& label,
                       MetaTensor* out,
paddle/phi/kernels/CMakeLists.txt

@@ -27,7 +27,7 @@ kernel_library(full_kernel DEPS ${COMMON_KERNEL_DEPS} empty_kernel)
 # Some kernels depend on some targets that are not commonly used.
 # These targets are not suitable for common dependencies.
 # In this case, you need to manually generate them here.
-set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel)
+set(MANUAL_BUILD_KERNELS math_kernel softmax_kernel softmax_grad_kernel triangular_solve_grad_kernel maxout_kernel maxout_grad_kernel put_along_axis_kernel put_along_axis_grad_kernel take_along_axis_kernel take_along_axis_grad_kernel eigh_kernel segment_pool_kernel segment_pool_grad_kernel)
 kernel_library(math_kernel DEPS ${COMMON_KERNEL_DEPS} cast_kernel copy_kernel)
 kernel_library(softmax_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
 kernel_library(softmax_grad_kernel DEPS ${COMMON_KERNEL_DEPS} softmax)
@@ -39,6 +39,8 @@ kernel_library(put_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scat
 kernel_library(take_along_axis_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
 kernel_library(take_along_axis_grad_kernel DEPS ${COMMON_KERNEL_DEPS} gather_scatter_kernel)
 kernel_library(eigh_kernel DEPS ${COMMON_KERNEL_DEPS} lapack_function)
+kernel_library(segment_pool_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling)
+kernel_library(segment_pool_grad_kernel DEPS ${COMMON_KERNEL_DEPS} segment_pooling)
 
 # 4. auto parse and build kernel targets by cmake
 register_kernels(EXCLUDES ${COMMON_BAISC_KERNELS} ${MANUAL_BUILD_KERNELS} DEPS ${COMMON_KERNEL_DEPS} ${COMMON_BAISC_KERNELS})
paddle/phi/kernels/cpu/segment_pool_grad_kernel.cc (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/segment_pool_grad_kernel.h"
#include "paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(segment_pool_grad,
                   CPU,
                   ALL_LAYOUT,
                   phi::SegmentPoolGradKernel,
                   float,
                   double) {}
paddle/phi/kernels/cpu/segment_pool_kernel.cc (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/segment_pool_kernel.h"
#include "paddle/phi/kernels/impl/segment_pool_kernel_impl.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(
    segment_pool, CPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {}
paddle/phi/kernels/funcs/CMakeLists.txt

@@ -4,6 +4,7 @@ add_subdirectory(lapack)
 add_subdirectory(detail)
 
 math_library(math_function DEPS blas dense_tensor tensor)
+math_library(segment_pooling)
 math_library(sequence2batch)
 math_library(gru_compute DEPS activation_functions math_function)
 math_library(lstm_compute DEPS activation_functions)
paddle/fluid/operators/math/segment_pooling.cc → paddle/phi/kernels/funcs/segment_pooling.cc

@@ -12,45 +12,52 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include "paddle/fluid/operators/math/segment_pooling.h"
+#include "paddle/phi/kernels/funcs/segment_pooling.h"
+
 #include <string>
-#include "paddle/fluid/framework/eigen.h"
 
-namespace paddle {
-namespace operators {
+#include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/kernels/funcs/eigen/common.h"
 
-using Tensor = framework::Tensor;
+namespace phi {
+namespace funcs {
+
+using Tensor = DenseTensor;
 
 template <typename T, typename IndexT>
-class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> {
+class SegmentPoolFunctor<phi::CPUContext, T, IndexT> {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& segments, framework::Tensor* output,
-                  framework::Tensor* index,
+  void operator()(const phi::CPUContext& dev_ctx,
+                  const DenseTensor& input,
+                  const DenseTensor& segments,
+                  DenseTensor* output,
+                  DenseTensor* index,
                   const std::string pooltype = "SUM") {
     const IndexT* segment_ids = segments.data<IndexT>();
     auto curent_id = segment_ids[0];
     int64_t last_idx = 0;
     int64_t w = input.numel() / input.dims()[0];
-    auto& place = *context.eigen_device();
+    auto& place = *dev_ctx.eigen_device();
     for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
       if (idx < segments.numel()) {
         if (segment_ids[idx] == curent_id) continue;
-        PADDLE_ENFORCE_GE(segment_ids[idx], curent_id,
-                          platform::errors::InvalidArgument(
+        PADDLE_ENFORCE_GE(segment_ids[idx],
+                          curent_id,
+                          phi::errors::InvalidArgument(
                               "The segment ids should be sorted, but got "
                               "segment_ids[%d]:%d > segment_ids[%d]:%d.",
-                              idx - 1, curent_id, idx, segment_ids[idx]));
+                              idx - 1,
+                              curent_id,
+                              idx,
+                              segment_ids[idx]));
       }
       Tensor out_t = output->Slice(curent_id, curent_id + 1);
       Tensor in_t = input.Slice(last_idx, idx);
       int64_t h = idx - last_idx;
-      auto in_e = framework::EigenMatrix<T>::From(in_t, phi::make_ddim({h, w}));
-      auto out_e = framework::EigenVector<T>::Flatten(out_t);
+      auto in_e = EigenMatrix<T>::From(in_t, phi::make_ddim({h, w}));
+      auto out_e = EigenVector<T>::Flatten(out_t);
       auto reduce_dim = Eigen::array<int, 1>({{0}});
       if (pooltype == "MEAN") {
@@ -62,7 +69,7 @@ class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> {
       } else if (pooltype == "MIN") {
         out_e.device(place) = in_e.minimum(reduce_dim);
       } else {
-        PADDLE_THROW(platform::errors::InvalidArgument(
+        PADDLE_THROW(phi::errors::InvalidArgument(
             "Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
             "available, but got %s.",
             pooltype));
@@ -75,36 +82,41 @@ class SegmentPoolFunctor<platform::CPUDeviceContext, T, IndexT> {
 };
 
 template <typename T, typename IndexT>
-class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> {
+class SegmentPoolGradFunctor<phi::CPUContext, T, IndexT> {
  public:
-  void operator()(const platform::CPUDeviceContext& context,
-                  const framework::Tensor& input,
-                  const framework::Tensor& output,
-                  const framework::Tensor& out_grad,
-                  const framework::Tensor& segments, framework::Tensor* in_grad,
-                  const framework::Tensor* index = nullptr,
+  void operator()(const phi::CPUContext& dev_ctx,
+                  const DenseTensor& input,
+                  const DenseTensor& output,
+                  const DenseTensor& out_grad,
+                  const DenseTensor& segments,
+                  DenseTensor* in_grad,
+                  paddle::optional<const DenseTensor&> index,
                   const std::string pooltype = "SUM") {
     const IndexT* segment_ids = segments.data<IndexT>();
-    auto& place = *context.eigen_device();
+    auto& place = *dev_ctx.eigen_device();
     auto curent_id = segment_ids[0];
     int64_t last_idx = 0;
     int64_t w = in_grad->numel() / in_grad->dims()[0];
     for (int64_t idx = 1; idx <= segments.numel(); ++idx) {
       if (idx < segments.numel()) {
         if (segment_ids[idx] == curent_id) continue;
-        PADDLE_ENFORCE_GE(segment_ids[idx], curent_id,
-                          platform::errors::InvalidArgument(
+        PADDLE_ENFORCE_GE(segment_ids[idx],
+                          curent_id,
+                          phi::errors::InvalidArgument(
                               "The segment ids should be sorted, but got "
                               "segment_ids[%d]:%d > segment_ids[%d]:%d.",
-                              idx - 1, curent_id, idx, segment_ids[idx]));
+                              idx - 1,
+                              curent_id,
+                              idx,
+                              segment_ids[idx]));
       }
       Tensor out_g_t = out_grad.Slice(curent_id, curent_id + 1);
       Tensor in_g_t = in_grad->Slice(last_idx, idx);
       int64_t h = idx - last_idx;
-      auto in_g_e = framework::EigenMatrix<T>::From(in_g_t, {h, w});
-      auto out_g_e = framework::EigenMatrix<T>::From(out_g_t, {1, w});
+      auto in_g_e = EigenMatrix<T>::From(in_g_t, {h, w});
+      auto out_g_e = EigenMatrix<T>::From(out_g_t, {1, w});
       Eigen::DSizes<int, 2> bcast(h, 1);
       if (pooltype == "MEAN") {
@@ -114,13 +126,13 @@ class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> {
       } else if (pooltype == "MAX" || pooltype == "MIN") {
         Tensor out_t = output.Slice(curent_id, curent_id + 1);
         Tensor in_t = input.Slice(last_idx, idx);
-        auto in_e = framework::EigenMatrix<T>::From(in_t, {h, w});
-        auto out_e = framework::EigenMatrix<T>::From(out_t, {1, w});
+        auto in_e = EigenMatrix<T>::From(in_t, {h, w});
+        auto out_e = EigenMatrix<T>::From(out_t, {1, w});
         in_g_e.device(place) =
             (in_e == out_e.broadcast(bcast)).template cast<T>() *
             out_g_e.broadcast(bcast);
       } else {
-        PADDLE_THROW(platform::errors::InvalidArgument(
+        PADDLE_THROW(phi::errors::InvalidArgument(
            "Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
            "available, but got %s.",
            pooltype));
@@ -132,7 +144,7 @@ class SegmentPoolGradFunctor<platform::CPUDeviceContext, T, IndexT> {
   }
 };
 
-using CPU = platform::CPUDeviceContext;
+using CPU = phi::CPUContext;
 template class SegmentPoolFunctor<CPU, float, int>;
 template class SegmentPoolFunctor<CPU, float, int64_t>;
 template class SegmentPoolFunctor<CPU, double, int>;
@@ -142,5 +154,5 @@ template class SegmentPoolGradFunctor<CPU, float, int64_t>;
 template class SegmentPoolGradFunctor<CPU, double, int>;
 template class SegmentPoolGradFunctor<CPU, double, int64_t>;
 
-}  // namespace operators
-}  // namespace paddle
+}  // namespace funcs
+}  // namespace phi
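Stripped of Eigen and tensor plumbing, the CPU functors above share one control-flow idea: scan the sorted ids once, and every time the id changes (or the scan ends), pool the run of rows [last_idx, idx) into output row curent_id. A minimal plain-C++ sketch of that loop for MEAN pooling, assuming sorted ids and a pre-zeroed out buffer (names are illustrative, not Paddle APIs):

#include <cstdint>

// x: n rows of width w (row-major); ids: sorted segment ids, one per row;
// out: pre-zeroed, (max id + 1) rows of width w. Rows of a segment with no
// inputs stay zero, matching the functor's behavior for gaps.
void segment_mean_rows(const double* x, const int64_t* ids, int64_t n,
                       int64_t w, double* out) {
  int64_t current_id = ids[0];
  int64_t last_idx = 0;
  for (int64_t idx = 1; idx <= n; ++idx) {
    if (idx < n && ids[idx] == current_id) continue;  // still the same run
    int64_t h = idx - last_idx;  // run length, the "h" of the functor above
    for (int64_t j = 0; j < w; ++j) {
      double sum = 0.0;
      for (int64_t i = last_idx; i < idx; ++i) sum += x[i * w + j];
      out[current_id * w + j] = sum / h;
    }
    if (idx < n) {
      current_id = ids[idx];
      last_idx = idx;
    }
  }
}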
paddle/
fluid/operators/math
/segment_pooling.cu
→
paddle/
phi/kernels/funcs
/segment_pooling.cu
浏览文件 @
a07f19ee
...
@@ -12,20 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,20 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/phi/kernels/funcs/segment_pooling.h"
#include <algorithm>
#include <algorithm>
#include "paddle/fluid/operators/math/segment_pooling.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
namespace
p
addle
{
namespace
p
hi
{
namespace
operator
s
{
namespace
func
s
{
using
Tensor
=
framework
::
Tensor
;
using
Tensor
=
Dense
Tensor
;
template
<
typename
T
,
typename
Index
,
int
DimTileSize
>
template
<
typename
T
,
typename
Index
,
int
DimTileSize
>
__global__
void
SegmentSumIdsKernel
(
const
Index
*
segment_ids
,
T
*
summed_ids
,
__global__
void
SegmentSumIdsKernel
(
const
Index
*
segment_ids
,
T
*
summed_ids
,
const
Index
input_length_size
,
const
Index
input_length_size
,
const
Index
total_stripe_count
)
{
const
Index
total_stripe_count
)
{
CUDA_KERNEL_LOOP
(
stripe_index
,
total_stripe_count
)
{
CUDA_KERNEL_LOOP
(
stripe_index
,
total_stripe_count
)
{
...
@@ -45,16 +49,19 @@ __global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids,
...
@@ -45,16 +49,19 @@ __global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids,
PADDLE_ENFORCE
(
current_segment_id
>=
last_segment_id
,
PADDLE_ENFORCE
(
current_segment_id
>=
last_segment_id
,
"the segment ids should be sorted, but got "
"the segment ids should be sorted, but got "
"segment_ids[%d]:%d > segment_ids[%d]:%d."
,
"segment_ids[%d]:%d > segment_ids[%d]:%d."
,
dim_index_base
+
j
-
1
,
dim_index_base
+
j
,
dim_index_base
+
j
-
1
,
last_segment_id
,
current_segment_id
);
dim_index_base
+
j
,
last_segment_id
,
current_segment_id
);
if
(
current_segment_id
>
last_segment_id
)
{
if
(
current_segment_id
>
last_segment_id
)
{
for
(
Index
interval_id
=
last_segment_id
+
1
;
for
(
Index
interval_id
=
last_segment_id
+
1
;
interval_id
<
current_segment_id
;
++
interval_id
)
{
interval_id
<
current_segment_id
;
++
interval_id
)
{
*
(
summed_ids
+
interval_id
)
=
0
;
*
(
summed_ids
+
interval_id
)
=
0
;
}
}
if
(
j
>
0
)
{
if
(
j
>
0
)
{
if
(
last_segment_id
==
first_segment_id
)
{
if
(
last_segment_id
==
first_segment_id
)
{
platform
::
CudaAtomicAdd
(
summed_ids
+
last_segment_id
,
sum
);
p
addle
::
p
latform
::
CudaAtomicAdd
(
summed_ids
+
last_segment_id
,
sum
);
}
else
{
}
else
{
*
(
summed_ids
+
last_segment_id
)
=
sum
;
*
(
summed_ids
+
last_segment_id
)
=
sum
;
}
}
...
@@ -64,13 +71,15 @@ __global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids,
...
@@ -64,13 +71,15 @@ __global__ void SegmentSumIdsKernel(const Index* segment_ids, T* summed_ids,
sum
+=
T
(
1
);
sum
+=
T
(
1
);
last_segment_id
=
current_segment_id
;
last_segment_id
=
current_segment_id
;
}
}
platform
::
CudaAtomicAdd
(
summed_ids
+
last_segment_id
,
sum
);
p
addle
::
p
latform
::
CudaAtomicAdd
(
summed_ids
+
last_segment_id
,
sum
);
}
}
}
}
template
<
typename
T
,
typename
Index
,
int
DimTileSize
>
template
<
typename
T
,
typename
Index
,
int
DimTileSize
>
__global__
void
SegmentMeanKernel
(
const
Index
*
segment_ids
,
const
T
*
input
,
__global__
void
SegmentMeanKernel
(
const
Index
*
segment_ids
,
T
*
output
,
T
*
summed_ids
,
const
T
*
input
,
T
*
output
,
T
*
summed_ids
,
const
Index
input_length_size
,
const
Index
input_length_size
,
const
Index
inner_dim_size
,
const
Index
inner_dim_size
,
const
Index
output_length_size
,
const
Index
output_length_size
,
...
@@ -93,7 +102,8 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
...
@@ -93,7 +102,8 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
if
(
current_segment_id
>
last_segment_id
)
{
if
(
current_segment_id
>
last_segment_id
)
{
// reset the interval value which do not have corresponding ids.
// reset the interval value which do not have corresponding ids.
for
(
Index
interval_id
=
last_segment_id
+
1
;
for
(
Index
interval_id
=
last_segment_id
+
1
;
interval_id
<
current_segment_id
;
++
interval_id
)
{
interval_id
<
current_segment_id
;
++
interval_id
)
{
*
(
output
+
interval_id
*
inner_dim_size
+
segment_offset
)
=
T
(
0
);
*
(
output
+
interval_id
*
inner_dim_size
+
segment_offset
)
=
T
(
0
);
}
}
...
@@ -102,8 +112,8 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
...
@@ -102,8 +112,8 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
last_segment_id
*
inner_dim_size
+
segment_offset
;
last_segment_id
*
inner_dim_size
+
segment_offset
;
if
(
last_segment_id
==
first_segment_id
)
{
if
(
last_segment_id
==
first_segment_id
)
{
p
latform
::
CudaAtomicAdd
(
output
+
output_index
,
p
addle
::
platform
::
CudaAtomicAdd
(
sum
/
*
(
summed_ids
+
last_segment_id
));
output
+
output_index
,
sum
/
*
(
summed_ids
+
last_segment_id
));
}
else
{
}
else
{
*
(
output
+
output_index
)
=
sum
/
*
(
summed_ids
+
last_segment_id
);
*
(
output
+
output_index
)
=
sum
/
*
(
summed_ids
+
last_segment_id
);
}
}
...
@@ -114,15 +124,14 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
...
@@ -114,15 +124,14 @@ __global__ void SegmentMeanKernel(const Index* segment_ids, const T* input,
last_segment_id
=
current_segment_id
;
last_segment_id
=
current_segment_id
;
}
}
Index
output_index
=
last_segment_id
*
inner_dim_size
+
segment_offset
;
Index
output_index
=
last_segment_id
*
inner_dim_size
+
segment_offset
;
platform
::
CudaAtomicAdd
(
output
+
output_index
,
p
addle
::
p
latform
::
CudaAtomicAdd
(
output
+
output_index
,
sum
/
*
(
summed_ids
+
last_segment_id
));
sum
/
*
(
summed_ids
+
last_segment_id
));
}
}
}
}
template
<
typename
T
,
typename
Index
,
typename
Helper
,
typename
Pool
>
template
<
typename
T
,
typename
Index
,
typename
Helper
,
typename
Pool
>
__global__
void
__launch_bounds__
(
1024
,
1
)
__global__
void
__launch_bounds__
(
1024
,
1
)
SegmentOpsKernel
(
SegmentOpsKernel
(
const
Index
*
segment_ids
,
const
T
*
input
,
T
*
output
,
const
Index
*
segment_ids
,
const
T
*
input
,
T
*
output
,
Helper
h
,
Pool
pool
)
{
Helper
h
,
Pool
pool
)
{
CUDA_KERNEL_LOOP
(
stripe_index
,
h
.
total_stripe_count
)
{
CUDA_KERNEL_LOOP
(
stripe_index
,
h
.
total_stripe_count
)
{
Index
segment_offset
,
dim_index_base
,
actual_height
;
Index
segment_offset
,
dim_index_base
,
actual_height
;
Index
inner_dim_size
=
h
.
inner_dim_size
;
Index
inner_dim_size
=
h
.
inner_dim_size
;
...
@@ -142,13 +151,16 @@ __global__ void __launch_bounds__(1024, 1)
...
@@ -142,13 +151,16 @@ __global__ void __launch_bounds__(1024, 1)
PADDLE_ENFORCE
(
current_segment_id
>=
last_segment_id
,
PADDLE_ENFORCE
(
current_segment_id
>=
last_segment_id
,
"The segment ids should be sorted, but got "
"The segment ids should be sorted, but got "
"segment_ids[%d]:%d > segment_ids[%d]:%d."
,
"segment_ids[%d]:%d > segment_ids[%d]:%d."
,
dim_index_base
+
j
-
1
,
dim_index_base
+
j
,
dim_index_base
+
j
-
1
,
last_segment_id
,
current_segment_id
);
dim_index_base
+
j
,
last_segment_id
,
current_segment_id
);
if
(
current_segment_id
>
last_segment_id
)
{
if
(
current_segment_id
>
last_segment_id
)
{
// reset the interval value which do not have corresponding ids.
// reset the interval value which do not have corresponding ids.
for
(
Index
interval_id
=
last_segment_id
+
1
;
for
(
Index
interval_id
=
last_segment_id
+
1
;
interval_id
<
current_segment_id
;
++
interval_id
)
{
interval_id
<
current_segment_id
;
++
interval_id
)
{
*
(
output
+
interval_id
*
inner_dim_size
+
segment_offset
)
=
T
(
0
);
*
(
output
+
interval_id
*
inner_dim_size
+
segment_offset
)
=
T
(
0
);
}
}
// don't update result when j=0
// don't update result when j=0
...
@@ -175,9 +187,12 @@ __global__ void __launch_bounds__(1024, 1)
...
@@ -175,9 +187,12 @@ __global__ void __launch_bounds__(1024, 1)
}
}
template
<
typename
T
,
typename
Index
,
typename
Helper
>
template
<
typename
T
,
typename
Index
,
typename
Helper
>
__global__
void
SegmentIndexGradKernel
(
const
Index
*
segment_ids
,
const
T
*
input
,
__global__
void
SegmentIndexGradKernel
(
const
Index
*
segment_ids
,
const
T
*
output
,
const
T
*
out_grad
,
const
T
*
input
,
T
*
in_grad
,
Helper
h
)
{
const
T
*
output
,
const
T
*
out_grad
,
T
*
in_grad
,
Helper
h
)
{
CUDA_KERNEL_LOOP
(
stripe_index
,
h
.
total_stripe_count
)
{
CUDA_KERNEL_LOOP
(
stripe_index
,
h
.
total_stripe_count
)
{
Index
segment_offset
,
dim_index_base
,
actual_height
;
Index
segment_offset
,
dim_index_base
,
actual_height
;
h
.
calculate
(
stripe_index
,
&
segment_offset
,
&
dim_index_base
,
&
actual_height
);
h
.
calculate
(
stripe_index
,
&
segment_offset
,
&
dim_index_base
,
&
actual_height
);
...
@@ -201,7 +216,7 @@ class MaxPool {
...
@@ -201,7 +216,7 @@ class MaxPool {
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
-
FLT_MAX
);
}
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
-
FLT_MAX
);
}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
=
*
y
>
x
?
*
y
:
x
;
}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
=
*
y
>
x
?
*
y
:
x
;
}
DEVICE
inline
T
atomic
(
T
*
address
,
const
T
val
)
{
DEVICE
inline
T
atomic
(
T
*
address
,
const
T
val
)
{
return
platform
::
CudaAtomicMax
(
address
,
val
);
return
p
addle
::
p
latform
::
CudaAtomicMax
(
address
,
val
);
}
}
};
};
...
@@ -211,7 +226,7 @@ class MinPool {
...
@@ -211,7 +226,7 @@ class MinPool {
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
FLT_MAX
);
}
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
FLT_MAX
);
}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
=
*
y
<
x
?
*
y
:
x
;
}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
=
*
y
<
x
?
*
y
:
x
;
}
DEVICE
inline
T
atomic
(
T
*
address
,
const
T
val
)
{
DEVICE
inline
T
atomic
(
T
*
address
,
const
T
val
)
{
return
platform
::
CudaAtomicMin
(
address
,
val
);
return
p
addle
::
p
latform
::
CudaAtomicMin
(
address
,
val
);
}
}
};
};
...
@@ -221,7 +236,7 @@ class SumPool {
...
@@ -221,7 +236,7 @@ class SumPool {
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
0
);
}
DEVICE
inline
T
initial
()
{
return
static_cast
<
T
>
(
0
);
}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
=
*
y
+
x
;
}
DEVICE
inline
void
compute
(
const
T
&
x
,
T
*
y
)
{
*
y
=
*
y
+
x
;
}
DEVICE
inline
T
atomic
(
T
*
address
,
const
T
val
)
{
DEVICE
inline
T
atomic
(
T
*
address
,
const
T
val
)
{
return
platform
::
CudaAtomicAdd
(
address
,
val
);
return
p
addle
::
p
latform
::
CudaAtomicAdd
(
address
,
val
);
}
}
};
};
...
@@ -243,8 +258,10 @@ class ArrangeHelper {
...
@@ -243,8 +258,10 @@ class ArrangeHelper {
total_stripe_count
=
inner_dim_size
*
input_outer_dim_num_stripe
;
total_stripe_count
=
inner_dim_size
*
input_outer_dim_num_stripe
;
}
}
DEVICE
inline
void
calculate
(
T
stripe_index
,
T
*
segment_offset
,
DEVICE
inline
void
calculate
(
T
stripe_index
,
T
*
dim_index_base
,
T
*
actual_height
)
{
T
*
segment_offset
,
T
*
dim_index_base
,
T
*
actual_height
)
{
*
segment_offset
=
stripe_index
%
inner_dim_size
;
*
segment_offset
=
stripe_index
%
inner_dim_size
;
*
dim_index_base
=
stripe_index
/
inner_dim_size
*
DimTileSize
;
*
dim_index_base
=
stripe_index
/
inner_dim_size
*
DimTileSize
;
*
actual_height
=
min
(
DimTileSize
,
input_length_size
-
*
dim_index_base
);
*
actual_height
=
min
(
DimTileSize
,
input_length_size
-
*
dim_index_base
);
...
@@ -252,23 +269,32 @@ class ArrangeHelper {
...
@@ -252,23 +269,32 @@ class ArrangeHelper {
};
};
template
<
typename
T
,
typename
Index
>
template
<
typename
T
,
typename
Index
>
void
SegmentPoolCUDAGradFunctor
(
const
p
latform
::
CUDADevice
Context
&
ctx
,
void
SegmentPoolCUDAGradFunctor
(
const
p
hi
::
GPU
Context
&
ctx
,
const
framework
::
Tensor
&
input
,
const
Dense
Tensor
&
input
,
const
framework
::
Tensor
&
segment_ids
,
const
Dense
Tensor
&
segment_ids
,
const
framework
::
Tensor
&
output
,
const
Dense
Tensor
&
output
,
const
framework
::
Tensor
&
out_grad
,
const
Dense
Tensor
&
out_grad
,
framework
::
Tensor
*
in_grad
,
Dense
Tensor
*
in_grad
,
const
std
::
string
pooltype
=
"SUM"
)
{
const
std
::
string
pooltype
=
"SUM"
)
{
auto
h
=
ArrangeHelper
<
Index
>
(
input
.
numel
(),
segment_ids
.
dims
()[
0
],
auto
h
=
ArrangeHelper
<
Index
>
(
output
.
dims
()[
0
]);
input
.
numel
(),
segment_ids
.
dims
()[
0
],
output
.
dims
()[
0
]);
auto
config
=
platform
::
GetGpuLaunchConfig1D
(
ctx
,
h
.
total_stripe_count
);
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
ctx
,
h
.
total_stripe_count
);
if
(
pooltype
==
"MAX"
||
pooltype
==
"MIN"
)
{
if
(
pooltype
==
"MAX"
||
pooltype
==
"MIN"
)
{
SegmentIndexGradKernel
<
T
,
Index
,
ArrangeHelper
<
Index
>><<<
SegmentIndexGradKernel
<
T
,
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
ctx
.
stream
()
>>>
(
Index
,
segment_ids
.
data
<
Index
>
(),
input
.
data
<
T
>
(),
output
.
data
<
T
>
(),
ArrangeHelper
<
Index
>><<<
config
.
block_per_grid
.
x
,
out_grad
.
data
<
T
>
(),
in_grad
->
data
<
T
>
(),
h
);
config
.
thread_per_block
.
x
,
0
,
ctx
.
stream
()
>>>
(
segment_ids
.
data
<
Index
>
(),
input
.
data
<
T
>
(),
output
.
data
<
T
>
(),
out_grad
.
data
<
T
>
(),
in_grad
->
data
<
T
>
(),
h
);
}
else
{
}
else
{
PADDLE_THROW
(
p
latform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
p
hi
::
errors
::
InvalidArgument
(
"Unsupported segment pooling grad operation, Only MAX, MIN "
"Unsupported segment pooling grad operation, Only MAX, MIN "
"available, but got %s."
,
"available, but got %s."
,
pooltype
));
pooltype
));
...
@@ -291,13 +317,13 @@ __global__ void SimpleDiv(T* x, const T* y, const int len, const int dim) {
...
@@ -291,13 +317,13 @@ __global__ void SimpleDiv(T* x, const T* y, const int len, const int dim) {
}
}
template
<
typename
T
,
typename
IndexT
>
template
<
typename
T
,
typename
IndexT
>
class
SegmentPoolFunctor
<
p
latform
::
CUDADevice
Context
,
T
,
IndexT
>
{
class
SegmentPoolFunctor
<
p
hi
::
GPU
Context
,
T
,
IndexT
>
{
public:
public:
void
operator
()(
const
p
latform
::
CUDADevice
Context
&
ctx
,
void
operator
()(
const
p
hi
::
GPU
Context
&
ctx
,
const
framework
::
Tensor
&
input
,
const
Dense
Tensor
&
input
,
const
framework
::
Tensor
&
segment_ids
,
const
Dense
Tensor
&
segment_ids
,
framework
::
Tensor
*
output
,
Dense
Tensor
*
output
,
framework
::
Tensor
*
summed_ids
=
nullptr
,
Dense
Tensor
*
summed_ids
=
nullptr
,
const
std
::
string
pooltype
=
"SUM"
)
{
const
std
::
string
pooltype
=
"SUM"
)
{
if
(
pooltype
==
"MEAN"
)
{
if
(
pooltype
==
"MEAN"
)
{
// Sum the segment id num first
// Sum the segment id num first
...
@@ -305,50 +331,76 @@ class SegmentPoolFunctor<platform::CUDADeviceContext, T, IndexT> {
...
@@ -305,50 +331,76 @@ class SegmentPoolFunctor<platform::CUDADeviceContext, T, IndexT> {
auto
input_length_size
=
segment_ids
.
numel
();
auto
input_length_size
=
segment_ids
.
numel
();
auto
total_stripe_count
=
auto
total_stripe_count
=
(
input_length_size
+
DimTileSize
-
1
)
/
DimTileSize
;
(
input_length_size
+
DimTileSize
-
1
)
/
DimTileSize
;
auto
config
=
platform
::
GetGpuLaunchConfig1D
(
ctx
,
total_stripe_count
);
auto
config
=
SegmentSumIdsKernel
<
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
ctx
,
total_stripe_count
);
T
,
IndexT
,
IndexT
(
8
)
><<<
config
.
block_per_grid
.
x
,
SegmentSumIdsKernel
<
T
,
IndexT
,
IndexT
(
8
)
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
ctx
.
stream
()
>>>
(
config
.
thread_per_block
.
x
,
segment_ids
.
data
<
IndexT
>
(),
summed_ids
->
data
<
T
>
(),
input_length_size
,
0
,
ctx
.
stream
()
>>>
(
segment_ids
.
data
<
IndexT
>
(),
summed_ids
->
data
<
T
>
(),
input_length_size
,
total_stripe_count
);
total_stripe_count
);
}
}
auto
h
=
ArrangeHelper
<
IndexT
>
(
input
.
numel
(),
segment_ids
.
dims
()[
0
],
auto
h
=
ArrangeHelper
<
IndexT
>
(
output
->
dims
()[
0
]);
input
.
numel
(),
segment_ids
.
dims
()[
0
],
output
->
dims
()[
0
]);
auto
config
=
platform
::
GetGpuLaunchConfig1D
(
ctx
,
h
.
total_stripe_count
);
auto
config
=
phi
::
backends
::
gpu
::
GetGpuLaunchConfig1D
(
ctx
,
h
.
total_stripe_count
);
if
(
pooltype
==
"MEAN"
)
{
if
(
pooltype
==
"MEAN"
)
{
SegmentMeanKernel
<
SegmentMeanKernel
<
T
,
IndexT
,
IndexT
(
8
)
><<<
config
.
block_per_grid
.
x
,
T
,
IndexT
,
IndexT
(
8
)
><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
config
.
thread_per_block
.
x
,
0
,
ctx
.
stream
()
>>>
(
0
,
segment_ids
.
data
<
IndexT
>
(),
input
.
data
<
T
>
(),
output
->
data
<
T
>
(),
ctx
.
stream
()
>>>
(
summed_ids
->
data
<
T
>
(),
h
.
input_length_size
,
h
.
inner_dim_size
,
segment_ids
.
data
<
IndexT
>
(),
h
.
output_length_size
,
h
.
total_stripe_count
);
input
.
data
<
T
>
(),
output
->
data
<
T
>
(),
summed_ids
->
data
<
T
>
(),
h
.
input_length_size
,
h
.
inner_dim_size
,
h
.
output_length_size
,
h
.
total_stripe_count
);
}
else
if
(
pooltype
==
"SUM"
)
{
}
else
if
(
pooltype
==
"SUM"
)
{
SumPool
<
T
>
pool
;
SumPool
<
T
>
pool
;
SegmentOpsKernel
<
SegmentOpsKernel
<
T
,
T
,
IndexT
,
ArrangeHelper
<
IndexT
>
,
IndexT
,
SumPool
<
T
>><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
ArrangeHelper
<
IndexT
>
,
ctx
.
stream
()
>>>
(
segment_ids
.
data
<
IndexT
>
(),
SumPool
<
T
>><<<
config
.
block_per_grid
.
x
,
input
.
data
<
T
>
(),
output
->
data
<
T
>
(),
h
,
config
.
thread_per_block
.
x
,
pool
);
0
,
ctx
.
stream
()
>>>
(
segment_ids
.
data
<
IndexT
>
(),
input
.
data
<
T
>
(),
output
->
data
<
T
>
(),
h
,
pool
);
}
else
if
(
pooltype
==
"MAX"
)
{
}
else
if
(
pooltype
==
"MAX"
)
{
MaxPool
<
T
>
pool
;
MaxPool
<
T
>
pool
;
SegmentOpsKernel
<
SegmentOpsKernel
<
T
,
T
,
IndexT
,
ArrangeHelper
<
IndexT
>
,
IndexT
,
MaxPool
<
T
>><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
ArrangeHelper
<
IndexT
>
,
ctx
.
stream
()
>>>
(
segment_ids
.
data
<
IndexT
>
(),
MaxPool
<
T
>><<<
config
.
block_per_grid
.
x
,
input
.
data
<
T
>
(),
output
->
data
<
T
>
(),
h
,
config
.
thread_per_block
.
x
,
pool
);
0
,
ctx
.
stream
()
>>>
(
segment_ids
.
data
<
IndexT
>
(),
input
.
data
<
T
>
(),
output
->
data
<
T
>
(),
h
,
pool
);
}
else
if
(
pooltype
==
"MIN"
)
{
}
else
if
(
pooltype
==
"MIN"
)
{
MinPool
<
T
>
pool
;
MinPool
<
T
>
pool
;
SegmentOpsKernel
<
SegmentOpsKernel
<
T
,
T
,
IndexT
,
ArrangeHelper
<
IndexT
>
,
IndexT
,
MinPool
<
T
>><<<
config
.
block_per_grid
.
x
,
config
.
thread_per_block
.
x
,
0
,
ArrangeHelper
<
IndexT
>
,
ctx
.
stream
()
>>>
(
segment_ids
.
data
<
IndexT
>
(),
MinPool
<
T
>><<<
config
.
block_per_grid
.
x
,
input
.
data
<
T
>
(),
output
->
data
<
T
>
(),
h
,
config
.
thread_per_block
.
x
,
pool
);
0
,
ctx
.
stream
()
>>>
(
segment_ids
.
data
<
IndexT
>
(),
input
.
data
<
T
>
(),
output
->
data
<
T
>
(),
h
,
pool
);
}
else
{
}
else
{
PADDLE_THROW
(
p
latform
::
errors
::
InvalidArgument
(
PADDLE_THROW
(
p
hi
::
errors
::
InvalidArgument
(
"Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
"Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
"available, but got %s."
,
"available, but got %s."
,
pooltype
));
pooltype
));
...
@@ -357,33 +409,38 @@ class SegmentPoolFunctor<platform::CUDADeviceContext, T, IndexT> {
...
@@ -357,33 +409,38 @@ class SegmentPoolFunctor<platform::CUDADeviceContext, T, IndexT> {
};
};
template <typename T, typename IndexT>
class SegmentPoolGradFunctor<phi::GPUContext, T, IndexT> {
 public:
  void operator()(const phi::GPUContext& dev_ctx,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& out_grad,
                  const DenseTensor& segments,
                  DenseTensor* in_grad,
                  paddle::optional<const DenseTensor&> summed_ids,
                  const std::string pooltype = "SUM") {
    if (pooltype == "MAX" || pooltype == "MIN") {
      SegmentPoolCUDAGradFunctor<T, IndexT>(
          dev_ctx, input, segments, output, out_grad, in_grad, pooltype);
    } else if (pooltype == "MEAN") {
      DenseTensor mean_grad;
      mean_grad.Resize(input.dims());
      dev_ctx.template Alloc<T>(&mean_grad);
      paddle::framework::TensorCopy(
          out_grad, dev_ctx.GetPlace(), dev_ctx, &mean_grad);
      int len = output.dims()[0];
      int dim = output.numel() / len;
      auto config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, len);
      SimpleDiv<T><<<config.block_per_grid.x,
                     config.thread_per_block.x,
                     0,
                     dev_ctx.stream()>>>(
          mean_grad.data<T>(), summed_ids->data<T>(), len, dim);
      phi::funcs::GPUGather<T, IndexT>(dev_ctx, mean_grad, segments, in_grad);
    } else if (pooltype == "SUM") {
      phi::funcs::GPUGather<T, IndexT>(dev_ctx, out_grad, segments, in_grad);
    } else {
      PADDLE_THROW(phi::errors::InvalidArgument(
          "Unsupported segment pooling operation, Only MEAN, SUM, MAX, MIN "
          "available, but got %s.",
          pooltype));
...
@@ -391,15 +448,15 @@ class SegmentPoolGradFunctor<platform::CUDADeviceContext, T, IndexT> {
  }
};
using GPU = phi::GPUContext;
template class SegmentPoolFunctor<GPU, float, int>;
template class SegmentPoolFunctor<GPU, float, int64_t>;
template class SegmentPoolFunctor<GPU, double, int>;
template class SegmentPoolFunctor<GPU, double, int64_t>;
template class SegmentPoolGradFunctor<GPU, float, int>;
template class SegmentPoolGradFunctor<GPU, float, int64_t>;
template class SegmentPoolGradFunctor<GPU, double, int>;
template class SegmentPoolGradFunctor<GPU, double, int64_t>;

}  // namespace funcs
}  // namespace phi
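
These explicit instantiations keep the CUDA definitions inside the .cu translation unit: the header below only declares the functors, and the linker finds the bodies through the eight instantiations above. A minimal sketch of the same split, with hypothetical names:

// pool_sketch.h (hypothetical): declaration only, safe to include anywhere.
template <typename T>
struct CopyPool {
  void operator()(const T* in, T* out, int n);  // body lives in the .cu file
};

// pool_sketch.cu (hypothetical): definition plus explicit instantiations.
template <typename T>
void CopyPool<T>::operator()(const T* in, T* out, int n) {
  for (int i = 0; i < n; ++i) out[i] = in[i];  // placeholder body
}
template struct CopyPool<float>;   // emits CopyPool<float>::operator()
template struct CopyPool<double>;  // emits CopyPool<double>::operator()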
paddle/fluid/operators/math/segment_pooling.h → paddle/phi/kernels/funcs/segment_pooling.h
View file @ a07f19ee
...
@@ -14,33 +14,36 @@ limitations under the License. */
#pragma once

#include <string>
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace funcs {

template <typename Context, typename T, typename IndexT>
class SegmentPoolFunctor {
 public:
  /* mean pool has summed_ids output */
  void operator()(const Context& dev_ctx,
                  const DenseTensor& input,
                  const DenseTensor& segments,
                  DenseTensor* output,
                  DenseTensor* summed_ids = nullptr,
                  const std::string pooltype = "SUM");
};

template <typename Context, typename T, typename IndexT>
class SegmentPoolGradFunctor {
 public:
  /* mean pool has summed_ids output */
  void operator()(const Context& dev_ctx,
                  const DenseTensor& input,
                  const DenseTensor& output,
                  const DenseTensor& out_grad,
                  const DenseTensor& segments,
                  DenseTensor* in_grad,
                  paddle::optional<const DenseTensor&> summed_ids,
                  const std::string pooltype = "SUM");
};

}  // namespace funcs
}  // namespace phi
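
A minimal sketch of a call site for the functor declared above, assuming dev_ctx and all tensors are already allocated the way segment_pool_kernel_impl.h prepares them (the function and variable names here are hypothetical):

#include "paddle/phi/kernels/funcs/segment_pooling.h"

// Hypothetical call site: x is [N, D], segment_ids is [N] with sorted,
// non-negative ids; out and summed_ids must be pre-allocated outputs.
void RunMeanPool(const phi::GPUContext& dev_ctx,
                 const phi::DenseTensor& x,
                 const phi::DenseTensor& segment_ids,
                 phi::DenseTensor* out,
                 phi::DenseTensor* summed_ids) {
  phi::funcs::SegmentPoolFunctor<phi::GPUContext, float, int> pool;
  // For "MEAN" on GPU, summed_ids also receives the per-segment counts.
  pool(dev_ctx, x, segment_ids, out, summed_ids, "MEAN");
}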
paddle/phi/kernels/gpu/segment_pool_grad_kernel.cu
0 → 100644
View file @ a07f19ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h"
#include "paddle/phi/kernels/segment_pool_grad_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(segment_pool_grad,
                   GPU,
                   ALL_LAYOUT,
                   phi::SegmentPoolGradKernel,
                   float,
                   double) {}
paddle/phi/kernels/gpu/segment_pool_kernel.cu
0 → 100644
View file @ a07f19ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/impl/segment_pool_kernel_impl.h"
#include "paddle/phi/kernels/segment_pool_kernel.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
PD_REGISTER_KERNEL(
    segment_pool, GPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {}
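
PD_REGISTER_KERNEL ties the kernel name, backend, layout, kernel function, and supported dtypes together; this commit registers a CPU counterpart the same way. A sketch of what such an analogous CPU registration would look like (assumed shape, not quoted from this hunk):

// Assumed analogue in paddle/phi/kernels/cpu/segment_pool_kernel.cc:
PD_REGISTER_KERNEL(
    segment_pool, CPU, ALL_LAYOUT, phi::SegmentPoolKernel, float, double) {}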
paddle/phi/kernels/impl/segment_pool_grad_kernel_impl.h
0 → 100644
View file @ a07f19ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/segment_pooling.h"
namespace phi {

template <typename T, typename Context>
void SegmentPoolGradKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& segment_ids,
                           const DenseTensor& out,
                           paddle::optional<const DenseTensor&> summed_ids,
                           const DenseTensor& out_grad,
                           const std::string& pooltype,
                           DenseTensor* x_grad) {
  dev_ctx.template Alloc<T>(x_grad);
  phi::funcs::SetConstant<Context, T> set_zero;
  set_zero(dev_ctx, x_grad, static_cast<T>(0));

  auto index_type = segment_ids.type();
  if (index_type == DataType::INT32) {
    phi::funcs::SegmentPoolGradFunctor<Context, T, int> pool;
    pool(dev_ctx, x, out, out_grad, segment_ids, x_grad, summed_ids, pooltype);
  } else if (index_type == DataType::INT64) {
    phi::funcs::SegmentPoolGradFunctor<Context, T, int64_t> pool;
    pool(dev_ctx, x, out, out_grad, segment_ids, x_grad, summed_ids, pooltype);
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Unsupported index type, Expected int, int64, but got %s.",
        index_type));
  }
}

}  // namespace phi
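
For the SUM branch the backward pass is a pure gather: every input row receives the gradient of the output row of its segment. A small standalone sketch of that rule on raw arrays (hypothetical helper, not part of the diff):

// in_grad[i][:] = out_grad[segment_ids[i]][:] for SUM pooling backward.
void SumPoolGradGather(const float* out_grad, const int* segment_ids,
                       float* in_grad, int n_rows, int dim) {
  for (int i = 0; i < n_rows; ++i) {
    const float* src = out_grad + segment_ids[i] * dim;
    for (int j = 0; j < dim; ++j) in_grad[i * dim + j] = src[j];
  }
}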
paddle/phi/kernels/impl/segment_pool_kernel_impl.h
0 → 100644
View file @ a07f19ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/segment_pooling.h"
namespace phi {

template <typename Context, typename T, typename IndexT>
void SegmentKernelLaunchHelper(const Context& dev_ctx,
                               const DenseTensor& x,
                               const DenseTensor& segment_ids,
                               const std::string& pooltype,
                               DenseTensor* out,
                               DenseTensor* summed_ids) {
  int64_t num_indices = segment_ids.numel();
  PADDLE_ENFORCE_EQ(
      num_indices,
      x.dims()[0],
      phi::errors::InvalidArgument(
          "Segment_ids should be the same size as dimension 0 of input X."));
  PADDLE_ENFORCE_EQ(num_indices,
                    segment_ids.dims()[0],
                    phi::errors::InvalidArgument(
                        "Segment_ids should be 1-D tensor, or it's other "
                        "dimension size is 1. Segment_ids's shape is: [%s].",
                        segment_ids.dims()));

  if (x.numel() == 0 || segment_ids.numel() == 0) {
    return;
  }

  bool cpu_place = dev_ctx.GetPlace().GetType() == phi::AllocationType::CPU;
  if (cpu_place) {
    auto dims = x.dims();
    auto* segment_ids_ptr = segment_ids.data<IndexT>();
    dims[0] =
        static_cast<int64_t>(segment_ids_ptr[segment_ids.numel() - 1] + 1);
    PADDLE_ENFORCE_GT(
        dims[0],
        0,
        phi::errors::InvalidArgument(
            "Segment ids must be >= 0, but got last id %d", dims[0]));
    out->Resize({dims});
    dev_ctx.template Alloc<T>(out);

    phi::funcs::SetConstant<Context, T> set_zero;
    set_zero(dev_ctx, out, static_cast<T>(0));
  }
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (!cpu_place) {
    DenseTensor length;
    length.Resize(phi::make_ddim({1}));
    IndexT* length_data = dev_ctx.template HostAlloc<IndexT>(&length);
    const IndexT* segment_ids_ptr = segment_ids.data<IndexT>();

#ifdef PADDLE_WITH_HIP
    PADDLE_ENFORCE_GPU_SUCCESS(hipMemcpy(length_data,
                                         segment_ids_ptr + num_indices - 1,
                                         sizeof(IndexT),
                                         hipMemcpyDeviceToHost));
#else
    PADDLE_ENFORCE_GPU_SUCCESS(cudaMemcpy(length_data,
                                          segment_ids_ptr + num_indices - 1,
                                          sizeof(IndexT),
                                          cudaMemcpyDeviceToHost));
#endif

    IndexT length_host = length_data[0];
    length_host++;
    PADDLE_ENFORCE_GT(
        length_host,
        0,
        phi::errors::InvalidArgument(
            "Segment ids must be >= 0, but got last id %d", length_data[0]));
    auto dims = x.dims();
    dims[0] = static_cast<int64_t>(length_host);
    out->Resize({dims});
    dev_ctx.template Alloc<T>(out);

    T init_value = 0;
    if (pooltype == "MAX") {
      init_value = static_cast<T>(-FLT_MAX);
    } else if (pooltype == "MIN") {
      init_value = static_cast<T>(FLT_MAX);
    }
    phi::funcs::SetConstant<Context, T> setconst;
    setconst(dev_ctx, out, static_cast<T>(init_value));
    // the gpu kernel of mean pool records the counts of segment_ids
    if (pooltype == "MEAN") {
      summed_ids->Resize({dims[0], 1});
      dev_ctx.template Alloc<T>(summed_ids);
      setconst(dev_ctx, summed_ids, static_cast<T>(1e-12));
    }
  }
#endif

  phi::funcs::SegmentPoolFunctor<Context, T, IndexT> pool;
  pool(dev_ctx, x, segment_ids, out, summed_ids, pooltype);
}

template <typename T, typename Context>
void SegmentPoolKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& segment_ids,
                       const std::string& pooltype,
                       DenseTensor* out,
                       DenseTensor* summed_ids) {
  auto index_type = segment_ids.dtype();
  if (index_type == DataType::INT32) {
    SegmentKernelLaunchHelper<Context, T, int>(
        dev_ctx, x, segment_ids, pooltype, out, summed_ids);
  } else if (index_type == DataType::INT64) {
    SegmentKernelLaunchHelper<Context, T, int64_t>(
        dev_ctx, x, segment_ids, pooltype, out, summed_ids);
  } else {
    PADDLE_THROW(phi::errors::InvalidArgument(
        "Unsupported index type, Expected int, int64, but got %s.",
        index_type));
  }
}

}  // namespace phi
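
As a concrete check of the forward semantics implemented above: with x = {{1,2},{3,4},{5,6}} and segment_ids = {0,0,1}, SUM pooling yields {{4,6},{5,6}} and MEAN yields {{2,3},{5,6}}. A tiny reference loop, independent of the kernels (assumes sorted, non-negative ids, as the checks above enforce):

#include <vector>

// Reference SUM segment pooling; returns (last_id + 1) pooled rows.
std::vector<std::vector<float>> SegmentSumRef(
    const std::vector<std::vector<float>>& x, const std::vector<int>& ids) {
  int num_segments = ids.back() + 1;
  std::vector<std::vector<float>> out(num_segments,
                                      std::vector<float>(x[0].size(), 0.f));
  for (size_t i = 0; i < x.size(); ++i)
    for (size_t j = 0; j < x[i].size(); ++j) out[ids[i]][j] += x[i][j];
  return out;
}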
paddle/phi/kernels/segment_pool_grad_kernel.h
0 → 100644
View file @ a07f19ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {

template <typename T, typename Context>
void SegmentPoolGradKernel(const Context& dev_ctx,
                           const DenseTensor& x,
                           const DenseTensor& segment_ids,
                           const DenseTensor& out,
                           paddle::optional<const DenseTensor&> summed_ids,
                           const DenseTensor& out_grad,
                           const std::string& pooltype,
                           DenseTensor* x_grad);

}  // namespace phi
paddle/phi/kernels/segment_pool_kernel.h
0 → 100644
View file @ a07f19ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
namespace phi {

template <typename T, typename Context>
void SegmentPoolKernel(const Context& dev_ctx,
                       const DenseTensor& x,
                       const DenseTensor& segment_ids,
                       const std::string& pooltype,
                       DenseTensor* out,
                       DenseTensor* summed_ids);

}  // namespace phi
paddle/phi/ops/compat/segment_pool_sig.cc
0 → 100644
View file @ a07f19ee
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {

KernelSignature SegmentPoolGradOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "segment_pool_grad",
      {"X", "SegmentIds", "Out", "SummedIds", GradVarName("Out")},
      {"pooltype"},
      {GradVarName("X")});
}

}  // namespace phi

PD_REGISTER_ARG_MAPPING_FN(segment_pool_grad,
                           phi::SegmentPoolGradOpArgumentMapping);
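
GradVarName("Out") resolves to the gradient variable associated with Out (conventionally "Out@GRAD" in fluid), so the lists above map the old operator's inputs, attribute, and gradient output onto SegmentPoolGradKernel's argument order. For comparison, a hand-written forward mapping in the same style would look roughly like this (hypothetical, not part of this diff):

// Hypothetical forward mapping, mirroring SegmentPoolKernel's argument order:
KernelSignature SegmentPoolOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  return KernelSignature(
      "segment_pool", {"X", "SegmentIds"}, {"pooltype"}, {"Out", "SummedIds"});
}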