Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
741046e8
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
741046e8
编写于
6月 06, 2018
作者:
G
guosheng
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix and enhance beam_search_op and beam_searc_decode_op to be comparable with python beam search
上级
01fdf17e
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
243 addition
and
62 deletion
+243
-62
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+2
-2
paddle/fluid/operators/beam_search_decode_op.cc
paddle/fluid/operators/beam_search_decode_op.cc
+31
-11
paddle/fluid/operators/beam_search_decode_op.h
paddle/fluid/operators/beam_search_decode_op.h
+110
-6
paddle/fluid/operators/beam_search_op.cc
paddle/fluid/operators/beam_search_op.cc
+62
-21
paddle/fluid/operators/beam_search_op.h
paddle/fluid/operators/beam_search_op.h
+30
-16
paddle/fluid/operators/tensor_array_read_write_op.cc
paddle/fluid/operators/tensor_array_read_write_op.cc
+2
-3
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+6
-3
未找到文件。
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
741046e8
...
@@ -288,8 +288,8 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
...
@@ -288,8 +288,8 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
cc_test
(
gather_test SRCS gather_test.cc DEPS tensor
)
cc_test
(
gather_test SRCS gather_test.cc DEPS tensor
)
cc_test
(
scatter_test SRCS scatter_test.cc DEPS tensor
)
cc_test
(
scatter_test SRCS scatter_test.cc DEPS tensor
)
cc_test
(
beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor
)
#
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test
(
beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op
)
#
cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
cc_test
(
strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory
)
cc_test
(
strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory
)
cc_test
(
save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op
)
cc_test
(
save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op
)
cc_test
(
save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op
)
cc_test
(
save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op
)
...
...
paddle/fluid/operators/beam_search_decode_op.cc
浏览文件 @
741046e8
...
@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include
"paddle/fluid/operators/beam_search_decode_op.h"
#include
<algorithm>
#include <string>
#include <string>
#include "paddle/fluid/operators/beam_search_decode_op.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
paddle
{
...
@@ -22,8 +24,11 @@ namespace operators {
...
@@ -22,8 +24,11 @@ namespace operators {
struct
BeamSearchDecodeFunctor
{
struct
BeamSearchDecodeFunctor
{
BeamSearchDecodeFunctor
(
const
LoDTensorArray
&
step_ids
,
BeamSearchDecodeFunctor
(
const
LoDTensorArray
&
step_ids
,
const
LoDTensorArray
&
step_scores
,
const
LoDTensorArray
&
step_scores
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
)
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
,
:
step_ids_origin_
(
step_ids
),
size_t
beam_size
,
int
end_id
)
:
beam_size_
(
beam_size
),
end_id_
(
end_id
),
step_ids_origin_
(
step_ids
),
step_scores_origin_
(
step_scores
),
step_scores_origin_
(
step_scores
),
id_tensor_
(
id_tensor
),
id_tensor_
(
id_tensor
),
score_tensor_
(
score_tensor
)
{
score_tensor_
(
score_tensor
)
{
...
@@ -67,6 +72,8 @@ struct BeamSearchDecodeFunctor {
...
@@ -67,6 +72,8 @@ struct BeamSearchDecodeFunctor {
void
operator
()()
const
;
void
operator
()()
const
;
bool
tensor_on_gpu_
;
bool
tensor_on_gpu_
;
size_t
beam_size_
;
int
end_id_
;
const
LoDTensorArray
&
step_ids_origin_
;
const
LoDTensorArray
&
step_ids_origin_
;
const
LoDTensorArray
&
step_scores_origin_
;
const
LoDTensorArray
&
step_scores_origin_
;
LoDTensorArray
step_ids_
=
LoDTensorArray
();
LoDTensorArray
step_ids_
=
LoDTensorArray
();
...
@@ -77,14 +84,18 @@ struct BeamSearchDecodeFunctor {
...
@@ -77,14 +84,18 @@ struct BeamSearchDecodeFunctor {
template
<
typename
T
>
template
<
typename
T
>
void
BeamSearchDecodeFunctor
::
operator
()()
const
{
void
BeamSearchDecodeFunctor
::
operator
()()
const
{
BeamSearchDecoder
<
T
>
beam_search_decoder
;
BeamSearchDecoder
<
T
>
beam_search_decoder
(
beam_size_
,
end_id_
)
;
// Check if the tensor is on GPU. If so, use the CPU copy instead
// Check if the tensor is on GPU. If so, use the CPU copy instead
if
(
tensor_on_gpu_
)
{
if
(
tensor_on_gpu_
)
{
beam_search_decoder
.
PackAllSteps
(
step_ids_
,
step_scores_
,
id_tensor_
,
// beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_,
score_tensor_
);
// score_tensor_);
beam_search_decoder
.
Backtrace
(
step_ids_
,
step_scores_
,
id_tensor_
,
score_tensor_
);
}
else
{
}
else
{
beam_search_decoder
.
PackAllSteps
(
step_ids_origin_
,
step_scores_origin_
,
// beam_search_decoder.PackAllSteps(step_ids_origin_, step_scores_origin_,
id_tensor_
,
score_tensor_
);
// id_tensor_, score_tensor_);
beam_search_decoder
.
Backtrace
(
step_ids_origin_
,
step_scores_origin_
,
id_tensor_
,
score_tensor_
);
}
}
}
}
...
@@ -122,13 +133,17 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
...
@@ -122,13 +133,17 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
"Level of LodTensor should be 2"
);
"Level of LodTensor should be 2"
);
}
}
size_t
beam_size
=
ctx
.
Attr
<
int
>
(
"beam_size"
);
int
end_id
=
ctx
.
Attr
<
int
>
(
"end_id"
);
// prepare output
// prepare output
LoDTensor
*
sentenceIds
=
ctx
.
Output
<
LoDTensor
>
(
"SentenceIds"
);
LoDTensor
*
sentenceIds
=
ctx
.
Output
<
LoDTensor
>
(
"SentenceIds"
);
LoDTensor
*
sentenceScores
=
ctx
.
Output
<
LoDTensor
>
(
"SentenceScores"
);
LoDTensor
*
sentenceScores
=
ctx
.
Output
<
LoDTensor
>
(
"SentenceScores"
);
framework
::
VisitDataType
(
framework
::
VisitDataType
(
framework
::
ToDataType
(
scores
->
at
(
0
).
type
()),
framework
::
ToDataType
(
scores
->
at
(
0
).
type
()),
BeamSearchDecodeFunctor
(
*
ids
,
*
scores
,
sentenceIds
,
sentenceScores
));
BeamSearchDecodeFunctor
(
*
ids
,
*
scores
,
sentenceIds
,
sentenceScores
,
beam_size
,
end_id
));
}
}
};
};
...
@@ -147,6 +162,9 @@ class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -147,6 +162,9 @@ class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput
(
"SentenceScores"
,
AddOutput
(
"SentenceScores"
,
"(LodTensor)"
"(LodTensor)"
"All possible result sentences of word scores"
);
"All possible result sentences of word scores"
);
AddAttr
<
int
>
(
"beam_size"
,
"beam size for beam search"
);
AddAttr
<
int
>
(
"end_id"
,
"the token id which indicates the end of a sequence"
);
AddComment
(
R"DOC(
AddComment
(
R"DOC(
Pack the result of Beam search op into SentenceIds and SentenceScores.
Pack the result of Beam search op into SentenceIds and SentenceScores.
)DOC"
);
)DOC"
);
...
@@ -172,10 +190,12 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
...
@@ -172,10 +190,12 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
for
(
auto
&
o
:
op_desc
.
Output
(
"SentenceIds"
))
{
for
(
auto
&
o
:
op_desc
.
Output
(
"SentenceIds"
))
{
block
->
Var
(
o
)
->
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
auto
&
sentence_ids
=
block
->
FindRecursiveOrCreateVar
(
o
);
sentence_ids
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
for
(
auto
&
o
:
op_desc
.
Output
(
"SentenceScores"
))
{
for
(
auto
&
o
:
op_desc
.
Output
(
"SentenceScores"
))
{
block
->
Var
(
o
)
->
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
auto
&
sentence_scores
=
block
->
FindRecursiveOrCreateVar
(
o
);
sentence_scores
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/beam_search_decode_op.h
浏览文件 @
741046e8
...
@@ -14,7 +14,9 @@ limitations under the License. */
...
@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#pragma once
#include <algorithm>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
...
@@ -72,6 +74,9 @@ using SentenceVector = std::vector<Sentence<T>>;
...
@@ -72,6 +74,9 @@ using SentenceVector = std::vector<Sentence<T>>;
template
<
typename
T
>
template
<
typename
T
>
struct
BeamSearchDecoder
{
struct
BeamSearchDecoder
{
BeamSearchDecoder
(
size_t
beam_size
,
int
end_id
)
:
beam_size_
(
beam_size
),
end_id_
(
end_id
)
{}
/**
/**
* make a BeamNode and all it's related prefix BeanNode into a Sentence.
* make a BeamNode and all it's related prefix BeanNode into a Sentence.
*/
*/
...
@@ -103,7 +108,8 @@ struct BeamSearchDecoder {
...
@@ -103,7 +108,8 @@ struct BeamSearchDecoder {
*/
*/
void
ConvertSentenceVectorToLodTensor
(
void
ConvertSentenceVectorToLodTensor
(
std
::
vector
<
SentenceVector
<
T
>>
sentence_vector_list
,
LoDTensor
*
id_tensor
,
std
::
vector
<
SentenceVector
<
T
>>
sentence_vector_list
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
)
const
;
LoDTensor
*
score_tensor
,
bool
reverse
=
false
,
bool
sort_by_score
=
true
)
const
;
/**
/**
* Pack all steps of id/score LodTensor into sentence LoDTensor
* Pack all steps of id/score LodTensor into sentence LoDTensor
...
@@ -121,6 +127,13 @@ struct BeamSearchDecoder {
...
@@ -121,6 +127,13 @@ struct BeamSearchDecoder {
void
PackAllSteps
(
const
LoDTensorArray
&
step_ids
,
void
PackAllSteps
(
const
LoDTensorArray
&
step_ids
,
const
LoDTensorArray
&
step_scores
,
LoDTensor
*
id_tensor
,
const
LoDTensorArray
&
step_scores
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
)
const
;
LoDTensor
*
score_tensor
)
const
;
void
Backtrace
(
const
LoDTensorArray
&
step_ids
,
const
LoDTensorArray
&
step_scores
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
)
const
;
size_t
beam_size_
;
int
end_id_
;
};
};
template
<
typename
T
>
template
<
typename
T
>
...
@@ -200,7 +213,7 @@ std::vector<BeamNodeVector<T>> BeamSearchDecoder<T>::PackTwoSteps(
...
@@ -200,7 +213,7 @@ std::vector<BeamNodeVector<T>> BeamSearchDecoder<T>::PackTwoSteps(
template
<
typename
T
>
template
<
typename
T
>
void
BeamSearchDecoder
<
T
>::
ConvertSentenceVectorToLodTensor
(
void
BeamSearchDecoder
<
T
>::
ConvertSentenceVectorToLodTensor
(
std
::
vector
<
SentenceVector
<
T
>>
sentence_vector_list
,
LoDTensor
*
id_tensor
,
std
::
vector
<
SentenceVector
<
T
>>
sentence_vector_list
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
)
const
{
LoDTensor
*
score_tensor
,
bool
reverse
,
bool
sort_by_score
)
const
{
size_t
src_num
=
sentence_vector_list
.
size
();
size_t
src_num
=
sentence_vector_list
.
size
();
PADDLE_ENFORCE_NE
(
src_num
,
0
,
"src_num should not be 0"
);
PADDLE_ENFORCE_NE
(
src_num
,
0
,
"src_num should not be 0"
);
...
@@ -211,11 +224,29 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
...
@@ -211,11 +224,29 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
std
::
vector
<
T
>
score_data
;
std
::
vector
<
T
>
score_data
;
for
(
size_t
src_idx
=
0
;
src_idx
<
src_num
;
++
src_idx
)
{
for
(
size_t
src_idx
=
0
;
src_idx
<
src_num
;
++
src_idx
)
{
if
(
sort_by_score
)
{
sort
(
sentence_vector_list
[
src_idx
].
begin
(),
sentence_vector_list
[
src_idx
].
end
(),
[
reverse
](
const
Sentence
<
T
>&
a
,
const
Sentence
<
T
>&
b
)
{
if
(
reverse
)
return
a
.
scores
.
front
()
>
b
.
scores
.
front
();
else
return
a
.
scores
.
back
()
>
b
.
scores
.
back
();
});
}
for
(
Sentence
<
T
>&
sentence
:
sentence_vector_list
[
src_idx
])
{
for
(
Sentence
<
T
>&
sentence
:
sentence_vector_list
[
src_idx
])
{
id_data
.
insert
(
id_data
.
end
(),
sentence
.
word_ids
.
begin
(),
if
(
reverse
)
{
sentence
.
word_ids
.
end
());
id_data
.
insert
(
id_data
.
end
(),
sentence
.
word_ids
.
rbegin
(),
score_data
.
insert
(
score_data
.
end
(),
sentence
.
scores
.
begin
(),
sentence
.
word_ids
.
rend
());
sentence
.
scores
.
end
());
score_data
.
insert
(
score_data
.
end
(),
sentence
.
scores
.
rbegin
(),
sentence
.
scores
.
rend
());
}
else
{
id_data
.
insert
(
id_data
.
end
(),
sentence
.
word_ids
.
begin
(),
sentence
.
word_ids
.
end
());
score_data
.
insert
(
score_data
.
end
(),
sentence
.
scores
.
begin
(),
sentence
.
scores
.
end
());
}
sentence_level_lod
.
push_back
(
sentence_level_lod
.
back
()
+
sentence_level_lod
.
push_back
(
sentence_level_lod
.
back
()
+
sentence
.
word_ids
.
size
());
sentence
.
word_ids
.
size
());
}
}
...
@@ -278,5 +309,78 @@ void BeamSearchDecoder<T>::PackAllSteps(const LoDTensorArray& step_ids,
...
@@ -278,5 +309,78 @@ void BeamSearchDecoder<T>::PackAllSteps(const LoDTensorArray& step_ids,
score_tensor
);
score_tensor
);
}
}
template
<
typename
T
>
void
BeamSearchDecoder
<
T
>::
Backtrace
(
const
LoDTensorArray
&
step_ids
,
const
LoDTensorArray
&
step_scores
,
LoDTensor
*
id_tensor
,
LoDTensor
*
score_tensor
)
const
{
PADDLE_ENFORCE
(
!
step_ids
.
empty
(),
"step num should be larger than 0"
);
PADDLE_ENFORCE_EQ
(
step_ids
.
size
(),
step_scores
.
size
(),
"step_ids and step_scores should be the same"
);
const
size_t
step_num
=
step_ids
.
size
();
const
size_t
src_num
=
step_ids
.
at
(
0
).
lod
().
at
(
kSourceLevel
).
size
()
-
1
;
std
::
vector
<
SentenceVector
<
T
>>
sentence_vector_list
(
src_num
,
SentenceVector
<
T
>
(
beam_size_
));
std
::
vector
<
std
::
vector
<
size_t
>>
prefix_idx_vector_list
(
src_num
,
std
::
vector
<
size_t
>
());
for
(
int
step_id
=
step_num
-
1
;
step_id
>=
0
;
--
step_id
)
{
auto
&
cur_ids
=
step_ids
.
at
(
step_id
);
auto
&
cur_scores
=
step_scores
.
at
(
step_id
);
for
(
size_t
src_idx
=
0
;
src_idx
<
src_num
;
++
src_idx
)
{
// for each source sentence
auto
&
sentence_vector
=
sentence_vector_list
.
at
(
src_idx
);
auto
&
prefix_idx_vector
=
prefix_idx_vector_list
.
at
(
src_idx
);
size_t
src_prefix_start
=
cur_ids
.
lod
().
at
(
kSourceLevel
)[
src_idx
];
size_t
src_prefix_end
=
cur_ids
.
lod
().
at
(
kSourceLevel
)[
src_idx
+
1
];
if
(
prefix_idx_vector
.
empty
())
{
// be finished and pruned at this step
// or the last time step
for
(
size_t
prefix_idx
=
src_prefix_start
;
prefix_idx
<
src_prefix_end
;
++
prefix_idx
)
{
size_t
candidate_start
=
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
];
size_t
candidate_end
=
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
+
1
];
for
(
size_t
candidate_idx
=
candidate_start
;
candidate_idx
<
candidate_end
;
++
candidate_idx
)
{
prefix_idx_vector
.
push_back
(
prefix_idx
);
size_t
idx
=
prefix_idx_vector
.
size
()
-
1
;
auto
cur_id
=
cur_ids
.
data
<
int64_t
>
()[
candidate_idx
];
auto
cur_score
=
cur_scores
.
data
<
T
>
()[
candidate_idx
];
sentence_vector
.
at
(
idx
).
word_ids
.
push_back
(
cur_id
);
sentence_vector
.
at
(
idx
).
scores
.
push_back
(
cur_score
);
}
}
}
else
{
// use prefix_idx_vector to backtrace
size_t
src_candidate_start
=
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
src_prefix_start
];
size_t
prefix_idx
=
src_prefix_start
;
size_t
candidate_num
=
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
+
1
]
-
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
];
for
(
size_t
idx
=
0
;
idx
<
prefix_idx_vector
.
size
();
++
idx
)
{
auto
candidate_idx
=
prefix_idx_vector
.
at
(
idx
);
auto
cur_id
=
cur_ids
.
data
<
int64_t
>
()[
candidate_idx
];
auto
cur_score
=
cur_scores
.
data
<
T
>
()[
candidate_idx
];
if
(
cur_id
!=
end_id_
||
sentence_vector
.
at
(
idx
).
word_ids
.
empty
())
{
// to skip redundant end tokens
sentence_vector
.
at
(
idx
).
word_ids
.
push_back
(
cur_id
);
sentence_vector
.
at
(
idx
).
scores
.
push_back
(
cur_score
);
}
while
(
src_candidate_start
+
candidate_num
<=
candidate_idx
)
{
// search the corresponding prefix
prefix_idx
++
;
candidate_num
+=
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
+
1
]
-
cur_ids
.
lod
().
at
(
kSentenceLevel
)[
prefix_idx
];
}
prefix_idx_vector
.
at
(
idx
)
=
prefix_idx
;
}
}
}
}
ConvertSentenceVectorToLodTensor
(
sentence_vector_list
,
id_tensor
,
score_tensor
,
true
,
true
);
}
}
// namespace operators
}
// namespace operators
}
// namespace paddle
}
// namespace paddle
paddle/fluid/operators/beam_search_op.cc
浏览文件 @
741046e8
...
@@ -12,25 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -12,25 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License. */
limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
#include <algorithm>
#include <algorithm>
#include <limits>
#include <map>
#include <map>
#include <string>
#include <string>
#include <vector>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/beam_search_op.h"
namespace
paddle
{
namespace
paddle
{
namespace
operators
{
namespace
operators
{
void
BeamSearch
::
operator
()(
const
framework
::
LoDTensor
&
pre_ids
,
void
BeamSearch
::
operator
()(
const
framework
::
LoDTensor
&
pre_ids
,
const
framework
::
LoDTensor
&
pre_scores
,
framework
::
LoDTensor
*
selected_ids
,
framework
::
LoDTensor
*
selected_ids
,
framework
::
LoDTensor
*
selected_scores
)
{
framework
::
LoDTensor
*
selected_scores
)
{
auto
abs_lod
=
framework
::
ToAbsOffset
(
ids_
->
lod
());
auto
abs_lod
=
framework
::
ToAbsOffset
(
ids_
->
lod
());
auto
&
high_level
=
abs_lod
[
lod_level_
];
auto
&
high_level
=
abs_lod
[
lod_level_
];
auto
items
=
SelectTopBeamSizeItems
();
auto
items
=
SelectTopBeamSizeItems
(
pre_ids
,
pre_scores
);
auto
selected_items
=
ToMap
(
items
,
high_level
.
back
());
auto
selected_items
=
ToMap
(
items
,
high_level
.
back
());
VLOG
(
3
)
<<
"selected_items:"
;
VLOG
(
3
)
<<
"selected_items:"
;
for
(
size_t
i
=
0
;
i
<
selected_items
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
selected_items
.
size
();
++
i
)
{
...
@@ -39,7 +41,8 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
...
@@ -39,7 +41,8 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
VLOG
(
3
)
<<
ItemToString
(
item
);
VLOG
(
3
)
<<
ItemToString
(
item
);
}
}
}
}
PruneEndidCandidates
(
pre_ids
,
&
selected_items
);
PruneEndBeams
(
pre_ids
,
&
selected_items
);
// calculate the output tensor's height
// calculate the output tensor's height
size_t
num_instances
=
std
::
accumulate
(
size_t
num_instances
=
std
::
accumulate
(
std
::
begin
(
selected_items
),
std
::
end
(
selected_items
),
0
,
std
::
begin
(
selected_items
),
std
::
end
(
selected_items
),
0
,
...
@@ -61,12 +64,6 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
...
@@ -61,12 +64,6 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
size_t
low_offset
=
0
;
size_t
low_offset
=
0
;
for
(
auto
&
items
:
selected_items
)
{
for
(
auto
&
items
:
selected_items
)
{
low_level
.
push_back
(
low_offset
);
low_level
.
push_back
(
low_offset
);
sort
(
items
.
begin
(),
items
.
end
(),
[](
const
Item
&
a
,
const
Item
&
b
)
{
if
(
a
.
offset
<
b
.
offset
)
{
return
true
;
}
return
a
.
id
<
b
.
id
;
});
for
(
auto
&
item
:
items
)
{
for
(
auto
&
item
:
items
)
{
ids_data
[
low_offset
]
=
item
.
id
;
ids_data
[
low_offset
]
=
item
.
id
;
scores_data
[
low_offset
]
=
item
.
score
;
scores_data
[
low_offset
]
=
item
.
score
;
...
@@ -86,6 +83,33 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
...
@@ -86,6 +83,33 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
selected_scores
->
set_lod
(
lod
);
selected_scores
->
set_lod
(
lod
);
}
}
void
BeamSearch
::
PruneEndBeams
(
const
framework
::
LoDTensor
&
pre_ids
,
std
::
vector
<
std
::
vector
<
Item
>>
*
items
)
{
auto
*
pre_ids_data
=
pre_ids
.
data
<
int64_t
>
();
auto
abs_lod
=
framework
::
ToAbsOffset
(
ids_
->
lod
());
auto
&
high_level
=
abs_lod
[
lod_level_
];
for
(
size_t
src_idx
=
0
;
src_idx
<
high_level
.
size
();
++
src_idx
)
{
size_t
src_prefix_start
=
high_level
[
src_idx
];
size_t
src_prefix_end
=
high_level
[
src_idx
+
1
];
bool
finish_flag
=
true
;
for
(
size_t
offset
=
src_prefix_start
;
offset
<
src_prefix_end
;
offset
++
)
{
for
(
auto
&
item
:
items
->
at
(
offset
))
{
if
(
item
.
id
!=
static_cast
<
size_t
>
(
end_id_
)
||
pre_ids_data
[
offset
]
!=
end_id_
)
{
finish_flag
=
false
;
break
;
}
}
if
(
!
finish_flag
)
break
;
}
if
(
finish_flag
)
{
// all branchs of the beam (source sentence) end and
// prune this beam
for
(
size_t
offset
=
src_prefix_start
;
offset
<
src_prefix_end
;
offset
++
)
items
->
at
(
offset
).
clear
();
}
}
}
int
BeamSearch
::
PruneEndidCandidates
(
const
framework
::
LoDTensor
&
pre_ids
,
int
BeamSearch
::
PruneEndidCandidates
(
const
framework
::
LoDTensor
&
pre_ids
,
std
::
vector
<
std
::
vector
<
Item
>>
*
items
)
{
std
::
vector
<
std
::
vector
<
Item
>>
*
items
)
{
auto
*
pre_ids_data
=
pre_ids
.
data
<
int64_t
>
();
auto
*
pre_ids_data
=
pre_ids
.
data
<
int64_t
>
();
...
@@ -115,13 +139,14 @@ std::vector<std::vector<BeamSearch::Item>> BeamSearch::ToMap(
...
@@ -115,13 +139,14 @@ std::vector<std::vector<BeamSearch::Item>> BeamSearch::ToMap(
return
result
;
return
result
;
}
}
std
::
vector
<
std
::
vector
<
BeamSearch
::
Item
>>
std
::
vector
<
std
::
vector
<
BeamSearch
::
Item
>>
BeamSearch
::
SelectTopBeamSizeItems
(
BeamSearch
::
SelectTopBeamSizeItems
()
{
const
framework
::
LoDTensor
&
pre_ids
,
const
framework
::
LoDTensor
&
pre_scores
)
{
std
::
vector
<
std
::
vector
<
Item
>>
result
;
std
::
vector
<
std
::
vector
<
Item
>>
result
;
std
::
vector
<
Item
>
items
;
std
::
vector
<
Item
>
items
;
// for each source sentence, select the top beam_size items across all
// for each source sentence, select the top beam_size items across all
// candidate sets.
// candidate sets.
while
(
NextItemSet
(
&
items
))
{
while
(
NextItemSet
(
pre_ids
,
pre_scores
,
&
items
))
{
std
::
nth_element
(
std
::
begin
(
items
),
std
::
begin
(
items
)
+
beam_size_
,
std
::
nth_element
(
std
::
begin
(
items
),
std
::
begin
(
items
)
+
beam_size_
,
std
::
end
(
items
),
[](
const
Item
&
a
,
const
Item
&
b
)
{
std
::
end
(
items
),
[](
const
Item
&
a
,
const
Item
&
b
)
{
// TODO(superjom) make score's comparation customizable.
// TODO(superjom) make score's comparation customizable.
...
@@ -146,7 +171,9 @@ BeamSearch::SelectTopBeamSizeItems() {
...
@@ -146,7 +171,9 @@ BeamSearch::SelectTopBeamSizeItems() {
}
}
// the candidates of a source
// the candidates of a source
bool
BeamSearch
::
NextItemSet
(
std
::
vector
<
BeamSearch
::
Item
>
*
items
)
{
bool
BeamSearch
::
NextItemSet
(
const
framework
::
LoDTensor
&
pre_ids
,
const
framework
::
LoDTensor
&
pre_scores
,
std
::
vector
<
BeamSearch
::
Item
>
*
items
)
{
if
(
sent_offset_
>=
ids_
->
NumElements
(
lod_level_
))
{
if
(
sent_offset_
>=
ids_
->
NumElements
(
lod_level_
))
{
return
false
;
return
false
;
}
}
...
@@ -164,14 +191,25 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
...
@@ -164,14 +191,25 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
instance_dim
*=
ids
.
dims
()[
i
];
instance_dim
*=
ids
.
dims
()[
i
];
}
}
auto
*
pre_ids_data
=
pre_ids
.
data
<
int64_t
>
();
auto
*
pre_scores_data
=
pre_scores
.
data
<
float
>
();
items
->
clear
();
items
->
clear
();
items
->
reserve
(
framework
::
product
(
ids
.
dims
()));
items
->
reserve
(
framework
::
product
(
ids
.
dims
()));
for
(
size_t
offset
=
abs_lod
[
lod_level_
][
sent_offset_
];
for
(
size_t
offset
=
abs_lod
[
lod_level_
][
sent_offset_
];
offset
<
abs_lod
[
lod_level_
][
sent_offset_
+
1
];
offset
++
)
{
offset
<
abs_lod
[
lod_level_
][
sent_offset_
+
1
];
offset
++
)
{
for
(
size_t
d
=
0
;
d
<
instance_dim
;
d
++
)
{
auto
pre_id
=
pre_ids_data
[
offset
];
const
size_t
dim_offset
=
offset
*
instance_dim
+
d
;
auto
pre_score
=
pre_scores_data
[
offset
];
items
->
emplace_back
(
offset
,
ids_data
[
dim_offset
],
if
(
pre_id
==
end_id_
)
{
scores_data
[
dim_offset
]);
// Allocate all probability mass to eos_id for finished branchs and the
// other
// candidate ids can be ignored.
items
->
emplace_back
(
offset
,
end_id_
,
pre_score
);
}
else
{
for
(
size_t
d
=
0
;
d
<
instance_dim
;
d
++
)
{
const
size_t
dim_offset
=
offset
*
instance_dim
+
d
;
items
->
emplace_back
(
offset
,
ids_data
[
dim_offset
],
scores_data
[
dim_offset
]);
}
}
}
}
}
...
@@ -199,7 +237,8 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
...
@@ -199,7 +237,8 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
public:
void
Make
()
override
{
void
Make
()
override
{
// inputs and outputs stored in proto
// inputs and outputs stored in proto
AddInput
(
"pre_ids"
,
"ids in previous step"
);
AddInput
(
"pre_ids"
,
"ids in the previous step"
);
AddInput
(
"pre_scores"
,
"accumulated scores in the previous step"
);
AddInput
(
"ids"
,
"a LoDTensor of shape of [None,k]"
);
AddInput
(
"ids"
,
"a LoDTensor of shape of [None,k]"
);
AddInput
(
"scores"
,
AddInput
(
"scores"
,
"a LoDTensor that has the same shape and LoD with `ids`"
);
"a LoDTensor that has the same shape and LoD with `ids`"
);
...
@@ -253,10 +292,12 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
...
@@ -253,10 +292,12 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
framework
::
BlockDesc
*
block
)
const
override
{
for
(
auto
&
o
:
op_desc
.
Output
(
"selected_ids"
))
{
for
(
auto
&
o
:
op_desc
.
Output
(
"selected_ids"
))
{
block
->
Var
(
o
)
->
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
auto
&
selected_ids
=
block
->
FindRecursiveOrCreateVar
(
o
);
selected_ids
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
for
(
auto
&
o
:
op_desc
.
Output
(
"selected_scores"
))
{
for
(
auto
&
o
:
op_desc
.
Output
(
"selected_scores"
))
{
block
->
Var
(
o
)
->
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
auto
&
selected_scores
=
block
->
FindRecursiveOrCreateVar
(
o
);
selected_scores
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
}
}
};
};
...
...
paddle/fluid/operators/beam_search_op.h
浏览文件 @
741046e8
...
@@ -132,6 +132,7 @@ class BeamSearch {
...
@@ -132,6 +132,7 @@ class BeamSearch {
* that means no candidates is provided, and the task will stop running.
* that means no candidates is provided, and the task will stop running.
*/
*/
void
operator
()(
const
framework
::
LoDTensor
&
pre_ids
,
void
operator
()(
const
framework
::
LoDTensor
&
pre_ids
,
const
framework
::
LoDTensor
&
pre_scores
,
framework
::
LoDTensor
*
selected_ids
,
framework
::
LoDTensor
*
selected_ids
,
framework
::
LoDTensor
*
selected_scores
);
framework
::
LoDTensor
*
selected_scores
);
/*
/*
...
@@ -152,6 +153,14 @@ class BeamSearch {
...
@@ -152,6 +153,14 @@ class BeamSearch {
};
};
protected:
protected:
/*
* Prune the source sentences all branchs finished, and it is optional.
* Pruning must one step later than finishing, since the end tokens
* must be writed out. Also the finished branchs with top 1 score can
* be pruned.
*/
void
PruneEndBeams
(
const
framework
::
LoDTensor
&
pre_ids
,
std
::
vector
<
std
::
vector
<
Item
>>*
items
);
/*
/*
* Delete all the records that follows the end token.
* Delete all the records that follows the end token.
*/
*/
...
@@ -160,7 +169,7 @@ class BeamSearch {
...
@@ -160,7 +169,7 @@ class BeamSearch {
/*
/*
* Transform the items into a map whose key is offset, value is the items.
* Transform the items into a map whose key is offset, value is the items.
* NOTE low performance
* NOTE low performance
.
*/
*/
std
::
vector
<
std
::
vector
<
Item
>>
ToMap
(
std
::
vector
<
std
::
vector
<
Item
>>
ToMap
(
const
std
::
vector
<
std
::
vector
<
Item
>>&
inputs
,
size_t
element_num
);
const
std
::
vector
<
std
::
vector
<
Item
>>&
inputs
,
size_t
element_num
);
...
@@ -168,12 +177,16 @@ class BeamSearch {
...
@@ -168,12 +177,16 @@ class BeamSearch {
/*
/*
* For each source, select top beam_size records.
* For each source, select top beam_size records.
*/
*/
std
::
vector
<
std
::
vector
<
Item
>>
SelectTopBeamSizeItems
();
std
::
vector
<
std
::
vector
<
Item
>>
SelectTopBeamSizeItems
(
const
framework
::
LoDTensor
&
pre_ids
,
const
framework
::
LoDTensor
&
pre_scores
);
/*
/*
* Get the items of next source sequence, return false if no remaining items.
* Get the items of next source sequence, return false if no remaining items.
*/
*/
bool
NextItemSet
(
std
::
vector
<
Item
>*
items
);
bool
NextItemSet
(
const
framework
::
LoDTensor
&
pre_ids
,
const
framework
::
LoDTensor
&
pre_scores
,
std
::
vector
<
Item
>*
items
);
private:
private:
size_t
beam_size_
;
size_t
beam_size_
;
...
@@ -192,24 +205,25 @@ template <typename DeviceContext, typename T>
...
@@ -192,24 +205,25 @@ template <typename DeviceContext, typename T>
class
BeamSearchOpKernel
:
public
framework
::
OpKernel
<
T
>
{
class
BeamSearchOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
ids_var
=
context
.
Input
<
framework
::
LoDTensor
>
(
"ids"
);
auto
*
ids
=
context
.
Input
<
framework
::
LoDTensor
>
(
"ids"
);
auto
*
scores_var
=
context
.
Input
<
framework
::
LoDTensor
>
(
"scores"
);
auto
*
scores
=
context
.
Input
<
framework
::
LoDTensor
>
(
"scores"
);
auto
*
pre_ids_var
=
context
.
Input
<
framework
::
LoDTensor
>
(
"pre_ids"
);
auto
*
pre_ids
=
context
.
Input
<
framework
::
LoDTensor
>
(
"pre_ids"
);
PADDLE_ENFORCE_NOT_NULL
(
ids_var
);
auto
*
pre_scores
=
context
.
Input
<
framework
::
LoDTensor
>
(
"pre_scores"
);
PADDLE_ENFORCE_NOT_NULL
(
scores_var
);
PADDLE_ENFORCE_NOT_NULL
(
ids
);
PADDLE_ENFORCE_NOT_NULL
(
pre_ids_var
);
PADDLE_ENFORCE_NOT_NULL
(
scores
);
PADDLE_ENFORCE_NOT_NULL
(
pre_ids
);
PADDLE_ENFORCE_NOT_NULL
(
pre_scores
);
size_t
level
=
context
.
Attr
<
int
>
(
"level"
);
size_t
level
=
context
.
Attr
<
int
>
(
"level"
);
size_t
beam_size
=
context
.
Attr
<
int
>
(
"beam_size"
);
size_t
beam_size
=
context
.
Attr
<
int
>
(
"beam_size"
);
int
end_id
=
context
.
Attr
<
int
>
(
"end_id"
);
int
end_id
=
context
.
Attr
<
int
>
(
"end_id"
);
BeamSearch
alg
(
*
ids_var
,
*
scores_var
,
level
,
beam_size
,
end_id
);
BeamSearch
alg
(
*
ids
,
*
scores
,
level
,
beam_size
,
end_id
);
auto
selected_ids_var
=
auto
selected_ids
=
context
.
Output
<
framework
::
LoDTensor
>
(
"selected_ids"
);
context
.
Output
<
framework
::
LoDTensor
>
(
"selected_ids"
);
auto
selected_scores
=
auto
selected_scores_var
=
context
.
Output
<
framework
::
LoDTensor
>
(
"selected_scores"
);
context
.
Output
<
framework
::
LoDTensor
>
(
"selected_scores"
);
PADDLE_ENFORCE_NOT_NULL
(
selected_ids
_var
);
PADDLE_ENFORCE_NOT_NULL
(
selected_ids
);
PADDLE_ENFORCE_NOT_NULL
(
selected_scores
_var
);
PADDLE_ENFORCE_NOT_NULL
(
selected_scores
);
alg
(
*
pre_ids
_var
,
selected_ids_var
,
selected_scores_var
);
alg
(
*
pre_ids
,
*
pre_scores
,
selected_ids
,
selected_scores
);
}
}
};
};
}
// namespace operators
}
// namespace operators
...
...
paddle/fluid/operators/tensor_array_read_write_op.cc
浏览文件 @
741046e8
...
@@ -38,15 +38,14 @@ class WriteToArrayOp : public ArrayOp {
...
@@ -38,15 +38,14 @@ class WriteToArrayOp : public ArrayOp {
<<
" to "
<<
offset
+
1
;
<<
" to "
<<
offset
+
1
;
out
->
resize
(
offset
+
1
);
out
->
resize
(
offset
+
1
);
}
}
auto
*
out_tensor
=
&
out
->
at
(
offset
);
out_tensor
->
set_lod
(
x_tensor
.
lod
());
if
(
x_tensor
.
memory_size
()
>
0
)
{
if
(
x_tensor
.
memory_size
()
>
0
)
{
auto
*
out_tensor
=
&
out
->
at
(
offset
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
::
Instance
();
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
auto
&
dev_ctx
=
*
pool
.
Get
(
place
);
TensorCopy
(
x_tensor
,
place
,
dev_ctx
,
out_tensor
);
TensorCopy
(
x_tensor
,
place
,
dev_ctx
,
out_tensor
);
out_tensor
->
set_lod
(
x_tensor
.
lod
());
}
else
{
}
else
{
VLOG
(
10
)
<<
"WARNING: The input tensor 'x_tensor' holds no memory, so "
VLOG
(
10
)
<<
"WARNING: The input tensor 'x_tensor' holds no memory, so "
"nothing has been written to output array["
"nothing has been written to output array["
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
741046e8
...
@@ -1686,7 +1686,7 @@ def layer_norm(input,
...
@@ -1686,7 +1686,7 @@ def layer_norm(input,
return
helper
.
append_activation
(
layer_norm_out
)
return
helper
.
append_activation
(
layer_norm_out
)
def
beam_search_decode
(
ids
,
scores
,
name
=
None
):
def
beam_search_decode
(
ids
,
scores
,
beam_size
,
end_id
,
name
=
None
):
helper
=
LayerHelper
(
'beam_search_decode'
,
**
locals
())
helper
=
LayerHelper
(
'beam_search_decode'
,
**
locals
())
sentence_ids
=
helper
.
create_tmp_variable
(
dtype
=
ids
.
dtype
)
sentence_ids
=
helper
.
create_tmp_variable
(
dtype
=
ids
.
dtype
)
sentence_scores
=
helper
.
create_tmp_variable
(
dtype
=
ids
.
dtype
)
sentence_scores
=
helper
.
create_tmp_variable
(
dtype
=
ids
.
dtype
)
...
@@ -1698,7 +1698,9 @@ def beam_search_decode(ids, scores, name=None):
...
@@ -1698,7 +1698,9 @@ def beam_search_decode(ids, scores, name=None):
outputs
=
{
outputs
=
{
"SentenceIds"
:
sentence_ids
,
"SentenceIds"
:
sentence_ids
,
"SentenceScores"
:
sentence_scores
"SentenceScores"
:
sentence_scores
})
},
attrs
=
{
"beam_size"
:
beam_size
,
"end_id"
:
end_id
})
return
sentence_ids
,
sentence_scores
return
sentence_ids
,
sentence_scores
...
@@ -1926,7 +1928,7 @@ def sequence_expand(x, y, ref_level=-1, name=None):
...
@@ -1926,7 +1928,7 @@ def sequence_expand(x, y, ref_level=-1, name=None):
return
tmp
return
tmp
def
beam_search
(
pre_ids
,
ids
,
scores
,
beam_size
,
end_id
,
level
=
0
):
def
beam_search
(
pre_ids
,
pre_scores
,
ids
,
scores
,
beam_size
,
end_id
,
level
=
0
):
'''
'''
This function implements the beam search algorithm.
This function implements the beam search algorithm.
'''
'''
...
@@ -1941,6 +1943,7 @@ def beam_search(pre_ids, ids, scores, beam_size, end_id, level=0):
...
@@ -1941,6 +1943,7 @@ def beam_search(pre_ids, ids, scores, beam_size, end_id, level=0):
type
=
'beam_search'
,
type
=
'beam_search'
,
inputs
=
{
inputs
=
{
'pre_ids'
:
pre_ids
,
'pre_ids'
:
pre_ids
,
'pre_scores'
:
pre_scores
,
'ids'
:
ids
,
'ids'
:
ids
,
'scores'
:
scores
,
'scores'
:
scores
,
},
},
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录