提交 66ead07e 编写于 作者: Y Yiqun Liu 提交者: ceci3

Make parent_idx a dispensable output for beam_search op to support models...

Make parent_idx a dispensable output for beam_search op to support models saved by older paddle version. (#16106)

test=develop
上级 02170583
......@@ -51,9 +51,9 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
AddOutput("selected_scores",
"A LoDTensor containing the accumulated scores corresponding to "
"Output(selected_ids).");
AddOutput(
"parent_idx",
"A Tensor preserving the selected_ids' parent indice in pre_ids.");
AddOutput("parent_idx",
"A Tensor preserving the selected_ids' parent indice in pre_ids.")
.AsDispensable();
// Attributes stored in AttributeMap
AddAttr<int>("level", "the level of LoDTensor");
......
......@@ -44,7 +44,6 @@ class BeamSearchOpKernel : public framework::OpKernel<T> {
auto* parent_idx = context.Output<framework::Tensor>("parent_idx");
PADDLE_ENFORCE_NOT_NULL(selected_ids);
PADDLE_ENFORCE_NOT_NULL(selected_scores);
PADDLE_ENFORCE_NOT_NULL(parent_idx);
math::BeamSearchFunctor<DeviceContext, T> alg;
alg(context.template device_context<DeviceContext>(), pre_ids, pre_scores,
......
......@@ -56,15 +56,15 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
// the output tensor shape should be [num_instances, 1]
auto dims = framework::make_ddim(
std::vector<int64_t>({static_cast<int>(num_instances), 1}));
selected_ids->Resize(dims);
selected_scores->Resize(dims);
parent_idx->Resize({static_cast<int64_t>(num_instances)});
auto *selected_ids_data =
selected_ids->mutable_data<int64_t>(platform::CPUPlace());
selected_ids->mutable_data<int64_t>(dims, platform::CPUPlace());
auto *selected_scores_data =
selected_scores->mutable_data<float>(platform::CPUPlace());
auto *parent_idx_data = parent_idx->mutable_data<int>(platform::CPUPlace());
selected_scores->mutable_data<float>(dims, platform::CPUPlace());
auto *parent_idx_data =
parent_idx
? parent_idx->mutable_data<int>(
{static_cast<int64_t>(num_instances)}, platform::CPUPlace())
: nullptr;
// fill in data
std::vector<size_t> low_level;
......@@ -72,7 +72,9 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
for (auto &items : selected_items) {
low_level.push_back(low_offset);
for (auto &item : items) {
if (parent_idx) {
parent_idx_data[low_offset] = static_cast<int>(low_level.size() - 1);
}
selected_ids_data[low_offset] = item.id;
selected_scores_data[low_offset] = item.score;
low_offset++;
......
......@@ -168,6 +168,7 @@ __device__ __forceinline__ bool PruneEndBeams(Triple* top_beam_local,
return finish_flag;
}
template <bool ReturnParentIdx = false>
__device__ __forceinline__ void WriteBack(
int64_t* selected_ids, float* selected_scores, int* parent_idx,
size_t* selected_offsets, Triple* top_beam_local,
......@@ -183,7 +184,9 @@ __device__ __forceinline__ void WriteBack(
selected_ids[global_index] =
static_cast<int64_t>(top_beam_local[local_index].id);
selected_scores[global_index] = top_beam_local[local_index].score;
if (ReturnParentIdx) {
parent_idx[global_index] = static_cast<int>(global_offset);
}
global_index++;
}
}
......@@ -241,9 +244,15 @@ __device__ void BeamSearchDetails(
selected_offsets[0] = 0;
}
WriteBack(selected_ids, selected_scores, parent_idx, selected_offsets,
top_beam_local, seq_offset_start, seq_offset_end,
selected_seq_start, selected_seq_length);
if (parent_idx) {
WriteBack<true>(selected_ids, selected_scores, parent_idx,
selected_offsets, top_beam_local, seq_offset_start,
seq_offset_end, selected_seq_start, selected_seq_length);
} else {
WriteBack<false>(selected_ids, selected_scores, parent_idx,
selected_offsets, top_beam_local, seq_offset_start,
seq_offset_end, selected_seq_start, selected_seq_length);
}
}
}
......@@ -337,8 +346,12 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
selected_ids->mutable_data<int64_t>(selected_dims, context.GetPlace());
float* selected_scores_data =
selected_scores->mutable_data<float>(selected_dims, context.GetPlace());
int* parent_idx_data = parent_idx->mutable_data<int>(
{static_cast<int64_t>(num_seqs * beam_size)}, context.GetPlace());
int* parent_idx_data =
parent_idx
? parent_idx->mutable_data<int>(
{static_cast<int64_t>(num_seqs * beam_size)},
context.GetPlace())
: nullptr;
framework::LoD selected_lod(2);
selected_lod[0].assign(abs_lod[level].begin(), abs_lod[level].end());
......@@ -396,9 +409,11 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
{static_cast<int64_t>(selected_lod[1].back()), 1});
selected_ids->Resize(final_selected_dims);
selected_scores->Resize(final_selected_dims);
if (parent_idx) {
parent_idx->Resize({static_cast<int64_t>(selected_lod[1].back())});
}
}
}
};
template class BeamSearchFunctor<platform::CUDADeviceContext, int>;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册