未验证 提交 3008fa12 编写于 作者: Y Yiqun Liu 提交者: GitHub

Add the CUDA kernel for beam_search op (#15020)

* Refine the beam_search op and test.

* A basic CUDA implementation of beam_search for small batch_size.

* Implement CUDA kernel for beam_search_op.

* Use multiple CUDA threads in the same block to select the top beam.

* Update the python api of beam_search op.

* Enable extend function in CPU kernel of beam_search op.

* Unify the CUDA codes.
test=develop

* Unify the CPU kernel of beam_search op.

* Ensure the seletced items of beam_search_op's CPU kernel sorted by scores.

* Update the description of beam_search in API.spec.

* Enable the use of CUDA kernel in beam_search op.

* Exclude the beam_search's CUDA unittest when there is no CUDA gpu, and delete some debuging statements.
test=develop

* Follow comments.
test=develop

* Call the CPU kernel for beam_search op when batch_size > 4.
test=develop

* Remove the except of is_empty op in PrepareData.
test=develop
上级 ed1726ea
......@@ -122,7 +122,7 @@ paddle.fluid.layers.transpose ArgSpec(args=['x', 'perm', 'name'], varargs=None,
paddle.fluid.layers.im2sequence ArgSpec(args=['input', 'filter_size', 'stride', 'padding', 'input_image_size', 'out_stride', 'name'], varargs=None, keywords=None, defaults=(1, 1, 0, None, 1, None))
paddle.fluid.layers.nce ArgSpec(args=['input', 'label', 'num_total_classes', 'sample_weight', 'param_attr', 'bias_attr', 'num_neg_samples', 'name', 'sampler', 'custom_dist', 'seed', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, 'uniform', None, 0, False))
paddle.fluid.layers.hsigmoid ArgSpec(args=['input', 'label', 'num_classes', 'param_attr', 'bias_attr', 'name', 'path_table', 'path_code', 'is_custom', 'is_sparse'], varargs=None, keywords=None, defaults=(None, None, None, None, None, False, False))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'name'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.beam_search ArgSpec(args=['pre_ids', 'pre_scores', 'ids', 'scores', 'beam_size', 'end_id', 'level', 'is_accumulated', 'name'], varargs=None, keywords=None, defaults=(0, True, None))
paddle.fluid.layers.row_conv ArgSpec(args=['input', 'future_context_size', 'param_attr', 'act'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.multiplex ArgSpec(args=['inputs', 'index'], varargs=None, keywords=None, defaults=None)
paddle.fluid.layers.layer_norm ArgSpec(args=['input', 'scale', 'shift', 'begin_norm_axis', 'epsilon', 'param_attr', 'bias_attr', 'act', 'name'], varargs=None, keywords=None, defaults=(True, True, 1, 1e-05, None, None, None, None))
......
......@@ -54,13 +54,14 @@ std::ostream &operator<<(std::ostream &os, const LoD &lod) {
std::ostream &operator<<(std::ostream &os, const LoDTensor &t) {
if (!platform::is_cpu_place(t.place())) {
LoDTensor tt;
framework::TensorCopy(t, platform::CPUPlace(), &tt);
LoDTensor cpu_tensor;
cpu_tensor.set_lod(t.lod());
framework::TensorCopy(t, platform::CPUPlace(), &cpu_tensor);
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(t.place());
dev_ctx.Wait();
os << tt;
os << cpu_tensor;
return os;
}
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
......
......@@ -66,7 +66,7 @@ set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} dynload_warpctc)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence_padding sequence_scale cos_sim_functor memory jit_kernel_helper concat_and_split cross_entropy softmax vol2col im2col sampler tree2col)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} sequence2batch lstm_compute matrix_bit_code gru_compute activation_functions beam_search)
if (WITH_GPU)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} depthwise_conv prelu)
endif()
......@@ -86,7 +86,6 @@ set(GLOB_OPERATOR_DEPS ${OPERATOR_DEPS} CACHE INTERNAL "Global Op dependencies")
cc_test(gather_test SRCS gather_test.cc DEPS tensor)
cc_test(scatter_test SRCS scatter_test.cc DEPS tensor math_function)
cc_test(beam_search_decode_op_test SRCS beam_search_decode_op_test.cc DEPS lod_tensor)
cc_test(beam_search_op_test SRCS beam_search_op_test.cc DEPS lod_tensor beam_search_op)
cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
......
......@@ -12,205 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <map>
#include "paddle/fluid/operators/beam_search_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/beam_search_op.h"
namespace paddle {
namespace operators {
void BeamSearch::operator()(const framework::LoDTensor &pre_ids,
const framework::LoDTensor &pre_scores,
framework::LoDTensor *selected_ids,
framework::LoDTensor *selected_scores) {
auto abs_lod = framework::ToAbsOffset(ids_->lod());
auto &high_level = abs_lod[lod_level_];
auto items = SelectTopBeamSizeItems(pre_ids, pre_scores);
auto selected_items = ToMap(items, high_level.back());
VLOG(3) << "selected_items:";
for (size_t i = 0; i < selected_items.size(); ++i) {
VLOG(3) << "offset:" << i;
for (auto &item : selected_items[i]) {
VLOG(3) << ItemToString(item);
}
}
PruneEndBeams(pre_ids, &selected_items);
// calculate the output tensor's height
size_t num_instances = std::accumulate(
std::begin(selected_items), std::end(selected_items), 0,
[](size_t a, std::vector<Item> &b) { return a + b.size(); });
// the output tensor shape should be [num_instances, 1]
auto dims = framework::make_ddim(
std::vector<int64_t>({static_cast<int>(num_instances), 1}));
selected_ids->Resize(dims);
selected_scores->Resize(dims);
std::map<size_t /*offset*/, std::vector<Item>> hash;
framework::LoD new_lod;
auto *ids_data = selected_ids->mutable_data<int64_t>(platform::CPUPlace());
auto *scores_data =
selected_scores->mutable_data<float>(platform::CPUPlace());
// fill in data
std::vector<size_t> low_level;
size_t low_offset = 0;
for (auto &items : selected_items) {
low_level.push_back(low_offset);
for (auto &item : items) {
ids_data[low_offset] = item.id;
scores_data[low_offset] = item.score;
low_offset++;
}
}
low_level.push_back(low_offset);
// fill lod
framework::LoD lod(2);
lod[0].assign(high_level.begin(), high_level.end());
lod[1].assign(low_level.begin(), low_level.end());
if (!framework::CheckLoD(lod)) {
PADDLE_THROW("lod %s is not right", framework::LoDToString(lod));
}
selected_ids->set_lod(lod);
selected_scores->set_lod(lod);
}
void BeamSearch::PruneEndBeams(const framework::LoDTensor &pre_ids,
std::vector<std::vector<Item>> *items) {
auto *pre_ids_data = pre_ids.data<int64_t>();
auto abs_lod = framework::ToAbsOffset(ids_->lod());
auto &high_level = abs_lod[lod_level_];
for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) {
size_t src_prefix_start = high_level[src_idx];
size_t src_prefix_end = high_level[src_idx + 1];
bool finish_flag = true;
for (size_t offset = src_prefix_start; offset < src_prefix_end; offset++) {
for (auto &item : items->at(offset)) {
if (item.id != static_cast<size_t>(end_id_) ||
pre_ids_data[offset] != end_id_) {
finish_flag = false;
break;
}
}
if (!finish_flag) break;
}
if (finish_flag) { // all branchs of the beam (source sentence) end and
// prune this beam
for (size_t offset = src_prefix_start; offset < src_prefix_end; offset++)
items->at(offset).clear();
}
}
}
std::vector<std::vector<BeamSearch::Item>> BeamSearch::ToMap(
const std::vector<std::vector<Item>> &items, size_t element_num) {
std::vector<std::vector<Item>> result;
result.resize(element_num);
for (auto &entries : items) {
for (const auto &item : entries) {
result[item.offset].push_back(item);
}
}
return result;
}
std::vector<std::vector<BeamSearch::Item>> BeamSearch::SelectTopBeamSizeItems(
const framework::LoDTensor &pre_ids,
const framework::LoDTensor &pre_scores) {
std::vector<std::vector<Item>> result;
std::vector<Item> items;
// for each source sentence, select the top beam_size items across all
// candidate sets.
while (NextItemSet(pre_ids, pre_scores, &items)) {
std::nth_element(
std::begin(items), std::begin(items) + beam_size_, std::end(items),
[](const Item &a, const Item &b) { return a.score > b.score; });
// prune the top beam_size items.
if (items.size() > beam_size_) {
items.resize(beam_size_);
}
result.emplace_back(items);
}
VLOG(3) << "SelectTopBeamSizeItems result size " << result.size();
for (auto &items : result) {
VLOG(3) << "item set:";
for (auto &item : items) {
VLOG(3) << ItemToString(item);
}
}
return result;
}
// the candidates of a source
bool BeamSearch::NextItemSet(const framework::LoDTensor &pre_ids,
const framework::LoDTensor &pre_scores,
std::vector<BeamSearch::Item> *items) {
if (sent_offset_ >= ids_->NumElements(lod_level_)) {
return false;
}
// find the current candidates
auto ids = *ids_;
auto scores = *scores_;
auto abs_lod = framework::ToAbsOffset(ids.lod());
auto *ids_data = ids.data<int64_t>();
auto *scores_data = scores.data<float>();
size_t instance_dim = 1;
for (int i = 1; i < ids.dims().size(); i++) {
instance_dim *= ids.dims()[i];
}
auto *pre_ids_data = pre_ids.data<int64_t>();
auto *pre_scores_data = pre_scores.data<float>();
items->clear();
items->reserve(framework::product(ids.dims()));
for (size_t offset = abs_lod[lod_level_][sent_offset_];
offset < abs_lod[lod_level_][sent_offset_ + 1]; offset++) {
auto pre_id = pre_ids_data[offset];
auto pre_score = pre_scores_data[offset];
if (pre_id == end_id_) {
// Allocate all probability mass to eos_id for finished branchs and the
// other candidate ids can be ignored.
items->emplace_back(offset, end_id_, pre_score);
} else {
for (size_t d = 0; d < instance_dim; d++) {
const size_t dim_offset = offset * instance_dim + d;
items->emplace_back(offset, ids_data[dim_offset],
scores_data[dim_offset]);
}
}
}
sent_offset_++;
return true;
}
std::ostream &operator<<(std::ostream &os, const BeamSearch::Item &item) {
os << "{";
os << "offset: " << item.offset << ", ";
os << "id: " << item.id << ", ";
os << "score: " << item.score << "";
os << "}";
return os;
}
std::string ItemToString(const BeamSearch::Item &item) {
std::ostringstream stream;
stream << item;
return stream.str();
}
class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
......@@ -219,18 +29,23 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
"(LoDTensor) The LoDTensor containing the selected ids at the "
"previous step. It should be a tensor with shape (batch_size, 1) "
"and lod `[[0, 1, ... , batch_size], [0, 1, ..., batch_size]]` at "
"thefirst step.");
"the first step.");
AddInput("pre_scores",
"(LoDTensor) The LoDTensor containing the accumulated "
"scores corresponding to the selected ids at the previous step.");
AddInput("ids",
"(LoDTensor) The LoDTensor containing the candidates ids. Its "
"shape should be (batch_size * beam_size, K), where K supposed to "
"be beam_size.");
"shape should be (batch_size * beam_size, W). If not set, it will "
"be calculated out according to Input(scores) in this operator.")
.AsDispensable();
AddInput("scores",
"(LoDTensor) The LodTensor containing the accumulated scores "
"corresponding to Input(ids) and its shape is the same as the "
"shape of Input(ids).");
"(LoDTensor) The LoDTensor containing the current scores "
"corresponding to Input(ids). If Input(ids) is not nullptr, its "
"shape is the same as that of Input(ids)."
"If is_accumulated is true, Input(scores) is accumulated scores "
"and will be used derectedly. Else, each score will be "
"transformed to the log field and accumulate Input(pre_sores) "
"first.");
AddOutput("selected_ids",
"A LodTensor that stores the IDs selected by beam search.");
AddOutput("selected_scores",
......@@ -242,6 +57,9 @@ class BeamSearchOpMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<int>("beam_size", "beam size for beam search");
AddAttr<int>("end_id",
"the token id which indicates the end of a sequence");
AddAttr<bool>("is_accumulated",
"Whether the Input(scores) is accumulated scores.")
.SetDefault(true);
AddComment(R"DOC(
This operator does the search in beams for one time step.
......@@ -265,10 +83,9 @@ class BeamSearchOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *ctx) const override {
for (const std::string &arg :
std::vector<std::string>({"pre_ids", "ids", "scores"})) {
std::vector<std::string>({"pre_ids", "scores"})) {
PADDLE_ENFORCE(ctx->HasInput(arg), "BeamSearch need input argument '%s'",
arg);
}
......@@ -279,12 +96,22 @@ class BeamSearchOp : public framework::OperatorWithKernel {
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
framework::OpKernelType kt = framework::OpKernelType(
ctx.Input<framework::LoDTensor>("pre_ids")->type(),
platform::CPUPlace());
return kt;
auto *scores = ctx.Input<framework::LoDTensor>("scores");
size_t level = ctx.Attr<int>("level");
size_t batch_size = scores->lod()[level].size() - 1;
// The current CUDA kernel only support cases with batch_size < 4.
// Compute on CPU for cases with batch_size > 4.
if (batch_size <= 4) {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("pre_ids")->type(), ctx.GetPlace());
} else {
return framework::OpKernelType(
ctx.Input<framework::LoDTensor>("pre_ids")->type(),
platform::CPUPlace());
}
}
};
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
#include "paddle/fluid/framework/op_registry.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
beam_search,
ops::BeamSearchOpKernel<paddle::platform::CUDADeviceContext, float>,
ops::BeamSearchOpKernel<paddle::platform::CUDADeviceContext, double>,
ops::BeamSearchOpKernel<paddle::platform::CUDADeviceContext, int>,
ops::BeamSearchOpKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
......@@ -14,187 +14,12 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/math/beam_search.h"
namespace paddle {
namespace operators {
/*
* This is an implementation of beam search.
*
* To explain the details, lets take machine translation task for example, in
* this task, one source sentence is translated to multiple target sentences,
* during this period, one sentence will be translated to multiple translation
* prefixes(target sentence that have not ended), in each time step a prefix
* will have some candidates, input the candidate ids and their corresponding
* scores (probabilities), it will sort and select the top beam_size candidates
* for each source sentence, and store the selected candidates's score and their
* corresponding ids to LoDTensors.
*
* A detailed example:
*
* Input
*
* ids:
* LoD (should have 2 levels)
* first level: [0, 1, 4]
* second level: [0, 1, 2, 3, 4]
*
* tensor's data
* [
* [4, 2, 5]
* [2, 1, 3]
* [3, 5, 2]
* [8, 2, 1]
* ]
*
* scores:
* LoD same as `ids`
* tensor's data
* [
* [0.5, 0.3, 0.2]
* [0.6, 0.3, 0.1]
* [0.9, 0.5, 0.1]
* [0.7, 0.5, 0.1]
* ]
*
* the inputs means that there are 2 source sentences to translate, and the
* first source has 1 prefix, the second source has 2 prefix.
*
* lets assume beam size is 2, and the beam search's output should be
* LoD
* first level:
* [0, 1, 2]
* second level:
* [0, 2, 4]
*
* id tensor's data
* [[
* 4,
* 1,
* 3,
* 8,
* ]]
*
* score tensor's data
* [[
* 0.5,
* 0.3,
* 0.9,
* 0.7
* ]]
*
* TODO all the prune operations should be in the beam search, so it is better
* to split the beam search algorithm into a sequence of smaller operators, and
* the prune operators can be inserted in this sequence.
*/
class BeamSearch {
public:
// TODO(superjom) make type customizable
using id_t = size_t;
using score_t = float;
/*
* Input the arguments that needed by this class.
*/
BeamSearch(const framework::LoDTensor& ids,
const framework::LoDTensor& scores, size_t level, size_t beam_size,
int end_id)
: beam_size_(beam_size),
ids_(&ids),
scores_(&scores),
lod_level_(level),
end_id_(end_id) {}
/*
* The main function of beam search.
*
* @selected_ids: a [None, 1]-shaped tensor with LoD.
* In a machine translation model, it might be the candidate term id sets,
* each set stored as a varience-length sequence.
* The format might be described with a two-level LoD
* - [[0 1]
* - [0 1 2]]
* - [[]
* - [0 1]]
* the first level of LoD tells that there are two source sentences. The
* second level describes the details of the candidate id set's offsets in
* the
* source sentences.
*
* @selected_scores: a LoD tensor with the same shape and LoD with
* selected_ids.
* It stores the corresponding scores of candidate ids in selected_ids.
*
* Return false if all the input tensor is empty, in machine translation task
* that means no candidates is provided, and the task will stop running.
*/
void operator()(const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores,
framework::LoDTensor* selected_ids,
framework::LoDTensor* selected_scores);
/*
* The basic items help to sort.
*/
struct Item {
Item() {}
Item(size_t offset, size_t id, float score)
: offset(offset), id(id), score(score) {}
// offset in the higher lod level.
size_t offset;
// // prefix id in the lower lod level.
// size_t prefix;
// the candidate id
id_t id;
// the corresponding score
score_t score;
};
protected:
/*
* Prune the source sentences all branchs finished, and it is optional.
* Pruning must one step later than finishing (thus pre_ids is needed here),
* since the end tokens must be writed out.
*/
void PruneEndBeams(const framework::LoDTensor& pre_ids,
std::vector<std::vector<Item>>* items);
/*
* Transform the items into a map whose key is offset, value is the items.
* NOTE low performance.
*/
std::vector<std::vector<Item>> ToMap(
const std::vector<std::vector<Item>>& inputs, size_t element_num);
/*
* For each source, select top beam_size records.
*/
std::vector<std::vector<Item>> SelectTopBeamSizeItems(
const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores);
/*
* Get the items of next source sequence, return false if no remaining items.
*/
bool NextItemSet(const framework::LoDTensor& pre_ids,
const framework::LoDTensor& pre_scores,
std::vector<Item>* items);
private:
size_t beam_size_;
const framework::LoDTensor* ids_;
const framework::LoDTensor* scores_;
size_t lod_level_{0};
size_t sent_offset_{0};
int end_id_{0};
};
std::ostream& operator<<(std::ostream& os, const BeamSearch::Item& item);
std::string ItemToString(const BeamSearch::Item& item);
template <typename DeviceContext, typename T>
class BeamSearchOpKernel : public framework::OpKernel<T> {
public:
......@@ -203,7 +28,7 @@ class BeamSearchOpKernel : public framework::OpKernel<T> {
auto* scores = context.Input<framework::LoDTensor>("scores");
auto* pre_ids = context.Input<framework::LoDTensor>("pre_ids");
auto* pre_scores = context.Input<framework::LoDTensor>("pre_scores");
PADDLE_ENFORCE_NOT_NULL(ids);
PADDLE_ENFORCE_NOT_NULL(scores);
PADDLE_ENFORCE_NOT_NULL(pre_ids);
PADDLE_ENFORCE_NOT_NULL(pre_scores);
......@@ -211,14 +36,20 @@ class BeamSearchOpKernel : public framework::OpKernel<T> {
size_t level = context.Attr<int>("level");
size_t beam_size = context.Attr<int>("beam_size");
int end_id = context.Attr<int>("end_id");
BeamSearch alg(*ids, *scores, level, beam_size, end_id);
bool is_accumulated = context.Attr<bool>("is_accumulated");
auto selected_ids = context.Output<framework::LoDTensor>("selected_ids");
auto selected_scores =
context.Output<framework::LoDTensor>("selected_scores");
PADDLE_ENFORCE_NOT_NULL(selected_ids);
PADDLE_ENFORCE_NOT_NULL(selected_scores);
alg(*pre_ids, *pre_scores, selected_ids, selected_scores);
math::BeamSearchFunctor<DeviceContext, T> alg;
alg(context.template device_context<DeviceContext>(), pre_ids, pre_scores,
ids, scores, selected_ids, selected_scores, level, beam_size, end_id,
is_accumulated);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/beam_search_op.h"
#include <gtest/gtest.h>
#include <vector>
namespace paddle {
namespace test {
using std::vector;
using framework::LoDTensor;
using framework::LoD;
using operators::BeamSearch;
using paddle::platform::CPUPlace;
using std::cout;
using std::endl;
void CreateInput(LoDTensor* ids, LoDTensor* scores) {
LoD lod;
vector<size_t> level0({0, 2, 4});
vector<size_t> level1({0, 1, 2, 3, 4});
lod.push_back(level0);
lod.push_back(level1);
ids->set_lod(lod);
scores->set_lod(lod);
auto dims = framework::make_ddim(vector<int64_t>({4, 3}));
ids->Resize(dims);
scores->Resize(dims);
CPUPlace place;
auto* ids_data = ids->mutable_data<int64_t>(place);
auto* scores_data = scores->mutable_data<float>(place);
vector<int64_t> _ids({4, 2, 5, 2, 1, 3, 3, 5, 2, 8, 2, 1});
vector<float> _scores(
{0.5f, 0.3f, 0.2f, 0.6f, 0.3f, 0.1f, 0.9f, 0.5f, 0.1f, 0.7f, 0.5f, 0.1f});
for (int i = 0; i < 12; i++) {
ids_data[i] = _ids[i];
scores_data[i] = _scores[i];
}
}
// It seems that beam_search_op has bugs.
TEST(DISABLED_beam_search_op, run) {
CPUPlace place;
LoDTensor ids, scores;
CreateInput(&ids, &scores);
LoDTensor pre_ids;
pre_ids.Resize(framework::make_ddim(vector<int64_t>(4, 1)));
for (int i = 0; i < 4; i++) {
pre_ids.mutable_data<int64_t>(place)[i] = i + 1;
}
LoDTensor pre_scores;
pre_scores.Resize(framework::make_ddim(vector<int64_t>(4, 1)));
for (int i = 0; i < 4; i++) {
pre_scores.mutable_data<float>(place)[i] = 0.1 * (i + 1);
}
BeamSearch beamsearch(ids, scores, (size_t)0, (size_t)2, 0);
LoDTensor sids, sscores;
beamsearch(pre_ids, pre_scores, &sids, &sscores);
LOG(INFO) << "score: " << sscores << endl;
ASSERT_EQ(sids.lod(), sscores.lod());
vector<int> tids({4, 2, 3, 8});
vector<float> tscores({0.5f, 0.6f, 0.9f, 0.7f});
for (int i = 0; i < 4; i++) {
ASSERT_EQ(tids[i], sids.data<int64_t>()[i]);
ASSERT_EQ(tscores[i], sscores.data<float>()[i]);
}
}
} // namespace test
} // namespace paddle
......@@ -87,8 +87,8 @@ class BprLossGradientOpKernel : public framework::OpKernel<T> {
auto* label = ctx.Input<Tensor>("Label");
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
const int step_size = x->dims()[0];
const int num_classes = x->dims()[1];
const size_t step_size = static_cast<size_t>(x->dims()[0]);
const size_t num_classes = static_cast<size_t>(x->dims()[1]);
T* dx_data = dx->mutable_data<T>(ctx.GetPlace());
const T* dy_data = dy->data<T>();
const T* x_data = x->data<T>();
......
......@@ -54,6 +54,7 @@ math_library(sequence_padding)
math_library(sequence_pooling DEPS math_function jit_kernel_helper)
math_library(sequence_scale)
math_library(softmax DEPS math_function)
math_library(beam_search DEPS math_function)
math_library(matrix_bit_code)
......@@ -68,6 +69,7 @@ cc_test(im2col_test SRCS im2col_test.cc DEPS im2col)
cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col)
cc_test(sequence_padding_test SRCS sequence_padding_test.cc DEPS sequence_padding)
cc_test(sequence_pooling_test SRCS sequence_pooling_test.cc DEPS sequence_pooling)
cc_test(beam_search_test SRCS beam_search_test.cc DEPS beam_search)
if(WITH_GPU)
nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function)
nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu.cc DEPS selected_rows_functor math_function)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/beam_search.h"
#include <algorithm>
#include <map>
namespace paddle {
namespace operators {
namespace math {
template <typename T>
class BeamSearchFunctor<platform::CPUDeviceContext, T> {
public:
void operator()(const platform::CPUDeviceContext &context,
const framework::LoDTensor *pre_ids,
const framework::LoDTensor *pre_scores,
const framework::LoDTensor *ids,
const framework::LoDTensor *scores,
framework::LoDTensor *selected_ids,
framework::LoDTensor *selected_scores, size_t level,
size_t beam_size, int end_id, bool is_accumulated) {
auto abs_lod = framework::ToAbsOffset(scores->lod());
auto &high_level = abs_lod[level];
auto items = SelectTopBeamSizeItems(pre_ids, pre_scores, ids, scores, level,
beam_size, end_id, is_accumulated);
auto selected_items = ToMap(items, high_level.back());
if (FLAGS_v == 3) {
VLOG(3) << "selected_items:";
for (size_t i = 0; i < selected_items.size(); ++i) {
VLOG(3) << "offset: " << i;
for (auto &item : selected_items[i]) {
VLOG(3) << item.ToString();
}
}
}
PruneEndBeams(pre_ids, abs_lod, &selected_items, level, end_id);
// calculate the output tensor's height
size_t num_instances = std::accumulate(
std::begin(selected_items), std::end(selected_items), 0,
[](size_t a, std::vector<Item> &b) { return a + b.size(); });
// the output tensor shape should be [num_instances, 1]
auto dims = framework::make_ddim(
std::vector<int64_t>({static_cast<int>(num_instances), 1}));
selected_ids->Resize(dims);
selected_scores->Resize(dims);
auto *selected_ids_data =
selected_ids->mutable_data<int64_t>(platform::CPUPlace());
auto *selected_scores_data =
selected_scores->mutable_data<float>(platform::CPUPlace());
// fill in data
std::vector<size_t> low_level;
size_t low_offset = 0;
for (auto &items : selected_items) {
low_level.push_back(low_offset);
for (auto &item : items) {
selected_ids_data[low_offset] = item.id;
selected_scores_data[low_offset] = item.score;
low_offset++;
}
}
low_level.push_back(low_offset);
// fill lod
framework::LoD lod(2);
lod[0].assign(high_level.begin(), high_level.end());
lod[1].assign(low_level.begin(), low_level.end());
if (!framework::CheckLoD(lod)) {
PADDLE_THROW("lod %s is not right", framework::LoDToString(lod));
}
selected_ids->set_lod(lod);
selected_scores->set_lod(lod);
}
/*
* The basic items help to sort.
*/
struct Item {
Item() {}
Item(size_t offset, size_t id, float score)
: offset(offset), id(id), score(score) {}
// offset in the higher lod level.
size_t offset;
// prefix id in the lower lod level.
// size_t prefix;
// the candidate id
size_t id;
// the corresponding score
float score;
inline bool operator<(const Item &in) const {
return (score < in.score) ||
((score == in.score) && (offset < in.offset));
}
inline void operator=(const Item &in) {
offset = in.offset;
id = in.id;
score = in.score;
}
std::string ToString() {
std::ostringstream os;
os << "{";
os << "offset: " << offset << ", ";
os << "id: " << id << ", ";
os << "score: " << score << "";
os << "}";
return os.str();
}
};
protected:
/*
* Prune the source sentences all branchs finished, and it is optional.
* Pruning must one step later than finishing (thus pre_ids is needed here),
* since the end tokens must be writed out.
*/
void PruneEndBeams(const framework::LoDTensor *pre_ids,
const framework::LoD &abs_lod,
std::vector<std::vector<Item>> *items, size_t lod_level,
int end_id) {
auto *pre_ids_data = pre_ids->data<int64_t>();
auto &high_level = abs_lod[lod_level];
for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) {
size_t src_prefix_start = high_level[src_idx];
size_t src_prefix_end = high_level[src_idx + 1];
bool finish_flag = true;
for (size_t offset = src_prefix_start; offset < src_prefix_end;
offset++) {
for (auto &item : items->at(offset)) {
if (item.id != static_cast<size_t>(end_id) ||
pre_ids_data[offset] != end_id) {
finish_flag = false;
break;
}
}
if (!finish_flag) break;
}
if (finish_flag) { // all branchs of the beam (source sentence) end and
// prune this beam
for (size_t offset = src_prefix_start; offset < src_prefix_end;
offset++)
items->at(offset).clear();
}
}
}
/*
* Transform the items into a map whose key is offset, value is the items.
* NOTE low performance.
*/
std::vector<std::vector<Item>> ToMap(
const std::vector<std::vector<Item>> &items, size_t element_num) {
std::vector<std::vector<Item>> result;
result.resize(element_num);
for (auto &entries : items) {
for (const auto &item : entries) {
result[item.offset].push_back(item);
}
}
return result;
}
void Insert(std::vector<Item> *top_beam_ptr, const Item &item,
size_t beam_size) {
std::vector<Item> &top_beam = *top_beam_ptr;
size_t num_beams = top_beam.size();
if (num_beams < beam_size) {
top_beam.resize(num_beams + 1);
num_beams++;
} else {
if (item < top_beam[beam_size - 1]) {
return;
}
}
for (int k = static_cast<int>(num_beams) - 2; k >= 0; --k) {
if (top_beam[k] < item) {
top_beam[k + 1] = top_beam[k];
} else {
top_beam[k + 1] = item;
return;
}
}
top_beam[0] = item;
}
/*
* For each source, select top beam_size records.
*/
std::vector<std::vector<Item>> SelectTopBeamSizeItems(
const framework::LoDTensor *pre_ids,
const framework::LoDTensor *pre_scores, const framework::LoDTensor *ids,
const framework::LoDTensor *scores, size_t lod_level, size_t beam_size,
int end_id, bool is_accumulated) {
std::vector<std::vector<Item>> result;
// find the current candidates
auto abs_lod = framework::ToAbsOffset(scores->lod());
auto *pre_ids_data = pre_ids->data<int64_t>();
auto *pre_scores_data = pre_scores->data<float>();
auto *ids_data = ids ? ids->data<int64_t>() : nullptr;
auto *scores_data = scores->data<float>();
size_t num_seqs = scores->NumElements(lod_level);
size_t seq_width = 1;
for (int i = 1; i < scores->dims().size(); i++) {
seq_width *= scores->dims()[i];
}
for (size_t seq_id = 0; seq_id < num_seqs; ++seq_id) {
size_t seq_offset_start = abs_lod[lod_level][seq_id];
size_t seq_offset_end = abs_lod[lod_level][seq_id + 1];
std::vector<Item> top_beam;
top_beam.reserve(beam_size);
for (size_t offset = seq_offset_start; offset < seq_offset_end;
++offset) {
auto pre_id = pre_ids_data[offset];
auto pre_score = pre_scores_data[offset];
if (pre_id == end_id) {
// Allocate all probability mass to end_id for finished branchs and
// the other candidate ids can be ignored.
Item item(offset, end_id, pre_score);
Insert(&top_beam, item, beam_size);
} else {
size_t index = offset * seq_width;
for (size_t d = 0; d < seq_width; d++, index++) {
int64_t id = ids_data ? ids_data[index] : static_cast<int64_t>(d);
float score = is_accumulated
? scores_data[index]
: pre_score + std::log(scores_data[index]);
Item item(offset, id, score);
Insert(&top_beam, item, beam_size);
}
}
}
result.emplace_back(top_beam);
}
if (FLAGS_v == 3) {
VLOG(3) << "SelectTopBeamSizeItems result size " << result.size();
for (auto &items : result) {
VLOG(3) << "item set:";
for (auto &item : items) {
VLOG(3) << item.ToString();
}
}
}
return result;
}
};
template class BeamSearchFunctor<platform::CPUDeviceContext, int>;
template class BeamSearchFunctor<platform::CPUDeviceContext, int64_t>;
template class BeamSearchFunctor<platform::CPUDeviceContext, float>;
template class BeamSearchFunctor<platform::CPUDeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/beam_search.h"
#include "paddle/fluid/platform/cuda_device_function.h"
namespace paddle {
namespace operators {
namespace math {
struct Triple {
__device__ __forceinline__ Triple() {}
__device__ __forceinline__ Triple(int o, int i, float s)
: offset(o), id(i), score(s) {}
__device__ __forceinline__ void set(int o, int i, float s) {
offset = o;
id = i;
score = s;
}
__device__ __forceinline__ void operator=(const Triple& in) {
offset = in.offset;
id = in.id;
score = in.score;
}
__device__ __forceinline__ bool operator<(const float s) const {
return score < s;
}
__device__ __forceinline__ bool operator<(const Triple& in) const {
return (score < in.score) || ((score == in.score) && (offset < in.offset));
}
int offset;
int id;
float score;
};
__device__ __forceinline__ void Insert(Triple* top_beam, const Triple& p,
int beam_size) {
if (p < top_beam[beam_size - 1]) {
return;
}
for (int k = beam_size - 2; k >= 0; --k) {
if (top_beam[k] < p) {
top_beam[k + 1] = top_beam[k];
} else {
top_beam[k + 1] = p;
return;
}
}
top_beam[0] = p;
}
template <int MaxThreadsPerSeq, bool IsAccumulated = true>
__device__ __forceinline__ int SelectTopBeam(
Triple* top_beam, const int64_t* pre_ids, const float* pre_scores,
const int64_t* ids, const float* scores, const int seq_offset_start,
const int seq_offset_end, const int seq_width, int beam_size, int end_id,
int used_threads) {
// top_beam is shared memory
const int tid = threadIdx.x;
const int tid_of_seq = threadIdx.x % MaxThreadsPerSeq;
int num_used_threads = used_threads;
Triple* top_beam_local = top_beam + tid * beam_size;
if (tid_of_seq < num_used_threads) {
for (int i = 0; i < beam_size; ++i) {
top_beam_local[i].set(-1, -1, -INFINITY);
}
for (int offset = seq_offset_start; offset < seq_offset_end; ++offset) {
int pre_id = static_cast<int>(pre_ids[offset]);
if (pre_id == end_id) {
if (tid_of_seq == 0) {
Triple tmp(offset, end_id, pre_scores[offset]);
Insert(top_beam_local, tmp, beam_size);
}
} else {
int index = offset * seq_width + tid_of_seq;
if (!IsAccumulated) {
float pre_score = pre_scores[offset];
for (int i = tid_of_seq; i < seq_width; i += num_used_threads) {
float score = pre_score + __logf(scores[index]);
int id = ids ? static_cast<int>(ids[index]) : i;
Triple tmp(offset, id, score);
Insert(top_beam_local, tmp, beam_size);
index += num_used_threads;
}
} else {
for (int i = tid_of_seq; i < seq_width; i += num_used_threads) {
int id = ids ? static_cast<int>(ids[index]) : i;
float score = scores[index];
Triple tmp(offset, id, score);
Insert(top_beam_local, tmp, beam_size);
index += num_used_threads;
}
}
}
}
}
while (num_used_threads > 1) {
if (num_used_threads > 16) {
__syncthreads();
}
num_used_threads = num_used_threads >> 1;
if (tid_of_seq < num_used_threads) {
int index_in_sh = (num_used_threads + tid) * beam_size;
for (int i = 0; i < beam_size; i++) {
Insert(top_beam_local, top_beam[index_in_sh], beam_size);
index_in_sh++;
}
}
}
if (tid_of_seq == 0) {
int num_items = 0;
for (int i = 0; i < beam_size; ++i) {
num_items =
(top_beam_local[i].score > -INFINITY) ? num_items + 1 : num_items;
}
return num_items;
}
return 0;
}
__device__ __forceinline__ bool PruneEndBeams(Triple* top_beam_local,
const int64_t* pre_ids,
const int end_id, int num_items) {
bool finish_flag = true;
for (int i = 0; i < num_items; ++i) {
int offset = top_beam_local[i].offset;
if (top_beam_local[i].id != end_id ||
static_cast<int>(pre_ids[offset]) != end_id) {
finish_flag = false;
break;
}
}
return finish_flag;
}
__device__ __forceinline__ void WriteBack(
int64_t* selected_ids, float* selected_scores, size_t* selected_offsets,
Triple* top_beam_local, const int seq_offset_start,
const int seq_offset_end, const int selected_seq_start,
const int selected_seq_length) {
const int tid = threadIdx.x; // use 1 thread only for each sequence
int global_index = selected_seq_start;
for (int global_offset = seq_offset_start; global_offset < seq_offset_end;
++global_offset) {
for (int local_index = 0; local_index < selected_seq_length;
++local_index) {
if (top_beam_local[local_index].offset == global_offset) {
selected_ids[global_index] =
static_cast<int64_t>(top_beam_local[local_index].id);
selected_scores[global_index] = top_beam_local[local_index].score;
global_index++;
}
}
selected_offsets[global_offset + 1] = static_cast<size_t>(global_index);
}
}
template <int MaxLength, int MaxThreadsPerSeq, int MaxSeqs>
__device__ void BeamSearchDetails(
int64_t* selected_ids, float* selected_scores, size_t* selected_offsets,
const int64_t* pre_ids, const float* pre_scores, const int64_t* ids,
const float* scores, const int seq_offset_start, const int seq_offset_end,
const int seq_width, int beam_size, int end_id, bool is_accumulated,
int num_used_threads) {
__shared__ Triple top_beam[MaxLength];
int num_items = 0;
if (is_accumulated) {
num_items = SelectTopBeam<MaxThreadsPerSeq, true>(
top_beam, pre_ids, pre_scores, ids, scores, seq_offset_start,
seq_offset_end, seq_width, beam_size, end_id, num_used_threads);
} else {
num_items = SelectTopBeam<MaxThreadsPerSeq, false>(
top_beam, pre_ids, pre_scores, ids, scores, seq_offset_start,
seq_offset_end, seq_width, beam_size, end_id, num_used_threads);
}
const int tid = threadIdx.x; // use 1 thread only for each sequence
const int tid_of_seq = tid % MaxThreadsPerSeq;
if (tid_of_seq == 0) {
// Use 1 thread for each sequence.
Triple* top_beam_local = top_beam + tid * beam_size;
bool finish_flag =
PruneEndBeams(top_beam_local, pre_ids, end_id, num_items);
int selected_seq_start = 0;
int selected_seq_length = finish_flag ? 0 : num_items;
if (MaxSeqs > 1) {
const int seq_id = (MaxSeqs > 1) ? tid / MaxThreadsPerSeq : tid;
__shared__ int shared_mem[MaxSeqs];
// [0, MaxSeqs - 1], length of each sequences
shared_mem[seq_id] = selected_seq_length;
__syncthreads();
for (int s = 0; s < seq_id; ++s) {
selected_seq_start += shared_mem[s];
}
if (seq_id == 0) {
selected_offsets[0] = 0;
}
} else {
selected_offsets[0] = 0;
}
WriteBack(selected_ids, selected_scores, selected_offsets, top_beam_local,
seq_offset_start, seq_offset_end, selected_seq_start,
selected_seq_length);
}
}
template <int MaxLength, int MaxThreadsPerSeq, int MaxSeqs>
__global__ void BeamSearchKernel(int64_t* selected_ids, float* selected_scores,
size_t* selected_offsets,
const int64_t* pre_ids,
const float* pre_scores, const int64_t* ids,
const float* scores, const size_t* seq_offsets,
const int num_seqs, const int seq_width,
int beam_size, int end_id, bool is_accumulated,
int num_used_threads) {
const int tid = threadIdx.x;
const int seq_id = (MaxSeqs > 1) ? tid / MaxThreadsPerSeq : tid;
int seq_offset_start = static_cast<int>(seq_offsets[seq_id]);
int seq_offset_end = static_cast<int>(seq_offsets[seq_id + 1]);
BeamSearchDetails<MaxLength, MaxThreadsPerSeq, MaxSeqs>(
selected_ids, selected_scores, selected_offsets, pre_ids, pre_scores, ids,
scores, seq_offset_start, seq_offset_end, seq_width, beam_size, end_id,
is_accumulated, num_used_threads);
}
template <int MaxLength, int MaxThreadsPerSeq>
__global__ void BeamSearchKernelSingle(
int64_t* selected_ids, float* selected_scores, size_t* selected_offsets,
const int64_t* pre_ids, const float* pre_scores, const int64_t* ids,
const float* scores, const int seq_length, const int seq_width,
int beam_size, int end_id, bool is_accumulated, int num_used_threads) {
const int seq_offset_start = 0;
const int seq_offset_end = seq_length;
BeamSearchDetails<MaxLength, MaxThreadsPerSeq, 1>(
selected_ids, selected_scores, selected_offsets, pre_ids, pre_scores, ids,
scores, seq_offset_start, seq_offset_end, seq_width, beam_size, end_id,
is_accumulated, num_used_threads);
}
static inline int GetNumUsedThreads(const int max_threads_per_seq,
const int seq_width, int beam_size) {
int num_used_threads = (seq_width + beam_size - 1) / beam_size;
num_used_threads = max_threads_per_seq < num_used_threads
? max_threads_per_seq
: num_used_threads;
num_used_threads =
num_used_threads > 32
? (num_used_threads >> 5) << 5
: (num_used_threads > 16
? 32
: (num_used_threads > 8
? 16
: (num_used_threads > 4
? 8
: (num_used_threads > 2 ? 4
: num_used_threads))));
return num_used_threads;
}
template <typename T>
class BeamSearchFunctor<platform::CUDADeviceContext, T> {
public:
void operator()(const platform::CUDADeviceContext& context,
const framework::LoDTensor* pre_ids,
const framework::LoDTensor* pre_scores,
const framework::LoDTensor* ids,
const framework::LoDTensor* scores,
framework::LoDTensor* selected_ids,
framework::LoDTensor* selected_scores, size_t level,
size_t beam_size, int end_id, bool is_accumulated) {
auto abs_lod = framework::ToAbsOffset(scores->lod());
const int64_t* pre_ids_data = pre_ids->data<int64_t>();
const float* pre_scores_data = pre_scores->data<float>();
const int64_t* ids_data = ids ? ids->data<int64_t>() : nullptr;
const float* scores_data = scores->data<float>();
const size_t num_seqs = abs_lod[level].size() - 1;
size_t seq_width = 1;
for (int i = 1; i < scores->dims().size(); i++) {
seq_width *= scores->dims()[i];
}
// Reserve a big enough memory.
auto selected_dims =
framework::make_ddim({static_cast<int64_t>(num_seqs * beam_size), 1});
int64_t* selected_ids_data =
selected_ids->mutable_data<int64_t>(selected_dims, context.GetPlace());
float* selected_scores_data =
selected_scores->mutable_data<float>(selected_dims, context.GetPlace());
framework::LoD selected_lod(2);
selected_lod[0].assign(abs_lod[level].begin(), abs_lod[level].end());
selected_lod[1].resize(scores->dims()[0] + 1);
size_t* selected_offsets =
selected_lod[1].CUDAMutableData(context.GetPlace());
if (num_seqs == 1) {
const int seq_length = static_cast<int>(abs_lod[level][1]);
const int kMaxThreadsPerSeq = 1024;
int num_used_threads =
GetNumUsedThreads(kMaxThreadsPerSeq, static_cast<int>(seq_width),
static_cast<int>(beam_size));
switch (platform::RoundToPowerOfTwo(beam_size * seq_width)) {
CUDA_LAUNCH_KERNEL_HELPER(
BeamSearchKernelSingle<kPowerOfTwoDim, kMaxThreadsPerSeq><<<
1, kMaxThreadsPerSeq, 0, context.stream()>>>(
selected_ids_data, selected_scores_data, selected_offsets,
pre_ids_data, pre_scores_data, ids_data, scores_data,
seq_length, static_cast<int>(seq_width),
static_cast<int>(beam_size), static_cast<int>(end_id),
is_accumulated, num_used_threads));
}
} else if (num_seqs <= 4) {
const size_t* seq_offsets = abs_lod[level].CUDAData(context.GetPlace());
// Use only 1 block
const int kMaxThreadsPerSeq = 32;
const int kMaxSeqs = 4;
int num_used_threads =
GetNumUsedThreads(kMaxThreadsPerSeq, static_cast<int>(seq_width),
static_cast<int>(beam_size));
switch (platform::RoundToPowerOfTwo(beam_size * num_seqs * 32)) {
CUDA_LAUNCH_KERNEL_HELPER(
BeamSearchKernel<kPowerOfTwoDim, kMaxThreadsPerSeq, kMaxSeqs><<<
1, num_seqs * kMaxThreadsPerSeq, 0, context.stream()>>>(
selected_ids_data, selected_scores_data, selected_offsets,
pre_ids_data, pre_scores_data, ids_data, scores_data,
seq_offsets, static_cast<int>(num_seqs),
static_cast<int>(seq_width), static_cast<int>(beam_size),
end_id, is_accumulated, num_used_threads));
}
} else {
LOG(FATAL) << "Not implemented.";
}
context.Wait();
if (!framework::CheckLoD(selected_lod)) {
PADDLE_THROW("lod %s is not right", framework::LoDToString(selected_lod));
}
selected_ids->set_lod(selected_lod);
selected_scores->set_lod(selected_lod);
if (selected_lod[1].back() < num_seqs * beam_size) {
auto final_selected_dims = framework::make_ddim(
{static_cast<int64_t>(selected_lod[1].back()), 1});
selected_ids->Resize(final_selected_dims);
selected_scores->Resize(final_selected_dims);
}
}
};
template class BeamSearchFunctor<platform::CUDADeviceContext, int>;
template class BeamSearchFunctor<platform::CUDADeviceContext, int64_t>;
template class BeamSearchFunctor<platform::CUDADeviceContext, float>;
template class BeamSearchFunctor<platform::CUDADeviceContext, double>;
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace operators {
namespace math {
/*
* This is an implementation of beam search.
*
* To explain the details, lets take machine translation task for example, in
* this task, one source sentence is translated to multiple target sentences,
* during this period, one sentence will be translated to multiple translation
* prefixes(target sentence that have not ended), in each time step a prefix
* will have some candidates, input the candidate ids and their corresponding
* scores (probabilities), it will sort and select the top beam_size candidates
* for each source sentence, and store the selected candidates's score and their
* corresponding ids to LoDTensors.
*
* A detailed example:
*
* Input
*
* ids:
* - LoD (should have 2 levels)
* - first level: [0, 1, 4]
* - second level: [0, 1, 2, 3, 4]
* - tensor's data:
* [[4, 2, 5]
* [2, 1, 3]
* [3, 5, 2]
* [8, 2, 1]]
*
* scores:
* - LoD same as `ids`
* - tensor's data
* [[0.5, 0.3, 0.2]
* [0.6, 0.3, 0.1]
* [0.9, 0.5, 0.1]
* [0.7, 0.5, 0.1]]
*
* The inputs means that there are 2 source sentences to translate, and the
* first source has 1 prefix, the second source has 2 prefix.
*
* Lets assume beam size is 2, and the beam search's output should be
* - LoD
* - first level: [0, 1, 2]
* - second level: [0, 2, 4]
* - id tensor's data
* [[4,
* 1,
* 3,
* 8]]
* - score tensor's data
* [[0.5,
* 0.3,
* 0.9,
* 0.7]]
*
* TODO all the prune operations should be in the beam search, so it is better
* to split the beam search algorithm into a sequence of smaller operators, and
* the prune operators can be inserted in this sequence.
*/
template <typename DeviceContext, typename T>
class BeamSearchFunctor {
public:
/*
* The main function of beam search.
*
* @selected_ids: a [None, 1]-shaped tensor with LoD.
* In a machine translation model, it might be the candidate term id sets,
* each set stored as a varience-length sequence.
* The format might be described with a two-level LoD
* - [[0 1],
* [0 1 2]]
* - [[]
* [0 1]]
* the first level of LoD tells that there are two source sentences. The
* second level describes the details of the candidate id set's offsets in
* the source sentences.
*
* @selected_scores: a LoD tensor with the same shape and LoD with
* selected_ids.
* It stores the corresponding scores of candidate ids in selected_ids.
*
* Return false if all the input tensor is empty, in machine translation task
* that means no candidates is provided, and the task will stop running.
*/
void operator()(const DeviceContext& context,
const framework::LoDTensor* pre_ids,
const framework::LoDTensor* pre_scores,
const framework::LoDTensor* ids,
const framework::LoDTensor* scores,
framework::LoDTensor* selected_ids,
framework::LoDTensor* selected_scores, size_t level,
size_t beam_size, int end_id, bool is_accumulated);
};
} // namespace math
} // namespace operators
} // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/math/beam_search.h"
#include <gtest/gtest.h>
#include <vector>
void PrepareCPUTensors(paddle::framework::LoDTensor* ids,
paddle::framework::LoDTensor* scores,
paddle::framework::LoDTensor* pre_ids,
paddle::framework::LoDTensor* pre_scores) {
// lod
paddle::framework::LoD lod;
std::vector<size_t> level0({0, 2, 4});
std::vector<size_t> level1({0, 1, 2, 3, 4});
lod.push_back(level0);
lod.push_back(level1);
ids->set_lod(lod);
scores->set_lod(lod);
auto dims = paddle::framework::make_ddim({4, 3});
ids->Resize(dims);
scores->Resize(dims);
paddle::platform::CPUPlace place;
auto* ids_data = ids->mutable_data<int64_t>(place);
auto* scores_data = scores->mutable_data<float>(place);
std::vector<int64_t> ids_vec_data({4, 2, 5, 2, 1, 3, 3, 5, 2, 8, 2, 1});
std::vector<float> scores_vec_data(
{0.6f, 0.3f, 0.5f, 0.2f, 0.3f, 0.1f, 0.9f, 0.5f, 0.1f, 0.7f, 0.5f, 0.1f});
CHECK_EQ(static_cast<size_t>(ids->numel()), ids_vec_data.size());
CHECK_EQ(static_cast<size_t>(ids->numel()), scores_vec_data.size());
for (int i = 0; i < ids->numel(); i++) {
ids_data[i] = ids_vec_data[i];
scores_data[i] = scores_vec_data[i];
}
// pre_ids
pre_ids->Resize(paddle::framework::make_ddim({4, 1}));
for (int i = 0; i < 4; i++) {
pre_ids->mutable_data<int64_t>(place)[i] = i + 1;
}
// pre_scores
pre_scores->Resize(paddle::framework::make_ddim({4, 1}));
for (int i = 0; i < 4; i++) {
pre_scores->mutable_data<float>(place)[i] = 0.1 * (i + 1);
}
}
template <typename DeviceContext, typename Place>
void TestBeamSearch() {
paddle::framework::LoDTensor ids;
paddle::framework::LoDTensor scores;
paddle::framework::LoDTensor pre_ids;
paddle::framework::LoDTensor pre_scores;
auto* place = new Place();
DeviceContext* context = new DeviceContext(*place);
if (paddle::platform::is_cpu_place(*place)) {
PrepareCPUTensors(&ids, &scores, &pre_ids, &pre_scores);
} else {
paddle::framework::LoDTensor cpu_ids;
paddle::framework::LoDTensor cpu_scores;
paddle::framework::LoDTensor cpu_pre_ids;
paddle::framework::LoDTensor cpu_pre_scores;
PrepareCPUTensors(&cpu_ids, &cpu_scores, &cpu_pre_ids, &cpu_pre_scores);
TensorCopySync(cpu_ids, *place, &ids);
TensorCopySync(cpu_scores, *place, &scores);
TensorCopySync(cpu_pre_ids, *place, &pre_ids);
TensorCopySync(cpu_pre_scores, *place, &pre_scores);
ids.set_lod(cpu_ids.lod());
scores.set_lod(cpu_scores.lod());
pre_ids.set_lod(cpu_pre_ids.lod());
pre_scores.set_lod(cpu_pre_scores.lod());
}
paddle::framework::LoDTensor selected_ids;
paddle::framework::LoDTensor selected_scores;
size_t level = 0;
size_t beam_size = 2;
int end_id = 0;
paddle::operators::math::BeamSearchFunctor<DeviceContext, float> beamsearch;
beamsearch(*context, &pre_ids, &pre_scores, &ids, &scores, &selected_ids,
&selected_scores, level, beam_size, end_id, true);
ASSERT_EQ(selected_ids.lod(), selected_scores.lod());
paddle::framework::LoDTensor cpu_selected_ids;
paddle::framework::LoDTensor cpu_selected_scores;
if (paddle::platform::is_cpu_place(*place)) {
cpu_selected_ids = selected_ids;
cpu_selected_scores = selected_scores;
} else {
TensorCopySync(selected_ids, paddle::platform::CPUPlace(),
&cpu_selected_ids);
TensorCopySync(selected_scores, paddle::platform::CPUPlace(),
&cpu_selected_scores);
cpu_selected_ids.set_lod(selected_ids.lod());
cpu_selected_scores.set_lod(selected_scores.lod());
}
std::vector<int64_t> expected_ids({4, 5, 3, 8});
std::vector<float> expected_scores({0.6f, 0.5f, 0.9f, 0.7f});
for (int i = 0; i < 4; i++) {
ASSERT_EQ(expected_ids[i], cpu_selected_ids.data<int64_t>()[i]);
ASSERT_EQ(expected_scores[i], cpu_selected_scores.data<float>()[i]);
}
delete place;
delete context;
}
TEST(BeamSearch, CPU) {
TestBeamSearch<paddle::platform::CPUDeviceContext,
paddle::platform::CPUPlace>();
}
#ifdef PADDLE_WITH_CUDA
TEST(BeamSearch, GPU) {
TestBeamSearch<paddle::platform::CUDADeviceContext,
paddle::platform::CUDAPlace>();
}
#endif
......@@ -354,7 +354,7 @@ TEST(selected_rows_functor, cpu_merge_add_multi) {
auto* out_data = output->value().data<float>();
for (size_t i = 0; i < ret_rows.size(); ++i) {
for (size_t j = 0; j < row_numel; ++j) {
for (size_t j = 0; j < static_cast<size_t>(row_numel); ++j) {
EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
}
}
......
......@@ -301,7 +301,7 @@ TEST(selected_rows_functor, gpu_merge_add) {
auto* out_data = output_cpu.data<float>();
for (size_t i = 0; i < ret_rows.size(); ++i) {
for (size_t j = 0; j < row_numel; ++j) {
for (size_t j = 0; j < static_cast<size_t>(row_numel); ++j) {
EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]);
}
}
......
......@@ -66,7 +66,7 @@ void TestSequencePoolingSum(const paddle::framework::LoD& lod) {
cpu_in_grad.set_lod(in_grad.lod());
}
EXPECT_EQ(in_grad.numel(), lod[0].back() * second_dim);
EXPECT_EQ(in_grad.numel(), static_cast<int64_t>(lod[0].back() * second_dim));
EXPECT_EQ(in_grad.lod(), lod);
if (paddle::platform::is_cpu_place(*place)) {
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda.h>
// NOTE(): support float16 to half in header file.
#define PADDLE_CUDA_FP16
......@@ -30,6 +31,34 @@ namespace platform {
mask = __ballot_sync(FULL_WARP_MASK, (predicate))
#endif
inline static int RoundToPowerOfTwo(int dim) {
if (dim > 512) {
return 1024;
} else if (dim > 256) {
return 512;
} else if (dim > 128) {
return 256;
} else if (dim > 64) {
return 128;
} else if (dim > 32) {
return 64;
} else {
return 32;
}
}
#define CUDA_LAUNCH_KERNEL_BASE(dim, ...) \
case (dim): { \
constexpr auto kPowerOfTwoDim = (dim); \
__VA_ARGS__; \
} break
#define CUDA_LAUNCH_KERNEL_HELPER(...) \
CUDA_LAUNCH_KERNEL_BASE(256, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(128, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(64, ##__VA_ARGS__); \
CUDA_LAUNCH_KERNEL_BASE(32, ##__VA_ARGS__);
template <typename T>
__forceinline__ __device__ T CudaShuffleDownSync(unsigned mask, T val,
int delta, int width = 32) {
......
......@@ -221,13 +221,17 @@ size_t GpuMaxChunkSize() {
void GpuMemcpyAsync(void *dst, const void *src, size_t count,
enum cudaMemcpyKind kind, cudaStream_t stream) {
PADDLE_ENFORCE(cudaMemcpyAsync(dst, src, count, kind, stream),
"cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync");
"cudaMemcpyAsync failed in paddle::platform::GpuMemcpyAsync "
"(%p -> %p, length: %d)",
src, dst, static_cast<int>(count));
}
void GpuMemcpySync(void *dst, const void *src, size_t count,
enum cudaMemcpyKind kind) {
PADDLE_ENFORCE(cudaMemcpy(dst, src, count, kind),
"cudaMemcpy failed in paddle::platform::GpuMemcpySync");
"cudaMemcpy failed in paddle::platform::GpuMemcpySync (%p -> "
"%p, length: %d)",
src, dst, static_cast<int>(count));
}
void GpuMemcpyPeerAsync(void *dst, int dst_device, const void *src,
......
......@@ -3875,6 +3875,7 @@ def beam_search(pre_ids,
beam_size,
end_id,
level=0,
is_accumulated=True,
name=None):
"""
Beam search is a classical algorithm for selecting candidate words in a
......@@ -3887,14 +3888,17 @@ def beam_search(pre_ids,
selects the top-K candidate word ids of current step from :attr:`ids`
according to their :attr:`scores` for all source sentences, where K is
:attr:`beam_size` and :attr:`ids, scores` are predicted results from the
computation cell. Additionally, :attr:`pre_ids` and :attr:`pre_scores` are
the output of beam_search at previous step, they are needed for special use
to handle ended candidate translations.
Note that the :attr:`scores` passed in should be accumulated scores, and
length penalty should be done with extra operators before calculating the
accumulated scores if needed, also suggest finding top-K before it and
using the top-K candidates following.
computation cell. If :attr:`ids` is not set, it will be calculated out
according to :attr:`scores`. Additionally, :attr:`pre_ids` and
:attr:`pre_scores` are the output of beam_search at previous step, they
are needed for special use to handle ended candidate translations.
Note that if :attr:`is_accumulated` is :attr:`True`, the :attr:`scores`
passed in should be accumulated scores. Else, the :attr:`scores` are
considered as the straightforward scores and will be transformed to the
log field and accumulated the :attr:`pre_scores` in this operator.
Length penalty should be done with extra operators before calculating the
accumulated scores if needed.
Please see the following demo for a fully beam search usage example:
......@@ -3924,6 +3928,8 @@ def beam_search(pre_ids,
describes how these candidates belong to the prefix. The paths
linking prefixes and selected candidates are organized and reserved
in lod.
is_accumulated(bool, default True): Whether the input :attr:`score` is
accumulated scores.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
......@@ -3952,8 +3958,12 @@ def beam_search(pre_ids,
end_id=end_id)
"""
helper = LayerHelper('beam_search', **locals())
score_type = scores.dtype
id_type = ids.dtype
score_type = pre_scores.dtype
id_type = pre_ids.dtype
inputs = {"pre_ids": pre_ids, "pre_scores": pre_scores, "scores": scores}
if ids is not None:
inputs["ids"] = ids
selected_scores = helper.create_variable_for_type_inference(
dtype=score_type)
......@@ -3961,12 +3971,7 @@ def beam_search(pre_ids,
helper.append_op(
type='beam_search',
inputs={
'pre_ids': pre_ids,
'pre_scores': pre_scores,
'ids': ids,
'scores': scores,
},
inputs=inputs,
outputs={
'selected_ids': selected_ids,
'selected_scores': selected_scores,
......@@ -3976,6 +3981,7 @@ def beam_search(pre_ids,
'level': level,
'beam_size': beam_size,
'end_id': end_id,
'is_accumulated': is_accumulated,
})
return selected_ids, selected_scores
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册