提交 d15b2e02 编写于 作者: G guosheng

Fix copying empty tensor in beam_search_decode_op

上级 983566d9
...@@ -42,9 +42,11 @@ struct BeamSearchDecodeFunctor { ...@@ -42,9 +42,11 @@ struct BeamSearchDecodeFunctor {
// Copy all tensors in the input tensor array // Copy all tensors in the input tensor array
for (auto& step_id : step_ids_origin_) { for (auto& step_id : step_ids_origin_) {
framework::LoDTensor out; framework::LoDTensor out;
dev_ctx->Wait(); if (step_id.numel() > 0) {
framework::TensorCopy(step_id, platform::CPUPlace(), *dev_ctx, &out); dev_ctx->Wait();
dev_ctx->Wait(); framework::TensorCopy(step_id, platform::CPUPlace(), *dev_ctx, &out);
dev_ctx->Wait();
}
out.set_lod(step_id.lod()); out.set_lod(step_id.lod());
step_ids_.push_back(out); step_ids_.push_back(out);
...@@ -58,9 +60,12 @@ struct BeamSearchDecodeFunctor { ...@@ -58,9 +60,12 @@ struct BeamSearchDecodeFunctor {
// Copy all tensors in the input tensor array // Copy all tensors in the input tensor array
for (auto& step_score : step_scores_origin_) { for (auto& step_score : step_scores_origin_) {
framework::LoDTensor out; framework::LoDTensor out;
dev_ctx->Wait(); if (step_score.numel() > 0) {
framework::TensorCopy(step_score, platform::CPUPlace(), *dev_ctx, &out); dev_ctx->Wait();
dev_ctx->Wait(); framework::TensorCopy(step_score, platform::CPUPlace(), *dev_ctx,
&out);
dev_ctx->Wait();
}
out.set_lod(step_score.lod()); out.set_lod(step_score.lod());
step_scores_.push_back(out); step_scores_.push_back(out);
......
...@@ -151,8 +151,7 @@ void BeamSearchDecoder<T>::Backtrace(const LoDTensorArray& step_ids, ...@@ -151,8 +151,7 @@ void BeamSearchDecoder<T>::Backtrace(const LoDTensorArray& step_ids,
const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1; const size_t src_num = step_ids.at(0).lod().at(kSourceLevel).size() - 1;
std::vector<SentenceVector<T>> sentence_vector_list( std::vector<SentenceVector<T>> sentence_vector_list(
src_num, SentenceVector<T>(beam_size_)); src_num, SentenceVector<T>(beam_size_));
std::vector<std::vector<size_t>> prefix_idx_vector_list( std::vector<std::vector<size_t>> prefix_idx_vector_list(src_num);
src_num, std::vector<size_t>());
for (int step_id = step_num - 1; step_id >= 0; --step_id) { for (int step_id = step_num - 1; step_id >= 0; --step_id) {
auto& cur_ids = step_ids.at(step_id); auto& cur_ids = step_ids.at(step_id);
auto& cur_scores = step_scores.at(step_id); auto& cur_scores = step_scores.at(step_id);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册