提交 49ca424d 编写于 作者: G guosheng

Fix src_idx out of range in beam_search_op

上级 6e38cc33
...@@ -87,7 +87,7 @@ void BeamSearch::PruneEndBeams(const framework::LoDTensor &pre_ids, ...@@ -87,7 +87,7 @@ void BeamSearch::PruneEndBeams(const framework::LoDTensor &pre_ids,
auto *pre_ids_data = pre_ids.data<int64_t>(); auto *pre_ids_data = pre_ids.data<int64_t>();
auto abs_lod = framework::ToAbsOffset(ids_->lod()); auto abs_lod = framework::ToAbsOffset(ids_->lod());
auto &high_level = abs_lod[lod_level_]; auto &high_level = abs_lod[lod_level_];
for (size_t src_idx = 0; src_idx < high_level.size(); ++src_idx) { for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) {
size_t src_prefix_start = high_level[src_idx]; size_t src_prefix_start = high_level[src_idx];
size_t src_prefix_end = high_level[src_idx + 1]; size_t src_prefix_end = high_level[src_idx + 1];
bool finish_flag = true; bool finish_flag = true;
......
...@@ -148,7 +148,11 @@ def decode(context, is_sparse): ...@@ -148,7 +148,11 @@ def decode(context, is_sparse):
pd.array_write(selected_ids, array=ids_array, i=counter) pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter) pd.array_write(selected_scores, array=scores_array, i=counter)
pd.less_than(x=counter, y=array_len, cond=cond) # update the break condition: up to the max length or all candidates of
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)
translation_ids, translation_scores = pd.beam_search_decode( translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10) ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)
......
...@@ -147,7 +147,11 @@ def decoder_decode(context, is_sparse): ...@@ -147,7 +147,11 @@ def decoder_decode(context, is_sparse):
pd.array_write(selected_ids, array=ids_array, i=counter) pd.array_write(selected_ids, array=ids_array, i=counter)
pd.array_write(selected_scores, array=scores_array, i=counter) pd.array_write(selected_scores, array=scores_array, i=counter)
pd.less_than(x=counter, y=array_len, cond=cond) # update the break condition: up to the max length or all candidates of
# source sentences have ended.
length_cond = pd.less_than(x=counter, y=array_len)
finish_cond = pd.logical_not(pd.is_empty(x=selected_ids))
pd.logical_and(x=length_cond, y=finish_cond, out=cond)
translation_ids, translation_scores = pd.beam_search_decode( translation_ids, translation_scores = pd.beam_search_decode(
ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10) ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册