未验证 提交 3760be06 编写于 作者: P pangyoki 提交者: GitHub

[NPU] add beam_search npu op (#34860)

* add beam_search npu op

* fix CMakeList and add unittest

* fix bug of beam search npu op

* fix unittest

* let input ids become int64

* set output ids to int64_t

* delete check_dygraph

* fix beam_width=1
上级 9f588cc2
......@@ -51,11 +51,11 @@ class BeamSearchOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_NOT_NULL(
selected_ids,
platform::errors::NotFound(
"Output(selected_scores) of BeamSearchOp is not found."));
"Output(selected_ids) of BeamSearchOp is not found."));
PADDLE_ENFORCE_NOT_NULL(
selected_scores,
platform::errors::NotFound(
"Output(parent_idx) of BeamSearchOp is not found."));
"Output(selected_scores) of BeamSearchOp is not found."));
math::BeamSearchFunctor<DeviceContext, T> alg;
alg(context.template device_context<DeviceContext>(), pre_ids, pre_scores,
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
#include "paddle/fluid/framework/op_registry.h"
// Short alias for the operator namespace used in the registration below.
namespace ops = paddle::operators;
// Register the generic BeamSearchOpKernel for the NPU device context with
// four element types (float, double, int, int64_t).
// NOTE(review): the NPU BeamSearchFunctor specialization allocates its score
// buffers with mutable_data<float> regardless of T -- confirm that the
// non-float registrations are actually supported.
REGISTER_OP_NPU_KERNEL(
beam_search,
ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, float>,
ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, double>,
ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, int>,
ops::BeamSearchOpKernel<paddle::platform::NPUDeviceContext, int64_t>);
......@@ -39,6 +39,10 @@ function(math_library TARGET)
endif()
endfunction()
# Build the NPU beam-search functor as its own library, only when the build is
# configured with WITH_ASCEND_CL. It depends on npu_op_runner.
if (WITH_ASCEND_CL)
cc_library(beam_search_npu SRCS beam_search_npu.cc DEPS npu_op_runner)
endif()
# please add new math_library in alphabetical order
math_library(concat_and_split)
math_library(context_project DEPS im2col math_function)
......@@ -68,7 +72,11 @@ math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function jit_kernel_helper)
math_library(sequence_scale)
math_library(softmax DEPS math_function jit_kernel_helper)
math_library(beam_search DEPS math_function)
# beam_search additionally depends on beam_search_npu when WITH_ASCEND_CL is
# set; otherwise it only depends on math_function.
if (WITH_ASCEND_CL)
math_library(beam_search DEPS math_function beam_search_npu)
else()
math_library(beam_search DEPS math_function)
endif()
math_library(fc DEPS blas)
math_library(matrix_bit_code)
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/beam_search.h"
#include "paddle/fluid/operators/npu_op_runner.h"
// Forward declarations of the framework tensor types and the NPU device
// context that are named in the functor signature below.
namespace paddle {
namespace framework {
class LoDTensor;
class Tensor;
} // namespace framework
namespace platform {
class NPUDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
namespace math {
// Specialization of BeamSearchFunctor for the NPU device context.
//
// One beam-search step is expressed as a sequence of NPU operators launched
// through NpuOpRunner on the context's stream, followed by a host-side pass
// that builds the output LoD.
//
// Inputs (as used by the code below):
//   ctx            - NPU device context; supplies the place and the stream.
//   pre_ids        - ids selected in the previous step (cast to int32 here).
//   pre_scores     - scores of the previous step.
//   ids            - candidate ids of this step (cast to int32 here).
//   scores         - candidate scores of this step; its LoD at `level` defines
//                    the source sequences.
//   level          - LoD level of `scores` that delimits source sequences.
//   beam_size      - number of candidates kept per source sequence.
//   end_id         - id that marks a finished prefix.
//   is_accumulated - when false, log(scores) is added to pre_scores before
//                    ranking; when true, scores are ranked as given.
// Outputs:
//   selected_ids    - int64, shape [num_seqs * beam_size, 1], 2-level LoD set.
//   selected_scores - float, shape [num_seqs * beam_size, 1], 2-level LoD set.
//   parent_idx      - int64, shape [num_seqs * beam_size].
//
// NOTE(review): the template parameter T is not referenced in the body; all
// score buffers are allocated with mutable_data<float>. Confirm behaviour for
// the non-float instantiations.
// NOTE(review): pruning of finished beams is not implemented here (see the
// TODO further down), so the output always holds num_seqs * beam_size rows.
template <typename T>
class BeamSearchFunctor<platform::NPUDeviceContext, T> {
public:
void operator()(const platform::NPUDeviceContext& ctx,
const framework::LoDTensor* pre_ids,
const framework::LoDTensor* pre_scores,
const framework::LoDTensor* ids,
const framework::LoDTensor* scores,
framework::LoDTensor* selected_ids,
framework::LoDTensor* selected_scores,
framework::Tensor* parent_idx, size_t level, size_t beam_size,
int end_id, bool is_accumulated) {
// Absolute-offset LoD of `scores`; high_level is reused unchanged as level 0
// of the output LoD.
auto abs_lod = framework::ToAbsOffset(scores->lod());
auto& high_level = abs_lod[level];
// Number of source sequences at the requested LoD level.
// NOTE(review): num_seqs == 0 would make the division below undefined --
// confirm callers never pass an empty LoD level.
int64_t num_seqs = scores->NumElements(level);
// size of the first beam is 1, others are equal to beam_size
int64_t real_beam_size = static_cast<int64_t>(scores->dims()[0] / num_seqs);
// K: number of candidates per row, i.e. the product of all dims after dim 0
int64_t seq_width = 1;
for (int i = 1; i < scores->dims().size(); i++) {
seq_width *= scores->dims()[i];
}
auto place = ctx.GetPlace();
auto stream = ctx.stream();
// Number of output rows, and number of input rows (dim 0 of `scores`).
int64_t total_length = num_seqs * beam_size;
int64_t batch_size = static_cast<int64_t>(scores->dims()[0]);
// Allocate the three outputs with their final shapes and dtypes.
selected_ids->mutable_data<int64_t>(framework::make_ddim({total_length, 1}),
place);
selected_scores->mutable_data<float>(
framework::make_ddim({total_length, 1}), place);
parent_idx->mutable_data<int64_t>(framework::make_ddim({total_length}),
place);
// Step1: Define Tensors and Preprocess the situation that pre_id == end_id
// cast ids and pre_ids to int32 (shared without a copy if already int32)
Tensor ids_int32(framework::proto::VarType::INT32);
if (ids->type() != framework::proto::VarType::INT32) {
ids_int32.Resize(ids->dims());
ids_int32.mutable_data<int>(ctx.GetPlace());
auto dst_dtype_ids_int32 = ConvertToNpuDtype(ids_int32.type());
const auto& runner_ids_int32 =
NpuOpRunner("Cast", {*ids}, {ids_int32},
{{"dst_type", static_cast<int>(dst_dtype_ids_int32)}});
runner_ids_int32.Run(stream);
} else {
ids_int32.ShareDataWith(*ids);
}
Tensor pre_ids_int32(framework::proto::VarType::INT32);
if (pre_ids->type() != framework::proto::VarType::INT32) {
pre_ids_int32.Resize(pre_ids->dims());
pre_ids_int32.mutable_data<int>(ctx.GetPlace());
auto dst_dtype_pre_ids_int32 = ConvertToNpuDtype(pre_ids_int32.type());
const auto& runner_pre_ids_int32 = NpuOpRunner(
"Cast", {*pre_ids}, {pre_ids_int32},
{{"dst_type", static_cast<int>(dst_dtype_pre_ids_int32)}});
runner_pre_ids_int32.Run(stream);
} else {
pre_ids_int32.ShareDataWith(*pre_ids);
}
// Tile pre_ids along axis 1 to [batch_size, seq_width] so it can be compared
// element-wise with the candidates, then view it with the dims of ids.
Tensor expand_pre_ids(pre_ids_int32.type());
expand_pre_ids.Resize(framework::make_ddim({batch_size, seq_width}));
expand_pre_ids.mutable_data<int>(place);
const auto& runner_tile_pre_ids =
NpuOpRunner("TileWithAxis", {pre_ids_int32}, {expand_pre_ids},
{{"axis", 1}, {"tiles", seq_width}});
runner_tile_pre_ids.Run(stream);
expand_pre_ids.Resize(ids_int32.dims());
// Same tiling for pre_scores.
Tensor expand_pre_scores(pre_scores->type());
expand_pre_scores.Resize(framework::make_ddim({batch_size, seq_width}));
expand_pre_scores.mutable_data<float>(place);
const auto& runner_tile_pre_scores =
NpuOpRunner("TileWithAxis", {*pre_scores}, {expand_pre_scores},
{{"axis", 1}, {"tiles", seq_width}});
runner_tile_pre_scores.Run(stream);
expand_pre_scores.Resize(scores->dims());
// End_id Tensors: a one-element tensor holding end_id, filled out to the
// dims of ids
Tensor end_id_tmp_tensor(framework::proto::VarType::INT32);
end_id_tmp_tensor.mutable_data<int>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<int>(&end_id_tmp_tensor, end_id);
Tensor end_id_tensors(ids_int32.type());
end_id_tensors.mutable_data<int>(ids_int32.dims(), place);
const auto& runner_fill_end_id =
NpuOpRunner("FillD", {end_id_tmp_tensor}, {end_id_tensors},
{{"dims", framework::vectorize(ids_int32.dims())}});
runner_fill_end_id.Run(stream);
// whether expand_pre_ids == end_ids?
Tensor equal_end_ids(framework::proto::VarType::BOOL);
equal_end_ids.mutable_data<bool>(ids_int32.dims(), place);
const auto& runner_equal_end_ids = NpuOpRunner(
"Equal", {expand_pre_ids, end_id_tensors}, {equal_end_ids}, {});
runner_equal_end_ids.Run(stream);
// construct a Tensor with dimension ids->dims():
// [[False, True, True, True, ...],
// [False, True, True, True, ...],
// ...]
// It is built as int32 (a column of 0 concatenated with columns of 1) and
// cast to bool afterwards.
Tensor false_tmp_tensor(framework::proto::VarType::INT32);
false_tmp_tensor.mutable_data<int>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<int>(&false_tmp_tensor, static_cast<int>(false));
Tensor first_pos_false_tensors(framework::proto::VarType::INT32);
first_pos_false_tensors.Resize(framework::make_ddim({batch_size, 1}));
first_pos_false_tensors.mutable_data<int>(place);
std::vector<int64_t> fill_dims = {batch_size, 1};
framework::NPUAttributeMap fill_attr = {{"dims", fill_dims}};
const auto& runner_fill_false_tensors = NpuOpRunner(
"FillD", {false_tmp_tensor}, {first_pos_false_tensors}, fill_attr);
runner_fill_false_tensors.Run(stream);
Tensor pos_tensors(framework::proto::VarType::INT32);
if (seq_width > 1) {
// More than one candidate per row: append seq_width - 1 columns of 1.
pos_tensors.Resize(framework::make_ddim({batch_size, seq_width}));
pos_tensors.mutable_data<int>(place);
Tensor true_tmp_tensor(framework::proto::VarType::INT32);
true_tmp_tensor.mutable_data<int>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<int>(&true_tmp_tensor, static_cast<int>(true));
Tensor second_pos_true_tensors(framework::proto::VarType::INT32);
second_pos_true_tensors.Resize(
framework::make_ddim({batch_size, seq_width - 1}));
second_pos_true_tensors.mutable_data<int>(place);
std::vector<int64_t> fill_dims2 = {batch_size, seq_width - 1};
framework::NPUAttributeMap fill_attr2 = {{"dims", fill_dims2}};
const auto& runner_fill_true_tensors = NpuOpRunner(
"FillD", {true_tmp_tensor}, {second_pos_true_tensors}, fill_attr2);
runner_fill_true_tensors.Run(stream);
std::vector<framework::Tensor> concat_inputs = {first_pos_false_tensors,
second_pos_true_tensors};
std::vector<std::string> concat_names = {"x0", "x1"};
NpuOpRunner runner_concat_false_true{"ConcatD",
{concat_inputs},
{pos_tensors},
{{"concat_dim", 1}, {"N", 2}}};
runner_concat_false_true.AddInputNames(concat_names);
runner_concat_false_true.Run(stream);
pos_tensors.Resize(ids_int32.dims());
} else {
// Single candidate per row: the mask is just the all-zero column.
pos_tensors.ShareDataWith(first_pos_false_tensors);
}
Tensor cast_pos_tensors_bool(framework::proto::VarType::BOOL);
cast_pos_tensors_bool.Resize(pos_tensors.dims());
cast_pos_tensors_bool.mutable_data<bool>(ctx.GetPlace());
auto dst_dtype = ConvertToNpuDtype(cast_pos_tensors_bool.type());
const auto& runner_cast_pos_tensors =
NpuOpRunner("Cast", {pos_tensors}, {cast_pos_tensors_bool},
{{"dst_type", static_cast<int>(dst_dtype)}});
runner_cast_pos_tensors.Run(stream);
// if pre_ids == end_ids, save only one score, and others become -inf
// save_one_end_score = (pre_id == end_id) AND (column != 0); despite its name
// it marks the entries that will be overwritten with -inf.
Tensor save_one_end_score(framework::proto::VarType::BOOL);
save_one_end_score.mutable_data<bool>(ids_int32.dims(), place);
const auto& runner_logical_and =
NpuOpRunner("LogicalAnd", {equal_end_ids, cast_pos_tensors_bool},
{save_one_end_score}, {});
runner_logical_and.Run(stream);
// if save_one_end_score is True, set score to -inf
// define -Inf Tensors
Tensor ninf_tmp_tensor(scores->type());
ninf_tmp_tensor.mutable_data<float>({1}, ctx.GetPlace());
float ninf_value =
static_cast<float>(-std::numeric_limits<float>::infinity());
FillNpuTensorWithConstant<float>(&ninf_tmp_tensor, ninf_value);
Tensor ninf_tensors(scores->type());
ninf_tensors.mutable_data<float>(scores->dims(), place);
const auto& runner_fill_ninf =
NpuOpRunner("FillD", {ninf_tmp_tensor}, {ninf_tensors},
{{"dims", framework::vectorize(scores->dims())}});
runner_fill_ninf.Run(stream);
// Step2: calculate topk scores
// get scores used in topk op
Tensor tmp_scores(scores->type());
tmp_scores.mutable_data<float>(scores->dims(), place);
if (!is_accumulated) {
// if pre_id == end_id, cal_scores = pre_score, and id = end_id
// else, cal_score = pre_score + log(score)
// calculate log(scores) as Log1p(scores - 1), since log1p(x) = log(1 + x)
Tensor log_scores(scores->type());
log_scores.mutable_data<float>(scores->dims(), place);
Tensor one(scores->type());
one.mutable_data<float>(scores->dims(), place);
const auto& runner_one = NpuOpRunner("OnesLike", {*scores}, {one}, {});
runner_one.Run(stream);
Tensor sub(scores->type());
sub.mutable_data<float>(scores->dims(), place);
const auto& runner_sub = NpuOpRunner("Sub", {*scores, one}, {sub}, {});
runner_sub.Run(stream);
const auto& runner_log_scores =
NpuOpRunner("Log1p", {sub}, {log_scores}, {});
runner_log_scores.Run(stream);
// tmp_scores = pre_score + log(scores)
// (presumably the Add op broadcasts pre_scores across the candidate axis --
// TODO confirm)
const auto& runner_add_scores =
NpuOpRunner("Add", {log_scores, *pre_scores}, {tmp_scores}, {});
runner_add_scores.Run(stream);
// if pre_ids == end_ids, use pre_score rather than score
// (tmp_scores is both an input and the output of this Select)
const auto& runner_select_equal_end_score =
NpuOpRunner("Select", {equal_end_ids, expand_pre_scores, tmp_scores},
{tmp_scores}, {});
runner_select_equal_end_score.Run(stream);
} else {
// if pre_ids == end_ids, use pre_score rather than score
const auto& runner_select_equal_end_score2 =
NpuOpRunner("Select", {equal_end_ids, expand_pre_scores, *scores},
{tmp_scores}, {});
runner_select_equal_end_score2.Run(stream);
}
// if pre_ids == end_ids, save only one score, and others become -inf
Tensor cal_scores(scores->type());
cal_scores.mutable_data<float>(scores->dims(), place);
const auto& runner_select_inf_score =
NpuOpRunner("Select", {save_one_end_score, ninf_tensors, tmp_scores},
{cal_scores}, {});
runner_select_inf_score.Run(stream);
// resize scores from [num_seqs * real_beam_size, K] to
// [num_seqs, real_beam_size * K]
// real_beam_size = 1 or beam_size
cal_scores.Resize(
framework::make_ddim({num_seqs, real_beam_size * seq_width}));
Tensor topk_scores(scores->type());
topk_scores.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size)}));
topk_scores.mutable_data<float>(ctx.GetPlace());
Tensor tmp_indices(framework::proto::VarType::INT32);
tmp_indices.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size)}));
tmp_indices.mutable_data<int>(ctx.GetPlace());
// run topk op: per source sequence, keep the beam_size largest scores and
// their flattened (beam * K + candidate) indices
NpuOpRunner runner_topk;
runner_topk.SetType("TopKV2")
.AddInput(cal_scores)
.AddInput(std::vector<int>{static_cast<int>(beam_size)})
.AddOutput(topk_scores)
.AddOutput(tmp_indices)
.AddAttr("sorted", true)
.AddAttr("dim", -1)
.AddAttr("largest", true);
runner_topk.Run(stream);
// cast tmp_indices from int to float32 for Sort op
Tensor cast_tmp_indices(framework::proto::VarType::FP32);
cast_tmp_indices.Resize(tmp_indices.dims());
cast_tmp_indices.mutable_data<float>(ctx.GetPlace());
auto dst_dtype_tmp_indices_fp32 =
ConvertToNpuDtype(cast_tmp_indices.type());
const auto& runner_cast_tmp_indices = NpuOpRunner(
"Cast", {tmp_indices}, {cast_tmp_indices},
{{"dst_type", static_cast<int>(dst_dtype_tmp_indices_fp32)}});
runner_cast_tmp_indices.Run(stream);
// sort tmp_indices ascending along axis 1, so the kept candidates are
// ordered by their flattened index rather than by score.
// sorted_score_indices receives the permutation, which is later used to
// reorder topk_scores the same way.
Tensor sorted_tmp_indices(framework::proto::VarType::FP32);
sorted_tmp_indices.Resize(tmp_indices.dims());
sorted_tmp_indices.mutable_data<float>(ctx.GetPlace());
Tensor sorted_score_indices(framework::proto::VarType::INT32);
sorted_score_indices.Resize(tmp_indices.dims());
sorted_score_indices.mutable_data<int>(ctx.GetPlace());
const auto& runner_sort_tmp_indices = NpuOpRunner(
"Sort", {cast_tmp_indices}, {sorted_tmp_indices, sorted_score_indices},
{{"axis", 1}, {"descending", false}});
runner_sort_tmp_indices.Run(stream);
// cast sorted_tmp_indices from float32 to int
Tensor cast_sort_tmp_indices(framework::proto::VarType::INT32);
cast_sort_tmp_indices.Resize(sorted_tmp_indices.dims());
cast_sort_tmp_indices.mutable_data<int>(ctx.GetPlace());
auto dst_dtype_tmp_indices_int32 =
ConvertToNpuDtype(cast_sort_tmp_indices.type());
const auto& runner_cast_sort_tmp_indices = NpuOpRunner(
"Cast", {sorted_tmp_indices}, {cast_sort_tmp_indices},
{{"dst_type", static_cast<int>(dst_dtype_tmp_indices_int32)}});
runner_cast_sort_tmp_indices.Run(stream);
// Step 3: infer selected ids from tmp_indices and ids
// if pre_ids == end_ids, use pre_ids rather than ids
Tensor cal_ids(ids_int32.type());
cal_ids.mutable_data<int>(ids_int32.dims(), place);
const auto& runner_select_equal_end_id = NpuOpRunner(
"Select", {equal_end_ids, expand_pre_ids, ids_int32}, {cal_ids}, {});
runner_select_equal_end_id.Run(stream);
// resize ids from [num_seqs * real_beam_size, K] to [num_seqs,
// real_beam_size * K]
// real_beam_size = 1 or beam_size
cal_ids.Resize(
framework::make_ddim({num_seqs, real_beam_size * seq_width}));
// construct batch_ids like [[0, 0, 0], [1, 1, 1], ..., [bs-1, bs-1, bs-1]]
// construct arange(num_seqs*beam_size).reshape((num_seqs, beam_size)) //
// beam_size
// The values are generated on the host and copied to the device.
Tensor batch_ids(framework::proto::VarType::INT32);
batch_ids.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size), 1}));
batch_ids.mutable_data<int>(place);
std::vector<int> vector_batch_ids;
for (int i = 0; i < num_seqs * static_cast<int>(beam_size); ++i) {
vector_batch_ids.push_back(static_cast<int>(i / beam_size));
}
framework::TensorFromVector(vector_batch_ids, ctx, &batch_ids);
// Resize again after the host copy.
// presumably TensorFromVector resets the dims -- TODO confirm
batch_ids.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size), 1}));
// sort topk_scores to get selected_scores
// get indices of gather_nd op for calculating selected_scores:
// each index is the pair (sequence index, position inside topk_scores)
Tensor gather_nd_score_indices(framework::proto::VarType::INT32);
gather_nd_score_indices.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size), 2}));
gather_nd_score_indices.mutable_data<int>(place);
sorted_score_indices.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size), 1}));
std::vector<framework::Tensor> concat_inputs2 = {batch_ids,
sorted_score_indices};
std::vector<std::string> concat_names = {"x0", "x1"};
NpuOpRunner runner_concat_score_indices{"ConcatD",
{concat_inputs2},
{gather_nd_score_indices},
{{"concat_dim", 2}, {"N", 2}}};
runner_concat_score_indices.AddInputNames(concat_names);
runner_concat_score_indices.Run(stream);
// use gather_nd to get selected_scores
const auto& runner_gather_nd_scores =
NpuOpRunner("GatherNd", {topk_scores, gather_nd_score_indices},
{*selected_scores}, {});
runner_gather_nd_scores.Run(stream);
// get indices of gather_nd op for the ids:
// each index is the pair (sequence index, flattened candidate index)
cast_sort_tmp_indices.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size), 1}));
Tensor gather_nd_id_indices(framework::proto::VarType::INT32);
gather_nd_id_indices.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size), 2}));
gather_nd_id_indices.mutable_data<int>(place);
std::vector<framework::Tensor> concat_inputs3 = {batch_ids,
cast_sort_tmp_indices};
NpuOpRunner runner_concat_id_indices{"ConcatD",
{concat_inputs3},
{gather_nd_id_indices},
{{"concat_dim", 2}, {"N", 2}}};
runner_concat_id_indices.AddInputNames(concat_names);
runner_concat_id_indices.Run(stream);
// use gather_nd to get selected_ids
Tensor topk_ids(framework::proto::VarType::INT32);
topk_ids.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size)}));
topk_ids.mutable_data<int>(ctx.GetPlace());
const auto& runner_gather_nd_ids = NpuOpRunner(
"GatherNd", {cal_ids, gather_nd_id_indices}, {topk_ids}, {});
runner_gather_nd_ids.Run(stream);
// cast topk_ids from int to int64 to get selected_ids
auto dst_dtype_selected_ids = ConvertToNpuDtype(selected_ids->type());
const auto& runner_cast_selected_ids =
NpuOpRunner("Cast", {topk_ids}, {*selected_ids},
{{"dst_type", static_cast<int>(dst_dtype_selected_ids)}});
runner_cast_selected_ids.Run(stream);
// TODO(pangyoki): PruneEndBeams
// Step 4: set lod of output Tensor
// define Tensor with value `seq_width`
Tensor seq_width_tensor(framework::proto::VarType::INT32);
seq_width_tensor.mutable_data<int>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<int>(&seq_width_tensor,
static_cast<int>(seq_width));
// beam_ids = sorted tmp_indices // seq_width
// (assumes Div on int32 inputs performs integer division -- TODO confirm)
Tensor beam_ids(framework::proto::VarType::INT32);
beam_ids.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size)}));
beam_ids.mutable_data<int>(ctx.GetPlace());
cast_sort_tmp_indices.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size)}));
const auto& runner_div = NpuOpRunner(
"Div", {cast_sort_tmp_indices, seq_width_tensor}, {beam_ids}, {});
runner_div.Run(stream);
// get parent_idx by adding batch_ids to beam_ids
// construct scale_batch_ids like [[0, 0, 0], [bw, bw, bw], ..., [bs-1*bw,
// bs-1*bw, bs-1*bw]]
// NOTE(review): the scale factor is beam_size, not real_beam_size. When the
// first step has real_beam_size == 1 and beam_size > 1, confirm the resulting
// parent indices still address valid input rows.
batch_ids.Resize(
framework::make_ddim({num_seqs, static_cast<int64_t>(beam_size)}));
// cast batch_ids from int to float32
Tensor cast_batch_ids(framework::proto::VarType::FP32);
cast_batch_ids.Resize(batch_ids.dims());
cast_batch_ids.mutable_data<float>(ctx.GetPlace());
auto dst_dtype1 = ConvertToNpuDtype(cast_batch_ids.type());
const auto& runner_cast_batch_ids =
NpuOpRunner("Cast", {batch_ids}, {cast_batch_ids},
{{"dst_type", static_cast<int>(dst_dtype1)}});
runner_cast_batch_ids.Run(stream);
// scale batch_ids with beam_size (Power with power=1, shift=0 is used as a
// plain multiply)
Tensor scale_batch_ids(framework::proto::VarType::FP32);
scale_batch_ids.Resize(batch_ids.dims());
scale_batch_ids.mutable_data<float>(place);
const auto& runner_power =
NpuOpRunner("Power", {cast_batch_ids}, {scale_batch_ids},
{{"power", static_cast<float>(1.0)},
{"scale", static_cast<float>(beam_size)},
{"shift", static_cast<float>(0.0)}});
runner_power.Run(stream);
// cast scale_batch_ids from float32 to int
Tensor cast_scale_batch_ids(framework::proto::VarType::INT32);
cast_scale_batch_ids.Resize(scale_batch_ids.dims());
cast_scale_batch_ids.mutable_data<int>(ctx.GetPlace());
auto dst_dtype2 = ConvertToNpuDtype(cast_scale_batch_ids.type());
const auto& runner_cast_scale_batch_ids =
NpuOpRunner("Cast", {scale_batch_ids}, {cast_scale_batch_ids},
{{"dst_type", static_cast<int>(dst_dtype2)}});
runner_cast_scale_batch_ids.Run(stream);
// calculate parent_idx
Tensor tmp_parent_idx(framework::proto::VarType::INT32);
tmp_parent_idx.Resize(parent_idx->dims());
tmp_parent_idx.mutable_data<int>(place);
const auto& runner_add_beam_id = NpuOpRunner(
"Add", {beam_ids, cast_scale_batch_ids}, {tmp_parent_idx}, {});
runner_add_beam_id.Run(stream);
// cast tmp_parent_idx from int to int64 to get parent_idx
auto dst_dtype_parent_idx = ConvertToNpuDtype(parent_idx->type());
const auto& runner_cast_parent_idx =
NpuOpRunner("Cast", {tmp_parent_idx}, {*parent_idx},
{{"dst_type", static_cast<int>(dst_dtype_parent_idx)}});
runner_cast_parent_idx.Run(stream);
// Copy the int32 parent indices to the host to build the LoD.
// NOTE(review): confirm this copy is synchronized with the stream before
// vector_parent_idx is read below.
std::vector<int> vector_parent_idx;
framework::TensorToVector(tmp_parent_idx, ctx, &vector_parent_idx);
// set low level, len(low_level) = high_level[-1]
std::vector<int> low_level;
std::vector<int> num_parent_ids(num_seqs * beam_size,
static_cast<int64_t>(0));
size_t low_level_size = high_level[num_seqs];
size_t sum_parent_id = 0;
// calculate number of every parent_id
// NOTE(review): the values copied from the device index num_parent_ids
// without a bounds check.
for (size_t i = 0; i < num_seqs * beam_size; ++i) {
num_parent_ids[vector_parent_idx[i]]++;
}
// update low_level: running sum of how many outputs each input row produced
low_level.push_back(0);
for (size_t i = 0; i < low_level_size; ++i) {
sum_parent_id += num_parent_ids[i];
low_level.push_back(sum_parent_id);
}
// fill lod: level 0 is the input's level, level 1 is the per-row offsets
framework::LoD lod(2);
lod[0].assign(high_level.begin(), high_level.end());
lod[1].assign(low_level.begin(), low_level.end());
if (!framework::CheckLoD(lod)) {
PADDLE_THROW(platform::errors::InvalidArgument(
"lod %s is not right in"
" beam_search, please check your code.",
framework::LoDToString(lod)));
}
selected_ids->set_lod(lod);
selected_scores->set_lod(lod);
}
};
// Explicit instantiations of the NPU functor for the four element types.
// NOTE(review): the specialization above does not use T; confirm that the
// non-float instantiations are intended.
template class BeamSearchFunctor<platform::NPUDeviceContext, int>;
template class BeamSearchFunctor<platform::NPUDeviceContext, int64_t>;
template class BeamSearchFunctor<platform::NPUDeviceContext, float>;
template class BeamSearchFunctor<platform::NPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle
import sys
sys.path.append("..")
from op_test import OpTest
import unittest
import numpy as np
import paddle.fluid as fluid
paddle.enable_static()
class TestBeamSearchNPUOp(OpTest):
    """Base case for the NPU ``beam_search`` op: three candidates per beam.

    Subclasses override ``init_data`` to supply other scenarios; ``setUp``
    turns the arrays prepared there into the op's inputs, attributes and
    expected outputs.
    """

    def setUp(self):
        self.set_npu()
        self.place = paddle.NPUPlace(0)
        self.op_type = "beam_search"
        self.init_data()

        # Every input carries the same (length-based) LoD.
        input_lod = self.lod
        self.inputs = {
            'pre_ids': (self.pre_ids, input_lod),
            'pre_scores': (self.pre_score, input_lod),
            'ids': (self.ids, input_lod),
            'scores': (self.score, input_lod)
        }

        # The end-of-sequence id is fixed at 0 for all scenarios.
        self.attrs = {
            'level': 0,
            'beam_size': self.beam_size,
            'end_id': 0,
            'is_accumulated': self.is_accumulated
        }

        expected_lod = self.out_lod
        self.outputs = {
            'selected_ids': (self.selected_ids, expected_lod),
            'selected_scores': (self.selected_scores, expected_lod),
            'parent_idx': self.parent_idx
        }

    def set_npu(self):
        # Flag the test class as an NPU test.
        self.__class__.use_npu = True

    def init_data(self):
        """Populate the input arrays and the expected results."""
        self.beam_size = 2
        self.is_accumulated = True

        self.pre_ids = np.array([1, 2, 3, 4], dtype='int64').reshape(4, 1)
        self.ids = np.array(
            [[4, 2, 5], [2, 1, 3], [3, 5, 2], [8, 2, 1]], dtype='int64')

        self.lod = [[2, 2], [1, 1, 1, 1]]
        self.out_lod = [[2, 2], [1, 1, 1, 1]]
        self.offset_lod = [[0, 2, 4], [0, 1, 2, 3, 4]]

        score_rows = [
            [0.5, 0.3, 0.2],
            [0.6, 0.3, 0.1],
            [0.9, 0.5, 0.1],
            [0.7, 0.5, 0.1],
        ]
        self.score = np.array(score_rows, dtype='float32')
        self.pre_score = np.array(
            [0.1, 0.2, 0.3, 0.4], dtype='float32').reshape(4, 1)

        # Expected results, as column vectors.
        self.selected_ids = np.array([4, 2, 3, 8]).reshape(-1, 1)
        self.selected_scores = np.array([0.5, 0.6, 0.9, 0.7]).reshape(-1, 1)
        self.parent_idx = np.array([0, 1, 2, 3])

    def test_check_output(self):
        self.check_output_with_place(self.place, atol=1e-3)
class TestBeamSearchNPUOp2(TestBeamSearchNPUOp):
    """Two candidates per beam; both winners of the first sequence share a parent."""

    def init_data(self):
        self.beam_size = 2
        self.is_accumulated = True

        self.pre_ids = np.array([1, 2, 3, 4], dtype='int64').reshape(4, 1)
        self.ids = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64')

        self.lod = [[2, 2], [1, 1, 1, 1]]
        self.out_lod = [[2, 2], [2, 0, 1, 1]]
        self.offset_lod = [[0, 2, 4], [0, 2, 2, 3, 4]]

        score_rows = [
            [0.6, 0.9],
            [0.5, 0.3],
            [0.9, 0.5],
            [0.1, 0.7],
        ]
        self.score = np.array(score_rows, dtype='float32')
        self.pre_score = np.array(
            [0.1, 0.2, 0.3, 0.4], dtype='float32').reshape(4, 1)

        # Expected results, as column vectors.
        self.selected_ids = np.array([4, 2, 3, 1]).reshape(-1, 1)
        self.selected_scores = np.array([0.6, 0.9, 0.9, 0.7]).reshape(-1, 1)
        self.parent_idx = np.array([0, 0, 2, 3])
class TestBeamSearchNPUOp3(TestBeamSearchNPUOp):
    """Some previous ids already equal the end id (0)."""

    def init_data(self):
        # end_id = 0
        self.beam_size = 2
        self.is_accumulated = True

        self.pre_ids = np.array([1, 0, 0, 4], dtype='int64').reshape(4, 1)
        self.ids = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64')

        self.lod = [[2, 2], [1, 1, 1, 1]]
        self.out_lod = [[2, 2], [1, 1, 0, 2]]
        self.offset_lod = [[0, 2, 4], [0, 1, 2, 2, 4]]

        score_rows = [
            [0.6, 0.9],
            [0.5, 0.3],
            [0.9, 0.5],
            [0.6, 0.7],
        ]
        self.score = np.array(score_rows, dtype='float32')
        self.pre_score = np.array(
            [0.1, 1.2, 0.5, 0.4], dtype='float32').reshape(4, 1)

        # Expected results, as column vectors.
        self.selected_ids = np.array([2, 0, 8, 1]).reshape(-1, 1)
        self.selected_scores = np.array([0.9, 1.2, 0.6, 0.7]).reshape(-1, 1)
        self.parent_idx = np.array([0, 1, 3, 3])
class TestBeamSearchNPUOp4(TestBeamSearchNPUOp):
    """Scores are not accumulated yet (``is_accumulated`` is False)."""

    def init_data(self):
        # is_accumulated = False
        self.beam_size = 2
        self.is_accumulated = False

        self.pre_ids = np.array([1, 2, 3, 4], dtype='int64').reshape(4, 1)
        self.ids = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64')

        self.lod = [[2, 2], [1, 1, 1, 1]]
        self.out_lod = [[2, 2], [0, 2, 1, 1]]
        self.offset_lod = [[0, 2, 4], [0, 0, 2, 3, 4]]

        score_rows = [
            [0.6, 0.9],
            [0.5, 0.3],
            [0.9, 0.5],
            [0.1, 0.7],
        ]
        self.score = np.array(score_rows, dtype='float32')
        self.pre_score = np.array(
            [0.1, 2.2, 0.3, 0.4], dtype='float32').reshape(4, 1)

        # Expected results, as column vectors.
        self.selected_ids = np.array([7, 3, 3, 1]).reshape(-1, 1)
        expected_scores = [1.50685, 0.996027, 0.194639, 0.043325]
        self.selected_scores = np.array(expected_scores).reshape(-1, 1)
        self.parent_idx = np.array([1, 1, 2, 3])
class TestBeamSearchNPUOp5(TestBeamSearchNPUOp):
    """Degenerate beam: ``beam_size`` is 1, one row per source sequence."""

    def init_data(self):
        # beam_size = 1
        self.beam_size = 1
        self.is_accumulated = True

        self.pre_ids = np.array([1, 2, 3, 4], dtype='int64').reshape(4, 1)
        self.ids = np.array([[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64')

        self.lod = [[1, 1, 1, 1], [1, 1, 1, 1]]
        self.out_lod = [[1, 1, 1, 1], [1, 1, 1, 1]]
        self.offset_lod = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]

        score_rows = [
            [0.6, 0.9],
            [0.5, 0.3],
            [0.9, 0.5],
            [0.1, 0.7],
        ]
        self.score = np.array(score_rows, dtype='float32')
        self.pre_score = np.array(
            [0.1, 0.2, 0.3, 0.4], dtype='float32').reshape(4, 1)

        # Expected results, as column vectors.
        self.selected_ids = np.array([2, 7, 3, 1]).reshape(-1, 1)
        self.selected_scores = np.array([0.9, 0.5, 0.9, 0.7]).reshape(-1, 1)
        self.parent_idx = np.array([0, 1, 2, 3])
if __name__ == '__main__':
unittest.main()
......@@ -38,6 +38,7 @@ class BeamSearchOpTester(unittest.TestCase):
self._create_pre_scores()
self._create_scores()
self._create_pre_ids()
self.set_outputs()
self.scope.var('selected_ids').get_tensor()
self.scope.var('selected_scores').get_tensor()
self.scope.var('parent_idx').get_tensor()
......@@ -53,22 +54,19 @@ class BeamSearchOpTester(unittest.TestCase):
selected_scores='selected_scores',
parent_idx='parent_idx',
level=0,
beam_size=2,
end_id=0, )
beam_size=self.beam_size,
end_id=0,
is_accumulated=self.is_accumulated)
op.run(self.scope, core.CPUPlace())
selected_ids = self.scope.find_var("selected_ids").get_tensor()
selected_scores = self.scope.find_var("selected_scores").get_tensor()
parent_idx = self.scope.find_var("parent_idx").get_tensor()
self.assertTrue(np.allclose(np.array(selected_ids), self.output_ids))
self.assertTrue(
np.allclose(
np.array(selected_ids), np.array([4, 2, 3, 8])[:, np.newaxis]))
np.allclose(np.array(selected_scores), self.output_scores))
self.assertEqual(selected_ids.lod(), self.output_lod)
self.assertTrue(
np.allclose(
np.array(selected_scores),
np.array([0.5, 0.6, 0.9, 0.7])[:, np.newaxis]))
self.assertEqual(selected_ids.lod(), [[0, 2, 4], [0, 1, 2, 3, 4]])
self.assertTrue(
np.allclose(np.array(parent_idx), np.array([0, 1, 2, 3])))
np.allclose(np.array(parent_idx), self.output_parent_idx))
def _create_pre_ids(self):
np_data = np.array([[1, 2, 3, 4]], dtype='int64')
......@@ -97,6 +95,194 @@ class BeamSearchOpTester(unittest.TestCase):
tensor = create_tensor(self.scope, "scores", np_data)
tensor.set_lod(self.lod)
def set_outputs(self):
    """Record the op attributes and the expected results for this scenario."""
    self.beam_size, self.is_accumulated = 2, True

    expected_ids = [4, 2, 3, 8]
    expected_scores = [0.5, 0.6, 0.9, 0.7]
    # Ids and scores are compared as column vectors.
    self.output_ids = np.array(expected_ids).reshape(-1, 1)
    self.output_scores = np.array(expected_scores).reshape(-1, 1)
    self.output_lod = [[0, 2, 4], [0, 1, 2, 3, 4]]
    self.output_parent_idx = np.array([0, 1, 2, 3])
class BeamSearchOpTester2(BeamSearchOpTester):
    """Two candidates per beam; both winners of the first sequence share a parent."""

    def _create_pre_ids(self):
        prev_ids = np.array([1, 2, 3, 4], dtype='int64').reshape(4, 1)
        create_tensor(self.scope, 'pre_ids', prev_ids)

    def _create_pre_scores(self):
        # Kept as a single row of shape (1, 4).
        prev_scores = np.array([[0.1, 0.2, 0.3, 0.4]], dtype='float32')
        create_tensor(self.scope, 'pre_scores', prev_scores)

    def _create_ids(self):
        # Offset-based LoD; also read by _create_scores.
        self.lod = [[0, 2, 4], [0, 1, 2, 3, 4]]
        candidate_ids = np.array(
            [[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64')
        ids_tensor = create_tensor(self.scope, "ids", candidate_ids)
        ids_tensor.set_lod(self.lod)

    def _create_scores(self):
        score_rows = [
            [0.6, 0.9],
            [0.5, 0.3],
            [0.9, 0.5],
            [0.1, 0.7],
        ]
        candidate_scores = np.array(score_rows, dtype='float32')
        scores_tensor = create_tensor(self.scope, "scores", candidate_scores)
        scores_tensor.set_lod(self.lod)

    def set_outputs(self):
        self.beam_size, self.is_accumulated = 2, True
        self.output_ids = np.array([2, 4, 3, 1]).reshape(-1, 1)
        self.output_scores = np.array([0.9, 0.6, 0.9, 0.7]).reshape(-1, 1)
        self.output_lod = [[0, 2, 4], [0, 2, 2, 3, 4]]
        self.output_parent_idx = np.array([0, 0, 2, 3])
class BeamSearchOpTester3(BeamSearchOpTester):
    """Some previous ids already equal the end id (pre_id = end_id)."""

    def _create_pre_ids(self):
        prev_ids = np.array([1, 0, 0, 4], dtype='int64').reshape(4, 1)
        create_tensor(self.scope, 'pre_ids', prev_ids)

    def _create_pre_scores(self):
        prev_scores = np.array(
            [0.1, 1.2, 0.5, 0.4], dtype='float32').reshape(4, 1)
        create_tensor(self.scope, 'pre_scores', prev_scores)

    def _create_ids(self):
        # Offset-based LoD; also read by _create_scores.
        self.lod = [[0, 2, 4], [0, 1, 2, 3, 4]]
        candidate_ids = np.array(
            [[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64')
        ids_tensor = create_tensor(self.scope, "ids", candidate_ids)
        ids_tensor.set_lod(self.lod)

    def _create_scores(self):
        score_rows = [
            [0.6, 0.9],
            [0.5, 0.3],
            [0.9, 0.5],
            [0.6, 0.7],
        ]
        candidate_scores = np.array(score_rows, dtype='float32')
        scores_tensor = create_tensor(self.scope, "scores", candidate_scores)
        scores_tensor.set_lod(self.lod)

    def set_outputs(self):
        self.beam_size, self.is_accumulated = 2, True
        self.output_ids = np.array([2, 0, 1, 8]).reshape(-1, 1)
        self.output_scores = np.array([0.9, 1.2, 0.7, 0.6]).reshape(-1, 1)
        self.output_lod = [[0, 2, 4], [0, 1, 2, 2, 4]]
        self.output_parent_idx = np.array([0, 1, 3, 3])
class BeamSearchOpTester4(BeamSearchOpTester):
    """Pruning case: every beam of the first source sequence has pre_id == end_id."""

    def _create_pre_ids(self):
        prev_ids = np.array([0, 0, 0, 4], dtype='int64').reshape(4, 1)
        create_tensor(self.scope, 'pre_ids', prev_ids)

    def _create_pre_scores(self):
        prev_scores = np.array(
            [0.1, 1.2, 0.5, 0.4], dtype='float32').reshape(4, 1)
        create_tensor(self.scope, 'pre_scores', prev_scores)

    def _create_ids(self):
        # Offset-based LoD; also read by _create_scores.
        self.lod = [[0, 2, 4], [0, 1, 2, 3, 4]]
        candidate_ids = np.array(
            [[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64')
        ids_tensor = create_tensor(self.scope, "ids", candidate_ids)
        ids_tensor.set_lod(self.lod)

    def _create_scores(self):
        score_rows = [
            [0.6, 0.9],
            [0.5, 0.3],
            [0.9, 0.5],
            [0.6, 0.7],
        ]
        candidate_scores = np.array(score_rows, dtype='float32')
        scores_tensor = create_tensor(self.scope, "scores", candidate_scores)
        scores_tensor.set_lod(self.lod)

    def set_outputs(self):
        self.beam_size, self.is_accumulated = 2, True
        # Only the second source sequence contributes output rows.
        self.output_ids = np.array([1, 8]).reshape(-1, 1)
        self.output_scores = np.array([0.7, 0.6]).reshape(-1, 1)
        self.output_lod = [[0, 2, 4], [0, 0, 0, 0, 2]]
        self.output_parent_idx = np.array([3, 3])
class BeamSearchOpTester5(BeamSearchOpTester):
    """Scores are not accumulated yet (is_accumulated = False)."""

    def _create_pre_ids(self):
        prev_ids = np.array([1, 2, 3, 4], dtype='int64').reshape(4, 1)
        create_tensor(self.scope, 'pre_ids', prev_ids)

    def _create_pre_scores(self):
        # Kept as a single row of shape (1, 4).
        prev_scores = np.array([[0.1, 2.2, 0.3, 0.4]], dtype='float32')
        create_tensor(self.scope, 'pre_scores', prev_scores)

    def _create_ids(self):
        # Offset-based LoD; also read by _create_scores.
        self.lod = [[0, 2, 4], [0, 1, 2, 3, 4]]
        candidate_ids = np.array(
            [[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64')
        ids_tensor = create_tensor(self.scope, "ids", candidate_ids)
        ids_tensor.set_lod(self.lod)

    def _create_scores(self):
        score_rows = [
            [0.6, 0.9],
            [0.5, 0.3],
            [0.9, 0.5],
            [0.1, 0.7],
        ]
        candidate_scores = np.array(score_rows, dtype='float32')
        scores_tensor = create_tensor(self.scope, "scores", candidate_scores)
        scores_tensor.set_lod(self.lod)

    def set_outputs(self):
        self.beam_size, self.is_accumulated = 2, False
        self.output_ids = np.array([7, 3, 3, 1]).reshape(-1, 1)
        expected_scores = [1.50685, 0.996027, 0.194639, 0.043325]
        self.output_scores = np.array(expected_scores).reshape(-1, 1)
        self.output_lod = [[0, 2, 4], [0, 0, 2, 3, 4]]
        self.output_parent_idx = np.array([1, 1, 2, 3])
class BeamSearchOpTester6(BeamSearchOpTester):
    """Degenerate beam: beam_size = 1, one row per source sequence."""

    def _create_pre_ids(self):
        prev_ids = np.array([1, 2, 3, 4], dtype='int64').reshape(4, 1)
        create_tensor(self.scope, 'pre_ids', prev_ids)

    def _create_pre_scores(self):
        # Kept as a single row of shape (1, 4).
        prev_scores = np.array([[0.1, 0.2, 0.3, 0.4]], dtype='float32')
        create_tensor(self.scope, 'pre_scores', prev_scores)

    def _create_ids(self):
        # Offset-based LoD; also read by _create_scores.
        self.lod = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]
        candidate_ids = np.array(
            [[4, 2], [7, 3], [3, 5], [8, 1]], dtype='int64')
        ids_tensor = create_tensor(self.scope, "ids", candidate_ids)
        ids_tensor.set_lod(self.lod)

    def _create_scores(self):
        score_rows = [
            [0.6, 0.9],
            [0.5, 0.3],
            [0.9, 0.5],
            [0.1, 0.7],
        ]
        candidate_scores = np.array(score_rows, dtype='float32')
        scores_tensor = create_tensor(self.scope, "scores", candidate_scores)
        scores_tensor.set_lod(self.lod)

    def set_outputs(self):
        self.beam_size, self.is_accumulated = 1, True
        self.output_ids = np.array([2, 7, 3, 1]).reshape(-1, 1)
        self.output_scores = np.array([0.9, 0.5, 0.9, 0.7]).reshape(-1, 1)
        self.output_lod = [[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]
        self.output_parent_idx = np.array([0, 1, 2, 3])
class TestBeamSearchOpError(unittest.TestCase):
def test_errors(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册