Unverified · Commit 16d54f7f authored by Yiqun Liu, committed by GitHub

Return parent_idx in beam_search op (#15520)

* Refine beam_search_op to output an extra parent_idx tensor.
test=develop

* Fix the unittest test_beam_search_op.
test=develop

* Fix the merging mistake.
test=develop
Parent 72ee3c62
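For context: with this change, `fluid.layers.beam_search` can additionally return an int32 `parent_idx` tensor recording, for each selected candidate, the row of `pre_ids` it extends. Below is a minimal sketch of one decode step with the new flag (not part of this diff; `pre_ids`, `pre_scores`, `topk_indices`, and `accu_scores` are placeholders the surrounding decoder would define):

```python
import paddle.fluid as fluid

# Sketch of one decode step. pre_ids/pre_scores come from the previous step;
# topk_indices/accu_scores are the candidate ids and accumulated scores that
# this step proposes (all placeholders here).
selected_ids, selected_scores, parent_idx = fluid.layers.beam_search(
    pre_ids=pre_ids,
    pre_scores=pre_scores,
    ids=topk_indices,
    scores=accu_scores,
    beam_size=4,
    end_id=0,
    return_parent_idx=True)  # default is False, which keeps the old 2-tuple
```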
@@ -122,7 +122,7 @@ paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None,
 paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
 paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
 paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False))
-paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name'], varargs=None, keywords=None, defaults=(0, True, None))
+paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name', 'return_parent_idx'], varargs=None, keywords=None, defaults=(0, True, None, False))
 paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
 paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
......
@@ -51,6 +51,9 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput("selected_scores",
               "A LoDTensor containing the accumulated scores corresponding to "
               "Output(selected_ids).");
+    AddOutput(
+        "parent_idx",
+        "A Tensor preserving the selected_ids' parent indice in pre_ids.");
     // Attributes stored in AttributeMap
     AddAttr<int>("level", "the level of LoDTensor");
......
@@ -41,13 +41,15 @@ class BeamSearchOpKernel : public framework::OpKernel<T> {
     auto selected_ids = context.Output<framework::LoDTensor>("selected_ids");
     auto selected_scores =
         context.Output<framework::LoDTensor>("selected_scores");
+    auto* parent_idx = context.Output<framework::Tensor>("parent_idx");
     PADDLE_ENFORCE_NOT_NULL(selected_ids);
     PADDLE_ENFORCE_NOT_NULL(selected_scores);
+    PADDLE_ENFORCE_NOT_NULL(parent_idx);
     math::BeamSearchFunctor<DeviceContext, T> alg;
     alg(context.template device_context<DeviceContext>(), pre_ids, pre_scores,
-        ids, scores, selected_ids, selected_scores, level, beam_size, end_id,
-        is_accumulated);
+        ids, scores, selected_ids, selected_scores, parent_idx, level,
+        beam_size, end_id, is_accumulated);
   }
 };
......
@@ -31,7 +31,7 @@ class GatherOpCUDAKernel : public framework::OpKernel<T> {
     auto *output = ctx.Output<Tensor>("Out");
     output->mutable_data<T>(ctx.GetPlace());
+    if (x->numel() == 0) return;
     GPUGather<T>(ctx.device_context(), *x, *index, output);
   }
 };
@@ -45,14 +45,13 @@ class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
     auto *Index = ctx.Input<Tensor>("Index");
     auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
-    auto *x = ctx.Input<Tensor>("X");
     dX->mutable_data<T>(ctx.GetPlace());
     auto dxt = framework::EigenVector<T>::Flatten(*dX);
     auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
                        .eigen_device();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
+    if (dO->numel() == 0) return;
     GPUScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
   }
 };
......
@@ -35,7 +35,7 @@ class GatherOpKernel : public framework::OpKernel<T> {
     auto *output = ctx.Output<Tensor>("Out");
     output->mutable_data<T>(ctx.GetPlace());
+    if (x->numel() == 0) return;
     CPUGather<T>(ctx.device_context(), *x, *index, output);
   }
 };
@@ -56,7 +56,7 @@ class GatherGradientOpKernel : public framework::OpKernel<T> {
     auto &place = *ctx.template device_context<platform::CPUDeviceContext>()
                        .eigen_device();
     dxt.device(place) = dxt.constant(static_cast<T>(0));
+    if (dO->numel() == 0) return;
     ScatterAssign<T>(ctx.device_context(), *dO, *Index, dX);
   }
 };
......
@@ -29,8 +29,9 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
                   const framework::LoDTensor *ids,
                   const framework::LoDTensor *scores,
                   framework::LoDTensor *selected_ids,
-                  framework::LoDTensor *selected_scores, size_t level,
-                  size_t beam_size, int end_id, bool is_accumulated) {
+                  framework::LoDTensor *selected_scores,
+                  framework::Tensor *parent_idx, size_t level, size_t beam_size,
+                  int end_id, bool is_accumulated) {
     auto abs_lod = framework::ToAbsOffset(scores->lod());
     auto &high_level = abs_lod[level];
@@ -57,11 +58,13 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
         std::vector<int64_t>({static_cast<int>(num_instances), 1}));
     selected_ids->Resize(dims);
     selected_scores->Resize(dims);
+    parent_idx->Resize({static_cast<int64_t>(num_instances)});
     auto *selected_ids_data =
         selected_ids->mutable_data<int64_t>(platform::CPUPlace());
     auto *selected_scores_data =
         selected_scores->mutable_data<float>(platform::CPUPlace());
+    auto *parent_idx_data = parent_idx->mutable_data<int>(platform::CPUPlace());
     // fill in data
     std::vector<size_t> low_level;
@@ -69,6 +72,7 @@ class BeamSearchFunctor<platform::CPUDeviceContext, T> {
     for (auto &items : selected_items) {
       low_level.push_back(low_offset);
       for (auto &item : items) {
+        parent_idx_data[low_offset] = static_cast<int>(low_level.size() - 1);
        selected_ids_data[low_offset] = item.id;
        selected_scores_data[low_offset] = item.score;
        low_offset++;
......
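The CPU functor above fills `parent_idx` in the same loop that writes `selected_ids`: every selected candidate records the index of the prefix (the row of `pre_ids`) it extends, which is `low_level.size() - 1` at the moment of writing. A pure-Python mirror of that loop, using toy values that echo the unit test at the end of this diff:

```python
# Pure-Python mirror of the fill loop above (illustration only, not Paddle code).
# Toy step: four prefixes, each contributing one selected (id, score) pair.
selected_items = [[(4, 0.5)], [(2, 0.6)], [(3, 0.9)], [(8, 0.7)]]

selected_ids, selected_scores, parent_idx, low_level = [], [], [], []
low_offset = 0
for items in selected_items:
    low_level.append(low_offset)
    for item_id, item_score in items:
        # Index of the prefix this candidate extends; low_level grows once per
        # prefix, so len(low_level) - 1 is the current prefix index.
        parent_idx.append(len(low_level) - 1)
        selected_ids.append(item_id)
        selected_scores.append(item_score)
        low_offset += 1

assert parent_idx == [0, 1, 2, 3]  # matches the expectation in the unit test
```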
@@ -157,10 +157,10 @@ __device__ __forceinline__ bool PruneEndBeams(Triple* top_beam_local,
 }
 __device__ __forceinline__ void WriteBack(
-    int64_t* selected_ids, float* selected_scores, size_t* selected_offsets,
-    Triple* top_beam_local, const int seq_offset_start,
-    const int seq_offset_end, const int selected_seq_start,
-    const int selected_seq_length) {
+    int64_t* selected_ids, float* selected_scores, int* parent_idx,
+    size_t* selected_offsets, Triple* top_beam_local,
+    const int seq_offset_start, const int seq_offset_end,
+    const int selected_seq_start, const int selected_seq_length) {
   const int tid = threadIdx.x;  // use 1 thread only for each sequence
   int global_index = selected_seq_start;
   for (int global_offset = seq_offset_start; global_offset < seq_offset_end;
@@ -171,6 +171,7 @@ __device__ __forceinline__ void WriteBack(
       selected_ids[global_index] =
           static_cast<int64_t>(top_beam_local[local_index].id);
       selected_scores[global_index] = top_beam_local[local_index].score;
+      parent_idx[global_index] = static_cast<int>(global_offset);
       global_index++;
     }
   }
@@ -180,11 +181,11 @@ __device__ __forceinline__ void WriteBack(
 template <int MaxLength, int MaxThreadsPerSeq, int MaxSeqs>
 __device__ void BeamSearchDetails(
-    int64_t* selected_ids, float* selected_scores, size_t* selected_offsets,
-    const int64_t* pre_ids, const float* pre_scores, const int64_t* ids,
-    const float* scores, const int seq_offset_start, const int seq_offset_end,
-    const int seq_width, int beam_size, int end_id, bool is_accumulated,
-    int num_used_threads) {
+    int64_t* selected_ids, float* selected_scores, int* parent_idx,
+    size_t* selected_offsets, const int64_t* pre_ids, const float* pre_scores,
+    const int64_t* ids, const float* scores, const int seq_offset_start,
+    const int seq_offset_end, const int seq_width, int beam_size, int end_id,
+    bool is_accumulated, int num_used_threads) {
   __shared__ Triple top_beam[MaxLength];
   int num_items = 0;
@@ -228,15 +229,15 @@ __device__ void BeamSearchDetails(
       selected_offsets[0] = 0;
     }
-    WriteBack(selected_ids, selected_scores, selected_offsets, top_beam_local,
-              seq_offset_start, seq_offset_end, selected_seq_start,
-              selected_seq_length);
+    WriteBack(selected_ids, selected_scores, parent_idx, selected_offsets,
+              top_beam_local, seq_offset_start, seq_offset_end,
+              selected_seq_start, selected_seq_length);
   }
 }
 template <int MaxLength, int MaxThreadsPerSeq, int MaxSeqs>
 __global__ void BeamSearchKernel(int64_t* selected_ids, float* selected_scores,
-                                 size_t* selected_offsets,
+                                 int* parent_idx, size_t* selected_offsets,
                                  const int64_t* pre_ids,
                                  const float* pre_scores, const int64_t* ids,
                                  const float* scores, const size_t* seq_offsets,
@@ -250,24 +251,25 @@ __global__ void BeamSearchKernel(int64_t* selected_ids, float* selected_scores,
   int seq_offset_end = static_cast<int>(seq_offsets[seq_id + 1]);
   BeamSearchDetails<MaxLength, MaxThreadsPerSeq, MaxSeqs>(
-      selected_ids, selected_scores, selected_offsets, pre_ids, pre_scores, ids,
-      scores, seq_offset_start, seq_offset_end, seq_width, beam_size, end_id,
-      is_accumulated, num_used_threads);
+      selected_ids, selected_scores, parent_idx, selected_offsets, pre_ids,
+      pre_scores, ids, scores, seq_offset_start, seq_offset_end, seq_width,
+      beam_size, end_id, is_accumulated, num_used_threads);
 }
 template <int MaxLength, int MaxThreadsPerSeq>
 __global__ void BeamSearchKernelSingle(
-    int64_t* selected_ids, float* selected_scores, size_t* selected_offsets,
-    const int64_t* pre_ids, const float* pre_scores, const int64_t* ids,
-    const float* scores, const int seq_length, const int seq_width,
-    int beam_size, int end_id, bool is_accumulated, int num_used_threads) {
+    int64_t* selected_ids, float* selected_scores, int* parent_idx,
+    size_t* selected_offsets, const int64_t* pre_ids, const float* pre_scores,
+    const int64_t* ids, const float* scores, const int seq_length,
+    const int seq_width, int beam_size, int end_id, bool is_accumulated,
+    int num_used_threads) {
   const int seq_offset_start = 0;
   const int seq_offset_end = seq_length;
   BeamSearchDetails<MaxLength, MaxThreadsPerSeq, 1>(
-      selected_ids, selected_scores, selected_offsets, pre_ids, pre_scores, ids,
-      scores, seq_offset_start, seq_offset_end, seq_width, beam_size, end_id,
-      is_accumulated, num_used_threads);
+      selected_ids, selected_scores, parent_idx, selected_offsets, pre_ids,
+      pre_scores, ids, scores, seq_offset_start, seq_offset_end, seq_width,
+      beam_size, end_id, is_accumulated, num_used_threads);
 }
 static inline int GetNumUsedThreads(const int max_threads_per_seq,
@@ -300,8 +302,9 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
                   const framework::LoDTensor* ids,
                   const framework::LoDTensor* scores,
                   framework::LoDTensor* selected_ids,
-                  framework::LoDTensor* selected_scores, size_t level,
-                  size_t beam_size, int end_id, bool is_accumulated) {
+                  framework::LoDTensor* selected_scores,
+                  framework::Tensor* parent_idx, size_t level, size_t beam_size,
+                  int end_id, bool is_accumulated) {
     auto abs_lod = framework::ToAbsOffset(scores->lod());
     const int64_t* pre_ids_data = pre_ids->data<int64_t>();
@@ -322,6 +325,8 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
         selected_ids->mutable_data<int64_t>(selected_dims, context.GetPlace());
     float* selected_scores_data =
         selected_scores->mutable_data<float>(selected_dims, context.GetPlace());
+    int* parent_idx_data = parent_idx->mutable_data<int>(
+        {static_cast<int64_t>(num_seqs * beam_size)}, context.GetPlace());
     framework::LoD selected_lod(2);
     selected_lod[0].assign(abs_lod[level].begin(), abs_lod[level].end());
@@ -339,9 +344,9 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
       CUDA_LAUNCH_KERNEL_HELPER(
           BeamSearchKernelSingle<kPowerOfTwoDim, kMaxThreadsPerSeq><<<
               1, kMaxThreadsPerSeq, 0, context.stream()>>>(
-              selected_ids_data, selected_scores_data, selected_offsets,
-              pre_ids_data, pre_scores_data, ids_data, scores_data,
-              seq_length, static_cast<int>(seq_width),
+              selected_ids_data, selected_scores_data, parent_idx_data,
+              selected_offsets, pre_ids_data, pre_scores_data, ids_data,
+              scores_data, seq_length, static_cast<int>(seq_width),
               static_cast<int>(beam_size), static_cast<int>(end_id),
               is_accumulated, num_used_threads));
     }
@@ -357,9 +362,9 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
       CUDA_LAUNCH_KERNEL_HELPER(
           BeamSearchKernel<kPowerOfTwoDim, kMaxThreadsPerSeq, kMaxSeqs><<<
              1, num_seqs * kMaxThreadsPerSeq, 0, context.stream()>>>(
-              selected_ids_data, selected_scores_data, selected_offsets,
-              pre_ids_data, pre_scores_data, ids_data, scores_data,
-              seq_offsets, static_cast<int>(num_seqs),
+              selected_ids_data, selected_scores_data, parent_idx_data,
+              selected_offsets, pre_ids_data, pre_scores_data, ids_data,
+              scores_data, seq_offsets, static_cast<int>(num_seqs),
              static_cast<int>(seq_width), static_cast<int>(beam_size),
              end_id, is_accumulated, num_used_threads));
    }
@@ -379,6 +384,7 @@ class BeamSearchFunctor<platform::CUDADeviceContext, T> {
         {static_cast<int64_t>(selected_lod[1].back()), 1});
     selected_ids->Resize(final_selected_dims);
     selected_scores->Resize(final_selected_dims);
+    parent_idx->Resize({static_cast<int64_t>(selected_lod[1].back())});
    }
  }
 };
......
@@ -104,14 +104,12 @@ class BeamSearchFunctor {
    * Return false if all the input tensor is empty, in machine translation task
    * that means no candidates is provided, and the task will stop running.
    */
-  void operator()(const DeviceContext& context,
-                  const framework::LoDTensor* pre_ids,
-                  const framework::LoDTensor* pre_scores,
-                  const framework::LoDTensor* ids,
-                  const framework::LoDTensor* scores,
-                  framework::LoDTensor* selected_ids,
-                  framework::LoDTensor* selected_scores, size_t level,
-                  size_t beam_size, int end_id, bool is_accumulated);
+  void operator()(
+      const DeviceContext& context, const framework::LoDTensor* pre_ids,
+      const framework::LoDTensor* pre_scores, const framework::LoDTensor* ids,
+      const framework::LoDTensor* scores, framework::LoDTensor* selected_ids,
+      framework::LoDTensor* selected_scores, framework::Tensor* parent_idx,
+      size_t level, size_t beam_size, int end_id, bool is_accumulated);
 };

 }  // namespace math
......
@@ -93,13 +93,14 @@ void TestBeamSearch() {
   paddle::framework::LoDTensor selected_ids;
   paddle::framework::LoDTensor selected_scores;
+  paddle::framework::LoDTensor parent_idx;
   size_t level = 0;
   size_t beam_size = 2;
   int end_id = 0;
   paddle::operators::math::BeamSearchFunctor<DeviceContext, float> beamsearch;
   beamsearch(*context, &pre_ids, &pre_scores, &ids, &scores, &selected_ids,
-             &selected_scores, level, beam_size, end_id, true);
+             &selected_scores, &parent_idx, level, beam_size, end_id, true);

   ASSERT_EQ(selected_ids.lod(), selected_scores.lod());
......
@@ -3877,7 +3877,8 @@ def beam_search(pre_ids,
                 end_id,
                 level=0,
                 is_accumulated=True,
-                name=None):
+                name=None,
+                return_parent_idx=False):
     """
     Beam search is a classical algorithm for selecting candidate words in a
     machine translation task.
@@ -3933,10 +3934,16 @@ def beam_search(pre_ids,
                         accumulated scores.
         name(str|None): A name for this layer(optional). If set None, the layer
                         will be named automatically.
+        return_parent_idx(bool): Whether to return an extra Tensor variable
+                        preserving the selected_ids' parent indice in pre_ids
+                        in output, which can be used to gather cell states at
+                        the next time step.

     Returns:
-        Variable: The LodTensor pair containing the selected ids and the \
-            corresponding scores.
+        Variable: The LodTensor tuple containing the selected ids and the \
+            corresponding scores. If :attr:`return_parent_idx` is :attr:`True`, \
+            an extra Tensor variable preserving the selected_ids' parent indice \
+            is included.

     Examples:
         .. code-block:: python
@@ -3969,6 +3976,11 @@ def beam_search(pre_ids,
     selected_scores = helper.create_variable_for_type_inference(
         dtype=score_type)
     selected_ids = helper.create_variable_for_type_inference(dtype=id_type)
+    # parent_idx is a tensor used to gather cell states at the next time
+    # step. Though lod in selected_ids can also be used to gather by
+    # sequence_expand, it is not efficient.
+    # gather_op's index input only supports int32 dtype currently
+    parent_idx = helper.create_variable_for_type_inference(dtype="int32")

     helper.append_op(
         type='beam_search',
@@ -3976,6 +3988,7 @@ def beam_search(pre_ids,
         outputs={
             'selected_ids': selected_ids,
             'selected_scores': selected_scores,
+            'parent_idx': parent_idx
         },
         attrs={
             # TODO(ChunweiYan) to assure other value support
@@ -3984,7 +3997,9 @@ def beam_search(pre_ids,
             'end_id': end_id,
             'is_accumulated': is_accumulated,
         })
-    return selected_ids, selected_scores
+    if return_parent_idx:
+        return selected_ids, selected_scores, parent_idx
+    else:
+        return selected_ids, selected_scores
......
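The comment added in the Python layer states why `parent_idx` is int32 and what it is for: the per-beam recurrent states of the next step can be reordered with a single `gather` instead of expanding them through `sequence_expand` over the LoD of `selected_ids`. A hedged fragment of that pattern (`pre_cell_states` is a placeholder for whatever state the decoder keeps per prefix):

```python
import paddle.fluid as fluid

# pre_cell_states: one row of decoder state per prefix in pre_ids (placeholder).
# parent_idx[i] is the row that selected_ids[i] descends from, so a single
# gather puts the states in the order the next time step expects.
next_cell_states = fluid.layers.gather(pre_cell_states, parent_idx)
```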
@@ -38,6 +38,7 @@ class BeamSearchOpTester(unittest.TestCase):
         self._create_pre_ids()
         self.scope.var('selected_ids')
         self.scope.var('selected_scores')
+        self.scope.var('parent_idx')

     def test_run(self):
         op = Operator(
@@ -48,12 +49,14 @@ class BeamSearchOpTester(unittest.TestCase):
             scores='scores',
             selected_ids='selected_ids',
             selected_scores='selected_scores',
+            parent_idx='parent_idx',
             level=0,
             beam_size=2,
             end_id=0, )
         op.run(self.scope, core.CPUPlace())
         selected_ids = self.scope.find_var("selected_ids").get_tensor()
         selected_scores = self.scope.find_var("selected_scores").get_tensor()
+        parent_idx = self.scope.find_var("parent_idx").get_tensor()
         self.assertTrue(
             np.allclose(
                 np.array(selected_ids), np.array([4, 2, 3, 8])[:, np.newaxis]))
@@ -62,6 +65,8 @@ class BeamSearchOpTester(unittest.TestCase):
                 np.array(selected_scores),
                 np.array([0.5, 0.6, 0.9, 0.7])[:, np.newaxis]))
         self.assertEqual(selected_ids.lod(), [[0, 2, 4], [0, 1, 2, 3, 4]])
+        self.assertTrue(
+            np.allclose(np.array(parent_idx), np.array([0, 1, 2, 3])))

     def _create_pre_ids(self):
         np_data = np.array([[1, 2, 3, 4]], dtype='int64')
......