未验证 提交 163b5e57 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #11805 from guoshengCS/cherry-pick-beam-search

[cherry-pick] Fix and enhance beam_search_op and beam_searc_decode_op
......@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/beam_search_decode_op.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/operators/beam_search_decode_op.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
......@@ -22,8 +24,11 @@ namespace operators {
struct BeamSearchDecodeFunctor {
BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor, LoDTensor* score_tensor)
: step_ids_origin_(step_ids),
LoDTensor* id_tensor, LoDTensor* score_tensor,
size_t beam_size, int end_id)
: beam_size_(beam_size),
end_id_(end_id),
step_ids_origin_(step_ids),
step_scores_origin_(step_scores),
id_tensor_(id_tensor),
score_tensor_(score_tensor) {
......@@ -37,9 +42,11 @@ struct BeamSearchDecodeFunctor {
// Copy all tensors in the input tensor array
for (auto& step_id : step_ids_origin_) {
framework::LoDTensor out;
dev_ctx->Wait();
framework::TensorCopy(step_id, platform::CPUPlace(), *dev_ctx, &out);
dev_ctx->Wait();
if (step_id.numel() > 0) {
dev_ctx->Wait();
framework::TensorCopy(step_id, platform::CPUPlace(), *dev_ctx, &out);
dev_ctx->Wait();
}
out.set_lod(step_id.lod());
step_ids_.push_back(out);
......@@ -53,9 +60,12 @@ struct BeamSearchDecodeFunctor {
// Copy all tensors in the input tensor array
for (auto& step_score : step_scores_origin_) {
framework::LoDTensor out;
dev_ctx->Wait();
framework::TensorCopy(step_score, platform::CPUPlace(), *dev_ctx, &out);
dev_ctx->Wait();
if (step_score.numel() > 0) {
dev_ctx->Wait();
framework::TensorCopy(step_score, platform::CPUPlace(), *dev_ctx,
&out);
dev_ctx->Wait();
}
out.set_lod(step_score.lod());
step_scores_.push_back(out);
......@@ -67,6 +77,8 @@ struct BeamSearchDecodeFunctor {
void operator()() const;
bool tensor_on_gpu_;
size_t beam_size_;
int end_id_;
const LoDTensorArray& step_ids_origin_;
const LoDTensorArray& step_scores_origin_;
LoDTensorArray step_ids_ = LoDTensorArray();
......@@ -77,14 +89,14 @@ struct BeamSearchDecodeFunctor {
template <typename T>
void BeamSearchDecodeFunctor::operator()() const {
BeamSearchDecoder<T> beam_search_decoder;
BeamSearchDecoder<T> beam_search_decoder(beam_size_, end_id_);
// Check if the tensor is on GPU. If so, use the CPU copy instead
if (tensor_on_gpu_) {
beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_,
score_tensor_);
beam_search_decoder.Backtrace(step_ids_, step_scores_, id_tensor_,
score_tensor_);
} else {
beam_search_decoder.PackAllSteps(step_ids_origin_, step_scores_origin_,
id_tensor_, score_tensor_);
beam_search_decoder.Backtrace(step_ids_origin_, step_scores_origin_,
id_tensor_, score_tensor_);
}
}
......@@ -122,13 +134,17 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
"Level of LodTensor should be 2");
}
size_t beam_size = ctx.Attr<int>("beam_size");
int end_id = ctx.Attr<int>("end_id");
// prepare output
LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
framework::VisitDataType(
framework::ToDataType(scores->at(0).type()),
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores));
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
beam_size, end_id));
}
};
......@@ -137,18 +153,32 @@ class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
void Make() override {
AddInput("Ids",
"(LodTensorArray)"
"score of the candidate words in each step");
"The LodTensorArray containing the selected ids of all steps");
AddInput("Scores",
"(LodTensorArray)"
"score of the candidate words in each step");
AddOutput("SentenceIds",
"(LodTensor)"
"All possible result sentences of word ids");
AddOutput("SentenceScores",
"(LodTensor)"
"All possible result sentences of word scores");
"The LodTensorArray containing the selected scores of all steps");
AddOutput(
"SentenceIds",
"(LodTensor)"
"An LodTensor containing all generated id sequences for all source "
"sentences");
AddOutput(
"SentenceScores",
"(LodTensor)"
"An LodTensor containing scores corresponding to Output(SentenceIds)");
AddAttr<int>("beam_size", "beam size for beam search");
AddAttr<int>("end_id",
"the token id which indicates the end of a sequence");
AddComment(R"DOC(
Pack the result of Beam search op into SentenceIds and SentenceScores.
Beam Search Decode Operator. This Operator constructs the full hypotheses for
each source sentence by walking back along the LoDTensorArray Input(ids)
whose lods can be used to restore the path in the beam search tree.
The Output(SentenceIds) and Output(SentenceScores) separately contain the
generated id sequences and the corresponding scores. The shapes and lods of the
two LodTensor are same. The lod level is 2 and the two levels separately
indicate how many hypotheses each source sentence has and how many ids each
hypothesis has.
)DOC");
}
};
......@@ -172,10 +202,12 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) {
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
auto& sentence_ids = block->FindRecursiveOrCreateVar(o);
sentence_ids.SetType(framework::proto::VarType::LOD_TENSOR);
}
for (auto& o : op_desc.Output("SentenceScores")) {
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
auto& sentence_scores = block->FindRecursiveOrCreateVar(o);
sentence_scores.SetType(framework::proto::VarType::LOD_TENSOR);
}
}
};
......
......@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -25,42 +27,12 @@ using LoDTensor = framework::LoDTensor;
using LoDTensorArray = framework::LoDTensorArray;
// all the lod have 2 levels.
// The First is source level, the second is sentence level.
// source level describe how many candidate words for this source.
// sentence level describe these candidates belong to which prefix
// The first is source level, the second is sentence level.
// source level describe how many prefixes (branchs) for each source sentece
// (beam). sentence level describe how these candidates belong to the prefixes.
const size_t kSourceLevel = 0;
const size_t kSentenceLevel = 1;
template <typename T>
struct BeamNode {
BeamNode(int64_t word_id, T score) : word_id_(word_id), score_(score) {}
~BeamNode() {
if (parent_) {
parent_->DropKid(this);
if (parent_->kids_.size() == 0UL) {
delete parent_;
}
}
VLOG(3) << "Delete BeamNode root with word_id:" << this->word_id_;
}
void AppendTo(BeamNode* parent) {
parent_ = parent;
parent->kids_.insert(this);
}
void DropKid(BeamNode* kid) { kids_.erase(kid); }
BeamNode* parent_ = nullptr;
std::unordered_set<BeamNode*> kids_;
int64_t word_id_;
T score_;
};
template <typename T>
using BeamNodeVector = std::vector<std::unique_ptr<BeamNode<T>>>;
template <typename T>
struct Sentence {
std::vector<int64_t> word_ids;
......@@ -72,24 +44,8 @@ using SentenceVector = std::vector<Sentence<T>>;
template <typename T>
struct BeamSearchDecoder {
/**
* make a BeamNode and all it's related prefix BeanNode into a Sentence.
*/
Sentence<T> MakeSentence(const BeamNode<T>* node) const;
/**
* Param:
* cur_ids: LoDTensor of One step for word ID
* cur_scores: LoDTensor of One Step for word score
* prefixes_list: prefixes for each source sentence.
* sentence_vector_list: result sentence_vector for each source sentence.
* Return:
* a new prefixes list for each source of current step
*/
std::vector<BeamNodeVector<T>> PackTwoSteps(
const LoDTensor& cur_ids, const LoDTensor& cur_scores,
std::vector<BeamNodeVector<T>>* prefixes_list,
std::vector<SentenceVector<T>>* sentence_vector_list) const;
BeamSearchDecoder(size_t beam_size, int end_id)
: beam_size_(beam_size), end_id_(end_id) {}
/**
* convert the result sentence_vector for each source sentence into two
......@@ -100,107 +56,30 @@ struct BeamSearchDecoder {
* sentence_vector_list: sentence_vector for each source sentence.
* id_tensor: result LoDTensor for sentences of id.
* score_tensor: result LoDTensor for sentences of score.
* reverse: whether ids of sentence in sentence_vector_list is reversed
* sort_by_score: whether to sort hypotheses of each sentence by scores.
*/
void ConvertSentenceVectorToLodTensor(
std::vector<SentenceVector<T>> sentence_vector_list, LoDTensor* id_tensor,
LoDTensor* score_tensor) const;
LoDTensor* score_tensor, bool reverse = true,
bool sort_by_score = true) const;
/**
* Pack all steps of id/score LodTensor into sentence LoDTensor
* it's main logic is:
* ```python
* prefix
* result_sentence
* result_lod_tensor
*
* for (step in steps):
* prefix = PackTwoSteps(prefix, step, &result_sentence)
* ConvertSentenceVector<T>ToLodTensor(result_sentence, &result_lod_tensor)
* ```
* Gather the hypotheses for each source sentence by backtrace though the
* LoDTensorArray step_ids whose lods reserve the path in the tree.
*/
void PackAllSteps(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores, LoDTensor* id_tensor,
LoDTensor* score_tensor) const;
};
template <typename T>
Sentence<T> BeamSearchDecoder<T>::MakeSentence(const BeamNode<T>* node) const {
Sentence<T> sentence;
while (node != nullptr) {
sentence.word_ids.emplace_back(node->word_id_);
sentence.scores.emplace_back(node->score_);
node = node->parent_;
}
std::reverse(std::begin(sentence.word_ids), std::end(sentence.word_ids));
std::reverse(std::begin(sentence.scores), std::end(sentence.scores));
return sentence;
}
template <typename T>
std::vector<BeamNodeVector<T>> BeamSearchDecoder<T>::PackTwoSteps(
const LoDTensor& cur_ids, const LoDTensor& cur_scores,
std::vector<BeamNodeVector<T>>* prefixes_list,
std::vector<SentenceVector<T>>* sentence_vector_list) const {
std::vector<BeamNodeVector<T>> result;
void Backtrace(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores, LoDTensor* id_tensor,
LoDTensor* score_tensor) const;
for (size_t src_idx = 0; src_idx < cur_ids.lod()[kSourceLevel].size() - 1;
++src_idx) {
size_t src_start = cur_ids.lod().at(kSourceLevel)[src_idx];
size_t src_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1];
BeamNodeVector<T> beam_nodes;
// if prefixes size is 0, it means this is the first step. In this step,
// all candidate id is the start of candidate sentences.
if (prefixes_list->empty()) {
PADDLE_ENFORCE_EQ(cur_ids.lod().at(kSourceLevel).back(),
cur_ids.lod().at(kSentenceLevel).back(),
"in the first step");
for (size_t id_idx = src_start; id_idx < src_end; ++id_idx) {
beam_nodes.push_back(std::unique_ptr<BeamNode<T>>(new BeamNode<T>(
cur_ids.data<int64_t>()[id_idx], cur_scores.data<T>()[id_idx])));
}
} else {
BeamNodeVector<T>& prefixes = prefixes_list->at(src_idx);
SentenceVector<T>& sentence_vector = (*sentence_vector_list)[src_idx];
PADDLE_ENFORCE_EQ(src_end - src_start, prefixes.size(),
"prefix and candidate set number should be the same");
auto candidate_offset = cur_ids.lod()[kSentenceLevel];
for (size_t prefix_idx = 0; prefix_idx < prefixes.size(); ++prefix_idx) {
std::unique_ptr<BeamNode<T>>& prefix = prefixes[prefix_idx];
size_t candidate_start = candidate_offset[src_start + prefix_idx];
size_t candidate_end = candidate_offset[src_start + prefix_idx + 1];
if (candidate_start == candidate_end) {
VLOG(3) << "this sentence has no more candidate, "
"add to result sentence and rm it from beam tree";
sentence_vector.push_back(MakeSentence(prefix.get()));
prefix.reset();
} else {
for (size_t candidate_idx = candidate_start;
candidate_idx < candidate_end; ++candidate_idx) {
auto* candidate =
new BeamNode<T>(cur_ids.data<int64_t>()[candidate_idx],
cur_scores.data<T>()[candidate_idx]);
candidate->AppendTo(prefix.get());
beam_nodes.push_back(std::unique_ptr<BeamNode<T>>(candidate));
}
prefix.release();
}
}
}
result.push_back(std::move(beam_nodes));
}
return result;
}
size_t beam_size_;
int end_id_;
};
template <typename T>
void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
std::vector<SentenceVector<T>> sentence_vector_list, LoDTensor* id_tensor,
LoDTensor* score_tensor) const {
LoDTensor* score_tensor, bool reverse, bool sort_by_score) const {
size_t src_num = sentence_vector_list.size();
PADDLE_ENFORCE_NE(src_num, 0, "src_num should not be 0");
......@@ -211,11 +90,29 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
std::vector<T> score_data;
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
if (sort_by_score) {
sort(sentence_vector_list[src_idx].begin(),
sentence_vector_list[src_idx].end(),
[reverse](const Sentence<T>& a, const Sentence<T>& b) {
if (reverse)
return a.scores.front() > b.scores.front();
else
return a.scores.back() > b.scores.back();
});
}
for (Sentence<T>& sentence : sentence_vector_list[src_idx]) {
id_data.insert(id_data.end(), sentence.word_ids.begin(),
sentence.word_ids.end());
score_data.insert(score_data.end(), sentence.scores.begin(),
sentence.scores.end());
if (reverse) {
id_data.insert(id_data.end(), sentence.word_ids.rbegin(),
sentence.word_ids.rend());
score_data.insert(score_data.end(), sentence.scores.rbegin(),
sentence.scores.rend());
} else {
id_data.insert(id_data.end(), sentence.word_ids.begin(),
sentence.word_ids.end());
score_data.insert(score_data.end(), sentence.scores.begin(),
sentence.scores.end());
}
sentence_level_lod.push_back(sentence_level_lod.back() +
sentence.word_ids.size());
}
......@@ -243,39 +140,75 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
}
template <typename T>
void BeamSearchDecoder<T>::PackAllSteps(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor,
LoDTensor* score_tensor) const {
void BeamSearchDecoder<T>::Backtrace(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor,
LoDTensor* score_tensor) const {
PADDLE_ENFORCE(!step_ids.empty(), "step num should be larger than 0");
PADDLE_ENFORCE_EQ(step_ids.size(), step_scores.size(),
"step_ids and step_scores should be the same");
const size_t step_num = step_ids.size();
const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1;
std::vector<SentenceVector<T>> sentence_vector_list(
src_num, SentenceVector<T>(beam_size_));
std::vector<std::vector<size_t>> prefix_idx_vector_list(src_num);
for (int step_id = step_num - 1; step_id >= 0; --step_id) {
auto& cur_ids = step_ids.at(step_id);
auto& cur_scores = step_scores.at(step_id);
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
// for each source sentence
auto& sentence_vector = sentence_vector_list.at(src_idx);
auto& prefix_idx_vector = prefix_idx_vector_list.at(src_idx);
size_t src_prefix_start = cur_ids.lod().at(kSourceLevel)[src_idx];
size_t src_prefix_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1];
if (prefix_idx_vector.empty()) { // be finished and pruned at this step
// or the last time step
for (size_t prefix_idx = src_prefix_start; prefix_idx < src_prefix_end;
++prefix_idx) {
size_t candidate_start = cur_ids.lod().at(kSentenceLevel)[prefix_idx];
size_t candidate_end =
cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1];
for (size_t candidate_idx = candidate_start;
candidate_idx < candidate_end; ++candidate_idx) {
prefix_idx_vector.push_back(prefix_idx);
size_t idx = prefix_idx_vector.size() - 1;
auto cur_id = cur_ids.data<int64_t>()[candidate_idx];
auto cur_score = cur_scores.data<T>()[candidate_idx];
sentence_vector.at(idx).word_ids.push_back(cur_id);
sentence_vector.at(idx).scores.push_back(cur_score);
}
}
} else { // use prefix_idx_vector to backtrace
size_t src_candidate_start =
cur_ids.lod().at(kSentenceLevel)[src_prefix_start];
size_t prefix_idx = src_prefix_start;
size_t candidate_num =
cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] -
cur_ids.lod().at(kSentenceLevel)[prefix_idx];
for (size_t idx = 0; idx < prefix_idx_vector.size(); ++idx) {
auto candidate_idx = prefix_idx_vector.at(idx);
auto cur_id = cur_ids.data<int64_t>()[candidate_idx];
auto cur_score = cur_scores.data<T>()[candidate_idx];
if (cur_id != end_id_ || sentence_vector.at(idx).word_ids.empty()) {
// to skip redundant end tokens
sentence_vector.at(idx).word_ids.push_back(cur_id);
sentence_vector.at(idx).scores.push_back(cur_score);
}
PADDLE_ENFORCE_GT(src_num, 0UL, "source num should be larger than 0");
// previous prefixes for each step,
// the init length is 0, means this is the first step.
std::vector<BeamNodeVector<T>> beamnode_vector_list(0);
std::vector<SentenceVector<T>> sentence_vector_list(src_num);
// pack all steps for one batch first, then another batch
for (size_t step_id = 0; step_id < step_num; ++step_id) {
beamnode_vector_list =
PackTwoSteps(step_ids.at(step_id), step_scores.at(step_id),
&beamnode_vector_list, &sentence_vector_list);
}
// append last beam_node to result
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
for (auto& beam_node : beamnode_vector_list.at(src_idx)) {
sentence_vector_list[src_idx].push_back(MakeSentence(beam_node.get()));
beam_node.reset();
while (src_candidate_start + candidate_num <=
candidate_idx) { // search the corresponding prefix
prefix_idx++;
candidate_num += cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] -
cur_ids.lod().at(kSentenceLevel)[prefix_idx];
}
prefix_idx_vector.at(idx) = prefix_idx;
}
}
}
}
ConvertSentenceVectorToLodTensor(sentence_vector_list, id_tensor,
score_tensor);
score_tensor, true, true);
}
} // namespace operators
......
......@@ -20,15 +20,11 @@ using LoD = paddle::framework::LoD;
using LoDTensor = paddle::framework::LoDTensor;
using LoDTensorArray = paddle::framework::LoDTensorArray;
template <typename T>
using BeamNode = paddle::operators::BeamNode<T>;
template <typename T>
using BeamSearchDecoder = paddle::operators::BeamSearchDecoder<T>;
template <typename T>
using Sentence = paddle::operators::Sentence<T>;
template <typename T>
using BeamNodeVector = paddle::operators::BeamNodeVector<T>;
template <typename T>
using SentenceVector = paddle::operators::SentenceVector<T>;
namespace paddle {
......@@ -77,138 +73,50 @@ void GenerateExample(const std::vector<size_t>& level_0,
} // namespace test
} // namespace paddle
TEST(BeamSearchDecodeOp, DeleteBeamNode) {
auto* root = new BeamNode<float>(0, 0);
auto* b1 = new BeamNode<float>(1, 1);
auto* b2 = new BeamNode<float>(2, 2);
auto* b3 = new BeamNode<float>(3, 3);
b1->AppendTo(root);
b2->AppendTo(root);
b3->AppendTo(b1);
delete b3;
delete b2;
}
TEST(BeamSearchDecodeOp, MakeSentence) {
auto* root = new BeamNode<float>(0, 0);
auto* b1 = new BeamNode<float>(1, 1);
auto* end = new BeamNode<float>(2, 2);
b1->AppendTo(root);
end->AppendTo(b1);
BeamSearchDecoder<float> helper;
Sentence<float> sentence = helper.MakeSentence(end);
delete end;
std::vector<int64_t> expect_ids = {0, 1, 2};
ASSERT_EQ(sentence.word_ids, expect_ids);
std::vector<float> expect_scores = {0, 1, 2};
ASSERT_EQ(sentence.scores, expect_scores);
}
TEST(BeamSearchDecodeOp, PackTwoStepsFistStep) {
CPUPlace place;
LoDTensorArray ids;
LoDTensorArray scores;
paddle::test::GenerateExample(
std::vector<size_t>{0, 2, 6}, std::vector<size_t>{0, 1, 2, 3, 4, 5, 6},
std::vector<int>{1, 2, 3, 4, 5, 6}, &ids, &scores);
std::vector<BeamNodeVector<float>> beamnode_vector_list;
std::vector<SentenceVector<float>> sentence_vector_list(
2, SentenceVector<float>());
BeamSearchDecoder<float> helper;
beamnode_vector_list = helper.PackTwoSteps(
ids[0], scores[0], &beamnode_vector_list, &sentence_vector_list);
ASSERT_EQ(beamnode_vector_list.size(), 2UL);
ASSERT_EQ(beamnode_vector_list[0].size(), 2UL);
ASSERT_EQ(beamnode_vector_list[1].size(), 4UL);
}
TEST(BeamSearchDecodeOp, PackTwoSteps) {
CPUPlace place;
// first source has three prefix
BeamNodeVector<float> source0_prefixes;
source0_prefixes.push_back(
std::unique_ptr<BeamNode<float>>(new BeamNode<float>(1, 1)));
source0_prefixes.push_back(
std::unique_ptr<BeamNode<float>>(new BeamNode<float>(0, 0)));
source0_prefixes.push_back(
std::unique_ptr<BeamNode<float>>(new BeamNode<float>(3, 3)));
// second source has two prefix
BeamNodeVector<float> source1_prefixes;
source1_prefixes.push_back(
std::unique_ptr<BeamNode<float>>(new BeamNode<float>(4, 4)));
source1_prefixes.push_back(
std::unique_ptr<BeamNode<float>>(new BeamNode<float>(5, 5)));
std::vector<BeamNodeVector<float>> beamnode_vector_list;
std::vector<SentenceVector<float>> sentence_vector_list(
2, SentenceVector<float>());
beamnode_vector_list.push_back(std::move(source0_prefixes));
beamnode_vector_list.push_back(std::move(source1_prefixes));
// generate data for one step
LoDTensorArray ids;
LoDTensorArray scores;
paddle::test::GenerateExample(std::vector<size_t>{0, 3, 5},
std::vector<size_t>{0, 1, 1, 3, 4, 5},
std::vector<int>{0, 1, 2, 3, 4}, &ids, &scores);
BeamSearchDecoder<float> helper1;
beamnode_vector_list = helper1.PackTwoSteps(
ids[0], scores[0], &beamnode_vector_list, &sentence_vector_list);
ASSERT_EQ(sentence_vector_list[0].size(), 1UL);
ASSERT_EQ(sentence_vector_list[1].size(), 0UL);
ASSERT_EQ(beamnode_vector_list[0].size(), 3UL);
ASSERT_EQ(beamnode_vector_list[1].size(), 2UL);
}
TEST(BeamSearchDecodeOp, PackAllSteps) {
TEST(BeamSearchDecodeOp, Backtrace) {
CPUPlace place;
// we will constuct a sample data with 3 steps and 2 source sentences
// Construct sample data with 5 steps and 2 source sentences
// beam_size = 2, start_id = 0, end_id = 1
LoDTensorArray ids;
LoDTensorArray scores;
paddle::test::GenerateExample(
std::vector<size_t>{0, 3, 6}, std::vector<size_t>{0, 1, 2, 3, 4, 5, 6},
std::vector<int>{1, 2, 3, 4, 5, 6}, &ids, &scores);
std::vector<size_t>{0, 1, 2}, std::vector<size_t>{0, 1, 2},
std::vector<int>{0, 0}, &ids, &scores); // start with start_id
paddle::test::GenerateExample(std::vector<size_t>{0, 1, 2},
std::vector<size_t>{0, 2, 4},
std::vector<int>{2, 3, 4, 5}, &ids, &scores);
paddle::test::GenerateExample(std::vector<size_t>{0, 2, 4},
std::vector<size_t>{0, 2, 2, 4, 4},
std::vector<int>{3, 1, 5, 4}, &ids, &scores);
paddle::test::GenerateExample(std::vector<size_t>{0, 2, 4},
std::vector<size_t>{0, 1, 2, 3, 4},
std::vector<int>{1, 1, 3, 5}, &ids, &scores);
paddle::test::GenerateExample(
std::vector<size_t>{0, 3, 6}, std::vector<size_t>{0, 1, 1, 3, 5, 5, 6},
std::vector<int>{0, 1, 2, 3, 4, 5}, &ids, &scores);
paddle::test::GenerateExample(std::vector<size_t>{0, 3, 6},
std::vector<size_t>{0, 0, 1, 2, 3, 4, 5},
std::vector<int>{0, 1, 2, 3, 4}, &ids, &scores);
std::vector<size_t>{0, 2, 4},
std::vector<size_t>{0, 0, 0, 2,
2}, // the branchs of the first source sentence
// are pruned since finished
std::vector<int>{5, 1},
&ids, &scores);
ASSERT_EQ(ids.size(), 3UL);
ASSERT_EQ(scores.size(), 3UL);
ASSERT_EQ(ids.size(), 5UL);
ASSERT_EQ(scores.size(), 5UL);
BeamSearchDecoder<float> helper;
BeamSearchDecoder<float> helper(2, 1); // beam_size = 2, end_id = 1
LoDTensor id_tensor;
LoDTensor score_tensor;
helper.PackAllSteps(ids, scores, &id_tensor, &score_tensor);
helper.Backtrace(ids, scores, &id_tensor, &score_tensor);
LoD lod = id_tensor.lod();
std::vector<size_t> expect_source_lod = {0, 4, 8};
std::vector<size_t> expect_source_lod = {0, 2, 4};
EXPECT_EQ(lod[0], expect_source_lod);
std::vector<size_t> expect_sentence_lod = {0, 1, 3, 6, 9, 10, 13, 16, 19};
std::vector<size_t> expect_sentence_lod = {0, 4, 7, 12, 17};
EXPECT_EQ(lod[1], expect_sentence_lod);
// 2| 1, 0| 3, 1, 0| 3, 2, 1| 5| 4, 3, 2| 4, 4, 3| 6, 5, 4
std::vector<int> expect_data = {2, 1, 0, 3, 1, 0, 3, 2, 1, 5,
4, 3, 2, 4, 4, 3, 6, 5, 4};
std::vector<int> expect_data = {0, 2, 3, 1, 0, 2, 1, 0, 4,
5, 3, 5, 0, 4, 5, 3, 1};
ASSERT_EQ(id_tensor.dims()[0], static_cast<int64_t>(expect_data.size()));
for (size_t i = 0; i < expect_data.size(); ++i) {
ASSERT_EQ(id_tensor.data<int64_t>()[i],
......
......@@ -12,25 +12,26 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
#include <algorithm>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/beam_search_op.h"
namespace paddle {
namespace operators {
void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
const framework::LoDTensor &pre_scores,
framework::LoDTensor *selected_ids,
framework::LoDTensor *selected_scores) {
auto abs_lod = framework::ToAbsOffset(ids_->lod());
auto &high_level = abs_lod[lod_level_];
auto items = SelectTopBeamSizeItems();
auto items = SelectTopBeamSizeItems(pre_ids, pre_scores);
auto selected_items = ToMap(items, high_level.back());
VLOG(3) << "selected_items:";
for (size_t i = 0; i < selected_items.size(); ++i) {
......@@ -39,7 +40,8 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
VLOG(3) << ItemToString(item);
}
}
PruneEndidCandidates(pre_ids, &selected_items);
PruneEndBeams(pre_ids, &selected_items);
// calculate the output tensor's height
size_t num_instances = std::accumulate(
std::begin(selected_items), std::end(selected_items), 0,
......@@ -61,12 +63,6 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
size_t low_offset = 0;
for (auto &items : selected_items) {
low_level.push_back(low_offset);
sort(items.begin(), items.end(), [](const Item &a, const Item &b) {
if (a.offset < b.offset) {
return true;
}
return a.id < b.id;
});
for (auto &item : items) {
ids_data[low_offset] = item.id;
scores_data[low_offset] = item.score;
......@@ -86,21 +82,31 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
selected_scores->set_lod(lod);
}
int BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids,
std::vector<std::vector<Item>> *items) {
void BeamSearch::PruneEndBeams(const framework::LoDTensor &pre_ids,
std::vector<std::vector<Item>> *items) {
auto *pre_ids_data = pre_ids.data<int64_t>();
int res = 0;
for (size_t offset = 0; offset < items->size(); offset++) {
auto prefix_id = pre_ids_data[offset];
if (prefix_id == end_id_) {
items->at(offset).clear();
} else {
res++;
auto abs_lod = framework::ToAbsOffset(ids_->lod());
auto &high_level = abs_lod[lod_level_];
for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) {
size_t src_prefix_start = high_level[src_idx];
size_t src_prefix_end = high_level[src_idx + 1];
bool finish_flag = true;
for (size_t offset = src_prefix_start; offset < src_prefix_end; offset++) {
for (auto &item : items->at(offset)) {
if (item.id != static_cast<size_t>(end_id_) ||
pre_ids_data[offset] != end_id_) {
finish_flag = false;
break;
}
}
if (!finish_flag) break;
}
if (finish_flag) { // all branchs of the beam (source sentence) end and
// prune this beam
for (size_t offset = src_prefix_start; offset < src_prefix_end; offset++)
items->at(offset).clear();
}
}
return res;
}
std::vector<std::vector<BeamSearch::Item>> BeamSearch::ToMap(
......@@ -115,19 +121,17 @@ std::vector<std::vector<BeamSearch::Item>> BeamSearch::ToMap(
return result;
}
std::vector<std::vector<BeamSearch::Item>>
BeamSearch::SelectTopBeamSizeItems() {
std::vector<std::vector<BeamSearch::Item>> BeamSearch::SelectTopBeamSizeItems(
const framework::LoDTensor &pre_ids,
const framework::LoDTensor &pre_scores) {
std::vector<std::vector<Item>> result;
std::vector<Item> items;
// for each source sentence, select the top beam_size items across all
// candidate sets.
while (NextItemSet(&items)) {
std::nth_element(std::begin(items), std::begin(items) + beam_size_,
std::end(items), [](const Item &a, const Item &b) {
// TODO(superjom) make score's comparation customizable.
// partial sort in descending order
return a.score > b.score;
});
while (NextItemSet(pre_ids, pre_scores, &items)) {
std::nth_element(
std::begin(items), std::begin(items) + beam_size_, std::end(items),
[](const Item &a, const Item &b) { return a.score > b.score; });
// prune the top beam_size items.
if (items.size() > beam_size_) {
items.resize(beam_size_);
......@@ -146,7 +150,9 @@ BeamSearch::SelectTopBeamSizeItems() {
}
// the candidates of a source
bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
bool BeamSearch::NextItemSet(const framework::LoDTensor &pre_ids,
const framework::LoDTensor &pre_scores,
std::vector<BeamSearch::Item> *items) {
if (sent_offset_ >= ids_->NumElements(lod_level_)) {
return false;
}
......@@ -164,14 +170,24 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
instance_dim *= ids.dims()[i];
}
auto *pre_ids_data = pre_ids.data<int64_t>();
auto *pre_scores_data = pre_scores.data<float>();
items->clear();
items->reserve(framework::product(ids.dims()));
for (size_t offset = abs_lod[lod_level_][sent_offset_];
offset < abs_lod[lod_level_][sent_offset_ + 1]; offset++) {
for (size_t d = 0; d < instance_dim; d++) {
const size_t dim_offset = offset * instance_dim + d;
items->emplace_back(offset, ids_data[dim_offset],
scores_data[dim_offset]);
auto pre_id = pre_ids_data[offset];
auto pre_score = pre_scores_data[offset];
if (pre_id == end_id_) {
// Allocate all probability mass to eos_id for finished branchs and the
// other candidate ids can be ignored.
items->emplace_back(offset, end_id_, pre_score);
} else {
for (size_t d = 0; d < instance_dim; d++) {
const size_t dim_offset = offset * instance_dim + d;
items->emplace_back(offset, ids_data[dim_offset],
scores_data[dim_offset]);
}
}
}
......@@ -199,15 +215,27 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
// inputs and outputs stored in proto
AddInput("pre_ids", "ids in previous step");
AddInput("ids", "a LoDTensor of shape of [None,k]");
AddInput("pre_ids",
"(LoDTensor) The LoDTensor containing the selected ids at the "
"previous step. It should be a tensor with shape (batch_size, 1) "
"and lod `[[0, 1, ... , batch_size], [0, 1, ..., batch_size]]` at "
"thefirst step.");
AddInput("pre_scores",
"(LoDTensor) The LoDTensor containing the accumulated "
"scores corresponding to the selected ids at the previous step.");
AddInput("ids",
"(LoDTensor) The LoDTensor containing the candidates ids. Its "
"shape should be (batch_size * beam_size, K), where K supposed to "
"be beam_size.");
AddInput("scores",
"a LoDTensor that has the same shape and LoD with `ids`");
"(LoDTensor) The LodTensor containing the accumulated scores "
"corresponding to Input(ids) and its shape is the same as the "
"shape of Input(ids).");
AddOutput("selected_ids",
"a LoDTensor that stores the IDs selected by beam search");
AddOutput(
"selected_scores",
"a LoDTensor that has the same shape and LoD with `selected_ids`");
"A LodTensor that stores the IDs selected by beam search.");
AddOutput("selected_scores",
"A LoDTensor containing the accumulated scores corresponding to "
"Output(selected_ids).");
// Attributes stored in AttributeMap
AddAttr<int>("level", "the level of LoDTensor");
......@@ -215,8 +243,21 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("end_id",
"the token id which indicates the end of a sequence");
AddComment(
"This is a beam search operator that help to generate sequences.");
AddComment(R"DOC(
This operator does the search in beams for one time step.
Specifically, it selects the top-K candidate word ids of current step from
Input(ids) according to their Input(scores) for all source sentences,
where K is Attr(beam_size) and Input(ids), Input(scores) are predicted results
from the computation cell. Additionally, Input(pre_ids) and Input(pre_scores)
are the output of beam_search at previous step, they are needed for special use
to handle ended candidate translations. The paths linking prefixes and selected
candidates are organized and reserved in lod.
Note that the Input(scores) passed in should be accumulated scores, and
length penalty should be done with extra operators before calculating the
accumulated scores if needed, also suggest finding top-K before it and
using the top-K candidates following.
)DOC");
}
};
......@@ -253,10 +294,12 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("selected_ids")) {
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
auto &selected_ids = block->FindRecursiveOrCreateVar(o);
selected_ids.SetType(framework::proto::VarType::LOD_TENSOR);
}
for (auto &o : op_desc.Output("selected_scores")) {
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
auto &selected_scores = block->FindRecursiveOrCreateVar(o);
selected_scores.SetType(framework::proto::VarType::LOD_TENSOR);
}
}
};
......
......@@ -132,6 +132,7 @@ class BeamSearch {
* that means no candidates is provided, and the task will stop running.
*/
void operator()(const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores,
framework::LoDTensor* selected_ids,
framework::LoDTensor* selected_scores);
/*
......@@ -153,14 +154,16 @@ class BeamSearch {
protected:
/*
* Delete all the records that follows the end token.
* Prune the source sentences all branchs finished, and it is optional.
* Pruning must one step later than finishing (thus pre_ids is needed here),
* since the end tokens must be writed out.
*/
int PruneEndidCandidates(const framework::LoDTensor& pre_ids,
std::vector<std::vector<Item>>* items);
void PruneEndBeams(const framework::LoDTensor& pre_ids,
std::vector<std::vector<Item>>* items);
/*
* Transform the items into a map whose key is offset, value is the items.
* NOTE low performance
* NOTE low performance.
*/
std::vector<std::vector<Item>> ToMap(
const std::vector<std::vector<Item>>& inputs, size_t element_num);
......@@ -168,12 +171,16 @@ class BeamSearch {
/*
* For each source, select top beam_size records.
*/
std::vector<std::vector<Item>> SelectTopBeamSizeItems();
std::vector<std::vector<Item>> SelectTopBeamSizeItems(
const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores);
/*
* Get the items of next source sequence, return false if no remaining items.
*/
bool NextItemSet(std::vector<Item>* items);
bool NextItemSet(const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores,
std::vector<Item>* items);
private:
size_t beam_size_;
......@@ -192,24 +199,25 @@ template <typename DeviceContext, typename T>
class BeamSearchOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* ids_var = context.Input<framework::LoDTensor>("ids");
auto* scores_var = context.Input<framework::LoDTensor>("scores");
auto* pre_ids_var = context.Input<framework::LoDTensor>("pre_ids");
PADDLE_ENFORCE_NOT_NULL(ids_var);
PADDLE_ENFORCE_NOT_NULL(scores_var);
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
auto* ids = context.Input<framework::LoDTensor>("ids");
auto* scores = context.Input<framework::LoDTensor>("scores");
auto* pre_ids = context.Input<framework::LoDTensor>("pre_ids");
auto* pre_scores = context.Input<framework::LoDTensor>("pre_scores");
PADDLE_ENFORCE_NOT_NULL(ids);
PADDLE_ENFORCE_NOT_NULL(scores);
PADDLE_ENFORCE_NOT_NULL(pre_ids);
PADDLE_ENFORCE_NOT_NULL(pre_scores);
size_t level = context.Attr<int>("level");
size_t beam_size = context.Attr<int>("beam_size");
int end_id = context.Attr<int>("end_id");
BeamSearch alg(*ids_var, *scores_var, level, beam_size, end_id);
auto selected_ids_var =
context.Output<framework::LoDTensor>("selected_ids");
auto selected_scores_var =
BeamSearch alg(*ids, *scores, level, beam_size, end_id);
auto selected_ids = context.Output<framework::LoDTensor>("selected_ids");
auto selected_scores =
context.Output<framework::LoDTensor>("selected_scores");
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
alg(*pre_ids_var, selected_ids_var, selected_scores_var);
PADDLE_ENFORCE_NOT_NULL(selected_ids);
PADDLE_ENFORCE_NOT_NULL(selected_scores);
alg(*pre_ids, *pre_scores, selected_ids, selected_scores);
}
};
} // namespace operators
......
......@@ -30,7 +30,7 @@ using std::endl;
void CreateInput(LoDTensor* ids, LoDTensor* scores) {
LoD lod;
vector<size_t> level0({0, 1, 4});
vector<size_t> level0({0, 2, 4});
vector<size_t> level1({0, 1, 2, 3, 4});
lod.push_back(level0);
lod.push_back(level1);
......@@ -64,17 +64,22 @@ TEST(beam_search_op, run) {
for (int i = 0; i < 4; i++) {
pre_ids.mutable_data<int64_t>(place)[i] = i + 1;
}
LoDTensor pre_scores;
pre_scores.Resize(framework::make_ddim(vector<int64_t>(4, 1)));
for (int i = 0; i < 4; i++) {
pre_scores.mutable_data<float>(place)[i] = 0.1 * (i + 1);
}
BeamSearch beamsearch(ids, scores, (int64_t)0, (int64_t)2, 0);
BeamSearch beamsearch(ids, scores, (size_t)0, (size_t)2, 0);
LoDTensor sids, sscores;
beamsearch(pre_ids, &sids, &sscores);
beamsearch(pre_ids, pre_scores, &sids, &sscores);
LOG(INFO) << "score: " << sscores << endl;
ASSERT_EQ(sids.lod(), sscores.lod());
vector<int> tids({2, 4, 3, 8});
vector<float> tscores({0.3, 0.5, 0.9, 0.7});
vector<int> tids({4, 2, 3, 8});
vector<float> tscores({0.5, 0.6, 0.9, 0.7});
for (int i = 0; i < 4; i++) {
ASSERT_EQ(tids[i], sids.data<int64_t>()[i]);
......
......@@ -38,15 +38,14 @@ class WriteToArrayOp : public ArrayOp {
<< " to " << offset + 1;
out->resize(offset + 1);
}
auto *out_tensor = &out->at(offset);
out_tensor->set_lod(x_tensor.lod());
if (x_tensor.memory_size() > 0) {
auto *out_tensor = &out->at(offset);
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
TensorCopy(x_tensor, place, dev_ctx, out_tensor);
out_tensor->set_lod(x_tensor.lod());
} else {
VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so "
"nothing has been written to output array["
......
......@@ -2223,56 +2223,6 @@ def layer_norm(input,
return helper.append_activation(layer_norm_out)
def beam_search_decode(ids, scores, name=None):
"""
Beam Search Decode
This layers is to pack the output of beam search layer into sentences and
associated scores. It is usually called after the beam search layer.
Typically, the output of beam search layer is a tensor of selected ids, with
a tensor of the score of each id. Beam search layer's output ids, however,
are generated directly during the tree search, and they are stacked by each
level of the search tree. Thus we need to reorganize them into sentences,
based on the score of each id. This layer takes the output of beam search
layer as input and repack them into sentences.
Args:
ids (Variable): The selected ids, output of beam search layer.
scores (Variable): The associated scores of the ids, out put of beam
search layer.
name (str): The name of this layer. It is optional.
Returns:
tuple(Variable): a tuple of two output tensors: sentence_ids, sentence_scores.
sentence_ids is a tensor with shape [size, length], where size is the
beam size of beam search, and length is the length of each sentence.
Note that the length of sentences may vary.
sentence_scores is a tensor with the same shape as sentence_ids.
Examples:
.. code-block:: python
ids, scores = fluid.layers.beam_search(
pre_ids, ids, scores, beam_size, end_id)
sentence_ids, sentence_scores = fluid.layers.beam_search_decode(
ids, scores)
"""
helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.dtype)
sentence_scores = helper.create_tmp_variable(dtype=ids.dtype)
helper.append_op(
type="beam_search_decode",
inputs={"Ids": ids,
"Scores": scores},
outputs={
"SentenceIds": sentence_ids,
"SentenceScores": sentence_scores
})
return sentence_ids, sentence_scores
def conv2d_transpose(input,
num_filters,
output_size=None,
......@@ -2676,38 +2626,89 @@ def sequence_expand(x, y, ref_level=-1, name=None):
return tmp
def beam_search(pre_ids, ids, scores, beam_size, end_id, level=0):
'''
**beam search**
This function implements the beam search algorithm.
Beam search is a classical algorithm for selecting candidate words
in a machine translation task.
def beam_search(pre_ids,
pre_scores,
ids,
scores,
beam_size,
end_id,
level=0,
name=None):
"""
Beam search is a classical algorithm for selecting candidate words in a
machine translation task.
Refer to `Beam search <https://en.wikipedia.org/wiki/Beam_search>`_
for more details.
This layer does the search in beams for one time step. Specifically, it
selects the top-K candidate word ids of current step from :attr:`ids`
according to their :attr:`scores` for all source sentences, where K is
:attr:`beam_size` and :attr:`ids, scores` are predicted results from the
computation cell. Additionally, :attr:`pre_ids` and :attr:`pre_scores` are
the output of beam_search at previous step, they are needed for special use
to handle ended candidate translations.
Note that the :attr:`scores` passed in should be accumulated scores, and
length penalty should be done with extra operators before calculating the
accumulated scores if needed, also suggest finding top-K before it and
using the top-K candidates following.
Please see the following demo for a fully beam search usage example:
fluid/tests/book/test_machine_translation.py
Args:
pre_ids (Variable): ids in previous step.
ids (Variable): a LoDTensor of shape of [None,k]
scores (Variable): a LoDTensor that has the same shape and LoD with `ids`
beam_size (int): beam size for beam search
end_id (int): the token id which indicates the end of a sequence
level (int): the level of LoDTensor
pre_ids(Variable): The LodTensor variable which is the output of
beam_search at previous step. It should be a LodTensor with shape
:math:`(batch_size, 1)` and lod
:math:`[[0, 1, ... , batch_size], [0, 1, ..., batch_size]]` at the
first step.
pre_scores(Variable): The LodTensor variable which is the output of
beam_search at previous step.
ids(Variable): The LodTensor variable containing the candidates ids.
Its shape should be :math:`(batch_size \\times beam_size, K)`,
where :math:`K` supposed to be :attr:`beam_size`.
scores(Variable): The LodTensor variable containing the accumulated
scores corresponding to :attr:`ids` and its shape is the same as
the shape of :attr:`ids`.
beam_size(int): The beam width used in beam search.
end_id(int): The id of end token.
level(int, default 0): It can be ignored and mustn't change currently.
It means the source level of lod, which is explained as following.
The lod level of :attr:`ids` should be 2. The first level is source
level which describes how many prefixes (branchs) for each source
sentece (beam), and the second level is sentence level which
describes how these candidates belong to the prefix. The paths
linking prefixes and selected candidates are organized and reserved
in lod.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
tuple: a tuple of beam_search output variables: `selected_ids`, `selected_scores`
Variable: The LodTensor pair containing the selected ids and the \
corresponding scores.
Examples:
.. code-block:: python
# current_score is a Tensor of shape (num_batch_size, embed_size), which
# consists score of each candidate word.
topk_scores, topk_indices = pd.topk(current_score, k=50)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
'''
# Suppose `probs` contains predicted results from the computation
# cell and `pre_ids` and `pre_scores` is the output of beam_search
# at previous step.
topk_scores, topk_indices = layers.topk(probs, k=beam_size)
accu_scores = layers.elementwise_add(
x=layers.log(x=topk_scores)),
y=layers.reshape(
pre_scores, shape=[-1]),
axis=0)
selected_ids, selected_scores = layers.beam_search(
pre_ids=pre_ids,
pre_scores=pre_scores,
ids=topk_indices,
scores=accu_scores,
beam_size=beam_size,
end_id=end_id)
"""
helper = LayerHelper('beam_search', **locals())
score_type = scores.dtype
id_type = ids.dtype
......@@ -2719,6 +2720,7 @@ def beam_search(pre_ids, ids, scores, beam_size, end_id, level=0):
type='beam_search',
inputs={
'pre_ids': pre_ids,
'pre_scores': pre_scores,
'ids': ids,
'scores': scores,
},
......@@ -2736,6 +2738,56 @@ def beam_search(pre_ids, ids, scores, beam_size, end_id, level=0):
return selected_ids, selected_scores
def beam_search_decode(ids, scores, beam_size, end_id, name=None):
"""
Beam Search Decode Layer. This layer constructs the full hypotheses for
each source sentence by walking back along the LoDTensorArray :attr:`ids`
whose lods can be used to restore the path in the beam search tree.
Please see the following demo for a fully beam search usage example:
fluid/tests/book/test_machine_translation.py
Args:
ids(Variable): The LodTensorArray variable containing the selected ids
of all steps.
scores(Variable): The LodTensorArray variable containing the selected
scores of all steps.
beam_size(int): The beam width used in beam search.
end_id(int): The id of end token.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The LodTensor pair containing the generated id sequences \
and the corresponding scores. The shapes and lods of the two \
LodTensor are same. The lod level is 2 and the two levels \
separately indicate how many hypotheses each source sentence has \
and how many ids each hypothesis has.
Examples:
.. code-block:: python
# Suppose `ids` and `scores` are LodTensorArray variables reserving
# the selected ids and scores of all steps
finished_ids, finished_scores = layers.beam_search_decode(
ids, scores, beam_size=5, end_id=0)
"""
helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.dtype)
sentence_scores = helper.create_tmp_variable(dtype=ids.dtype)
helper.append_op(
type="beam_search_decode",
inputs={"Ids": ids,
"Scores": scores},
outputs={
"SentenceIds": sentence_ids,
"SentenceScores": sentence_scores
},
attrs={"beam_size": beam_size,
"end_id": end_id})
return sentence_ids, sentence_scores
def lstm_unit(x_t,
hidden_t_prev,
cell_t_prev,
......
......@@ -127,9 +127,19 @@ def decode(context, is_sparse):
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=topk_size)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(
pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)
pd.increment(x=counter, value=1, in_place=True)
......@@ -138,10 +148,14 @@ def decode(context, is_sparse):
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)
pd.less_than(x=counter, y=array_len, cond=cond)
# update the break condition: up to the max length or all candidates of
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)
translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)
# return init_ids, init_scores
......
......@@ -126,9 +126,19 @@ def decoder_decode(context, is_sparse):
current_score = pd.fc(input=current_state_with_lod,
size=target_dict_dim,
act='softmax')
topk_scores, topk_indices = pd.topk(current_score, k=50)
topk_scores, topk_indices = pd.topk(current_score, k=beam_size)
# calculate accumulated scores after topk to reduce computation cost
accu_scores = pd.elementwise_add(
x=pd.log(topk_scores), y=pd.reshape(
pre_score, shape=[-1]), axis=0)
selected_ids, selected_scores = pd.beam_search(
pre_ids, topk_indices, topk_scores, beam_size, end_id=10, level=0)
pre_ids,
pre_score,
topk_indices,
accu_scores,
beam_size,
end_id=10,
level=0)
pd.increment(x=counter, value=1, in_place=True)
......@@ -137,10 +147,14 @@ def decoder_decode(context, is_sparse):
pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter)
pd.less_than(x=counter, y=array_len, cond=cond)
# update the break condition: up to the max length or all candidates of
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)
translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array)
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)
# return init_ids, init_scores
......
......@@ -20,44 +20,58 @@ from paddle.fluid.op import Operator
class TestBeamSearchDecodeOp(unittest.TestCase):
"""unittest of beam_search_decode_op"""
def setUp(self):
self.scope = core.Scope()
self.place = core.CPUPlace()
def append_lod_tensor(self, tensor_array, lod, data):
lod_tensor = core.LoDTensor()
lod_tensor.set_recursive_sequence_lengths(lod)
lod_tensor.set_lod(lod)
lod_tensor.set(data, self.place)
tensor_array.append(lod_tensor)
def test_get_set(self):
ids = self.scope.var("ids").get_lod_tensor_array()
self.append_lod_tensor(
ids, [[3, 3], [1, 1, 1, 1, 1, 1]],
np.array(
[1, 2, 3, 4, 5, 6], dtype="int64"))
self.append_lod_tensor(
ids, [[3, 3], [1, 0, 2, 2, 0, 1]],
np.array(
[0, 1, 2, 3, 4, 5], dtype="int64"))
self.append_lod_tensor(
ids, [[3, 3], [0, 1, 1, 1, 1, 1]],
np.array(
[0, 1, 2, 3, 4], dtype="int64"))
scores = self.scope.var("scores").get_lod_tensor_array()
self.append_lod_tensor(
scores, [[3, 3], [1, 1, 1, 1, 1, 1]],
np.array(
[1, 2, 3, 4, 5, 6], dtype="float64"))
self.append_lod_tensor(
scores, [[3, 3], [1, 0, 2, 2, 0, 1]],
np.array(
[0, 1, 2, 3, 4, 5], dtype="float64"))
self.append_lod_tensor(
scores, [[3, 3], [0, 1, 1, 1, 1, 1]],
np.array(
[0, 1, 2, 3, 4], dtype="float64"))
# Construct sample data with 5 steps and 2 source sentences
# beam_size = 2, end_id = 1
# start with start_id
[
self.append_lod_tensor(
array, [[0, 1, 2], [0, 1, 2]], np.array(
[0, 0], dtype=dtype))
for array, dtype in ((ids, "int64"), (scores, "float32"))
]
[
self.append_lod_tensor(
array, [[0, 1, 2], [0, 2, 4]],
np.array(
[2, 3, 4, 5], dtype=dtype))
for array, dtype in ((ids, "int64"), (scores, "float32"))
]
[
self.append_lod_tensor(
array, [[0, 2, 4], [0, 2, 2, 4, 4]],
np.array(
[3, 1, 5, 4], dtype=dtype))
for array, dtype in ((ids, "int64"), (scores, "float32"))
]
[
self.append_lod_tensor(
array, [[0, 2, 4], [0, 1, 2, 3, 4]],
np.array(
[1, 1, 3, 5], dtype=dtype))
for array, dtype in ((ids, "int64"), (scores, "float32"))
]
[
self.append_lod_tensor(
array, [[0, 2, 4], [0, 0, 0, 2, 2]],
np.array(
[5, 1], dtype=dtype))
for array, dtype in ((ids, "int64"), (scores, "float32"))
]
sentence_ids = self.scope.var("sentence_ids").get_tensor()
sentence_scores = self.scope.var("sentence_scores").get_tensor()
......@@ -69,18 +83,18 @@ class TestBeamSearchDecodeOp(unittest.TestCase):
Scores="scores",
# outputs
SentenceIds="sentence_ids",
SentenceScores="sentence_scores")
SentenceScores="sentence_scores",
beam_size=2,
end_id=1, )
beam_search_decode_op.run(self.scope, self.place)
expected_lod = [[4, 4], [1, 2, 3, 3, 1, 3, 3, 3]]
self.assertEqual(sentence_ids.recursive_sequence_lengths(),
expected_lod)
self.assertEqual(sentence_scores.recursive_sequence_lengths(),
expected_lod)
expected_lod = [[0, 2, 4], [0, 4, 7, 12, 17]]
self.assertEqual(sentence_ids.lod(), expected_lod)
self.assertEqual(sentence_scores.lod(), expected_lod)
expected_data = np.array(
[2, 1, 0, 3, 1, 0, 3, 2, 1, 5, 4, 3, 2, 4, 4, 3, 6, 5, 4], "int64")
[0, 2, 3, 1, 0, 2, 1, 0, 4, 5, 3, 5, 0, 4, 5, 3, 1], "int64")
self.assertTrue(np.array_equal(np.array(sentence_ids), expected_data))
self.assertTrue(
np.array_equal(np.array(sentence_scores), expected_data))
......
......@@ -26,9 +26,12 @@ def create_tensor(scope, name, np_data):
class BeamSearchOpTester(unittest.TestCase):
"""unittest of beam_search_op"""
def setUp(self):
self.scope = core.Scope()
self._create_ids()
self._create_pre_scores()
self._create_scores()
self._create_pre_ids()
self.scope.var('selected_ids')
......@@ -37,7 +40,8 @@ class BeamSearchOpTester(unittest.TestCase):
def test_run(self):
op = Operator(
'beam_search',
pre_ids="pre_ids",
pre_ids='pre_ids',
pre_scores='pre_scores',
ids='ids',
scores='scores',
selected_ids='selected_ids',
......@@ -47,19 +51,31 @@ class BeamSearchOpTester(unittest.TestCase):
end_id=0, )
op.run(self.scope, core.CPUPlace())
selected_ids = self.scope.find_var("selected_ids").get_tensor()
print 'selected_ids', np.array(selected_ids)
print 'lod', selected_ids.recursive_sequence_lengths()
selected_scores = self.scope.find_var("selected_scores").get_tensor()
self.assertTrue(
np.allclose(
np.array(selected_ids), np.array([4, 2, 3, 8])[:, np.newaxis]))
self.assertTrue(
np.allclose(
np.array(selected_scores),
np.array([0.5, 0.6, 0.9, 0.7])[:, np.newaxis]))
self.assertEqual(selected_ids.lod(),
[[0L, 2L, 4L], [0L, 1L, 2L, 3L, 4L]])
def _create_pre_ids(self):
np_data = np.array([[1, 2, 3, 4]], dtype='int64')
tensor = create_tensor(self.scope, "pre_ids", np_data)
tensor = create_tensor(self.scope, 'pre_ids', np_data)
def _create_pre_scores(self):
np_data = np.array([[0.1, 0.2, 0.3, 0.4]], dtype='float32')
tensor = create_tensor(self.scope, 'pre_scores', np_data)
def _create_ids(self):
self.lod = [[1, 3], [1, 1, 1, 1]]
self.lod = [[0, 2, 4], [0, 1, 2, 3, 4]]
np_data = np.array(
[[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int64')
tensor = create_tensor(self.scope, "ids", np_data)
tensor.set_recursive_sequence_lengths(self.lod)
tensor.set_lod(self.lod)
def _create_scores(self):
np_data = np.array(
......@@ -71,7 +87,7 @@ class BeamSearchOpTester(unittest.TestCase):
],
dtype='float32')
tensor = create_tensor(self.scope, "scores", np_data)
tensor.set_recursive_sequence_lengths(self.lod)
tensor.set_lod(self.lod)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册