From 3760be067aaa5e0da250a4751866b368e9aea790 Mon Sep 17 00:00:00 2001 From: pangyoki Date: Wed, 15 Sep 2021 16:00:42 +0800 Subject: [PATCH] [NPU] add beam_search npu op (#34860) * add beam_search npu op * fix CMakeList and add unittest * fix bug of beam search npu op * fix unittest * let input ids become int64 * set output ids to int64_t * delete check_dygraph * fix beam_width=1 --- paddle/fluid/operators/beam_search_op.h | 4 +- paddle/fluid/operators/beam_search_op_npu.cc | 24 + paddle/fluid/operators/math/CMakeLists.txt | 10 +- .../fluid/operators/math/beam_search_npu.cc | 538 ++++++++++++++++++ .../unittests/npu/test_beam_search_op_npu.py | 175 ++++++ .../tests/unittests/test_beam_search_op.py | 206 ++++++- 6 files changed, 944 insertions(+), 13 deletions(-) create mode 100644 paddle/fluid/operators/beam_search_op_npu.cc create mode 100644 paddle/fluid/operators/math/beam_search_npu.cc create mode 100644 python/paddle/fluid/tests/unittests/npu/test_beam_search_op_npu.py diff --git a/paddle/fluid/operators/beam_search_op.h b/paddle/fluid/operators/beam_search_op.h index 708f3cd808e..a977f5a4c01 100644 --- a/paddle/fluid/operators/beam_search_op.h +++ b/paddle/fluid/operators/beam_search_op.h @@ -51,11 +51,11 @@ class BeamSearchOpKernel : public framework::OpKernel { PADDLE_ENFORCE_NOT_NULL( selected_ids, platform::errors::NotFound( - "Output(selected_scores) of BeamSearchOp is not found.")); + "Output(selected_ids) of BeamSearchOp is not found.")); PADDLE_ENFORCE_NOT_NULL( selected_scores, platform::errors::NotFound( - "Output(parent_idx) of BeamSearchOp is not found.")); + "Output(selected_scores) of BeamSearchOp is not found.")); math::BeamSearchFunctor alg; alg(context.template device_context(), pre_ids, pre_scores, diff --git a/paddle/fluid/operators/beam_search_op_npu.cc b/paddle/fluid/operators/beam_search_op_npu.cc new file mode 100644 index 00000000000..cae3d0e55fc --- /dev/null +++ b/paddle/fluid/operators/beam_search_op_npu.cc @@ -0,0 +1,24 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/beam_search_op.h" +#include "paddle/fluid/framework/op_registry.h" + +namespace ops = paddle::operators; +REGISTER_OP_NPU_KERNEL( + beam_search, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel, + ops::BeamSearchOpKernel); diff --git a/paddle/fluid/operators/math/CMakeLists.txt b/paddle/fluid/operators/math/CMakeLists.txt index a13fffe15cf..25cea2a6711 100644 --- a/paddle/fluid/operators/math/CMakeLists.txt +++ b/paddle/fluid/operators/math/CMakeLists.txt @@ -39,6 +39,10 @@ function(math_library TARGET) endif() endfunction() +if (WITH_ASCEND_CL) + cc_library(beam_search_npu SRCS beam_search_npu.cc DEPS npu_op_runner) +endif() + # please add new math_library in alphabetical order math_library(concat_and_split) math_library(context_project DEPS im2col math_function) @@ -68,7 +72,11 @@ math_library(sequence_padding) math_library(sequence_pooling DEPS math_function jit_kernel_helper) math_library(sequence_scale) math_library(softmax DEPS math_function jit_kernel_helper) -math_library(beam_search DEPS math_function) +if (WITH_ASCEND_CL) + math_library(beam_search DEPS math_function beam_search_npu) +else() + math_library(beam_search DEPS math_function) +endif() math_library(fc DEPS blas) math_library(matrix_bit_code) diff --git a/paddle/fluid/operators/math/beam_search_npu.cc b/paddle/fluid/operators/math/beam_search_npu.cc new file mode 100644 index 00000000000..891822a2923 --- /dev/null +++ b/paddle/fluid/operators/math/beam_search_npu.cc @@ -0,0 +1,538 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/math/beam_search.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace framework { +class LoDTensor; +class Tensor; +} // namespace framework +namespace platform { +class NPUDeviceContext; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace operators { +namespace math { + +template +class BeamSearchFunctor { + public: + void operator()(const platform::NPUDeviceContext& ctx, + const framework::LoDTensor* pre_ids, + const framework::LoDTensor* pre_scores, + const framework::LoDTensor* ids, + const framework::LoDTensor* scores, + framework::LoDTensor* selected_ids, + framework::LoDTensor* selected_scores, + framework::Tensor* parent_idx, size_t level, size_t beam_size, + int end_id, bool is_accumulated) { + auto abs_lod = framework::ToAbsOffset(scores->lod()); + auto& high_level = abs_lod[level]; + + int64_t num_seqs = scores->NumElements(level); + // size of the first beam is 1, others are equal to beam_size + int64_t real_beam_size = static_cast(scores->dims()[0] / num_seqs); + // K + int64_t seq_width = 1; + for (int i = 1; i < scores->dims().size(); i++) { + seq_width *= scores->dims()[i]; + } + + auto place = ctx.GetPlace(); + auto stream = ctx.stream(); + + int64_t total_length = num_seqs * beam_size; + int64_t batch_size = static_cast(scores->dims()[0]); + selected_ids->mutable_data(framework::make_ddim({total_length, 1}), + place); + selected_scores->mutable_data( + framework::make_ddim({total_length, 1}), place); + parent_idx->mutable_data(framework::make_ddim({total_length}), + place); + + // Step1: Define Tensors and Preprocess the situation that pre_id == end_id + + // cast ids and pre_ids from int to float32 + Tensor ids_int32(framework::proto::VarType::INT32); + if (ids->type() != framework::proto::VarType::INT32) { + ids_int32.Resize(ids->dims()); + ids_int32.mutable_data(ctx.GetPlace()); + auto dst_dtype_ids_int32 = ConvertToNpuDtype(ids_int32.type()); + const auto& runner_ids_int32 = + NpuOpRunner("Cast", {*ids}, {ids_int32}, + {{"dst_type", static_cast(dst_dtype_ids_int32)}}); + runner_ids_int32.Run(stream); + } else { + ids_int32.ShareDataWith(*ids); + } + + Tensor pre_ids_int32(framework::proto::VarType::INT32); + if (pre_ids->type() != framework::proto::VarType::INT32) { + pre_ids_int32.Resize(pre_ids->dims()); + pre_ids_int32.mutable_data(ctx.GetPlace()); + auto dst_dtype_pre_ids_int32 = ConvertToNpuDtype(pre_ids_int32.type()); + const auto& runner_pre_ids_int32 = NpuOpRunner( + "Cast", {*pre_ids}, {pre_ids_int32}, + {{"dst_type", static_cast(dst_dtype_pre_ids_int32)}}); + runner_pre_ids_int32.Run(stream); + } else { + pre_ids_int32.ShareDataWith(*pre_ids); + } + + Tensor expand_pre_ids(pre_ids_int32.type()); + expand_pre_ids.Resize(framework::make_ddim({batch_size, seq_width})); + expand_pre_ids.mutable_data(place); + const auto& runner_tile_pre_ids = + NpuOpRunner("TileWithAxis", {pre_ids_int32}, {expand_pre_ids}, + {{"axis", 1}, {"tiles", seq_width}}); + runner_tile_pre_ids.Run(stream); + expand_pre_ids.Resize(ids_int32.dims()); + + Tensor expand_pre_scores(pre_scores->type()); + expand_pre_scores.Resize(framework::make_ddim({batch_size, seq_width})); + expand_pre_scores.mutable_data(place); + const auto& runner_tile_pre_scores = + NpuOpRunner("TileWithAxis", {*pre_scores}, {expand_pre_scores}, + {{"axis", 1}, {"tiles", seq_width}}); + runner_tile_pre_scores.Run(stream); + expand_pre_scores.Resize(scores->dims()); + + // End_id Tensors + Tensor end_id_tmp_tensor(framework::proto::VarType::INT32); + end_id_tmp_tensor.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&end_id_tmp_tensor, end_id); + + Tensor end_id_tensors(ids_int32.type()); + end_id_tensors.mutable_data(ids_int32.dims(), place); + const auto& runner_fill_end_id = + NpuOpRunner("FillD", {end_id_tmp_tensor}, {end_id_tensors}, + {{"dims", framework::vectorize(ids_int32.dims())}}); + runner_fill_end_id.Run(stream); + + // whether expand_pre_ids == end_ids? + Tensor equal_end_ids(framework::proto::VarType::BOOL); + equal_end_ids.mutable_data(ids_int32.dims(), place); + const auto& runner_equal_end_ids = NpuOpRunner( + "Equal", {expand_pre_ids, end_id_tensors}, {equal_end_ids}, {}); + runner_equal_end_ids.Run(stream); + + // construct a Tensor with dimension ids->dims(): + // [[False, True, True, True, ...], + // [False, True, True, True, ...], + // ...] + Tensor false_tmp_tensor(framework::proto::VarType::INT32); + false_tmp_tensor.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&false_tmp_tensor, static_cast(false)); + + Tensor first_pos_false_tensors(framework::proto::VarType::INT32); + first_pos_false_tensors.Resize(framework::make_ddim({batch_size, 1})); + first_pos_false_tensors.mutable_data(place); + std::vector fill_dims = {batch_size, 1}; + framework::NPUAttributeMap fill_attr = {{"dims", fill_dims}}; + const auto& runner_fill_false_tensors = NpuOpRunner( + "FillD", {false_tmp_tensor}, {first_pos_false_tensors}, fill_attr); + runner_fill_false_tensors.Run(stream); + + Tensor pos_tensors(framework::proto::VarType::INT32); + if (seq_width > 1) { + pos_tensors.Resize(framework::make_ddim({batch_size, seq_width})); + pos_tensors.mutable_data(place); + + Tensor true_tmp_tensor(framework::proto::VarType::INT32); + true_tmp_tensor.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&true_tmp_tensor, static_cast(true)); + + Tensor second_pos_true_tensors(framework::proto::VarType::INT32); + second_pos_true_tensors.Resize( + framework::make_ddim({batch_size, seq_width - 1})); + second_pos_true_tensors.mutable_data(place); + std::vector fill_dims2 = {batch_size, seq_width - 1}; + framework::NPUAttributeMap fill_attr2 = {{"dims", fill_dims2}}; + const auto& runner_fill_true_tensors = NpuOpRunner( + "FillD", {true_tmp_tensor}, {second_pos_true_tensors}, fill_attr2); + runner_fill_true_tensors.Run(stream); + + std::vector concat_inputs = {first_pos_false_tensors, + second_pos_true_tensors}; + std::vector concat_names = {"x0", "x1"}; + NpuOpRunner runner_concat_false_true{"ConcatD", + {concat_inputs}, + {pos_tensors}, + {{"concat_dim", 1}, {"N", 2}}}; + runner_concat_false_true.AddInputNames(concat_names); + runner_concat_false_true.Run(stream); + pos_tensors.Resize(ids_int32.dims()); + } else { + pos_tensors.ShareDataWith(first_pos_false_tensors); + } + + Tensor cast_pos_tensors_bool(framework::proto::VarType::BOOL); + cast_pos_tensors_bool.Resize(pos_tensors.dims()); + cast_pos_tensors_bool.mutable_data(ctx.GetPlace()); + auto dst_dtype = ConvertToNpuDtype(cast_pos_tensors_bool.type()); + const auto& runner_cast_pos_tensors = + NpuOpRunner("Cast", {pos_tensors}, {cast_pos_tensors_bool}, + {{"dst_type", static_cast(dst_dtype)}}); + runner_cast_pos_tensors.Run(stream); + + // if pre_ids == end_ids, save only one score, and others become -inf + // construct pre_ids == end_ids and save only one score + Tensor save_one_end_score(framework::proto::VarType::BOOL); + save_one_end_score.mutable_data(ids_int32.dims(), place); + const auto& runner_logical_and = + NpuOpRunner("LogicalAnd", {equal_end_ids, cast_pos_tensors_bool}, + {save_one_end_score}, {}); + runner_logical_and.Run(stream); + + // if save_one_end_score is True, set score to -inf + // define -Inf Tensors + Tensor ninf_tmp_tensor(scores->type()); + ninf_tmp_tensor.mutable_data({1}, ctx.GetPlace()); + float ninf_value = + static_cast(-std::numeric_limits::infinity()); + FillNpuTensorWithConstant(&ninf_tmp_tensor, ninf_value); + + Tensor ninf_tensors(scores->type()); + ninf_tensors.mutable_data(scores->dims(), place); + const auto& runner_fill_ninf = + NpuOpRunner("FillD", {ninf_tmp_tensor}, {ninf_tensors}, + {{"dims", framework::vectorize(scores->dims())}}); + runner_fill_ninf.Run(stream); + + // Step2: calculate topk scores + + // get scores used in topk op + Tensor tmp_scores(scores->type()); + tmp_scores.mutable_data(scores->dims(), place); + if (!is_accumulated) { + // if pre_id == end_id, cal_scores = pre_score, and id = end_id + // else, cal_score = pre_score + log(score) + + // calculate log(scores) + Tensor log_scores(scores->type()); + log_scores.mutable_data(scores->dims(), place); + + Tensor one(scores->type()); + one.mutable_data(scores->dims(), place); + const auto& runner_one = NpuOpRunner("OnesLike", {*scores}, {one}, {}); + runner_one.Run(stream); + + Tensor sub(scores->type()); + sub.mutable_data(scores->dims(), place); + const auto& runner_sub = NpuOpRunner("Sub", {*scores, one}, {sub}, {}); + runner_sub.Run(stream); + + const auto& runner_log_scores = + NpuOpRunner("Log1p", {sub}, {log_scores}, {}); + runner_log_scores.Run(stream); + + // tmp_scores = pre_score + log(scores) + const auto& runner_add_scores = + NpuOpRunner("Add", {log_scores, *pre_scores}, {tmp_scores}, {}); + runner_add_scores.Run(stream); + + // if pre_ids == end_ids, use pre_score rather than score + const auto& runner_select_equal_end_score = + NpuOpRunner("Select", {equal_end_ids, expand_pre_scores, tmp_scores}, + {tmp_scores}, {}); + runner_select_equal_end_score.Run(stream); + } else { + // if pre_ids == end_ids, use pre_score rather than score + const auto& runner_select_equal_end_score2 = + NpuOpRunner("Select", {equal_end_ids, expand_pre_scores, *scores}, + {tmp_scores}, {}); + runner_select_equal_end_score2.Run(stream); + } + + // if pre_ids == end_ids, save only one score, and others become -inf + Tensor cal_scores(scores->type()); + cal_scores.mutable_data(scores->dims(), place); + const auto& runner_select_inf_score = + NpuOpRunner("Select", {save_one_end_score, ninf_tensors, tmp_scores}, + {cal_scores}, {}); + runner_select_inf_score.Run(stream); + + // resize scores from [num_seqs * beam_size, K] to [num_seqs, beam_size * K] + // real_beam_size = 1 or beam_size + cal_scores.Resize( + framework::make_ddim({num_seqs, real_beam_size * seq_width})); + + Tensor topk_scores(scores->type()); + topk_scores.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size)})); + topk_scores.mutable_data(ctx.GetPlace()); + + Tensor tmp_indices(framework::proto::VarType::INT32); + tmp_indices.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size)})); + tmp_indices.mutable_data(ctx.GetPlace()); + + // run topk op + NpuOpRunner runner_topk; + runner_topk.SetType("TopKV2") + .AddInput(cal_scores) + .AddInput(std::vector{static_cast(beam_size)}) + .AddOutput(topk_scores) + .AddOutput(tmp_indices) + .AddAttr("sorted", true) + .AddAttr("dim", -1) + .AddAttr("largest", true); + runner_topk.Run(stream); + + // cast tmp_indices from int to float32 for Sort op + Tensor cast_tmp_indices(framework::proto::VarType::FP32); + cast_tmp_indices.Resize(tmp_indices.dims()); + cast_tmp_indices.mutable_data(ctx.GetPlace()); + auto dst_dtype_tmp_indices_fp32 = + ConvertToNpuDtype(cast_tmp_indices.type()); + const auto& runner_cast_tmp_indices = NpuOpRunner( + "Cast", {tmp_indices}, {cast_tmp_indices}, + {{"dst_type", static_cast(dst_dtype_tmp_indices_fp32)}}); + runner_cast_tmp_indices.Run(stream); + + // sort tmp_indices + Tensor sorted_tmp_indices(framework::proto::VarType::FP32); + sorted_tmp_indices.Resize(tmp_indices.dims()); + sorted_tmp_indices.mutable_data(ctx.GetPlace()); + Tensor sorted_score_indices(framework::proto::VarType::INT32); + sorted_score_indices.Resize(tmp_indices.dims()); + sorted_score_indices.mutable_data(ctx.GetPlace()); + const auto& runner_sort_tmp_indices = NpuOpRunner( + "Sort", {cast_tmp_indices}, {sorted_tmp_indices, sorted_score_indices}, + {{"axis", 1}, {"descending", false}}); + runner_sort_tmp_indices.Run(stream); + + // cast sorted_tmp_indices from float32 to int + Tensor cast_sort_tmp_indices(framework::proto::VarType::INT32); + cast_sort_tmp_indices.Resize(sorted_tmp_indices.dims()); + cast_sort_tmp_indices.mutable_data(ctx.GetPlace()); + auto dst_dtype_tmp_indices_int32 = + ConvertToNpuDtype(cast_sort_tmp_indices.type()); + const auto& runner_cast_sort_tmp_indices = NpuOpRunner( + "Cast", {sorted_tmp_indices}, {cast_sort_tmp_indices}, + {{"dst_type", static_cast(dst_dtype_tmp_indices_int32)}}); + runner_cast_sort_tmp_indices.Run(stream); + + // Step 3: infer selected ids from tmp_indices and ids + + // if pre_ids == end_ids, use pre_ids rather than ids + Tensor cal_ids(ids_int32.type()); + cal_ids.mutable_data(ids_int32.dims(), place); + const auto& runner_select_equal_end_id = NpuOpRunner( + "Select", {equal_end_ids, expand_pre_ids, ids_int32}, {cal_ids}, {}); + runner_select_equal_end_id.Run(stream); + + // resize ids from [num_seqs * real_beam_size, K] to [num_seqs, + // real_beam_size * K] + // real_beam_size = 1 or beam_size + cal_ids.Resize( + framework::make_ddim({num_seqs, real_beam_size * seq_width})); + + // construct batch_ids like [[0, 0, 0], [1, 1, 1], ..., [bs-1, bs-1, bs-1]] + // construct arange(num_seqs*beam_size).reshape((num_seqs, beam_size)) // + // beam_size + Tensor batch_ids(framework::proto::VarType::INT32); + batch_ids.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size), 1})); + batch_ids.mutable_data(place); + + std::vector vector_batch_ids; + for (int i = 0; i < num_seqs * static_cast(beam_size); ++i) { + vector_batch_ids.push_back(static_cast(i / beam_size)); + } + framework::TensorFromVector(vector_batch_ids, ctx, &batch_ids); + batch_ids.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size), 1})); + + // sort topk_scores to get selected_scores + // get indices of gather_nd op for calculating selected_scores + Tensor gather_nd_score_indices(framework::proto::VarType::INT32); + gather_nd_score_indices.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size), 2})); + gather_nd_score_indices.mutable_data(place); + + sorted_score_indices.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size), 1})); + std::vector concat_inputs2 = {batch_ids, + sorted_score_indices}; + std::vector concat_names = {"x0", "x1"}; + NpuOpRunner runner_concat_score_indices{"ConcatD", + {concat_inputs2}, + {gather_nd_score_indices}, + {{"concat_dim", 2}, {"N", 2}}}; + runner_concat_score_indices.AddInputNames(concat_names); + runner_concat_score_indices.Run(stream); + + // use gather_nd to get selected_scores + const auto& runner_gather_nd_scores = + NpuOpRunner("GatherNd", {topk_scores, gather_nd_score_indices}, + {*selected_scores}, {}); + runner_gather_nd_scores.Run(stream); + + // get indices of gather_nd op + cast_sort_tmp_indices.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size), 1})); + Tensor gather_nd_id_indices(framework::proto::VarType::INT32); + gather_nd_id_indices.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size), 2})); + gather_nd_id_indices.mutable_data(place); + + std::vector concat_inputs3 = {batch_ids, + cast_sort_tmp_indices}; + NpuOpRunner runner_concat_id_indices{"ConcatD", + {concat_inputs3}, + {gather_nd_id_indices}, + {{"concat_dim", 2}, {"N", 2}}}; + runner_concat_id_indices.AddInputNames(concat_names); + runner_concat_id_indices.Run(stream); + + // use gather_nd to get selected_ids + Tensor topk_ids(framework::proto::VarType::INT32); + topk_ids.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size)})); + topk_ids.mutable_data(ctx.GetPlace()); + + const auto& runner_gather_nd_ids = NpuOpRunner( + "GatherNd", {cal_ids, gather_nd_id_indices}, {topk_ids}, {}); + runner_gather_nd_ids.Run(stream); + + // cast topk_ids from int to int64 to get selected_ids + auto dst_dtype_selected_ids = ConvertToNpuDtype(selected_ids->type()); + const auto& runner_cast_selected_ids = + NpuOpRunner("Cast", {topk_ids}, {*selected_ids}, + {{"dst_type", static_cast(dst_dtype_selected_ids)}}); + runner_cast_selected_ids.Run(stream); + + // TODO(pangyoki): PruneEndBeams + + // Step 4: set lod of output Tensor + // define Tensor with value `seq_width` + Tensor seq_width_tensor(framework::proto::VarType::INT32); + seq_width_tensor.mutable_data({1}, ctx.GetPlace()); + FillNpuTensorWithConstant(&seq_width_tensor, + static_cast(seq_width)); + + // beam_ids = tmp_indices // seq_width + Tensor beam_ids(framework::proto::VarType::INT32); + beam_ids.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size)})); + beam_ids.mutable_data(ctx.GetPlace()); + cast_sort_tmp_indices.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size)})); + + const auto& runner_div = NpuOpRunner( + "Div", {cast_sort_tmp_indices, seq_width_tensor}, {beam_ids}, {}); + runner_div.Run(stream); + + // get parent_idx by adding batch_ids to beam_ids + // construct scale_batch_ids like [[0, 0, 0], [bw, bw, bw], ..., [bs-1*bw, + // bs-1*bw, bs-1*bw]] + batch_ids.Resize( + framework::make_ddim({num_seqs, static_cast(beam_size)})); + + // cast batch_ids from int to float32 + Tensor cast_batch_ids(framework::proto::VarType::FP32); + cast_batch_ids.Resize(batch_ids.dims()); + cast_batch_ids.mutable_data(ctx.GetPlace()); + auto dst_dtype1 = ConvertToNpuDtype(cast_batch_ids.type()); + const auto& runner_cast_batch_ids = + NpuOpRunner("Cast", {batch_ids}, {cast_batch_ids}, + {{"dst_type", static_cast(dst_dtype1)}}); + runner_cast_batch_ids.Run(stream); + + // scale batch_ids with beam_size + Tensor scale_batch_ids(framework::proto::VarType::FP32); + scale_batch_ids.Resize(batch_ids.dims()); + scale_batch_ids.mutable_data(place); + const auto& runner_power = + NpuOpRunner("Power", {cast_batch_ids}, {scale_batch_ids}, + {{"power", static_cast(1.0)}, + {"scale", static_cast(beam_size)}, + {"shift", static_cast(0.0)}}); + runner_power.Run(stream); + + // cast cast_scale_batch_ids from float32 to int + Tensor cast_scale_batch_ids(framework::proto::VarType::INT32); + cast_scale_batch_ids.Resize(scale_batch_ids.dims()); + cast_scale_batch_ids.mutable_data(ctx.GetPlace()); + auto dst_dtype2 = ConvertToNpuDtype(cast_scale_batch_ids.type()); + const auto& runner_cast_scale_batch_ids = + NpuOpRunner("Cast", {scale_batch_ids}, {cast_scale_batch_ids}, + {{"dst_type", static_cast(dst_dtype2)}}); + runner_cast_scale_batch_ids.Run(stream); + + // calculate parent_idx + Tensor tmp_parent_idx(framework::proto::VarType::INT32); + tmp_parent_idx.Resize(parent_idx->dims()); + tmp_parent_idx.mutable_data(place); + const auto& runner_add_beam_id = NpuOpRunner( + "Add", {beam_ids, cast_scale_batch_ids}, {tmp_parent_idx}, {}); + runner_add_beam_id.Run(stream); + + // cast tmp_parent_idx from int to int64 to get parent_idx + auto dst_dtype_parent_idx = ConvertToNpuDtype(parent_idx->type()); + const auto& runner_cast_parent_idx = + NpuOpRunner("Cast", {tmp_parent_idx}, {*parent_idx}, + {{"dst_type", static_cast(dst_dtype_parent_idx)}}); + runner_cast_parent_idx.Run(stream); + + std::vector vector_parent_idx; + framework::TensorToVector(tmp_parent_idx, ctx, &vector_parent_idx); + + // set low level, len(low_level) = high_level[-1] + std::vector low_level; + std::vector num_parent_ids(num_seqs * beam_size, + static_cast(0)); + size_t low_level_size = high_level[num_seqs]; + size_t sum_parent_id = 0; + + // calculate number of every parent_id + for (size_t i = 0; i < num_seqs * beam_size; ++i) { + num_parent_ids[vector_parent_idx[i]]++; + } + + // update low_level + low_level.push_back(0); + for (size_t i = 0; i < low_level_size; ++i) { + sum_parent_id += num_parent_ids[i]; + low_level.push_back(sum_parent_id); + } + + // fill lod + framework::LoD lod(2); + lod[0].assign(high_level.begin(), high_level.end()); + lod[1].assign(low_level.begin(), low_level.end()); + if (!framework::CheckLoD(lod)) { + PADDLE_THROW(platform::errors::InvalidArgument( + "lod %s is not right in" + " beam_search, please check your code.", + framework::LoDToString(lod))); + } + selected_ids->set_lod(lod); + selected_scores->set_lod(lod); + } +}; + +template class BeamSearchFunctor; +template class BeamSearchFunctor; +template class BeamSearchFunctor; +template class BeamSearchFunctor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/tests/unittests/npu/test_beam_search_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_beam_search_op_npu.py new file mode 100644 index 00000000000..14e4fbb73fd --- /dev/null +++ b/python/paddle/fluid/tests/unittests/npu/test_beam_search_op_npu.py @@ -0,0 +1,175 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle +import sys +sys.path.append("..") +from op_test import OpTest +import unittest +import numpy as np +import paddle.fluid as fluid + +paddle.enable_static() + + +class TestBeamSearchNPUOp(OpTest): + def setUp(self): + self.set_npu() + self.place = paddle.NPUPlace(0) + self.op_type = "beam_search" + self.init_data() + self.inputs = { + 'pre_ids': (self.pre_ids, self.lod), + 'pre_scores': (self.pre_score, self.lod), + 'ids': (self.ids, self.lod), + 'scores': (self.score, self.lod) + } + # The `target_lod` attribute is still based on offset + self.attrs = { + 'level': 0, + 'beam_size': self.beam_size, + 'end_id': 0, + 'is_accumulated': self.is_accumulated + } + self.outputs = { + 'selected_ids': (self.selected_ids, self.out_lod), + 'selected_scores': (self.selected_scores, self.out_lod), + 'parent_idx': self.parent_idx + } + + def set_npu(self): + self.__class__.use_npu = True + + def init_data(self): + self.beam_size = 2 + self.is_accumulated = True + self.pre_ids = np.array([[1], [2], [3], [4]], dtype='int64') + self.ids = np.array( + [[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int64') + self.lod = [[2, 2], [1, 1, 1, 1]] + self.out_lod = [[2, 2], [1, 1, 1, 1]] + self.offset_lod = [[0, 2, 4], [0, 1, 2, 3, 4]] + self.score = np.array( + [ + [0.5, 0.3, 0.2], + [0.6, 0.3, 0.1], + [0.9, 0.5, 0.1], + [0.7, 0.5, 0.1], + ], + dtype='float32') + self.pre_score = np.array([[0.1], [0.2], [0.3], [0.4]], dtype='float32') + self.selected_ids = np.array([4, 2, 3, 8])[:, np.newaxis] + self.selected_scores = np.array([0.5, 0.6, 0.9, 0.7])[:, np.newaxis] + self.parent_idx = np.array([0, 1, 2, 3]) + + def test_check_output(self): + self.check_output_with_place(self.place, atol=1e-3) + + +class TestBeamSearchNPUOp2(TestBeamSearchNPUOp): + def init_data(self): + self.beam_size = 2 + self.is_accumulated = True + self.pre_ids = np.array([[1], [2], [3], [4]], dtype='int64') + self.ids = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64') + self.lod = [[2, 2], [1, 1, 1, 1]] + self.out_lod = [[2, 2], [2, 0, 1, 1]] + self.offset_lod = [[0, 2, 4], [0, 2, 2, 3, 4]] + self.score = np.array( + [ + [0.6, 0.9], + [0.5, 0.3], + [0.9, 0.5], + [0.1, 0.7], + ], dtype='float32') + self.pre_score = np.array([[0.1], [0.2], [0.3], [0.4]], dtype='float32') + self.selected_ids = np.array([4, 2, 3, 1])[:, np.newaxis] + self.selected_scores = np.array([0.6, 0.9, 0.9, 0.7])[:, np.newaxis] + self.parent_idx = np.array([0, 0, 2, 3]) + + +class TestBeamSearchNPUOp3(TestBeamSearchNPUOp): + def init_data(self): + # end_id = 0 + self.beam_size = 2 + self.is_accumulated = True + self.pre_ids = np.array([[1], [0], [0], [4]], dtype='int64') + self.ids = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64') + self.lod = [[2, 2], [1, 1, 1, 1]] + self.out_lod = [[2, 2], [1, 1, 0, 2]] + self.offset_lod = [[0, 2, 4], [0, 1, 2, 2, 4]] + self.score = np.array( + [ + [0.6, 0.9], + [0.5, 0.3], + [0.9, 0.5], + [0.6, 0.7], + ], dtype='float32') + self.pre_score = np.array([[0.1], [1.2], [0.5], [0.4]], dtype='float32') + self.selected_ids = np.array([2, 0, 8, 1])[:, np.newaxis] + self.selected_scores = np.array([0.9, 1.2, 0.6, 0.7])[:, np.newaxis] + self.parent_idx = np.array([0, 1, 3, 3]) + + +class TestBeamSearchNPUOp4(TestBeamSearchNPUOp): + def init_data(self): + # is_accumulated = False + self.beam_size = 2 + self.is_accumulated = False + self.pre_ids = np.array([[1], [2], [3], [4]], dtype='int64') + self.ids = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64') + self.lod = [[2, 2], [1, 1, 1, 1]] + self.out_lod = [[2, 2], [0, 2, 1, 1]] + self.offset_lod = [[0, 2, 4], [0, 0, 2, 3, 4]] + self.score = np.array( + [ + [0.6, 0.9], + [0.5, 0.3], + [0.9, 0.5], + [0.1, 0.7], + ], dtype='float32') + self.pre_score = np.array([[0.1], [2.2], [0.3], [0.4]], dtype='float32') + self.selected_ids = np.array([7, 3, 3, 1])[:, np.newaxis] + self.selected_scores = np.array( + [1.50685, 0.996027, 0.194639, 0.043325])[:, np.newaxis] + self.parent_idx = np.array([1, 1, 2, 3]) + + +class TestBeamSearchNPUOp5(TestBeamSearchNPUOp): + def init_data(self): + # beam_size = 1 + self.beam_size = 1 + self.is_accumulated = True + self.pre_ids = np.array([[1], [2], [3], [4]], dtype='int64') + self.ids = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64') + self.lod = [[1, 1, 1, 1], [1, 1, 1, 1]] + self.out_lod = [[1, 1, 1, 1], [1, 1, 1, 1]] + self.offset_lod = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] + self.score = np.array( + [ + [0.6, 0.9], + [0.5, 0.3], + [0.9, 0.5], + [0.1, 0.7], + ], dtype='float32') + self.pre_score = np.array([[0.1], [0.2], [0.3], [0.4]], dtype='float32') + self.selected_ids = np.array([2, 7, 3, 1])[:, np.newaxis] + self.selected_scores = np.array([0.9, 0.5, 0.9, 0.7])[:, np.newaxis] + self.parent_idx = np.array([0, 1, 2, 3]) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_beam_search_op.py b/python/paddle/fluid/tests/unittests/test_beam_search_op.py index 346cd1e2129..99ca5779a69 100644 --- a/python/paddle/fluid/tests/unittests/test_beam_search_op.py +++ b/python/paddle/fluid/tests/unittests/test_beam_search_op.py @@ -38,6 +38,7 @@ class BeamSearchOpTester(unittest.TestCase): self._create_pre_scores() self._create_scores() self._create_pre_ids() + self.set_outputs() self.scope.var('selected_ids').get_tensor() self.scope.var('selected_scores').get_tensor() self.scope.var('parent_idx').get_tensor() @@ -53,22 +54,19 @@ class BeamSearchOpTester(unittest.TestCase): selected_scores='selected_scores', parent_idx='parent_idx', level=0, - beam_size=2, - end_id=0, ) + beam_size=self.beam_size, + end_id=0, + is_accumulated=self.is_accumulated) op.run(self.scope, core.CPUPlace()) selected_ids = self.scope.find_var("selected_ids").get_tensor() selected_scores = self.scope.find_var("selected_scores").get_tensor() parent_idx = self.scope.find_var("parent_idx").get_tensor() + self.assertTrue(np.allclose(np.array(selected_ids), self.output_ids)) self.assertTrue( - np.allclose( - np.array(selected_ids), np.array([4, 2, 3, 8])[:, np.newaxis])) + np.allclose(np.array(selected_scores), self.output_scores)) + self.assertEqual(selected_ids.lod(), self.output_lod) self.assertTrue( - np.allclose( - np.array(selected_scores), - np.array([0.5, 0.6, 0.9, 0.7])[:, np.newaxis])) - self.assertEqual(selected_ids.lod(), [[0, 2, 4], [0, 1, 2, 3, 4]]) - self.assertTrue( - np.allclose(np.array(parent_idx), np.array([0, 1, 2, 3]))) + np.allclose(np.array(parent_idx), self.output_parent_idx)) def _create_pre_ids(self): np_data = np.array([[1, 2, 3, 4]], dtype='int64') @@ -97,6 +95,194 @@ class BeamSearchOpTester(unittest.TestCase): tensor = create_tensor(self.scope, "scores", np_data) tensor.set_lod(self.lod) + def set_outputs(self): + self.beam_size = 2 + self.is_accumulated = True + self.output_ids = np.array([4, 2, 3, 8])[:, np.newaxis] + self.output_scores = np.array([0.5, 0.6, 0.9, 0.7])[:, np.newaxis] + self.output_lod = [[0, 2, 4], [0, 1, 2, 3, 4]] + self.output_parent_idx = np.array([0, 1, 2, 3]) + + +class BeamSearchOpTester2(BeamSearchOpTester): + def _create_pre_ids(self): + np_data = np.array([[1], [2], [3], [4]], dtype='int64') + tensor = create_tensor(self.scope, 'pre_ids', np_data) + + def _create_pre_scores(self): + np_data = np.array([[0.1, 0.2, 0.3, 0.4]], dtype='float32') + tensor = create_tensor(self.scope, 'pre_scores', np_data) + + def _create_ids(self): + self.lod = [[0, 2, 4], [0, 1, 2, 3, 4]] + np_data = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64') + tensor = create_tensor(self.scope, "ids", np_data) + tensor.set_lod(self.lod) + + def _create_scores(self): + np_data = np.array( + [ + [0.6, 0.9], + [0.5, 0.3], + [0.9, 0.5], + [0.1, 0.7], + ], dtype='float32') + tensor = create_tensor(self.scope, "scores", np_data) + tensor.set_lod(self.lod) + + def set_outputs(self): + self.beam_size = 2 + self.is_accumulated = True + self.output_ids = np.array([2, 4, 3, 1])[:, np.newaxis] + self.output_scores = np.array([0.9, 0.6, 0.9, 0.7])[:, np.newaxis] + self.output_lod = [[0, 2, 4], [0, 2, 2, 3, 4]] + self.output_parent_idx = np.array([0, 0, 2, 3]) + + +class BeamSearchOpTester3(BeamSearchOpTester): + # pre_id = end_id + def _create_pre_ids(self): + np_data = np.array([[1], [0], [0], [4]], dtype='int64') + tensor = create_tensor(self.scope, 'pre_ids', np_data) + + def _create_pre_scores(self): + np_data = np.array([[0.1], [1.2], [0.5], [0.4]], dtype='float32') + tensor = create_tensor(self.scope, 'pre_scores', np_data) + + def _create_ids(self): + self.lod = [[0, 2, 4], [0, 1, 2, 3, 4]] + np_data = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64') + tensor = create_tensor(self.scope, "ids", np_data) + tensor.set_lod(self.lod) + + def _create_scores(self): + np_data = np.array( + [ + [0.6, 0.9], + [0.5, 0.3], + [0.9, 0.5], + [0.6, 0.7], + ], dtype='float32') + tensor = create_tensor(self.scope, "scores", np_data) + tensor.set_lod(self.lod) + + def set_outputs(self): + self.beam_size = 2 + self.is_accumulated = True + self.output_ids = np.array([2, 0, 1, 8])[:, np.newaxis] + self.output_scores = np.array([0.9, 1.2, 0.7, 0.6])[:, np.newaxis] + self.output_lod = [[0, 2, 4], [0, 1, 2, 2, 4]] + self.output_parent_idx = np.array([0, 1, 3, 3]) + + +class BeamSearchOpTester4(BeamSearchOpTester): + # prune beam search while pre_id of in all beams is end_id + def _create_pre_ids(self): + np_data = np.array([[0], [0], [0], [4]], dtype='int64') + tensor = create_tensor(self.scope, 'pre_ids', np_data) + + def _create_pre_scores(self): + np_data = np.array([[0.1], [1.2], [0.5], [0.4]], dtype='float32') + tensor = create_tensor(self.scope, 'pre_scores', np_data) + + def _create_ids(self): + self.lod = [[0, 2, 4], [0, 1, 2, 3, 4]] + np_data = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64') + tensor = create_tensor(self.scope, "ids", np_data) + tensor.set_lod(self.lod) + + def _create_scores(self): + np_data = np.array( + [ + [0.6, 0.9], + [0.5, 0.3], + [0.9, 0.5], + [0.6, 0.7], + ], dtype='float32') + tensor = create_tensor(self.scope, "scores", np_data) + tensor.set_lod(self.lod) + + def set_outputs(self): + self.beam_size = 2 + self.is_accumulated = True + self.output_ids = np.array([1, 8])[:, np.newaxis] + self.output_scores = np.array([0.7, 0.6])[:, np.newaxis] + self.output_lod = [[0, 2, 4], [0, 0, 0, 0, 2]] + self.output_parent_idx = np.array([3, 3]) + + +class BeamSearchOpTester5(BeamSearchOpTester): + # is_accumulated = False + def _create_pre_ids(self): + np_data = np.array([[1], [2], [3], [4]], dtype='int64') + tensor = create_tensor(self.scope, 'pre_ids', np_data) + + def _create_pre_scores(self): + np_data = np.array([[0.1, 2.2, 0.3, 0.4]], dtype='float32') + tensor = create_tensor(self.scope, 'pre_scores', np_data) + + def _create_ids(self): + self.lod = [[0, 2, 4], [0, 1, 2, 3, 4]] + np_data = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64') + tensor = create_tensor(self.scope, "ids", np_data) + tensor.set_lod(self.lod) + + def _create_scores(self): + np_data = np.array( + [ + [0.6, 0.9], + [0.5, 0.3], + [0.9, 0.5], + [0.1, 0.7], + ], dtype='float32') + tensor = create_tensor(self.scope, "scores", np_data) + tensor.set_lod(self.lod) + + def set_outputs(self): + self.beam_size = 2 + self.is_accumulated = False + self.output_ids = np.array([7, 3, 3, 1])[:, np.newaxis] + self.output_scores = np.array( + [1.50685, 0.996027, 0.194639, 0.043325])[:, np.newaxis] + self.output_lod = [[0, 2, 4], [0, 0, 2, 3, 4]] + self.output_parent_idx = np.array([1, 1, 2, 3]) + + +class BeamSearchOpTester6(BeamSearchOpTester): + # beam_size = 1 + def _create_pre_ids(self): + np_data = np.array([[1], [2], [3], [4]], dtype='int64') + tensor = create_tensor(self.scope, 'pre_ids', np_data) + + def _create_pre_scores(self): + np_data = np.array([[0.1, 0.2, 0.3, 0.4]], dtype='float32') + tensor = create_tensor(self.scope, 'pre_scores', np_data) + + def _create_ids(self): + self.lod = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] + np_data = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64') + tensor = create_tensor(self.scope, "ids", np_data) + tensor.set_lod(self.lod) + + def _create_scores(self): + np_data = np.array( + [ + [0.6, 0.9], + [0.5, 0.3], + [0.9, 0.5], + [0.1, 0.7], + ], dtype='float32') + tensor = create_tensor(self.scope, "scores", np_data) + tensor.set_lod(self.lod) + + def set_outputs(self): + self.beam_size = 1 + self.is_accumulated = True + self.output_ids = np.array([2, 7, 3, 1])[:, np.newaxis] + self.output_scores = np.array([0.9, 0.5, 0.9, 0.7])[:, np.newaxis] + self.output_lod = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]] + self.output_parent_idx = np.array([0, 1, 2, 3]) + class TestBeamSearchOpError(unittest.TestCase): def test_errors(self): -- GitLab