// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "lite/backends/arm/math/beam_search.h" #include #include #include #include #include "lite/utils/cp_logging.h" namespace paddle { namespace lite { namespace arm { namespace math { /* * The basic items help to sort. */ struct Item { Item() {} Item(size_t offset, size_t id, float score) : offset(offset), id(id), score(score) {} // offset in the higher lod level. size_t offset; // prefix id in the lower lod level. // size_t prefix; // the candidate id size_t id; // the corresponding score float score; inline bool operator<(const Item &in) const { return (score < in.score) || ((score == in.score) && (offset < in.offset)); } inline void operator=(const Item &in) { offset = in.offset; id = in.id; score = in.score; } std::string ToString() { std::ostringstream os; os << "{"; os << "offset: " << offset << ", "; os << "id: " << id << ", "; os << "score: " << score << ""; os << "}"; return os.str(); } }; /* * Prune the source sentences all branchs finished, and it is optional. * Pruning must one step later than finishing (thus pre_ids is needed here), * since the end tokens must be writed out. */ void PruneEndBeams(const Tensor *pre_ids, const LoD &abs_lod, std::vector> *items, size_t lod_level, int end_id) { auto *pre_ids_data = pre_ids->data(); auto &high_level = abs_lod[lod_level]; for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) { size_t src_prefix_start = high_level[src_idx]; size_t src_prefix_end = high_level[src_idx + 1]; bool finish_flag = true; for (size_t offset = src_prefix_start; offset < src_prefix_end; offset++) { for (auto &item : items->at(offset)) { if (item.id != static_cast(end_id) || pre_ids_data[offset] != end_id) { finish_flag = false; break; } } if (!finish_flag) break; } if (finish_flag) { // all branchs of the beam (source sentence) end and // prune this beam for (size_t offset = src_prefix_start; offset < src_prefix_end; offset++) items->at(offset).clear(); } } } /* * Transform the items into a map whose key is offset, value is the items. * NOTE low performance. */ std::vector> ToMap( const std::vector> &items, size_t element_num) { std::vector> result; result.resize(element_num); for (auto &entries : items) { for (const auto &item : entries) { result[item.offset].push_back(item); } } return result; } void Insert(std::vector *top_beam_ptr, const Item &item, size_t beam_size) { std::vector &top_beam = *top_beam_ptr; size_t num_beams = top_beam.size(); if (num_beams < beam_size) { top_beam.resize(num_beams + 1); num_beams++; } else { if (item < top_beam[beam_size - 1]) { return; } } for (int k = static_cast(num_beams) - 2; k >= 0; --k) { if (top_beam[k] < item) { top_beam[k + 1] = top_beam[k]; } else { top_beam[k + 1] = item; return; } } top_beam[0] = item; } /* * For each source, select top beam_size records. */ std::vector> SelectTopBeamSizeItems(const Tensor *pre_ids, const Tensor *pre_scores, const Tensor *ids, const Tensor *scores, size_t lod_level, size_t beam_size, int end_id, bool is_accumulated) { std::vector> result; // find the current candidates // auto abs_lod = framework::ToAbsOffset(scores->lod()); auto abs_lod = scores->lod(); auto *pre_ids_data = pre_ids->data(); auto *pre_scores_data = pre_scores->data(); auto *ids_data = ids ? ids->data() : nullptr; auto *scores_data = scores->data(); size_t num_seqs = abs_lod[lod_level].size() - 1; size_t seq_width = 1; for (int i = 1; i < scores->dims().size(); i++) { seq_width *= scores->dims()[i]; } for (size_t seq_id = 0; seq_id < num_seqs; ++seq_id) { size_t seq_offset_start = abs_lod[lod_level][seq_id]; size_t seq_offset_end = abs_lod[lod_level][seq_id + 1]; std::vector top_beam; top_beam.reserve(beam_size); for (size_t offset = seq_offset_start; offset < seq_offset_end; ++offset) { auto pre_id = pre_ids_data[offset]; auto pre_score = pre_scores_data[offset]; if (pre_id == end_id) { // Allocate all probability mass to end_id for finished branchs and // the other candidate ids can be ignored. Item item(offset, end_id, pre_score); Insert(&top_beam, item, beam_size); } else { size_t index = offset * seq_width; for (size_t d = 0; d < seq_width; d++, index++) { int64_t id = ids_data ? ids_data[index] : static_cast(d); float score = is_accumulated ? scores_data[index] : pre_score + std::log(scores_data[index]); Item item(offset, id, score); Insert(&top_beam, item, beam_size); } } } result.emplace_back(top_beam); } return result; } void beam_search(const Tensor *pre_ids, const Tensor *pre_scores, const Tensor *ids, const Tensor *scores, Tensor *selected_ids, Tensor *selected_scores, Tensor *parent_idx, int level, int beam_size, int end_id, bool is_accumulated, Context *ctx) { // auto abs_lod = framework::ToAbsOffset(scores->lod()); auto abs_lod = scores->lod(); auto &high_level = abs_lod[level]; auto items = SelectTopBeamSizeItems(pre_ids, pre_scores, ids, scores, level, beam_size, end_id, is_accumulated); auto selected_items = ToMap(items, high_level.back()); PruneEndBeams(pre_ids, abs_lod, &selected_items, level, end_id); // calculate the output tensor's height size_t num_instances = std::accumulate( std::begin(selected_items), std::end(selected_items), 0, [](size_t a, std::vector &b) { return a + b.size(); }); // the output tensor shape should be [num_instances, 1] auto dims = std::vector({static_cast(num_instances), 1}); selected_ids->Resize(dims); selected_scores->Resize(dims); if (parent_idx) { parent_idx->Resize({static_cast(num_instances)}); } auto *selected_ids_data = selected_ids->mutable_data(); auto *selected_scores_data = selected_scores->mutable_data(); auto *parent_idx_data = parent_idx ? parent_idx->mutable_data() : nullptr; // fill in data std::vector low_level; size_t low_offset = 0; for (auto &items : selected_items) { low_level.push_back(low_offset); for (auto &item : items) { if (parent_idx) { parent_idx_data[low_offset] = static_cast(low_level.size() - 1); } selected_ids_data[low_offset] = item.id; selected_scores_data[low_offset] = item.score; low_offset++; } } low_level.push_back(low_offset); // fill lod LoD lod(2); lod[0].assign(high_level.begin(), high_level.end()); lod[1].assign(low_level.begin(), low_level.end()); *(selected_ids->mutable_lod()) = lod; *(selected_scores->mutable_lod()) = lod; } } // namespace math } // namespace arm } // namespace lite } // namespace paddle