From 49ca424d6e965fc390e013fd5e3843c6136d184b Mon Sep 17 00:00:00 2001 From: guosheng Date: Thu, 14 Jun 2018 01:19:20 +0800 Subject: [PATCH] Fix src_idx out of range in beam_search_op --- paddle/fluid/operators/beam_search_op.cc | 2 +- .../machine_translation/test_machine_translation.py | 6 +++++- python/paddle/fluid/tests/book/test_machine_translation.py | 6 +++++- 3 files changed, 11 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index 89e74e35d..62771d09f 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -87,7 +87,7 @@ void BeamSearch::PruneEndBeams(const framework::LoDTensor &pre_ids, auto *pre_ids_data = pre_ids.data(); auto abs_lod = framework::ToAbsOffset(ids_->lod()); auto &high_level = abs_lod[lod_level_]; - for (size_t src_idx = 0; src_idx < high_level.size(); ++src_idx) { + for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) { size_t src_prefix_start = high_level[src_idx]; size_t src_prefix_end = high_level[src_idx + 1]; bool finish_flag = true; diff --git a/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py b/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py index ccb7a4f9a..f690a0d23 100644 --- a/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/high-level-api/machine_translation/test_machine_translation.py @@ -148,7 +148,11 @@ def decode(context, is_sparse): pd.array_write(selected_ids, array=ids_array, i=counter) pd.array_write(selected_scores, array=scores_array, i=counter) - pd.less_than(x=counter, y=array_len, cond=cond) + # update the break condition: up to the max length or all candidates of + # source sentences have ended. + length_cond = pd.less_than(x=counter, y=array_len) + finish_cond = pd.logical_not(pd.is_empty(x=selected_ids)) + pd.logical_and(x=length_cond, y=finish_cond, out=cond) translation_ids, translation_scores = pd.beam_search_decode( ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10) diff --git a/python/paddle/fluid/tests/book/test_machine_translation.py b/python/paddle/fluid/tests/book/test_machine_translation.py index d8499fa3f..44e4c6264 100644 --- a/python/paddle/fluid/tests/book/test_machine_translation.py +++ b/python/paddle/fluid/tests/book/test_machine_translation.py @@ -147,7 +147,11 @@ def decoder_decode(context, is_sparse): pd.array_write(selected_ids, array=ids_array, i=counter) pd.array_write(selected_scores, array=scores_array, i=counter) - pd.less_than(x=counter, y=array_len, cond=cond) + # update the break condition: up to the max length or all candidates of + # source sentences have ended. + length_cond = pd.less_than(x=counter, y=array_len) + finish_cond = pd.logical_not(pd.is_empty(x=selected_ids)) + pd.logical_and(x=length_cond, y=finish_cond, out=cond) translation_ids, translation_scores = pd.beam_search_decode( ids=ids_array, scores=scores_array, beam_size=beam_size, end_id=10) -- GitLab