提交 741046e8 编写于 作者: G guosheng

Fix and enhance beam_search_op and beam_searc_decode_op to be comparable with python beam search

上级 01fdf17e
......@@ -288,8 +288,8 @@ set(GLOB_OP_LIB ${OP_LIBRARY} CACHE INTERNAL "Global OP library")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
# cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
# cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
......
......@@ -12,8 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/beam_search_decode_op.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/operators/beam_search_decode_op.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
......@@ -22,8 +24,11 @@ namespace operators {
struct BeamSearchDecodeFunctor {
BeamSearchDecodeFunctor(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor, LoDTensor* score_tensor)
: step_ids_origin_(step_ids),
LoDTensor* id_tensor, LoDTensor* score_tensor,
size_t beam_size, int end_id)
: beam_size_(beam_size),
end_id_(end_id),
step_ids_origin_(step_ids),
step_scores_origin_(step_scores),
id_tensor_(id_tensor),
score_tensor_(score_tensor) {
......@@ -67,6 +72,8 @@ struct BeamSearchDecodeFunctor {
void operator()() const;
bool tensor_on_gpu_;
size_t beam_size_;
int end_id_;
const LoDTensorArray& step_ids_origin_;
const LoDTensorArray& step_scores_origin_;
LoDTensorArray step_ids_ = LoDTensorArray();
......@@ -77,13 +84,17 @@ struct BeamSearchDecodeFunctor {
template <typename T>
void BeamSearchDecodeFunctor::operator()() const {
BeamSearchDecoder<T> beam_search_decoder;
BeamSearchDecoder<T> beam_search_decoder(beam_size_, end_id_);
// Check if the tensor is on GPU. If so, use the CPU copy instead
if (tensor_on_gpu_) {
beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_,
// beam_search_decoder.PackAllSteps(step_ids_, step_scores_, id_tensor_,
// score_tensor_);
beam_search_decoder.Backtrace(step_ids_, step_scores_, id_tensor_,
score_tensor_);
} else {
beam_search_decoder.PackAllSteps(step_ids_origin_, step_scores_origin_,
// beam_search_decoder.PackAllSteps(step_ids_origin_, step_scores_origin_,
// id_tensor_, score_tensor_);
beam_search_decoder.Backtrace(step_ids_origin_, step_scores_origin_,
id_tensor_, score_tensor_);
}
}
......@@ -122,13 +133,17 @@ class BeamSearchDecodeOp : public framework::OperatorBase {
"Level of LodTensor should be 2");
}
size_t beam_size = ctx.Attr<int>("beam_size");
int end_id = ctx.Attr<int>("end_id");
// prepare output
LoDTensor* sentenceIds = ctx.Output<LoDTensor>("SentenceIds");
LoDTensor* sentenceScores = ctx.Output<LoDTensor>("SentenceScores");
framework::VisitDataType(
framework::ToDataType(scores->at(0).type()),
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores));
BeamSearchDecodeFunctor(*ids, *scores, sentenceIds, sentenceScores,
beam_size, end_id));
}
};
......@@ -147,6 +162,9 @@ class BeamSearchDecodeOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("SentenceScores",
"(LodTensor)"
"All possible result sentences of word scores");
AddAttr<int>("beam_size", "beam size for beam search");
AddAttr<int>("end_id",
"the token id which indicates the end of a sequence");
AddComment(R"DOC(
Pack the result of Beam search op into SentenceIds and SentenceScores.
)DOC");
......@@ -172,10 +190,12 @@ class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc& op_desc,
framework::BlockDesc* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) {
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
auto& sentence_ids = block->FindRecursiveOrCreateVar(o);
sentence_ids.SetType(framework::proto::VarType::LOD_TENSOR);
}
for (auto& o : op_desc.Output("SentenceScores")) {
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
auto& sentence_scores = block->FindRecursiveOrCreateVar(o);
sentence_scores.SetType(framework::proto::VarType::LOD_TENSOR);
}
}
};
......
......@@ -14,7 +14,9 @@ limitations under the License. */
#pragma once
#include <algorithm>
#include <vector>
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
......@@ -72,6 +74,9 @@ using SentenceVector = std::vector<Sentence<T>>;
template <typename T>
struct BeamSearchDecoder {
BeamSearchDecoder(size_t beam_size, int end_id)
: beam_size_(beam_size), end_id_(end_id) {}
/**
* make a BeamNode and all it's related prefix BeanNode into a Sentence.
*/
......@@ -103,7 +108,8 @@ struct BeamSearchDecoder {
*/
void ConvertSentenceVectorToLodTensor(
std::vector<SentenceVector<T>> sentence_vector_list, LoDTensor* id_tensor,
LoDTensor* score_tensor) const;
LoDTensor* score_tensor, bool reverse = false,
bool sort_by_score = true) const;
/**
* Pack all steps of id/score LodTensor into sentence LoDTensor
......@@ -121,6 +127,13 @@ struct BeamSearchDecoder {
void PackAllSteps(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores, LoDTensor* id_tensor,
LoDTensor* score_tensor) const;
void Backtrace(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores, LoDTensor* id_tensor,
LoDTensor* score_tensor) const;
size_t beam_size_;
int end_id_;
};
template <typename T>
......@@ -200,7 +213,7 @@ std::vector<BeamNodeVector<T>> BeamSearchDecoder<T>::PackTwoSteps(
template <typename T>
void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
std::vector<SentenceVector<T>> sentence_vector_list, LoDTensor* id_tensor,
LoDTensor* score_tensor) const {
LoDTensor* score_tensor, bool reverse, bool sort_by_score) const {
size_t src_num = sentence_vector_list.size();
PADDLE_ENFORCE_NE(src_num, 0, "src_num should not be 0");
......@@ -211,11 +224,29 @@ void BeamSearchDecoder<T>::ConvertSentenceVectorToLodTensor(
std::vector<T> score_data;
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
if (sort_by_score) {
sort(sentence_vector_list[src_idx].begin(),
sentence_vector_list[src_idx].end(),
[reverse](const Sentence<T>& a, const Sentence<T>& b) {
if (reverse)
return a.scores.front() > b.scores.front();
else
return a.scores.back() > b.scores.back();
});
}
for (Sentence<T>& sentence : sentence_vector_list[src_idx]) {
if (reverse) {
id_data.insert(id_data.end(), sentence.word_ids.rbegin(),
sentence.word_ids.rend());
score_data.insert(score_data.end(), sentence.scores.rbegin(),
sentence.scores.rend());
} else {
id_data.insert(id_data.end(), sentence.word_ids.begin(),
sentence.word_ids.end());
score_data.insert(score_data.end(), sentence.scores.begin(),
sentence.scores.end());
}
sentence_level_lod.push_back(sentence_level_lod.back() +
sentence.word_ids.size());
}
......@@ -278,5 +309,78 @@ void BeamSearchDecoder<T>::PackAllSteps(const LoDTensorArray& step_ids,
score_tensor);
}
template <typename T>
void BeamSearchDecoder<T>::Backtrace(const LoDTensorArray& step_ids,
const LoDTensorArray& step_scores,
LoDTensor* id_tensor,
LoDTensor* score_tensor) const {
PADDLE_ENFORCE(!step_ids.empty(), "step num should be larger than 0");
PADDLE_ENFORCE_EQ(step_ids.size(), step_scores.size(),
"step_ids and step_scores should be the same");
const size_t step_num = step_ids.size();
const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1;
std::vector<SentenceVector<T>> sentence_vector_list(
src_num, SentenceVector<T>(beam_size_));
std::vector<std::vector<size_t>> prefix_idx_vector_list(
src_num, std::vector<size_t>());
for (int step_id = step_num - 1; step_id >= 0; --step_id) {
auto& cur_ids = step_ids.at(step_id);
auto& cur_scores = step_scores.at(step_id);
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
// for each source sentence
auto& sentence_vector = sentence_vector_list.at(src_idx);
auto& prefix_idx_vector = prefix_idx_vector_list.at(src_idx);
size_t src_prefix_start = cur_ids.lod().at(kSourceLevel)[src_idx];
size_t src_prefix_end = cur_ids.lod().at(kSourceLevel)[src_idx + 1];
if (prefix_idx_vector.empty()) { // be finished and pruned at this step
// or the last time step
for (size_t prefix_idx = src_prefix_start; prefix_idx < src_prefix_end;
++prefix_idx) {
size_t candidate_start = cur_ids.lod().at(kSentenceLevel)[prefix_idx];
size_t candidate_end =
cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1];
for (size_t candidate_idx = candidate_start;
candidate_idx < candidate_end; ++candidate_idx) {
prefix_idx_vector.push_back(prefix_idx);
size_t idx = prefix_idx_vector.size() - 1;
auto cur_id = cur_ids.data<int64_t>()[candidate_idx];
auto cur_score = cur_scores.data<T>()[candidate_idx];
sentence_vector.at(idx).word_ids.push_back(cur_id);
sentence_vector.at(idx).scores.push_back(cur_score);
}
}
} else { // use prefix_idx_vector to backtrace
size_t src_candidate_start =
cur_ids.lod().at(kSentenceLevel)[src_prefix_start];
size_t prefix_idx = src_prefix_start;
size_t candidate_num =
cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] -
cur_ids.lod().at(kSentenceLevel)[prefix_idx];
for (size_t idx = 0; idx < prefix_idx_vector.size(); ++idx) {
auto candidate_idx = prefix_idx_vector.at(idx);
auto cur_id = cur_ids.data<int64_t>()[candidate_idx];
auto cur_score = cur_scores.data<T>()[candidate_idx];
if (cur_id != end_id_ || sentence_vector.at(idx).word_ids.empty()) {
// to skip redundant end tokens
sentence_vector.at(idx).word_ids.push_back(cur_id);
sentence_vector.at(idx).scores.push_back(cur_score);
}
while (src_candidate_start + candidate_num <=
candidate_idx) { // search the corresponding prefix
prefix_idx++;
candidate_num += cur_ids.lod().at(kSentenceLevel)[prefix_idx + 1] -
cur_ids.lod().at(kSentenceLevel)[prefix_idx];
}
prefix_idx_vector.at(idx) = prefix_idx;
}
}
}
}
ConvertSentenceVectorToLodTensor(sentence_vector_list, id_tensor,
score_tensor, true, true);
}
} // namespace operators
} // namespace paddle
......@@ -12,25 +12,27 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
#include <algorithm>
#include <limits>
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/beam_search_op.h"
namespace paddle {
namespace operators {
void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
const framework::LoDTensor &pre_scores,
framework::LoDTensor *selected_ids,
framework::LoDTensor *selected_scores) {
auto abs_lod = framework::ToAbsOffset(ids_->lod());
auto &high_level = abs_lod[lod_level_];
auto items = SelectTopBeamSizeItems();
auto items = SelectTopBeamSizeItems(pre_ids, pre_scores);
auto selected_items = ToMap(items, high_level.back());
VLOG(3) << "selected_items:";
for (size_t i = 0; i < selected_items.size(); ++i) {
......@@ -39,7 +41,8 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
VLOG(3) << ItemToString(item);
}
}
PruneEndidCandidates(pre_ids, &selected_items);
PruneEndBeams(pre_ids, &selected_items);
// calculate the output tensor's height
size_t num_instances = std::accumulate(
std::begin(selected_items), std::end(selected_items), 0,
......@@ -61,12 +64,6 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
size_t low_offset = 0;
for (auto &items : selected_items) {
low_level.push_back(low_offset);
sort(items.begin(), items.end(), [](const Item &a, const Item &b) {
if (a.offset < b.offset) {
return true;
}
return a.id < b.id;
});
for (auto &item : items) {
ids_data[low_offset] = item.id;
scores_data[low_offset] = item.score;
......@@ -86,6 +83,33 @@ void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
selected_scores->set_lod(lod);
}
void BeamSearch::PruneEndBeams(const framework::LoDTensor &pre_ids,
std::vector<std::vector<Item>> *items) {
auto *pre_ids_data = pre_ids.data<int64_t>();
auto abs_lod = framework::ToAbsOffset(ids_->lod());
auto &high_level = abs_lod[lod_level_];
for (size_t src_idx = 0; src_idx < high_level.size(); ++src_idx) {
size_t src_prefix_start = high_level[src_idx];
size_t src_prefix_end = high_level[src_idx + 1];
bool finish_flag = true;
for (size_t offset = src_prefix_start; offset < src_prefix_end; offset++) {
for (auto &item : items->at(offset)) {
if (item.id != static_cast<size_t>(end_id_) ||
pre_ids_data[offset] != end_id_) {
finish_flag = false;
break;
}
}
if (!finish_flag) break;
}
if (finish_flag) { // all branchs of the beam (source sentence) end and
// prune this beam
for (size_t offset = src_prefix_start; offset < src_prefix_end; offset++)
items->at(offset).clear();
}
}
}
int BeamSearch::PruneEndidCandidates(const framework::LoDTensor &pre_ids,
std::vector<std::vector<Item>> *items) {
auto *pre_ids_data = pre_ids.data<int64_t>();
......@@ -115,13 +139,14 @@ std::vector<std::vector<BeamSearch::Item>> BeamSearch::ToMap(
return result;
}
std::vector<std::vector<BeamSearch::Item>>
BeamSearch::SelectTopBeamSizeItems() {
std::vector<std::vector<BeamSearch::Item>> BeamSearch::SelectTopBeamSizeItems(
const framework::LoDTensor &pre_ids,
const framework::LoDTensor &pre_scores) {
std::vector<std::vector<Item>> result;
std::vector<Item> items;
// for each source sentence, select the top beam_size items across all
// candidate sets.
while (NextItemSet(&items)) {
while (NextItemSet(pre_ids, pre_scores, &items)) {
std::nth_element(std::begin(items), std::begin(items) + beam_size_,
std::end(items), [](const Item &a, const Item &b) {
// TODO(superjom) make score's comparation customizable.
......@@ -146,7 +171,9 @@ BeamSearch::SelectTopBeamSizeItems() {
}
// the candidates of a source
bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
bool BeamSearch::NextItemSet(const framework::LoDTensor &pre_ids,
const framework::LoDTensor &pre_scores,
std::vector<BeamSearch::Item> *items) {
if (sent_offset_ >= ids_->NumElements(lod_level_)) {
return false;
}
......@@ -164,16 +191,27 @@ bool BeamSearch::NextItemSet(std::vector<BeamSearch::Item> *items) {
instance_dim *= ids.dims()[i];
}
auto *pre_ids_data = pre_ids.data<int64_t>();
auto *pre_scores_data = pre_scores.data<float>();
items->clear();
items->reserve(framework::product(ids.dims()));
for (size_t offset = abs_lod[lod_level_][sent_offset_];
offset < abs_lod[lod_level_][sent_offset_ + 1]; offset++) {
auto pre_id = pre_ids_data[offset];
auto pre_score = pre_scores_data[offset];
if (pre_id == end_id_) {
// Allocate all probability mass to eos_id for finished branchs and the
// other
// candidate ids can be ignored.
items->emplace_back(offset, end_id_, pre_score);
} else {
for (size_t d = 0; d < instance_dim; d++) {
const size_t dim_offset = offset * instance_dim + d;
items->emplace_back(offset, ids_data[dim_offset],
scores_data[dim_offset]);
}
}
}
sent_offset_++;
return true;
......@@ -199,7 +237,8 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
// inputs and outputs stored in proto
AddInput("pre_ids", "ids in previous step");
AddInput("pre_ids", "ids in the previous step");
AddInput("pre_scores", "accumulated scores in the previous step");
AddInput("ids", "a LoDTensor of shape of [None,k]");
AddInput("scores",
"a LoDTensor that has the same shape and LoD with `ids`");
......@@ -253,10 +292,12 @@ class BeamSearchInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("selected_ids")) {
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
auto &selected_ids = block->FindRecursiveOrCreateVar(o);
selected_ids.SetType(framework::proto::VarType::LOD_TENSOR);
}
for (auto &o : op_desc.Output("selected_scores")) {
block->Var(o)->SetType(framework::proto::VarType::LOD_TENSOR);
auto &selected_scores = block->FindRecursiveOrCreateVar(o);
selected_scores.SetType(framework::proto::VarType::LOD_TENSOR);
}
}
};
......
......@@ -132,6 +132,7 @@ class BeamSearch {
* that means no candidates is provided, and the task will stop running.
*/
void operator()(const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores,
framework::LoDTensor* selected_ids,
framework::LoDTensor* selected_scores);
/*
......@@ -152,6 +153,14 @@ class BeamSearch {
};
protected:
/*
* Prune the source sentences all branchs finished, and it is optional.
* Pruning must one step later than finishing, since the end tokens
* must be writed out. Also the finished branchs with top 1 score can
* be pruned.
*/
void PruneEndBeams(const framework::LoDTensor& pre_ids,
std::vector<std::vector<Item>>* items);
/*
* Delete all the records that follows the end token.
*/
......@@ -160,7 +169,7 @@ class BeamSearch {
/*
* Transform the items into a map whose key is offset, value is the items.
* NOTE low performance
* NOTE low performance.
*/
std::vector<std::vector<Item>> ToMap(
const std::vector<std::vector<Item>>& inputs, size_t element_num);
......@@ -168,12 +177,16 @@ class BeamSearch {
/*
* For each source, select top beam_size records.
*/
std::vector<std::vector<Item>> SelectTopBeamSizeItems();
std::vector<std::vector<Item>> SelectTopBeamSizeItems(
const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores);
/*
* Get the items of next source sequence, return false if no remaining items.
*/
bool NextItemSet(std::vector<Item>* items);
bool NextItemSet(const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores,
std::vector<Item>* items);
private:
size_t beam_size_;
......@@ -192,24 +205,25 @@ template <typename DeviceContext, typename T>
class BeamSearchOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* ids_var = context.Input<framework::LoDTensor>("ids");
auto* scores_var = context.Input<framework::LoDTensor>("scores");
auto* pre_ids_var = context.Input<framework::LoDTensor>("pre_ids");
PADDLE_ENFORCE_NOT_NULL(ids_var);
PADDLE_ENFORCE_NOT_NULL(scores_var);
PADDLE_ENFORCE_NOT_NULL(pre_ids_var);
auto* ids = context.Input<framework::LoDTensor>("ids");
auto* scores = context.Input<framework::LoDTensor>("scores");
auto* pre_ids = context.Input<framework::LoDTensor>("pre_ids");
auto* pre_scores = context.Input<framework::LoDTensor>("pre_scores");
PADDLE_ENFORCE_NOT_NULL(ids);
PADDLE_ENFORCE_NOT_NULL(scores);
PADDLE_ENFORCE_NOT_NULL(pre_ids);
PADDLE_ENFORCE_NOT_NULL(pre_scores);
size_t level = context.Attr<int>("level");
size_t beam_size = context.Attr<int>("beam_size");
int end_id = context.Attr<int>("end_id");
BeamSearch alg(*ids_var, *scores_var, level, beam_size, end_id);
auto selected_ids_var =
context.Output<framework::LoDTensor>("selected_ids");
auto selected_scores_var =
BeamSearch alg(*ids, *scores, level, beam_size, end_id);
auto selected_ids = context.Output<framework::LoDTensor>("selected_ids");
auto selected_scores =
context.Output<framework::LoDTensor>("selected_scores");
PADDLE_ENFORCE_NOT_NULL(selected_ids_var);
PADDLE_ENFORCE_NOT_NULL(selected_scores_var);
alg(*pre_ids_var, selected_ids_var, selected_scores_var);
PADDLE_ENFORCE_NOT_NULL(selected_ids);
PADDLE_ENFORCE_NOT_NULL(selected_scores);
alg(*pre_ids, *pre_scores, selected_ids, selected_scores);
}
};
} // namespace operators
......
......@@ -38,15 +38,14 @@ class WriteToArrayOp : public ArrayOp {
<< " to " << offset + 1;
out->resize(offset + 1);
}
if (x_tensor.memory_size() > 0) {
auto *out_tensor = &out->at(offset);
out_tensor->set_lod(x_tensor.lod());
if (x_tensor.memory_size() > 0) {
platform::DeviceContextPool &pool =
platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
TensorCopy(x_tensor, place, dev_ctx, out_tensor);
out_tensor->set_lod(x_tensor.lod());
} else {
VLOG(10) << "WARNING: The input tensor 'x_tensor' holds no memory, so "
"nothing has been written to output array["
......
......@@ -1686,7 +1686,7 @@ def layer_norm(input,
return helper.append_activation(layer_norm_out)
def beam_search_decode(ids, scores, name=None):
def beam_search_decode(ids, scores, beam_size, end_id, name=None):
helper = LayerHelper('beam_search_decode', **locals())
sentence_ids = helper.create_tmp_variable(dtype=ids.dtype)
sentence_scores = helper.create_tmp_variable(dtype=ids.dtype)
......@@ -1698,7 +1698,9 @@ def beam_search_decode(ids, scores, name=None):
outputs={
"SentenceIds": sentence_ids,
"SentenceScores": sentence_scores
})
},
attrs={"beam_size": beam_size,
"end_id": end_id})
return sentence_ids, sentence_scores
......@@ -1926,7 +1928,7 @@ def sequence_expand(x, y, ref_level=-1, name=None):
return tmp
def beam_search(pre_ids, ids, scores, beam_size, end_id, level=0):
def beam_search(pre_ids, pre_scores, ids, scores, beam_size, end_id, level=0):
'''
This function implements the beam search algorithm.
'''
......@@ -1941,6 +1943,7 @@ def beam_search(pre_ids, ids, scores, beam_size, end_id, level=0):
type='beam_search',
inputs={
'pre_ids': pre_ids,
'pre_scores': pre_scores,
'ids': ids,
'scores': scores,
},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册