Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
f4c750d7
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
f4c750d7
编写于
9月 22, 2020
作者:
Z
Zhong Hui
提交者:
GitHub
9月 22, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add the cpu version of segment sum mean max min op
Add the cpu version of segment sum mean max min op
上级
afe94903
变更
7
显示空白变更内容
内联
并排
Showing
7 changed file
with
694 addition
and
1 deletion
+694
-1
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+1
-1
paddle/fluid/operators/math/CMakeLists.txt
paddle/fluid/operators/math/CMakeLists.txt
+1
-0
paddle/fluid/operators/math/segment_pooling.cc
paddle/fluid/operators/math/segment_pooling.cc
+148
-0
paddle/fluid/operators/math/segment_pooling.h
paddle/fluid/operators/math/segment_pooling.h
+46
-0
paddle/fluid/operators/segment_pool_op.cc
paddle/fluid/operators/segment_pool_op.cc
+166
-0
paddle/fluid/operators/segment_pool_op.h
paddle/fluid/operators/segment_pool_op.h
+130
-0
python/paddle/fluid/tests/unittests/test_segment_ops.py
python/paddle/fluid/tests/unittests/test_segment_ops.py
+202
-0
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
f4c750d7
...
...
@@ -92,7 +92,7 @@ cc_library(common_infer_shape_functions SRCS common_infer_shape_functions.cc DEP
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
selected_rows_functor selected_rows
lod_tensor maxouting unpooling pooling lod_rank_table context_project
sequence_pooling executor device_memory_aligment generator
)
sequence_pooling
segment_pooling
executor device_memory_aligment generator
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
dynload_warpctc
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler sample_prob tree2col
)
set
(
COMMON_OP_DEPS
${
COMMON_OP_DEPS
}
sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search fc matrix_inverse
)
...
...
paddle/fluid/operators/math/CMakeLists.txt
浏览文件 @
f4c750d7
...
...
@@ -76,6 +76,7 @@ math_library(prelu)
math_library
(
bert_encoder_functor
)
math_library
(
tree2col DEPS math_function
)
math_library
(
matrix_inverse
)
math_library
(
segment_pooling
)
cc_test
(
math_function_test SRCS math_function_test.cc DEPS math_function
)
cc_test
(
selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor
)
...
...
paddle/fluid/operators/math/segment_pooling.cc
0 → 100644
浏览文件 @
f4c750d7
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/segment_pooling.h"
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
,
typename
IndexT
>
class
SegmentPoolFunctor
<
platform
::
CPUDeviceContext
,
T
,
IndexT
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
segments
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
index
,
const
std
::
string
pooltype
=
"SUM"
)
{
const
IndexT
*
segment_ids
=
segments
.
data
<
IndexT
>
();
auto
curent_id
=
segment_ids
[
0
];
int64_t
last_idx
=
0
;
int64_t
w
=
input
.
numel
()
/
input
.
dims
()[
0
];
auto
&
place
=
*
context
.
eigen_device
();
for
(
int64_t
idx
=
1
;
idx
<=
segments
.
numel
();
++
idx
)
{
if
(
idx
<
segments
.
numel
())
{
if
(
segment_ids
[
idx
]
==
curent_id
)
continue
;
PADDLE_ENFORCE_GE
(
segment_ids
[
idx
],
curent_id
,
platform
::
errors
::
InvalidArgument
(
"The segment ids should be sorted, but got "
"segment_ids[%d]:%d > segment_ids[%d]:%d."
,
idx
-
1
,
curent_id
,
idx
,
segment_ids
[
idx
]));
}
Tensor
out_t
=
output
->
Slice
(
curent_id
,
curent_id
+
1
);
Tensor
in_t
=
input
.
Slice
(
last_idx
,
idx
);
int64_t
h
=
idx
-
last_idx
;
auto
in_e
=
framework
::
EigenMatrix
<
T
>::
From
(
in_t
,
framework
::
make_ddim
({
h
,
w
}));
auto
out_e
=
framework
::
EigenVector
<
T
>::
Flatten
(
out_t
);
auto
reduce_dim
=
Eigen
::
array
<
int
,
1
>
({{
0
}});
if
(
pooltype
==
"MEAN"
)
{
out_e
.
device
(
place
)
=
in_e
.
mean
(
reduce_dim
);
}
else
if
(
pooltype
==
"SUM"
)
{
out_e
.
device
(
place
)
=
in_e
.
sum
(
reduce_dim
);
}
else
if
(
pooltype
==
"MAX"
)
{
out_e
.
device
(
place
)
=
in_e
.
maximum
(
reduce_dim
);
}
else
if
(
pooltype
==
"MIN"
)
{
out_e
.
device
(
place
)
=
in_e
.
minimum
(
reduce_dim
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
"available, but got %s."
,
pooltype
));
}
last_idx
=
idx
;
if
(
idx
<
segments
.
numel
())
curent_id
=
segment_ids
[
idx
];
}
}
};
template
<
typename
T
,
typename
IndexT
>
class
SegmentPoolGradFunctor
<
platform
::
CPUDeviceContext
,
T
,
IndexT
>
{
public:
void
operator
()(
const
platform
::
CPUDeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
out_grad
,
const
framework
::
Tensor
&
segments
,
framework
::
Tensor
*
in_grad
,
const
framework
::
Tensor
*
index
=
nullptr
,
const
std
::
string
pooltype
=
"SUM"
)
{
const
IndexT
*
segment_ids
=
segments
.
data
<
IndexT
>
();
auto
&
place
=
*
context
.
eigen_device
();
auto
curent_id
=
segment_ids
[
0
];
int64_t
last_idx
=
0
;
int64_t
w
=
in_grad
->
numel
()
/
in_grad
->
dims
()[
0
];
for
(
int64_t
idx
=
1
;
idx
<=
segments
.
numel
();
++
idx
)
{
if
(
idx
<
segments
.
numel
())
{
if
(
segment_ids
[
idx
]
==
curent_id
)
continue
;
PADDLE_ENFORCE_GE
(
segment_ids
[
idx
],
curent_id
,
platform
::
errors
::
InvalidArgument
(
"The segment ids should be sorted, but got "
"segment_ids[%d]:%d > segment_ids[%d]:%d."
,
idx
-
1
,
curent_id
,
idx
,
segment_ids
[
idx
]));
}
Tensor
out_g_t
=
out_grad
.
Slice
(
curent_id
,
curent_id
+
1
);
Tensor
in_g_t
=
in_grad
->
Slice
(
last_idx
,
idx
);
int64_t
h
=
idx
-
last_idx
;
auto
in_g_e
=
framework
::
EigenMatrix
<
T
>::
From
(
in_g_t
,
{
h
,
w
});
auto
out_g_e
=
framework
::
EigenMatrix
<
T
>::
From
(
out_g_t
,
{
1
,
w
});
Eigen
::
DSizes
<
int
,
2
>
bcast
(
h
,
1
);
if
(
pooltype
==
"MEAN"
)
{
in_g_e
.
device
(
place
)
=
(
out_g_e
/
static_cast
<
T
>
(
h
)).
broadcast
(
bcast
);
}
else
if
(
pooltype
==
"SUM"
)
{
in_g_e
.
device
(
place
)
=
out_g_e
.
broadcast
(
bcast
);
}
else
if
(
pooltype
==
"MAX"
||
pooltype
==
"MIN"
)
{
Tensor
out_t
=
output
.
Slice
(
curent_id
,
curent_id
+
1
);
Tensor
in_t
=
input
.
Slice
(
last_idx
,
idx
);
auto
in_e
=
framework
::
EigenMatrix
<
T
>::
From
(
in_t
,
{
h
,
w
});
auto
out_e
=
framework
::
EigenMatrix
<
T
>::
From
(
out_t
,
{
1
,
w
});
in_g_e
.
device
(
place
)
=
(
in_e
==
out_e
.
broadcast
(
bcast
)).
template
cast
<
T
>()
*
out_g_e
.
broadcast
(
bcast
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported segment pooling type, only MEAN, SUM, MAX, MIN "
"available, but got %s."
,
pooltype
));
}
last_idx
=
idx
;
if
(
idx
<
segments
.
numel
())
curent_id
=
segment_ids
[
idx
];
}
}
};
using
CPU
=
platform
::
CPUDeviceContext
;
template
class
SegmentPoolFunctor
<
CPU
,
float
,
int
>;
template
class
SegmentPoolFunctor
<
CPU
,
float
,
int64_t
>;
template
class
SegmentPoolFunctor
<
CPU
,
double
,
int
>;
template
class
SegmentPoolFunctor
<
CPU
,
double
,
int64_t
>;
template
class
SegmentPoolGradFunctor
<
CPU
,
float
,
int
>;
template
class
SegmentPoolGradFunctor
<
CPU
,
float
,
int64_t
>;
template
class
SegmentPoolGradFunctor
<
CPU
,
double
,
int
>;
template
class
SegmentPoolGradFunctor
<
CPU
,
double
,
int64_t
>;
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/math/segment_pooling.h
0 → 100644
浏览文件 @
f4c750d7
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
,
typename
IndexT
>
class
SegmentPoolFunctor
{
public:
/* mean pool has summed_ids output */
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
segments
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
summed_ids
=
nullptr
,
const
std
::
string
pooltype
=
"SUM"
);
};
template
<
typename
DeviceContext
,
typename
T
,
typename
IndexT
>
class
SegmentPoolGradFunctor
{
public:
/* mean pool has summed_ids output */
void
operator
()(
const
DeviceContext
&
context
,
const
framework
::
Tensor
&
input
,
const
framework
::
Tensor
&
output
,
const
framework
::
Tensor
&
out_grad
,
const
framework
::
Tensor
&
segments
,
framework
::
Tensor
*
in_grad
,
const
framework
::
Tensor
*
summed_ids
=
nullptr
,
const
std
::
string
pooltype
=
"SUM"
);
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/segment_pool_op.cc
0 → 100644
浏览文件 @
f4c750d7
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/segment_pool_op.h"
#include <memory>
#include <string>
namespace
paddle
{
namespace
operators
{
class
SegmentPoolOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"SegmentPool"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SegmentIds"
),
"Input"
,
"SegmentIds"
,
"SegmentPool"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"SegmentPool"
);
auto
dims
=
ctx
->
GetInputDim
(
"X"
);
dims
[
0
]
=
-
1
;
ctx
->
SetOutputDim
(
"Out"
,
dims
);
if
(
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"pooltype"
)
==
"MEAN"
)
{
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SummedIds"
),
"Output"
,
"SummedIds"
,
"SegmentPool"
);
ctx
->
SetOutputDim
(
"SummedIds"
,
{
-
1
,
1
});
}
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
"X"
),
ctx
.
device_context
());
}
};
class
SegmentPoolOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"(Tensor) The input data of SegmentPoolOp"
);
AddInput
(
"SegmentIds"
,
"(Tensor) 1-D tensor which have the same size with the fist "
"dimension of input X."
);
AddOutput
(
"Out"
,
"(Tensor) The output of SegmentPoolOp."
);
AddOutput
(
"SummedIds"
,
"(Tensor) This tensor is used to counts of segment ids for the "
"backward of the mean pool."
)
.
AsIntermediate
();
AddAttr
<
std
::
string
>
(
"pooltype"
,
"(string, default 'SUM') the pooling type of SegmentPoolOp."
)
.
SetDefault
(
"SUM"
)
.
InEnum
({
"SUM"
,
"MEAN"
,
"MIN"
,
"MAX"
});
AddComment
(
R"DOC(
Segment Pool Operator.
This operator will pool the elements of input `X` which with the same index
in `SegmentIds`.
For SUM operation, it computes a tensor such that $Out_i = \sum_{j} X_{j}$
where sum is over j such that `SegmentIds[j] == i`.
For MEAN operation, it computes a tensor such that
$Out_i = \frac{1}{n_i} \sum_{j} X_{j}$ where sum is over j such that
`SegmentIds[j] == i` and $n_i$ is the number of all index `SegmentIds[j] == i`.
For MIN operation, it computes a tensor such that $Out_i = \min_{j} X_{j}$
where min is over j such that `SegmentIds[j] == i`.
For MAX operation, it computes a tensor such that $Out_i = \max_{j} X_{j}$
where max is over j such that `SegmentIds[j] == i`.
)DOC"
);
}
};
class
SegmentPoolGradOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
framework
::
GradVarName
(
"Out"
),
"SegmentPoolGrad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"SegmentPoolGrad"
);
auto
og_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
og_dims
.
size
(),
x_dims
.
size
(),
platform
::
errors
::
InvalidArgument
(
"The rank of output grad must equal to Input(X). But "
"received: input rank %u, input shape [%s]."
,
og_dims
.
size
(),
og_dims
));
for
(
int64_t
i
=
1
;
i
<
og_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
og_dims
[
i
],
x_dims
[
i
],
platform
::
errors
::
InvalidArgument
(
"The dimension mismatch between Input(OUT@GRAD) and "
"Input(X). Received Input(OUT@GRAD): input rank %u, "
"input shape [%s]; received Input(X): input rank %u, "
"input shape [%s]."
,
og_dims
.
size
(),
og_dims
,
x_dims
.
size
(),
x_dims
));
}
ctx
->
ShareDim
(
"X"
,
/*->*/
framework
::
GradVarName
(
"X"
));
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
)),
ctx
.
device_context
());
}
};
template
<
typename
T
>
class
SegmentPoolGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op_desc_ptr
)
const
override
{
op_desc_ptr
->
SetType
(
"segment_pool_grad"
);
op_desc_ptr
->
SetInput
(
"X"
,
this
->
Input
(
"X"
));
op_desc_ptr
->
SetInput
(
"SegmentIds"
,
this
->
Input
(
"SegmentIds"
));
op_desc_ptr
->
SetInput
(
"Out"
,
this
->
Output
(
"Out"
));
if
(
BOOST_GET_CONST
(
std
::
string
,
this
->
GetAttr
(
"pooltype"
))
==
"MEAN"
)
{
op_desc_ptr
->
SetInput
(
"SummedIds"
,
this
->
Output
(
"SummedIds"
));
}
op_desc_ptr
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
op_desc_ptr
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
op_desc_ptr
->
SetAttrMap
(
this
->
Attrs
());
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
segment_pool
,
ops
::
SegmentPoolOp
,
ops
::
SegmentPoolOpMaker
,
ops
::
SegmentPoolGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
SegmentPoolGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
segment_pool_grad
,
ops
::
SegmentPoolGradOp
);
REGISTER_OP_CPU_KERNEL
(
segment_pool
,
ops
::
SegmentPoolKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SegmentPoolKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
REGISTER_OP_CPU_KERNEL
(
segment_pool_grad
,
ops
::
SegmentPoolGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SegmentPoolGradKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/segment_pool_op.h
0 → 100644
浏览文件 @
f4c750d7
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/segment_pooling.h"
#include "paddle/fluid/platform/macros.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
,
typename
IndexT
>
void
SegmentKernelLaunchHelper
(
const
framework
::
ExecutionContext
&
context
)
{
auto
*
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
segment
=
context
.
Input
<
Tensor
>
(
"SegmentIds"
);
auto
*
output
=
context
.
Output
<
Tensor
>
(
"Out"
);
std
::
string
pooltype
=
context
.
Attr
<
std
::
string
>
(
"pooltype"
);
Tensor
*
summed_ids
=
nullptr
;
int64_t
num_indices
=
segment
->
numel
();
PADDLE_ENFORCE_EQ
(
num_indices
,
input
->
dims
()[
0
],
platform
::
errors
::
InvalidArgument
(
"Segment_ids should be the same size as dimension 0 of input X."
));
PADDLE_ENFORCE_EQ
(
num_indices
,
segment
->
dims
()[
0
],
platform
::
errors
::
InvalidArgument
(
"Segment_ids should be 1-D tensor, or it's other "
"dimension size is 1. Segment_ids's shape is: [%s]."
,
segment
->
dims
()));
if
(
input
->
numel
()
==
0
||
segment
->
numel
()
==
0
)
{
return
;
}
bool
cpu_place
=
context
.
GetPlace
().
type
()
==
typeid
(
platform
::
CPUPlace
);
if
(
cpu_place
)
{
auto
dims
=
input
->
dims
();
auto
*
segment_ids
=
segment
->
data
<
IndexT
>
();
dims
[
0
]
=
static_cast
<
int64_t
>
(
segment_ids
[
segment
->
numel
()
-
1
]
+
1
);
PADDLE_ENFORCE_GT
(
dims
[
0
],
0
,
platform
::
errors
::
InvalidArgument
(
"Segment ids must be >= 0, but got last id %d"
,
dims
[
0
]));
output
->
Resize
({
dims
});
output
->
mutable_data
<
T
>
(
context
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
set_zero
(
dev_ctx
,
output
,
static_cast
<
T
>
(
0
));
}
SegmentPoolFunctor
<
DeviceContext
,
T
,
IndexT
>
pool
;
pool
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
segment
,
output
,
summed_ids
,
pooltype
);
}
template
<
typename
DeviceContext
,
typename
T
>
class
SegmentPoolKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
segment
=
context
.
Input
<
Tensor
>
(
"SegmentIds"
);
auto
index_type
=
segment
->
type
();
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
SegmentKernelLaunchHelper
<
DeviceContext
,
T
,
int
>
(
context
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
SegmentKernelLaunchHelper
<
DeviceContext
,
T
,
int64_t
>
(
context
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported index type, Expected int, int64, but got %s."
,
index_type
));
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SegmentPoolGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
output
=
context
.
Input
<
Tensor
>
(
"Out"
);
auto
*
segment
=
context
.
Input
<
Tensor
>
(
"SegmentIds"
);
auto
*
out_g
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
in_g
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
std
::
string
pooltype
=
context
.
Attr
<
std
::
string
>
(
"pooltype"
);
const
Tensor
*
summed_ids
=
nullptr
;
if
(
pooltype
==
"MEAN"
)
{
summed_ids
=
context
.
Input
<
Tensor
>
(
"SummedIds"
);
}
in_g
->
mutable_data
<
T
>
(
context
.
GetPlace
());
math
::
SetConstant
<
DeviceContext
,
T
>
set_zero
;
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
set_zero
(
dev_ctx
,
in_g
,
static_cast
<
T
>
(
0
));
auto
index_type
=
segment
->
type
();
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
SegmentPoolGradFunctor
<
DeviceContext
,
T
,
int
>
pool
;
pool
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
*
out_g
,
*
segment
,
in_g
,
summed_ids
,
pooltype
);
}
else
if
(
index_type
==
framework
::
proto
::
VarType
::
INT64
)
{
SegmentPoolGradFunctor
<
DeviceContext
,
T
,
int64_t
>
pool
;
pool
(
context
.
template
device_context
<
DeviceContext
>(),
*
input
,
*
output
,
*
out_g
,
*
segment
,
in_g
,
summed_ids
,
pooltype
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Unsupported index type, Expected int, int64, but got %s."
,
index_type
));
}
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_segment_ops.py
0 → 100644
浏览文件 @
f4c750d7
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
sys
from
op_test
import
OpTest
def
compute_segment_sum
(
x
,
segment_ids
):
length
=
segment_ids
[
-
1
]
+
1
target_shape
=
list
(
x
.
shape
)
target_shape
[
0
]
=
length
results
=
np
.
zeros
(
target_shape
,
dtype
=
x
.
dtype
)
for
index
,
ids
in
enumerate
(
segment_ids
):
results
[
ids
,
:]
+=
x
[
index
,
:]
return
results
def
compute_segment_mean
(
x
,
segment_ids
):
length
=
segment_ids
[
-
1
]
+
1
target_shape
=
list
(
x
.
shape
)
target_shape
[
0
]
=
length
results
=
np
.
zeros
(
target_shape
,
dtype
=
x
.
dtype
)
count
=
np
.
zeros
(
length
,
dtype
=
x
.
dtype
)
+
1e-8
for
index
,
ids
in
enumerate
(
segment_ids
):
results
[
ids
,
:]
+=
x
[
index
,
:]
count
[
ids
]
+=
1
results
=
results
/
count
.
reshape
([
-
1
,
1
])
return
results
def
compute_segment_min_max
(
x
,
segment_ids
,
pooltype
=
"MAX"
):
length
=
segment_ids
[
-
1
]
+
1
target_shape
=
list
(
x
.
shape
)
target_shape
[
0
]
=
length
gradient
=
np
.
zeros_like
(
x
)
results
=
np
.
zeros
(
target_shape
,
dtype
=
x
.
dtype
)
last_idx
=
0
current_id
=
segment_ids
[
0
]
for
idx
in
range
(
1
,
len
(
segment_ids
)
+
1
):
if
idx
<
len
(
segment_ids
):
if
segment_ids
[
idx
]
==
current_id
:
continue
sub_x
=
x
[
last_idx
:
idx
,
:]
if
pooltype
==
"MAX"
:
results
[
current_id
]
=
np
.
amax
(
sub_x
,
axis
=
0
)
elif
pooltype
==
"MIN"
:
results
[
current_id
]
=
np
.
amin
(
sub_x
,
axis
=
0
)
else
:
raise
ValueError
(
"Invalid pooltype, only MAX, MIN supported!"
)
gradient
[
last_idx
:
idx
,
:][
sub_x
==
results
[
current_id
]]
=
1
last_idx
=
idx
if
idx
<
len
(
segment_ids
):
current_id
=
segment_ids
[
idx
]
return
results
,
gradient
/
results
.
size
class
TestSegmentOps
(
OpTest
):
def
set_data
(
self
):
x
=
np
.
random
.
uniform
(
-
1
,
1
,
self
.
shape
).
astype
(
self
.
dtype
)
segment_ids
=
self
.
set_segment
(
len
(
x
),
len
(
x
)
//
5
+
1
)
return
x
,
segment_ids
def
set_segment
(
self
,
origin_len
,
reduce_len
):
segment
=
np
.
zeros
(
reduce_len
,
dtype
=
'int64'
)
segment
=
np
.
random
.
randint
(
0
,
reduce_len
,
size
=
[
origin_len
])
segment
=
np
.
sort
(
segment
)
return
segment
.
astype
(
'int64'
)
def
compute
(
self
,
x
,
segment_ids
):
return
compute_segment_sum
(
x
,
segment_ids
)
def
prepare
(
self
):
self
.
op_type
=
"segment_pool"
self
.
dtype
=
np
.
float64
self
.
shape
=
[
30
,
15
]
self
.
attrs
=
{
"pooltype"
:
"SUM"
}
def
setUp
(
self
):
self
.
prepare
()
x
,
segment_ids
=
self
.
set_data
()
result
=
self
.
compute
(
x
,
segment_ids
)
self
.
inputs
=
{
'X'
:
x
.
astype
(
self
.
dtype
),
'SegmentIds'
:
segment_ids
.
astype
(
np
.
int64
)
}
self
.
outputs
=
{
'Out'
:
result
.
astype
(
self
.
dtype
)}
def
test_check_output
(
self
):
self
.
check_output
()
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
)
class
TestSegmentSum2
(
TestSegmentOps
):
def
prepare
(
self
):
super
(
TestSegmentSum2
,
self
).
prepare
()
self
.
shape
=
[
40
,
20
]
self
.
dtype
=
np
.
float32
def
setUp
(
self
):
self
.
prepare
()
x
,
segment_ids
=
self
.
set_data
()
result
=
self
.
compute
(
x
,
segment_ids
)
self
.
inputs
=
{
'X'
:
x
.
astype
(
self
.
dtype
),
'SegmentIds'
:
segment_ids
.
astype
(
np
.
int32
)
}
self
.
outputs
=
{
'Out'
:
result
.
astype
(
self
.
dtype
)}
class
TestSegmentMax
(
TestSegmentOps
):
def
compute
(
self
,
x
,
segment_ids
):
return
compute_segment_min_max
(
x
,
segment_ids
,
pooltype
=
"MAX"
)
def
prepare
(
self
):
super
(
TestSegmentMax
,
self
).
prepare
()
self
.
shape
=
[
40
,
20
]
self
.
attrs
=
{
'pooltype'
:
"MAX"
}
def
setUp
(
self
):
self
.
prepare
()
x
,
segment_ids
=
self
.
set_data
()
result
,
self
.
gradient
=
self
.
compute
(
x
,
segment_ids
)
self
.
inputs
=
{
'X'
:
x
.
astype
(
self
.
dtype
),
'SegmentIds'
:
segment_ids
.
astype
(
np
.
int32
)
}
self
.
outputs
=
{
'Out'
:
result
.
astype
(
self
.
dtype
)}
def
test_check_grad
(
self
):
self
.
check_grad
([
"X"
],
"Out"
,
user_defined_grads
=
[
self
.
gradient
])
class
TestSegmentMax2
(
TestSegmentMax
):
def
prepare
(
self
):
super
(
TestSegmentMax2
,
self
).
prepare
()
self
.
dtype
=
np
.
float32
class
TestSegmentMin
(
TestSegmentMax
):
def
compute
(
self
,
x
,
segment_ids
):
return
compute_segment_min_max
(
x
,
segment_ids
,
pooltype
=
"MIN"
)
def
prepare
(
self
):
super
(
TestSegmentMin
,
self
).
prepare
()
self
.
attrs
=
{
'pooltype'
:
"MIN"
}
class
TestSegmentMin2
(
TestSegmentMin
):
def
prepare
(
self
):
super
(
TestSegmentMin2
,
self
).
prepare
()
self
.
dtype
=
np
.
float32
class
TestSegmentMean
(
TestSegmentOps
):
def
compute
(
self
,
x
,
segment_ids
):
return
compute_segment_mean
(
x
,
segment_ids
)
def
prepare
(
self
):
super
(
TestSegmentMean
,
self
).
prepare
()
self
.
shape
=
[
40
,
20
]
self
.
attrs
=
{
'pooltype'
:
"MEAN"
}
def
setUp
(
self
):
self
.
prepare
()
x
,
segment_ids
=
self
.
set_data
()
result
=
self
.
compute
(
x
,
segment_ids
)
self
.
inputs
=
{
'X'
:
x
,
'SegmentIds'
:
segment_ids
}
self
.
outputs
=
{
'Out'
:
result
,
'SummedIds'
:
compute_segment_sum
(
np
.
ones
([
len
(
x
),
1
]).
astype
(
self
.
dtype
),
segment_ids
)
}
class
TestSegmentMean2
(
TestSegmentMean
):
def
prepare
(
self
):
super
(
TestSegmentMean2
,
self
).
prepare
()
self
.
dtype
=
np
.
float32
self
.
shape
=
[
30
,
20
]
self
.
attrs
=
{
'pooltype'
:
"MEAN"
}
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录