Commit 381c6a02 (unverified)
Authored by Yang yaming on Mar 20, 2018; committed by GitHub on Mar 20, 2018

Merge pull request #9100 from pkuyym/fix-9049

Enhance sequence_expand operator

Parents: 5271c32d, 2c225525

Showing 9 changed files with 305 additions and 154 deletions:

    paddle/fluid/operators/math/math_function.cc                 +2    -0
    paddle/fluid/operators/math/math_function.cu                 +2    -0
    paddle/fluid/operators/sequence_expand_op.cc                 +124  -62
    paddle/fluid/operators/sequence_expand_op.cu                 +9    -2
    paddle/fluid/operators/sequence_expand_op.h                  +101  -47
    python/paddle/fluid/layers/nn.py                             +26   -23
    python/paddle/fluid/tests/book/test_machine_translation.py   +3    -3
    python/paddle/fluid/tests/unittests/test_layers.py           +2    -2
    python/paddle/fluid/tests/unittests/test_sequence_expand.py  +36   -15
paddle/fluid/operators/math/math_function.cc

@@ -371,6 +371,8 @@ template struct RowwiseAdd<platform::CPUDeviceContext, double>;
 template struct ColwiseSum<platform::CPUDeviceContext, float>;
 template struct ColwiseSum<platform::CPUDeviceContext, double>;
+template struct ColwiseSum<platform::CPUDeviceContext, int>;
+template struct ColwiseSum<platform::CPUDeviceContext, int64_t>;

 template struct RowwiseSum<platform::CPUDeviceContext, float>;
 template struct RowwiseSum<platform::CPUDeviceContext, double>;
paddle/fluid/operators/math/math_function.cu

@@ -422,6 +422,8 @@ struct RowwiseAdd<platform::CUDADeviceContext, T> {
 template struct RowwiseAdd<platform::CUDADeviceContext, float>;
 template struct RowwiseAdd<platform::CUDADeviceContext, double>;
 template struct ColwiseSum<platform::CUDADeviceContext, float>;
+template struct ColwiseSum<platform::CUDADeviceContext, int>;
+template struct ColwiseSum<platform::CUDADeviceContext, int64_t>;
 // template struct ColwiseSum<platform::CUDADeviceContext, double>;
 // The ColwiseSum<platform::CUDADeviceContext, double> failed in debug mode,
 // and only failed for this case. So reimplemented it.
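Note: the new ColwiseSum instantiations for int and int64_t on both devices back the int and int64_t kernel registrations added later in this change; the reworked SequenceExpandGradKernel below computes d(X) with math::ColwiseSum instead of an inline Eigen reduction, so each registered element type needs a matching instantiation.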
paddle/fluid/operators/sequence_expand_op.cc

@@ -17,7 +17,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

-using framework::Tensor;
+using framework::LoDTensor;

 class SequenceExpandOp : public framework::OperatorWithKernel {
  public:
@@ -25,15 +25,71 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"));
-    PADDLE_ENFORCE(ctx->HasOutput("Out"));
-    PADDLE_ENFORCE(ctx->HasInput("Y"));
-    framework::DDim out_dim;
-    auto y_dim = ctx->GetInputDim("Y");
-    out_dim = ctx->GetInputDim("X");
-    out_dim[0] = y_dim[0];
-    ctx->ShareLoD("Y", "Out");
-    ctx->SetOutputDim("Out", out_dim);
+    PADDLE_ENFORCE(ctx->HasInput("X"),
+                   "Input(X) of SequenceExpandOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Y"),
+                   "Input(Y) of SequenceExpandOp should not be null.");
+    PADDLE_ENFORCE(ctx->HasOutput("Out"),
+                   "Output(Out) of SequenceExpandOp should not be null.");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto out_dims = x_dims;
+    int ref_level = ctx->Attrs().Get<int>("ref_level");
+
+    PADDLE_ENFORCE_GE(x_dims.size(), 2,
+                      "Dimension number of Input(X) should be at least 2.");
+
+    if (ctx->IsRuntime()) {
+      framework::Variable* x_var =
+          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("X")[0]);
+      framework::Variable* y_var =
+          boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Y")[0]);
+
+      auto& x_lod = x_var->Get<LoDTensor>().lod();
+      auto& y_lod = y_var->Get<LoDTensor>().lod();
+
+      PADDLE_ENFORCE_LE(x_lod.size(), 1,
+                        "Level number of Input(X)'s lod should not be "
+                        "greater than 1.");
+      PADDLE_ENFORCE_GT(y_lod.size(), 0,
+                        "Level number of Input(Y)'s lod should be "
+                        "greater than 0.");
+      PADDLE_ENFORCE(
+          ref_level == -1 ||
+              (ref_level >= 0 && ref_level < static_cast<int>(y_lod.size())),
+          "Invlid `ref_level`, which should be either equal to -1 "
+          "or in [0, %d)",
+          y_lod.size());
+
+      if (ref_level == -1) ref_level = y_lod.size() - 1;
+
+      if (x_lod.size() > 0) {
+        PADDLE_ENFORCE(x_lod[0].size() == y_lod[ref_level].size(),
+                       "Level number of Input(X)'s lod could be 0. Otherwise "
+                       "size of Input(X)'s first level lod should be equal to "
+                       "size of Input(Y)'s referred level lod.");
+      }
+
+      int64_t out_first_dim = 0;
+      if (y_lod[ref_level].size() <= 1) {
+        out_first_dim = x_dims[0];
+      } else {
+        for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
+          int x_seq_len = 1;
+          if (x_lod.size() == 1) {
+            x_seq_len = x_lod[0][i] - x_lod[0][i - 1];
+          }
+          out_first_dim +=
+              (y_lod[ref_level][i] - y_lod[ref_level][i - 1]) * x_seq_len;
+        }
+      }
+      out_dims[0] = out_first_dim;
+      ctx->SetOutputDim("Out", out_dims);
+    } else {
+      out_dims[0] = -1;
+      ctx->SetOutputDim("Out", out_dims);
+      ctx->ShareLoD("X", /*->*/ "Out");
+    }
   }
 };
@@ -42,83 +98,81 @@ class SequenceExpandOpMaker : public framework::OpProtoAndCheckerMaker {
   SequenceExpandOpMaker(OpProto* proto, OpAttrChecker* op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X",
-             "(Tensor or LoDTensor) The input(X) of this operator can be a "
-             "LoDTensor or a base Tensor.");
+             "(LoDTensor, default LoDTensor<float>) A 2-D LoDTensor whose lod "
+             "level is at most 1.");
     AddInput("Y",
-             "(LoDTensor)The reference input(Y) of sequence_expand op."
-             "It must be a LoDTensor with k-level(k>0)."
-             "The input(X) will be expanded according to LOD of input(Y)."
-             "The element numbers of last level in input(Y) "
-             "must be equal to dims[0] of input(X).");
+             "(LoDTensor, default LoDTensor<float>) Referred LoDTensor whose "
+             "lod (specified level) is referred by Input(X).");
     AddOutput("Out",
-              "(LodTensor)The output of sequence_expand op."
-              "The lod of output will be as same as input(Y)'s lod.");
+              "(LodTensor, default LoDTensor<float>) Output LoDTensor which is "
+              "generated from Input(X) by referring lod of Input(Y).");
+    AddAttr<int>("ref_level", "Specify lod level of Input(Y).").SetDefault(-1);
     AddComment(R"DOC(
 Sequence Expand Operator.

-This operator expands input(X) according to LOD of input(Y).
+This operator expands `X` according to specified level lod of `Y`. Current
+implementation constaints that lod level of `X` should be at most 1. Attribute
+`ref_level` is used to specify which level lod of `Y` is referred to expand `X`.
+If set `ref_level` to -1, then last level lod of `Y` would be referred.
+Please note, rank of `X` should be at least 2, when the rank exceeds 2, `X`
+would be viewed as a 2-D tensor.
+
 Following are cases to better explain how this works:
+
 Case 1:

-Given a 2-level LoDTensor input(X)
-    X.lod = [[0, 2, 3],
-             [0, 1, 3, 4]]
-    X.data = [a, b, c, d]
+Given a 1-level LoDTensor input(X)
+    X.lod  = [[0,   2,        4]]
+    X.data = [[a], [b], [c], [d]]
     X.dims = [4, 1]
 and input(Y)
     Y.lod = [[0,    2,    4],
              [0, 3, 6, 7, 8]]
-with condition len(Y.lod[-1]) -1 == X.dims[0]
-then we get 2-level LoDTensor
-    Out.lod = [[0, 2, 4],
-               [0, 3, 6, 7, 8]]
-    Out.data = [a, a, a, b, b, b, c, d]
+ref_level: 0
+then we get 1-level LoDTensor
+    Out.lod  = [[0,   2,        4,        6,        8]]
+    Out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
     Out.dims = [8, 1]

 Case 2:

+Given 1-level LoDTensor input(X)
+    X.lod  = [[0,   1,        4]]
+    X.data = [[a], [b], [c], [d]]
+    X.dims = [4, 1]
+and input(Y)
+    Y.lod = [[0,    2,    4],
+             [0, 3, 6, 6, 8]]
+ref_level: 0
+then we get 1-level LoDTensor
+    Out.lod  = [[0,   1,   2,        5,             8]]
+    Out.data = [[a], [a], [b], [c], [d], [b], [c], [d]]
+    Out.dims = [8, 1]
+
+Case 3:
+
 Given a common Tensor input(X)
-    X.data = [a, b, c]
+    X.data = [[a], [b], [c]]
     X.dims = [3, 1]
 and input(Y)
     Y.lod = [[0, 2, 3, 6]]
-with condition len(Y.lod[-1]) -1 == X.dims[0]
-then we get 1-level LoDTensor
-    Out.lod = [[0, 2, 3, 6]]
-    Out.data = [a, a, b, c, c, c]
+ref_level: -1
+then we get a common Tensor
+    Out.data = [[a], [a], [b], [c], [c], [c]]
     Out.dims = [6, 1]

-Case 3:
+Case 4:

 Given a common Tensor input(X)
     X.data = [[a, b], [c, d], [e, f]]
     X.dims = [3, 2]
 and input(Y)
     Y.lod = [[0, 2, 3, 6]]
-with condition len(Y.lod[-1]) -1 == X.dims[0]
-then we get 1-level LoDTensor
-    Out.lod = [[0, 2, 3, 6]]
-    Out.data = [[a,b], [a,b] [c,d], [e, f], [e, f], [e, f]]
+ref_level: 0
+then we get a common LoDTensor
+    Out.data = [[a, b], [a, b] [c, d], [e, f], [e, f], [e, f]]
     Out.dims = [6, 2]

-Case 4:
-
-Given 2-level a LoDTensor input(X)
-    X.lod = [[0, 2, 3],
-             [0, 1, 3, 4]]
-    X.data = [a, b, c, d]
-    X.dims = [4, 1]
-and input(Y)
-    Y.lod = [[0, 2, 4],
-             [0, 3, 6, 6, 8]]
-with condition len(Y.lod[-1]) -1 == X.dims[0]
-then we get 2-level LoDTensor
-    Out.lod = [[0, 2, 4],
-               [0, 3, 6, 6, 8]]
-    Out.data = [a, a, a, b, b, b, d, d]
-    Out.dims = [8, 1]
-
 )DOC");
   }
 };
@@ -129,12 +183,14 @@ class SequenceExpandOpGrad : public framework::OperatorWithKernel {
  protected:
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"));
-    PADDLE_ENFORCE(ctx->HasInput("Out"));
+    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
+    PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) should not be null.");
     PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "The input(Out@GRAD) should not be null");
+                   "Input(Out@GRAD) should not be null.");
+
     auto x_dims = ctx->GetInputDim("X");
     auto x_grad_name = framework::GradVarName("X");
     if (ctx->HasOutput(x_grad_name)) {
       ctx->SetOutputDim(x_grad_name, x_dims);
     }
@@ -149,7 +205,13 @@ REGISTER_OP(sequence_expand, ops::SequenceExpandOp, ops::SequenceExpandOpMaker,
             sequence_expand_grad, ops::SequenceExpandOpGrad);
 REGISTER_OP_CPU_KERNEL(
     sequence_expand,
-    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceExpandKernel<paddle::platform::CPUDeviceContext, int64_t>);
 REGISTER_OP_CPU_KERNEL(
     sequence_expand_grad,
-    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, double>,
+    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int>,
+    ops::SequenceExpandGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
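To make the DOC cases above concrete, here is a minimal NumPy sketch of the forward semantics (not part of the commit; sequence_expand_ref is a hypothetical name, and lod offsets are 0-based begin/end positions as in the cases above). Each segment of X, delimited by its level-0 lod (or taken as single rows when X has no lod), is tiled repeat_num times, where repeat_num is the width of the corresponding span in the referred lod level of Y:

    import numpy as np

    def sequence_expand_ref(x_data, x_lod, y_lod, ref_level=-1):
        # ref_level == -1 selects the last lod level of Y, as the op does.
        if ref_level == -1:
            ref_level = len(y_lod) - 1
        ref = y_lod[ref_level]
        # Segment boundaries of X: its level-0 lod if present, else one row each.
        x_idx = x_lod[0] if x_lod else list(range(x_data.shape[0] + 1))
        blocks = []
        for i in range(1, len(ref)):
            repeat_num = ref[i] - ref[i - 1]
            x_sub = x_data[x_idx[i - 1]:x_idx[i]]
            # The kernel broadcasts the whole segment repeat_num times (a tile).
            blocks.append(np.tile(x_sub, (repeat_num, 1)))
        return np.vstack(blocks)

    x = np.array([[1.], [2.], [3.], [4.]])  # stands in for [[a], [b], [c], [d]]
    out = sequence_expand_ref(x, [[0, 2, 4]], [[0, 2, 4], [0, 3, 6, 7, 8]],
                              ref_level=0)
    # out.ravel() -> [1, 2, 1, 2, 3, 4, 3, 4], i.e. Case 1 of the DOC string.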
paddle/fluid/operators/sequence_expand_op.cu

@@ -18,7 +18,14 @@ limitations under the License. */
 namespace ops = paddle::operators;
 REGISTER_OP_CUDA_KERNEL(
     sequence_expand,
-    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceExpandKernel<paddle::platform::CUDADeviceContext, int64_t>);
 REGISTER_OP_CUDA_KERNEL(
     sequence_expand_grad,
-    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>);
+    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, int>,
+    ops::SequenceExpandGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
paddle/fluid/operators/sequence_expand_op.h

@@ -16,45 +16,75 @@ limitations under the License. */

 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/memcpy.h"
-#include "unsupported/Eigen/CXX11/Tensor"
+#include "paddle/fluid/operators/math/math_function.h"

 namespace paddle {
 namespace operators {

 using LoDTensor = framework::LoDTensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

 template <typename DeviceContext, typename T>
 class SequenceExpandKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x = context.Input<LoDTensor>("X");
-    auto* out = context.Output<LoDTensor>("Out");
-
-    const T* x_data = x->data<T>();
-    auto x_dims = x->dims();
     auto* y = context.Input<LoDTensor>("Y");
-    PADDLE_ENFORCE(!y->lod().empty(), "y should have lod");
-    PADDLE_ENFORCE_EQ(static_cast<size_t>(x_dims[0]),
-                      y->lod().back().size() - 1,
-                      "The size of last lod level in Input(Y)"
-                      "must be equal to dims[0] of Input(X).");
-    out->set_lod(y->lod());
-    auto* place =
-        context.template device_context<DeviceContext>().eigen_device();
-    size_t element_len = framework::product(x_dims) / x_dims[0];
-    T* out_data = out->mutable_data<T>(context.GetPlace());
-    auto out_starts = out->lod().back();
-
-    for (size_t i = 0; i < out_starts.size() - 1; i++) {
-      int scale = out_starts[i + 1] - out_starts[i];
-      Eigen::TensorMap<
-          Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
-          x_t(x_data, 1, element_len);
-      Eigen::TensorMap<Eigen::Tensor<T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
-          out_t(out_data, scale, element_len);
-      Eigen::array<int, 2> cast({{scale, 1}});
-      out_t.device(*place) = x_t.broadcast(cast);
-      x_data += element_len;
-      out_data += element_len * scale;
+    auto* out = context.Output<LoDTensor>("Out");
+
+    int ref_level = context.Attr<int>("ref_level");
+    auto& x_lod = x->lod();
+    auto& y_lod = y->lod();
+
+    if (ref_level == -1) ref_level = y_lod.size() - 1;
+
+    out->mutable_data<T>(context.GetPlace());
+
+    if (y_lod[ref_level].size() <= 1) {
+      framework::TensorCopy(*x, context.GetPlace(), out);
+      return;
+    }
+
+    auto& out_lod = *out->mutable_lod();
+    if (x_lod.size() == 1) {
+      out_lod.resize(1);
+      out_lod[0] = {0};
+    }
+
+    int out_offset = 0;
+    auto& eigen_place =
+        *context.template device_context<DeviceContext>().eigen_device();
+    for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
+      int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
+      int x_start = i - 1;
+      int x_end = i;
+      if (x_lod.size() == 1) {
+        x_start = x_lod[0][i - 1];
+        x_end = x_lod[0][i];
+      }
+      int x_seq_len = x_end - x_start;
+      if (repeat_num > 0) {
+        auto x_sub_tensor = x->Slice(x_start, x_end);
+        x_sub_tensor.Resize({1, x_sub_tensor.numel()});
+        int out_start = out_offset;
+        if (x_lod.size() == 1) {
+          out_start = out_lod[0][out_offset];
+        }
+        auto out_sub_tensor =
+            out->Slice(out_start, out_start + x_seq_len * repeat_num);
+        out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
+        EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
+            EigenMatrix<T>::From(x_sub_tensor)
+                .broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
+      }
+      for (int j = 0; j < repeat_num; ++j) {
+        if (x_lod.size() == 1) {
+          out_lod[0].push_back(out_lod[0].back() + x_seq_len);
+        }
+        out_offset++;
+      }
     }
   }
 };
@@ -75,27 +105,51 @@ template <typename DeviceContext, typename T>
 class SequenceExpandGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    auto* d_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
+    auto* g_out = context.Input<LoDTensor>(framework::GradVarName("Out"));
     auto* x = context.Input<LoDTensor>("X");
-    auto* out = context.Input<LoDTensor>("Out");
-    auto* d_x = context.Output<LoDTensor>(framework::GradVarName("X"));
-
-    auto out_last_level = out->lod().back();
-    d_x->set_lod(x->lod());
-    const T* d_out_data = d_out->data<T>();
-    T* d_x_data = d_x->mutable_data<T>(context.GetPlace());
-    size_t element_len = d_out->numel() / d_out->dims()[0];
-    for (size_t i = 0; i < out_last_level.size() - 1; ++i) {
-      size_t repeat = out_last_level[i + 1] - out_last_level[i];
-      Eigen::TensorMap<
-          Eigen::Tensor<const T, 2, Eigen::RowMajor, Eigen::DenseIndex>>
-          d_out_t(d_out_data, static_cast<int>(repeat), element_len);
-      Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, Eigen::DenseIndex>>
-          d_x_t(d_x_data, static_cast<int>(element_len));
-      auto place =
-          context.template device_context<DeviceContext>().eigen_device();
-      d_x_t.device(*place) = d_out_t.sum(Eigen::array<int, 1>({{0}}));
-      d_out_data += (repeat * element_len);
-      d_x_data += element_len;
+    auto* y = context.Input<LoDTensor>("Y");
+    auto* g_x = context.Output<LoDTensor>(framework::GradVarName("X"));
+    int ref_level = context.Attr<int>("ref_level");
+
+    g_x->mutable_data<T>(context.GetPlace());
+    g_x->set_lod(x->lod());
+
+    auto& x_lod = x->lod();
+    auto& y_lod = y->lod();
+
+    if (ref_level == -1) ref_level = y_lod.size() - 1;
+
+    // just copy the gradient
+    if (y_lod[ref_level].size() <= 1) {
+      framework::TensorCopy(*g_out, context.GetPlace(), g_x);
+      return;
+    }
+
+    auto& dev_ctx = context.template device_context<DeviceContext>();
+
+    math::SetConstant<DeviceContext, T> set_zero;
+    set_zero(dev_ctx, g_x, static_cast<T>(0));
+
+    int g_out_offset = 0;
+    for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
+      int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
+      if (repeat_num > 0) {
+        int x_start = i - 1;
+        int x_end = i;
+        if (x_lod.size() == 1) {
+          x_start = x_lod[0][i - 1];
+          x_end = x_lod[0][i];
+        }
+        int x_seq_len = x_end - x_start;
+        auto g_x_sub = g_x->Slice(x_start, x_end);
+        g_x_sub.Resize(flatten_to_1d(g_x_sub.dims()));
+        int g_out_end = g_out_offset + repeat_num * x_seq_len;
+        auto g_out_sub = g_out->Slice(g_out_offset, g_out_end);
+        g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]});
+        math::ColwiseSum<DeviceContext, T> col_sum;
+        col_sum(dev_ctx, g_out_sub, &g_x_sub);
+        g_out_offset += repeat_num * x_seq_len;
+      }
     }
   }
 };
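The gradient path mirrors the forward tiling: every block of d(Out) produced by repeating one segment of X is summed back into that segment of d(X). A minimal NumPy sketch of what the per-segment ColwiseSum computes (again not part of the commit; sequence_expand_grad_ref is a hypothetical name):

    import numpy as np

    def sequence_expand_grad_ref(g_out, x_shape, x_lod, y_lod, ref_level=-1):
        if ref_level == -1:
            ref_level = len(y_lod) - 1
        ref = y_lod[ref_level]
        x_idx = x_lod[0] if x_lod else list(range(x_shape[0] + 1))
        g_x = np.zeros(x_shape, dtype=g_out.dtype)
        offset = 0
        for i in range(1, len(ref)):
            repeat_num = ref[i] - ref[i - 1]
            seq_len = x_idx[i] - x_idx[i - 1]
            if repeat_num > 0:
                block = g_out[offset:offset + repeat_num * seq_len]
                # Flatten each repeated copy of the segment into one row and sum
                # the rows: the column-wise sum done by math::ColwiseSum.
                g_x[x_idx[i - 1]:x_idx[i]] = (
                    block.reshape(repeat_num, -1).sum(axis=0).reshape(seq_len, -1))
                offset += repeat_num * seq_len
        return g_x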
python/paddle/fluid/layers/nn.py

@@ -1809,52 +1809,52 @@ def conv2d_transpose(input,
     return out


-def sequence_expand(x, y, name=None):
+def sequence_expand(x, y, ref_level=-1, name=None):
     """Sequence Expand Layer. This layer will expand the input variable **x**
-    according to LoD information of **y**. And the following examples will
-    explain how sequence_expand works:
+    according to specified level lod of **y**. Please note that lod level of
+    **x** is at most 1 and rank of **x** is at least 2. When rank of **x**
+    is greater than 2, then it would be viewed as a 2-D tensor.
+    Following examples will explain how sequence_expand works:

     .. code-block:: text

         * Case 1
             x is a LoDTensor:
-                x.lod = [[0, 2, 3],
-                         [0, 1, 3, 4]]
-                x.data = [a, b, c, d]
+                x.lod  = [[0,   2,        4]]
+                x.data = [[a], [b], [c], [d]]
                 x.dims = [4, 1]

             y is a LoDTensor:
                 y.lod = [[0,    2,    4],
                          [0, 3, 6, 7, 8]]

-            with condition len(y.lod[-1]) - 1 == x.dims[0]
+            ref_level: 0

-            then output is a 2-level LoDTensor:
-                out.lod = [[0, 2, 4],
-                           [0, 3, 6, 7, 8]]
-                out.data = [a, a, a, b, b, b, c, d]
+            then output is a 1-level LoDTensor:
+                out.lod  = [[0,   2,        4,        6,        8]]
+                out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
                 out.dims = [8, 1]

         * Case 2
             x is a Tensor:
-                x.data = [a, b, c]
+                x.data = [[a], [b], [c]]
                 x.dims = [3, 1]

             y is a LoDTensor:
-                y.lod = [[0, 2, 3, 6]]
+                y.lod = [[0, 2, 2, 5]]

-            with condition len(y.lod[-1]) - 1 == x.dims[0]
+            ref_level: -1

-            then output is a 1-level LoDTensor:
-                out.lod = [[0, 2, 3, 6]]
-                out.data = [a, a, b, c, c, c]
-                out.dims = [6, 1]
+            then output is a Tensor:
+                out.data = [[a], [a], [c], [c], [c]]
+                out.dims = [5, 1]

     Args:
         x (Variable): The input variable which is a Tensor or LoDTensor.
         y (Variable): The input variable which is a LoDTensor.
+        ref_level (int): Lod level of `y` to be referred by `x`. If set to -1,
+                         refer the last level of lod.
         name(str|None): A name for this layer(optional). If set None, the layer
-            will be named automatically.
+                        will be named automatically.

     Returns:
         Variable: The expanded variable which is a LoDTensor.
@@ -1865,14 +1865,17 @@ def sequence_expand(x, y, name=None):
             x = fluid.layers.data(name='x', shape=[10], dtype='float32')
             y = fluid.layers.data(name='y', shape=[10, 20],
                              dtype='float32', lod_level=1)
-            out = layers.sequence_expand(x=x, y=y)
+            out = layers.sequence_expand(x=x, y=y, ref_level=0)
     """
     helper = LayerHelper('sequence_expand', input=x, **locals())
     dtype = helper.input_dtype()
     tmp = helper.create_tmp_variable(dtype)
     helper.append_op(
-        type='sequence_expand', inputs={'X': x,
-                                        'Y': y}, outputs={'Out': tmp})
+        type='sequence_expand',
+        inputs={'X': x,
+                'Y': y},
+        outputs={'Out': tmp},
+        attrs={'ref_level': ref_level})
     return tmp
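One signature caveat worth noting: ref_level is inserted before name, so a caller that previously passed name as the third positional argument would now be setting ref_level instead. The updated call sites in this change all pass keyword arguments, which sidesteps the issue.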
python/paddle/fluid/tests/book/test_machine_translation.py

@@ -118,12 +118,12 @@ def decoder_decode(context, is_sparse):
             is_sparse=is_sparse)

         # use rnn unit to update rnn
-        current_state = pd.fc(input=[pre_ids_emb, pre_state_expanded],
+        current_state = pd.fc(input=[pre_state_expanded, pre_ids_emb],
                               size=decoder_size,
                               act='tanh')
-
+        current_state_with_lod = pd.lod_reset(x=current_state, y=pre_score)
         # use score to do beam search
-        current_score = pd.fc(input=current_state,
+        current_score = pd.fc(input=current_state_with_lod,
                               size=target_dict_dim,
                               act='softmax')
         topk_scores, topk_indices = pd.topk(current_score, k=50)
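The new pd.lod_reset call presumably stamps pre_score's lod onto the decoder state before scoring, since the enhanced sequence_expand no longer unconditionally copies Y's lod onto its output, and the beam search downstream of current_score needs that lod.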
python/paddle/fluid/tests/unittests/test_layers.py

@@ -181,8 +181,8 @@ class TestBook(unittest.TestCase):
         with program_guard(program):
             x = layers.data(name='x', shape=[10], dtype='float32')
             y = layers.data(
-                name='y', shape=[10, 20], dtype='float32', lod_level=1)
-            self.assertIsNotNone(layers.sequence_expand(x=x, y=y))
+                name='y', shape=[10, 20], dtype='float32', lod_level=2)
+            self.assertIsNotNone(layers.sequence_expand(x=x, y=y, ref_level=1))
         print(str(program))

     def test_lstm_unit(self):
python/paddle/fluid/tests/unittests/test_sequence_expand.py

@@ -27,12 +27,36 @@ class TestSequenceExpand(OpTest):
     def compute(self):
         x = self.inputs['X']
         x_data, x_lod = x if type(x) == tuple else (x, None)
-        n = 1 + x_data.shape[0] if not x_lod else len(x_lod[0])
         y_data, y_lod = self.inputs['Y']
-        repeats = [((y_lod[-1][i + 1] - y_lod[-1][i]))
-                   for i in range(len(y_lod[-1]) - 1)]
-        out = x_data.repeat(repeats, axis=0)
-        self.outputs = {'Out': out}
+
+        if hasattr(self, 'attrs'):
+            ref_level = self.attrs['ref_level']
+        else:
+            ref_level = len(y_lod) - 1
+
+        out = np.zeros(shape=((0, ) + x_data.shape[1:]), dtype=x_data.dtype)
+
+        if x_lod is None:
+            x_idx = [i for i in xrange(x_data.shape[0] + 1)]
+        else:
+            x_idx = x_lod[0]
+
+        out_lod = [[0]]
+        for i in xrange(1, len(y_lod[ref_level])):
+            repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1]
+            x_len = x_idx[i] - x_idx[i - 1]
+            if repeat_num > 0:
+                x_sub = x_data[x_idx[i - 1]:x_idx[i], :]
+                x_sub = np.repeat(x_sub, repeat_num, axis=0)
+                out = np.vstack((out, x_sub))
+                if x_lod is not None:
+                    for j in xrange(repeat_num):
+                        out_lod[0].append(out_lod[0][-1] + x_len)
+
+        if x_lod is None:
+            self.outputs = {'Out': out}
+        else:
+            self.outputs = {'Out': (out, out_lod)}

     def setUp(self):
         self.op_type = 'sequence_expand'
@@ -52,7 +76,8 @@ class TestSequenceExpandCase1(TestSequenceExpand):
         x_lod = [[0, 2, 5]]
         y_data = np.random.uniform(0.1, 1, [13, 1]).astype('float32')
         y_lod = [[0, 2, 5], [0, 2, 4, 7, 10, 13]]
-        self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
+        self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}
+        self.attrs = {'ref_level': 0}


 class TestSequenceExpandCase2(TestSequenceExpand):
@@ -60,8 +85,9 @@ class TestSequenceExpandCase2(TestSequenceExpand):
         x_data = np.random.uniform(0.1, 1, [1, 2, 2]).astype('float32')
         x_lod = [[0, 1]]
         y_data = np.random.uniform(0.1, 1, [2, 2, 2]).astype('float32')
-        y_lod = [[0, 2]]
+        y_lod = [[0, 2], [0, 2]]
         self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}
+        self.attrs = {'ref_level': 0}


 class TestSequenceExpandCase3(TestSequenceExpand):
@@ -75,14 +101,9 @@ class TestSequenceExpandCase3(TestSequenceExpand):

 class TestSequenceExpandCase4(TestSequenceExpand):
     def set_data(self):
-        x_data = np.array(
-            [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]).reshape(
-                [2, 5]).astype('float32')
-        x_lod = [[0, 1, 2, ]]
+        data = [0.1, 0.3, 0.2, 0.15, 0.25, 0.2, 0.15, 0.25, 0.1, 0.3]
+        x_data = np.array(data).reshape([5, 2]).astype('float32')
+        x_lod = [[0, 2, 5]]
         y_data = np.random.uniform(0.1, 1, [2, 1]).astype('float32')
         y_lod = [[0, 1, 2], [0, 1, 2]]
         self.inputs = {'X': (x_data, x_lod), 'Y': (y_data, y_lod)}