Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
afc6343e
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
afc6343e
编写于
11月 02, 2017
作者:
D
dangqingqing
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Refine sequence max-pooling and add unit testing of gradient check.
上级
dfe851a0
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
362 addition
and
31 deletion
+362
-31
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+2
-0
paddle/operators/math/CMakeLists.txt
paddle/operators/math/CMakeLists.txt
+2
-0
paddle/operators/math/sequence_pooling.cc
paddle/operators/math/sequence_pooling.cc
+103
-0
paddle/operators/math/sequence_pooling.cu
paddle/operators/math/sequence_pooling.cu
+136
-0
paddle/operators/math/sequence_pooling.h
paddle/operators/math/sequence_pooling.h
+45
-0
paddle/operators/sequence_pool_op.cc
paddle/operators/sequence_pool_op.cc
+18
-3
paddle/operators/sequence_pool_op.h
paddle/operators/sequence_pool_op.h
+21
-18
python/paddle/v2/framework/tests/test_seq_pool.py
python/paddle/v2/framework/tests/test_seq_pool.py
+35
-10
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
afc6343e
...
...
@@ -141,6 +141,7 @@ set(DEPS_OPS
pool_with_index_op
nccl_op
sequence_conv_op
sequence_pool_op
lstm_op
)
op_library
(
cond_op SRCS cond_op.cc DEPS framework_proto tensor operator net_op
)
...
...
@@ -153,6 +154,7 @@ if(WITH_GPU)
op_library
(
nccl_op DEPS nccl_common
)
endif
()
op_library
(
sequence_conv_op DEPS context_project
)
op_library
(
sequence_pool_op DEPS sequence_pooling
)
op_library
(
lstm_op DEPS sequence2batch lstm_compute
)
op_library
(
dynamic_recurrent_op SRCS dynamic_recurrent_op.cc rnn/recurrent_op_utils.cc
DEPS net_op tensor_array
)
...
...
paddle/operators/math/CMakeLists.txt
浏览文件 @
afc6343e
...
...
@@ -8,6 +8,7 @@ if(WITH_GPU)
nv_library
(
softmax SRCS softmax.cc softmax.cu DEPS operator
)
nv_library
(
cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator
)
nv_library
(
pooling SRCS pooling.cc pooling.cu DEPS device_context
)
nv_library
(
sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function
)
nv_library
(
vol2col SRCS vol2col.cc vol2col.cu DEPS device_context
)
nv_library
(
context_project SRCS context_project.cc context_project.cu DEPS device_context
)
nv_library
(
sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context
)
...
...
@@ -18,6 +19,7 @@ else()
cc_library
(
softmax SRCS softmax.cc DEPS operator
)
cc_library
(
cross_entropy SRCS cross_entropy.cc DEPS operator
)
cc_library
(
pooling SRCS pooling.cc DEPS device_context
)
nv_library
(
sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function
)
cc_library
(
vol2col SRCS vol2col.cc DEPS device_context
)
cc_library
(
context_project SRCS context_project.cc DEPS device_context
)
cc_library
(
sequence2batch SRCS sequence2batch.cc DEPS device_context
)
...
...
paddle/operators/math/sequence_pooling.cc
0 → 100644
浏览文件 @
afc6343e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/sequence_pooling.h"
#include "paddle/operators/math/math_function.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
template
<
typename
T
>
class
MaxSeqPoolFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
input
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
index
)
{
auto
in_dims
=
input
.
dims
();
auto
out_dims
=
output
->
dims
();
auto
idx_dims
=
index
->
dims
();
PADDLE_ENFORCE_GT
(
in_dims
.
size
(),
1UL
);
PADDLE_ENFORCE_GT
(
out_dims
.
size
(),
1UL
);
for
(
size_t
i
=
1
;
i
<
in_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
in_dims
[
i
],
out_dims
[
i
]);
}
PADDLE_ENFORCE_EQ
(
idx_dims
,
out_dims
);
auto
starts
=
input
.
lod
()[
0
];
const
T
*
in_data
=
input
.
data
<
T
>
();
T
*
out_data
=
output
->
data
<
T
>
();
int
*
max_index
=
index
->
data
<
int
>
();
int64_t
num_seq
=
out_dims
[
0
];
int64_t
dim
=
output
->
numel
()
/
num_seq
;
for
(
int64_t
i
=
0
;
i
<
num_seq
;
++
i
)
{
for
(
int64_t
k
=
0
;
k
<
dim
;
++
k
)
{
out_data
[
i
*
dim
+
k
]
=
in_data
[
starts
[
i
]
*
dim
+
k
];
max_index
[
i
*
dim
+
k
]
=
starts
[
i
];
}
for
(
size_t
j
=
starts
[
i
]
+
1
;
j
<
starts
[
i
+
1
];
++
j
)
{
for
(
int64_t
k
=
0
;
k
<
dim
;
++
k
)
{
if
(
in_data
[
j
*
dim
+
k
]
>
out_data
[
i
*
dim
+
k
])
{
out_data
[
i
*
dim
+
k
]
=
in_data
[
j
*
dim
+
k
];
max_index
[
i
*
dim
+
k
]
=
j
;
}
}
}
}
}
};
template
<
typename
T
>
class
MaxSeqPoolGradFunctor
<
platform
::
CPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
out_grad
,
const
framework
::
Tensor
&
index
,
framework
::
LoDTensor
*
in_grad
)
{
auto
og_dims
=
out_grad
.
dims
();
auto
ig_dims
=
in_grad
->
dims
();
auto
idx_dims
=
index
.
dims
();
PADDLE_ENFORCE_GT
(
og_dims
.
size
(),
1UL
);
PADDLE_ENFORCE_GT
(
ig_dims
.
size
(),
1UL
);
for
(
size_t
i
=
1
;
i
<
og_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
og_dims
[
i
],
ig_dims
[
i
]);
}
PADDLE_ENFORCE_EQ
(
idx_dims
,
og_dims
);
const
T
*
og_data
=
out_grad
.
data
<
T
>
();
const
int
*
max_index
=
index
.
data
<
int
>
();
T
*
ig_data
=
in_grad
->
data
<
T
>
();
SetConstant
<
platform
::
CPUPlace
,
T
>
set_zero
;
set_zero
(
context
,
in_grad
,
static_cast
<
T
>
(
0.0
));
int64_t
num_seq
=
og_dims
[
0
];
int64_t
dim
=
out_grad
.
numel
()
/
num_seq
;
for
(
size_t
i
=
0
;
i
<
num_seq
;
++
i
)
{
for
(
size_t
j
=
0
;
j
<
dim
;
++
j
)
{
int
step_id
=
max_index
[
i
*
dim
+
j
];
ig_data
[
step_id
*
dim
+
j
]
=
og_data
[
i
*
dim
+
j
];
}
}
}
};
template
class
MaxSeqPoolFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxSeqPoolFunctor
<
platform
::
CPUPlace
,
double
>;
template
class
MaxSeqPoolGradFunctor
<
platform
::
CPUPlace
,
float
>;
template
class
MaxSeqPoolGradFunctor
<
platform
::
CPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/sequence_pooling.cu
0 → 100644
浏览文件 @
afc6343e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence_pooling.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
#define FLT_MAX __FLT_MAX__
template
<
typename
T
>
__global__
void
KeMaxSequencePool
(
const
T
*
input
,
const
size_t
*
starts
,
T
*
output
,
int
*
index
,
int64_t
num_seq
,
int64_t
dim
)
{
int
dim_idx
=
threadIdx
.
x
;
int
seq_id
=
blockIdx
.
x
;
if
(
seq_id
>=
num_seq
)
return
;
size_t
start
=
starts
[
seq_id
];
size_t
end
=
starts
[
seq_id
+
1
];
for
(
int
i
=
dim_idx
;
i
<
dim
;
i
+=
blockDim
.
x
)
{
T
max_val
=
static_cast
<
T
>
(
-
FLT_MAX
);
int
max_id
=
-
1
;
for
(
size_t
step_id
=
start
;
step_id
<
end
;
step_id
++
)
{
if
(
max_val
<
input
[
step_id
*
dim
+
i
])
{
max_val
=
input
[
step_id
*
dim
+
i
];
max_id
=
step_id
;
}
}
output
[
seq_id
*
dim
+
i
]
=
max_val
;
index
[
seq_id
*
dim
+
i
]
=
max_id
;
}
}
template
<
typename
T
>
class
MaxSeqPoolFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
input
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
index
)
{
auto
in_dims
=
input
.
dims
();
auto
out_dims
=
output
->
dims
();
auto
idx_dims
=
index
->
dims
();
PADDLE_ENFORCE_GT
(
in_dims
.
size
(),
1UL
);
PADDLE_ENFORCE_GT
(
out_dims
.
size
(),
1UL
);
for
(
size_t
i
=
1
;
i
<
in_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
in_dims
[
i
],
out_dims
[
i
]);
}
PADDLE_ENFORCE_EQ
(
idx_dims
,
out_dims
);
auto
starts
=
input
.
lod
()[
0
];
const
T
*
in_data
=
input
.
data
<
T
>
();
T
*
out_data
=
output
->
data
<
T
>
();
int
*
max_index
=
index
->
data
<
int
>
();
int64_t
num_seq
=
out_dims
[
0
];
int64_t
dim
=
output
->
numel
()
/
num_seq
;
dim3
threads
(
256
,
1
);
dim3
grid
(
num_seq
,
1
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
KeMaxSequencePool
<
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
in_data
,
starts
.
data
(),
out_data
,
max_index
,
num_seq
,
dim
);
}
};
template
<
typename
T
>
__global__
void
KeMaxSequencePoolGrad
(
const
T
*
out_grad
,
const
int
*
max_index
,
T
*
in_grad
,
int64_t
num_seq
,
int64_t
dim
)
{
int
idx
=
threadIdx
.
x
+
blockIdx
.
x
*
blockDim
.
x
;
int
col_idx
=
idx
%
dim
;
if
(
idx
<
num_seq
*
dim
)
{
int
step_id
=
max_index
[
idx
];
in_grad
[
step_id
*
dim
+
col_idx
]
=
out_grad
[
idx
];
}
}
template
<
typename
T
>
class
MaxSeqPoolGradFunctor
<
platform
::
GPUPlace
,
T
>
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
out_grad
,
const
framework
::
Tensor
&
index
,
framework
::
LoDTensor
*
in_grad
)
{
auto
og_dims
=
out_grad
.
dims
();
auto
idx_dims
=
index
.
dims
();
auto
ig_dims
=
in_grad
->
dims
();
PADDLE_ENFORCE_GT
(
og_dims
.
size
(),
1UL
);
PADDLE_ENFORCE_GT
(
ig_dims
.
size
(),
1UL
);
for
(
size_t
i
=
1
;
i
<
og_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
og_dims
[
i
],
ig_dims
[
i
]);
}
PADDLE_ENFORCE_EQ
(
idx_dims
,
og_dims
);
const
T
*
og_data
=
out_grad
.
data
<
T
>
();
const
int
*
max_index
=
index
.
data
<
int
>
();
T
*
ig_data
=
in_grad
->
data
<
T
>
();
SetConstant
<
platform
::
GPUPlace
,
T
>
set_zero
;
set_zero
(
context
,
in_grad
,
static_cast
<
T
>
(
0.0
));
int64_t
num_seq
=
og_dims
[
0
];
int64_t
dim
=
out_grad
.
numel
()
/
num_seq
;
unsigned
int
blocks
=
(
num_seq
*
dim
+
128
-
1
)
/
128
;
dim3
threads
(
128
,
1
);
dim3
grid
(
blocks
,
1
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
).
stream
();
KeMaxSequencePoolGrad
<
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
og_data
,
max_index
,
ig_data
,
num_seq
,
dim
);
}
};
template
class
MaxSeqPoolFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
MaxSeqPoolFunctor
<
platform
::
GPUPlace
,
double
>;
template
class
MaxSeqPoolGradFunctor
<
platform
::
GPUPlace
,
float
>;
template
class
MaxSeqPoolGradFunctor
<
platform
::
GPUPlace
,
double
>;
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/math/sequence_pooling.h
0 → 100644
浏览文件 @
afc6343e
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace
paddle
{
namespace
operators
{
namespace
math
{
#define FLT_MAX __FLT_MAX__
template
<
typename
Place
,
typename
T
>
class
MaxSeqPoolFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
LoDTensor
&
input
,
framework
::
Tensor
*
output
,
framework
::
Tensor
*
index
);
};
template
<
typename
Place
,
class
T
>
class
MaxSeqPoolGradFunctor
{
public:
void
operator
()(
const
platform
::
DeviceContext
&
context
,
const
framework
::
Tensor
&
out_grad
,
const
framework
::
Tensor
&
index
,
framework
::
LoDTensor
*
in_grad
);
};
}
// namespace math
}
// namespace operators
}
// namespace paddle
paddle/operators/sequence_pool_op.cc
浏览文件 @
afc6343e
...
...
@@ -27,6 +27,11 @@ class SequencePoolOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Out"
),
"Output(Out) of SequencePoolOp should not be null."
);
ctx
->
SetOutputDim
(
"Out"
,
ctx
->
GetInputDim
(
"X"
));
if
(
ctx
->
Attrs
().
Get
<
std
::
string
>
(
"pooltype"
)
==
"MAX"
)
{
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"MaxIndex"
),
"Output(MaxIndex) of SequencePoolOp should not be null."
);
ctx
->
SetOutputDim
(
"MaxIndex"
,
ctx
->
GetInputDim
(
"X"
));
}
}
};
...
...
@@ -35,13 +40,17 @@ class SequencePoolOpMaker : public framework::OpProtoAndCheckerMaker {
SequencePoolOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"(LoDTensor)
, t
he variable-length input of SequencePoolOp"
);
AddInput
(
"X"
,
"(LoDTensor)
T
he variable-length input of SequencePoolOp"
);
AddOutput
(
"Out"
,
"(Tensor)
, output of SequencePoolOp, which
does not contain LoD "
"(Tensor)
The output of SequencePoolOp
does not contain LoD "
"infomation."
);
AddOutput
(
"MaxIndex"
,
"(Tensor<int>) This tensor is used for the max-pooling "
"of sequence to record the max indexes."
)
.
AsIntermediate
();
AddAttr
<
std
::
string
>
(
"pooltype"
,
"(int, default AVERAGE)
t
he pooling pooltype of SequencePoolOp."
)
"(int, default AVERAGE)
T
he pooling pooltype of SequencePoolOp."
)
.
SetDefault
(
"AVERAGE"
);
AddComment
(
R"DOC(
SequencePoolOp pools features of all time-steps of each instance.
...
...
@@ -92,6 +101,12 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
}
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
x_dims
);
}
protected:
framework
::
DataType
IndicateDataType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
ToDataType
(
ctx
.
Input
<
Tensor
>
(
"X"
)
->
type
());
}
};
}
// namespace operators
...
...
paddle/operators/sequence_pool_op.h
浏览文件 @
afc6343e
...
...
@@ -16,6 +16,7 @@ limitations under the License. */
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/sequence_pooling.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -34,7 +35,7 @@ class SequencePoolKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out
=
context
.
Output
<
LoD
Tensor
>
(
"Out"
);
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
std
::
string
pooltype
=
context
.
Attr
<
std
::
string
>
(
"pooltype"
);
auto
dims
=
in
->
dims
();
...
...
@@ -53,6 +54,16 @@ class SequencePoolKernel : public framework::OpKernel<T> {
auto
lod_level_0
=
lod
[
0
];
out
->
mutable_data
<
T
>
(
context
.
GetPlace
());
if
(
pooltype
==
"MAX"
)
{
math
::
MaxSeqPoolFunctor
<
Place
,
T
>
max_pool
;
auto
*
index
=
context
.
Output
<
Tensor
>
(
"MaxIndex"
);
index
->
Resize
({
dims
});
index
->
mutable_data
<
int
>
(
context
.
GetPlace
());
max_pool
(
context
.
device_context
(),
*
in
,
out
,
index
);
return
;
}
auto
place
=
context
.
GetEigenDevice
<
Place
>
();
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
lod_level_0
.
size
())
-
1
;
++
i
)
{
Tensor
in_t
=
in
->
Slice
(
static_cast
<
int
>
(
lod_level_0
[
i
]),
...
...
@@ -69,8 +80,6 @@ class SequencePoolKernel : public framework::OpKernel<T> {
}
else
if
(
pooltype
==
"SQRT"
)
{
out_e
.
device
(
place
)
=
in_e
.
sum
(
Eigen
::
array
<
int
,
1
>
({{
0
}}))
/
std
::
sqrt
(
static_cast
<
T
>
(
h
));
}
else
if
(
pooltype
==
"MAX"
)
{
out_e
.
device
(
place
)
=
in_e
.
maximum
(
Eigen
::
array
<
int
,
1
>
({{
0
}}));
}
else
if
(
pooltype
==
"LAST"
)
{
out_e
.
device
(
place
)
=
in_e
.
chip
(
h
-
1
,
0
);
}
else
if
(
pooltype
==
"FIRST"
)
{
...
...
@@ -87,8 +96,8 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
in
=
context
.
Input
<
LoDTensor
>
(
"X"
);
auto
*
out_g
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
in_g
=
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
out_g
=
context
.
Input
<
LoDTensor
>
(
framework
::
GradVarName
(
"Out"
));
std
::
string
pooltype
=
context
.
Attr
<
std
::
string
>
(
"pooltype"
);
auto
dims
=
in
->
dims
();
...
...
@@ -96,6 +105,14 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
int64_t
w
=
in
->
numel
()
/
dims
[
0
];
in_g
->
mutable_data
<
T
>
(
context
.
GetPlace
());
if
(
pooltype
==
"MAX"
)
{
math
::
MaxSeqPoolGradFunctor
<
Place
,
T
>
max_pool_grad
;
auto
*
index
=
context
.
Input
<
Tensor
>
(
"MaxIndex"
);
max_pool_grad
(
context
.
device_context
(),
*
out_g
,
*
index
,
in_g
);
return
;
}
if
(
pooltype
==
"LAST"
||
pooltype
==
"FIRST"
)
{
// set X@Grad be zero at first when pooltype is LAST/FIRST
math
::
SetConstant
<
Place
,
T
>
functor
;
...
...
@@ -118,20 +135,6 @@ class SequencePoolGradKernel : public framework::OpKernel<T> {
}
else
if
(
pooltype
==
"SQRT"
)
{
in_g_e
.
device
(
place
)
=
(
out_g_e
/
std
::
sqrt
(
static_cast
<
T
>
(
h
))).
broadcast
(
bcast
);
}
else
if
(
pooltype
==
"MAX"
)
{
auto
in_t
=
in
->
Slice
(
static_cast
<
int
>
(
lod
[
i
]),
static_cast
<
int
>
(
lod
[
i
+
1
]));
Eigen
::
Map
<
const
Eigen
::
Matrix
<
T
,
Eigen
::
Dynamic
,
Eigen
::
Dynamic
>>
in_t_map
(
in_t
.
data
<
T
>
(),
h
,
w
);
int
row_id
;
Eigen
::
array
<
int
,
2
>
extents
{{
1
,
1
}};
for
(
int
col_id
=
0
;
col_id
<
w
;
col_id
++
)
{
in_t_map
.
col
(
col_id
).
maxCoeff
(
&
row_id
);
Eigen
::
array
<
int
,
2
>
in_offsets
{{
row_id
,
col_id
}};
Eigen
::
array
<
int
,
2
>
out_offsets
{{
0
,
col_id
}};
in_g_e
.
slice
(
in_offsets
,
extents
).
device
(
place
)
=
out_g_e
.
slice
(
out_offsets
,
extents
);
}
}
else
if
(
pooltype
==
"LAST"
)
{
in_g_e
.
chip
(
h
-
1
,
0
).
device
(
place
)
=
out_g_e
;
}
else
if
(
pooltype
==
"FIRST"
)
{
...
...
python/paddle/v2/framework/tests/test_seq_pool.py
浏览文件 @
afc6343e
...
...
@@ -29,6 +29,9 @@ class TestSeqAvgPool(OpTest):
self
.
check_output
()
def
test_check_grad
(
self
):
# Remove MaxIndex after check_grad is refined.
self
.
outputs
[
'MaxIndex'
]
=
\
np
.
zeros
(
self
.
outputs
[
'Out'
].
shape
).
astype
(
'int32'
)
self
.
check_grad
([
"X"
],
"Out"
)
...
...
@@ -85,31 +88,53 @@ class TestSeqSqrtPool2D(TestSeqAvgPool2D):
out
[
i
]
=
np
.
reshape
(
sub_x
.
sum
(
axis
=
0
)
/
np
.
sqrt
(
len
),
(
3
,
17
))
def
test_check_grad
(
self
):
# Remove MaxIndex after check_grad is refined.
self
.
outputs
[
'MaxIndex'
]
=
\
np
.
zeros
(
self
.
outputs
[
'Out'
].
shape
).
astype
(
'int32'
)
self
.
check_grad
([
"X"
],
"Out"
,
max_relative_error
=
0.06
)
class
TestSeqMaxPool
(
TestSeqAvgPool
):
def
set_data
(
self
):
self
.
op_type
=
'sequence_pool'
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
13
,
23
]).
astype
(
'float32'
)
lod
=
[[
0
,
4
,
5
,
8
,
13
]]
for
i
in
range
(
4
):
l
=
lod
[
0
][
i
+
1
]
-
lod
[
0
][
i
]
x
[
lod
[
0
][
i
]
+
np
.
random
.
randint
(
l
),
:]
+=
2.0
self
.
inputs
=
{
'X'
:
(
x
,
lod
)}
out
=
np
.
zeros
((
4
,
23
)).
astype
(
'float32'
)
self
.
outputs
=
{
'Out'
:
out
}
return
x
,
lod
,
out
def
compute
(
self
,
x
,
lod
,
out
):
self
.
attrs
=
{
'pooltype'
:
"MAX"
}
for
i
in
range
(
4
):
sub_x
=
x
[
lod
[
0
][
i
]:
lod
[
0
][
i
+
1
],
:]
out
[
i
]
=
np
.
amax
(
sub_x
,
axis
=
0
)
def
test_check_grad
(
self
):
# Remove MaxPool2D from gradient check to confirm the success of CI.
return
class
TestSeqMaxPool2D
(
TestSeqAvgPool2D
):
def
set_data
(
self
):
self
.
op_type
=
'sequence_pool'
x
=
np
.
random
.
uniform
(
0.1
,
1
,
[
13
,
3
,
11
]).
astype
(
'float32'
)
lod
=
[[
0
,
4
,
5
,
8
,
13
]]
self
.
inputs
=
{
'X'
:
(
x
,
lod
)}
for
i
in
range
(
4
):
l
=
lod
[
0
][
i
+
1
]
-
lod
[
0
][
i
]
x
[
lod
[
0
][
i
]
+
np
.
random
.
randint
(
l
),
:]
+=
1.0
out
=
np
.
zeros
((
4
,
3
,
11
)).
astype
(
'float32'
)
self
.
outputs
=
{
'Out'
:
out
}
return
x
,
lod
,
out
def
compute
(
self
,
x
,
lod
,
out
):
self
.
attrs
=
{
'pooltype'
:
"MAX"
}
for
i
in
range
(
4
):
sub_x
=
np
.
reshape
(
x
[
lod
[
0
][
i
]:
lod
[
0
][
i
+
1
],
:],
(
-
1
,
3
*
17
))
out
[
i
]
=
np
.
reshape
(
np
.
amax
(
sub_x
,
axis
=
0
),
(
3
,
17
))
def
test_check_grad
(
self
):
# Remove MaxPool2D from gradient check to confirm the success of CI.
return
sub_x
=
np
.
reshape
(
x
[
lod
[
0
][
i
]:
lod
[
0
][
i
+
1
],
:],
(
-
1
,
3
*
11
))
out
[
i
]
=
np
.
reshape
(
np
.
amax
(
sub_x
,
axis
=
0
),
(
3
,
11
))
class
TestSeqLastPool
(
TestSeqAvgPool
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录