From 5bde12024303ca294681a9f0ba7224f3c9f44f30 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Fri, 8 Mar 2019 10:58:35 +0800 Subject: [PATCH] Make parent_idx a dispensable output for beam_search op to support models saved by older paddle version. (#16106) test=develop --- paddle/fluid/operators/beam_search_op.cc | 6 ++--- paddle/fluid/operators/beam_search_op.h | 1 - paddle/fluid/operators/math/beam_search.cc | 18 ++++++++------ paddle/fluid/operators/math/beam_search.cu | 29 ++++++++++++++++------ 4 files changed, 35 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/operators/beam_search_op.cc b/paddle/fluid/operators/beam_search_op.cc index e93cd8615e..fa6b09b4e7 100644 --- a/paddle/fluid/operators/beam_search_op.cc +++ b/paddle/fluid/operators/beam_search_op.cc @@ -51,9 +51,9 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker { AddOutput("selected_scores", "A LoDTensor containing the accumulated scores corresponding to " "Output(selected_ids)."); - AddOutput( - "parent_idx", - "A Tensor preserving the selected_ids' parent indice in pre_ids."); + AddOutput("parent_idx", + "A Tensor preserving the selected_ids' parent indice in pre_ids.") + .AsDispensable(); // Attributes stored in AttributeMap AddAttr("level", "the level of LoDTensor"); diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index f808020cc7..3d32ea0cc9 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -44,7 +44,6 @@ class BeamSearchOpKernel : public framework::OpKernel { auto* parent_idx = context.Output("parent_idx"); PADDLE_ENFORCE_NOT_NULL(selected_ids); PADDLE_ENFORCE_NOT_NULL(selected_scores); - PADDLE_ENFORCE_NOT_NULL(parent_idx); math::BeamSearchFunctor alg; alg(context.template device_context(), pre_ids, pre_scores, diff --git a/paddle/fluid/operators/math/beam_search.cc b/paddle/fluid/operators/math/beam_search.cc index 69971ef742..0155ef188e 100644 --- a/paddle/fluid/operators/math/beam_search.cc +++ b/paddle/fluid/operators/math/beam_search.cc @@ -56,15 +56,15 @@ class BeamSearchFunctor { // the output tensor shape should be [num_instances, 1] auto dims = framework::make_ddim( std::vector({static_cast(num_instances), 1})); - selected_ids->Resize(dims); - selected_scores->Resize(dims); - parent_idx->Resize({static_cast(num_instances)}); - auto *selected_ids_data = - selected_ids->mutable_data(platform::CPUPlace()); + selected_ids->mutable_data(dims, platform::CPUPlace()); auto *selected_scores_data = - selected_scores->mutable_data(platform::CPUPlace()); - auto *parent_idx_data = parent_idx->mutable_data(platform::CPUPlace()); + selected_scores->mutable_data(dims, platform::CPUPlace()); + auto *parent_idx_data = + parent_idx + ? parent_idx->mutable_data( + {static_cast(num_instances)}, platform::CPUPlace()) + : nullptr; // fill in data std::vector low_level; @@ -72,7 +72,9 @@ class BeamSearchFunctor { for (auto &items : selected_items) { low_level.push_back(low_offset); for (auto &item : items) { - parent_idx_data[low_offset] = static_cast(low_level.size() - 1); + if (parent_idx) { + parent_idx_data[low_offset] = static_cast(low_level.size() - 1); + } selected_ids_data[low_offset] = item.id; selected_scores_data[low_offset] = item.score; low_offset++; diff --git a/paddle/fluid/operators/math/beam_search.cu b/paddle/fluid/operators/math/beam_search.cu index d66778a6fe..ecfeba3384 100644 --- a/paddle/fluid/operators/math/beam_search.cu +++ b/paddle/fluid/operators/math/beam_search.cu @@ -168,6 +168,7 @@ __device__ __forceinline__ bool PruneEndBeams(Triple* top_beam_local, return finish_flag; } +template __device__ __forceinline__ void WriteBack( int64_t* selected_ids, float* selected_scores, int* parent_idx, size_t* selected_offsets, Triple* top_beam_local, @@ -183,7 +184,9 @@ __device__ __forceinline__ void WriteBack( selected_ids[global_index] = static_cast(top_beam_local[local_index].id); selected_scores[global_index] = top_beam_local[local_index].score; - parent_idx[global_index] = static_cast(global_offset); + if (ReturnParentIdx) { + parent_idx[global_index] = static_cast(global_offset); + } global_index++; } } @@ -241,9 +244,15 @@ __device__ void BeamSearchDetails( selected_offsets[0] = 0; } - WriteBack(selected_ids, selected_scores, parent_idx, selected_offsets, - top_beam_local, seq_offset_start, seq_offset_end, - selected_seq_start, selected_seq_length); + if (parent_idx) { + WriteBack(selected_ids, selected_scores, parent_idx, + selected_offsets, top_beam_local, seq_offset_start, + seq_offset_end, selected_seq_start, selected_seq_length); + } else { + WriteBack(selected_ids, selected_scores, parent_idx, + selected_offsets, top_beam_local, seq_offset_start, + seq_offset_end, selected_seq_start, selected_seq_length); + } } } @@ -337,8 +346,12 @@ class BeamSearchFunctor { selected_ids->mutable_data(selected_dims, context.GetPlace()); float* selected_scores_data = selected_scores->mutable_data(selected_dims, context.GetPlace()); - int* parent_idx_data = parent_idx->mutable_data( - {static_cast(num_seqs * beam_size)}, context.GetPlace()); + int* parent_idx_data = + parent_idx + ? parent_idx->mutable_data( + {static_cast(num_seqs * beam_size)}, + context.GetPlace()) + : nullptr; framework::LoD selected_lod(2); selected_lod[0].assign(abs_lod[level].begin(), abs_lod[level].end()); @@ -396,7 +409,9 @@ class BeamSearchFunctor { {static_cast(selected_lod[1].back()), 1}); selected_ids->Resize(final_selected_dims); selected_scores->Resize(final_selected_dims); - parent_idx->Resize({static_cast(selected_lod[1].back())}); + if (parent_idx) { + parent_idx->Resize({static_cast(selected_lod[1].back())}); + } } } }; -- GitLab