Commit 66ead07e authored by Yiqun Liu, committed by ceci3

Make parent_idx a dispensable output for beam_search op to support models saved by older paddle version. (#16106)

test=develop
Parent 02170583
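What "dispensable" means in practice: a program saved by an older Paddle release lists only selected_ids and selected_scores as beam_search outputs, so after this change the op still passes verification and the kernel simply receives a null parent_idx that must never be written. A minimal sketch of that convention in plain C++ (BeamItem and write_selection are illustrative names, not part of the Paddle code):

    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for one selected beam-search candidate.
    struct BeamItem {
      int64_t id;
      float score;
      int parent;  // index of the beam this candidate was expanded from
    };

    // Required outputs are always written; the optional output is a pointer
    // that may be null when the saved model never declared it.
    void write_selection(const std::vector<BeamItem>& items,
                         std::vector<int64_t>* selected_ids,
                         std::vector<float>* selected_scores,
                         std::vector<int>* parent_idx /* may be nullptr */) {
      for (const BeamItem& item : items) {
        selected_ids->push_back(item.id);
        selected_scores->push_back(item.score);
        if (parent_idx) {  // only touch the dispensable output when it exists
          parent_idx->push_back(item.parent);
        }
      }
    }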
@@ -51,9 +51,9 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("selected_scores",
               "A LoDTensor containing the accumulated scores corresponding to "
               "Output(selected_ids).");
-    AddOutput(
-        "parent_idx",
-        "A Tensor preserving the selected_ids' parent indice in pre_ids.");
+    AddOutput("parent_idx",
+              "A Tensor preserving the selected_ids' parent indice in pre_ids.")
+        .AsDispensable();
 
     // Attributes stored in AttributeMap
     AddAttr<int>("level", "the level of LoDTensor");
...
@@ -44,7 +44,6 @@ class BeamSearchOpKernel : public framework::OpKernel<T> {
     auto* parent_idx = context.Output<framework::Tensor>("parent_idx");
     PADDLE_ENFORCE_NOT_NULL(selected_ids);
     PADDLE_ENFORCE_NOT_NULL(selected_scores);
-    PADDLE_ENFORCE_NOT_NULL(parent_idx);
 
     math::BeamSearchFunctor<DeviceContext, T> alg;
     alg(context.template device_context<DeviceContext>(), pre_ids, pre_scores,
...
@@ -56,15 +56,15 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
     // the output tensor shape should be [num_instances, 1]
     auto dims = framework::make_ddim(
         std::vector<int64_t>({static_cast<int>(num_instances), 1}));
-    selected_ids->Resize(dims);
-    selected_scores->Resize(dims);
-    parent_idx->Resize({static_cast<int64_t>(num_instances)});
-
     auto *selected_ids_data =
-        selected_ids->mutable_data<int64_t>(platform::CPUPlace());
+        selected_ids->mutable_data<int64_t>(dims, platform::CPUPlace());
     auto *selected_scores_data =
-        selected_scores->mutable_data<float>(platform::CPUPlace());
-    auto *parent_idx_data = parent_idx->mutable_data<int>(platform::CPUPlace());
+        selected_scores->mutable_data<float>(dims, platform::CPUPlace());
+    auto *parent_idx_data =
+        parent_idx
+            ? parent_idx->mutable_data<int>(
+                  {static_cast<int64_t>(num_instances)}, platform::CPUPlace())
+            : nullptr;
 
     // fill in data
     std::vector<size_t> low_level;
@@ -72,7 +72,9 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
     for (auto &items : selected_items) {
       low_level.push_back(low_offset);
       for (auto &item : items) {
-        parent_idx_data[low_offset] = static_cast<int>(low_level.size() - 1);
+        if (parent_idx) {
+          parent_idx_data[low_offset] = static_cast<int>(low_level.size() - 1);
+        }
         selected_ids_data[low_offset] = item.id;
         selected_scores_data[low_offset] = item.score;
         low_offset++;
...
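The CPU hunks above also show where the parent index comes from: selections are flushed beam by beam, and each emitted item records the index of the beam currently being flushed, i.e. low_level.size() - 1. A standalone sketch of that bookkeeping (flatten_groups and its arguments are illustrative names, not Paddle APIs):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Flatten per-beam selections into flat arrays, optionally recording for
    // each flattened entry the index of the beam (group) it came from,
    // mirroring the low_level / low_offset bookkeeping in the CPU functor.
    void flatten_groups(const std::vector<std::vector<int64_t>>& groups,
                        std::vector<int64_t>* flat_ids,
                        std::vector<size_t>* group_offsets,
                        std::vector<int>* parent_idx /* may be nullptr */) {
      size_t offset = 0;
      for (const auto& group : groups) {
        group_offsets->push_back(offset);  // start offset of this beam
        for (int64_t id : group) {
          if (parent_idx) {
            // parent = index of the group being flushed, just like
            // static_cast<int>(low_level.size() - 1) above
            parent_idx->push_back(static_cast<int>(group_offsets->size() - 1));
          }
          flat_ids->push_back(id);
          ++offset;
        }
      }
      group_offsets->push_back(offset);  // closing offset, LoD-style
    }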
@@ -168,6 +168,7 @@ __device__ __forceinline__ bool PruneEndBeams(Triple* top_beam_local,
   return finish_flag;
 }
 
+template <bool ReturnParentIdx = false>
 __device__ __forceinline__ void WriteBack(
     int64_t* selected_ids, float* selected_scores, int* parent_idx,
     size_t* selected_offsets, Triple* top_beam_local,
@@ -183,7 +184,9 @@ __device__ __forceinline__ void WriteBack(
     selected_ids[global_index] =
         static_cast<int64_t>(top_beam_local[local_index].id);
     selected_scores[global_index] = top_beam_local[local_index].score;
-    parent_idx[global_index] = static_cast<int>(global_offset);
+    if (ReturnParentIdx) {
+      parent_idx[global_index] = static_cast<int>(global_offset);
+    }
     global_index++;
   }
 }
@@ -241,9 +244,15 @@ __device__ void BeamSearchDetails(
       selected_offsets[0] = 0;
     }
 
-    WriteBack(selected_ids, selected_scores, parent_idx, selected_offsets,
-              top_beam_local, seq_offset_start, seq_offset_end,
-              selected_seq_start, selected_seq_length);
+    if (parent_idx) {
+      WriteBack<true>(selected_ids, selected_scores, parent_idx,
+                      selected_offsets, top_beam_local, seq_offset_start,
+                      seq_offset_end, selected_seq_start, selected_seq_length);
+    } else {
+      WriteBack<false>(selected_ids, selected_scores, parent_idx,
+                       selected_offsets, top_beam_local, seq_offset_start,
+                       seq_offset_end, selected_seq_start, selected_seq_length);
+    }
   }
 }
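In the CUDA path the null check is hoisted out of the per-element loop: the new template <bool ReturnParentIdx> parameter makes the check a compile-time constant inside WriteBack, and BeamSearchDetails branches once to pick the <true> or <false> instantiation. The same dispatch pattern in plain host-side C++ (write_back and Item are illustrative names):

    #include <cstddef>
    #include <cstdint>

    struct Item {
      int64_t id;
      float score;
    };

    // The bool is a template parameter, so the "if" below is a compile-time
    // constant and the compiler can drop the store entirely in the <false>
    // instantiation.
    template <bool ReturnParentIdx>
    void write_back(const Item* items, size_t n, int64_t* ids, float* scores,
                    int* parent_idx, int parent) {
      for (size_t i = 0; i < n; ++i) {
        ids[i] = items[i].id;
        scores[i] = items[i].score;
        if (ReturnParentIdx) {
          parent_idx[i] = parent;
        }
      }
    }

    // Runtime dispatch: one branch outside the loop instead of one per element.
    void write_back_dispatch(const Item* items, size_t n, int64_t* ids,
                             float* scores, int* parent_idx, int parent) {
      if (parent_idx != nullptr) {
        write_back<true>(items, n, ids, scores, parent_idx, parent);
      } else {
        write_back<false>(items, n, ids, scores, parent_idx, parent);
      }
    }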
@@ -337,8 +346,12 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
         selected_ids->mutable_data<int64_t>(selected_dims, context.GetPlace());
     float* selected_scores_data =
         selected_scores->mutable_data<float>(selected_dims, context.GetPlace());
-    int* parent_idx_data = parent_idx->mutable_data<int>(
-        {static_cast<int64_t>(num_seqs * beam_size)}, context.GetPlace());
+    int* parent_idx_data =
+        parent_idx
+            ? parent_idx->mutable_data<int>(
+                  {static_cast<int64_t>(num_seqs * beam_size)},
+                  context.GetPlace())
+            : nullptr;
 
     framework::LoD selected_lod(2);
     selected_lod[0].assign(abs_lod[level].begin(), abs_lod[level].end());
@@ -396,7 +409,9 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
           {static_cast<int64_t>(selected_lod[1].back()), 1});
       selected_ids->Resize(final_selected_dims);
       selected_scores->Resize(final_selected_dims);
-      parent_idx->Resize({static_cast<int64_t>(selected_lod[1].back())});
+      if (parent_idx) {
+        parent_idx->Resize({static_cast<int64_t>(selected_lod[1].back())});
+      }
     }
   }
 };
...