From 14f5a4089844fe3afa8ff4810a5431aaa03b2156 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Thu, 25 Oct 2018 03:16:26 +0000 Subject: [PATCH] fix unit test --- .../fluid/operators/math/selected_rows_functor.cu | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 20d1b2ed7bc..d237abc880b 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -267,10 +267,15 @@ struct MergeAdd { void operator()(const platform::CUDADeviceContext& context, const framework::SelectedRows& input, framework::SelectedRows* output) { - framework::SelectedRows& out = *output; framework::Vector input_rows(input.rows()); + if (input_rows.size() == 0) { + return; + } + + framework::SelectedRows& out = *output; std::set row_set(input_rows.begin(), input_rows.end()); - std::vector merge_rows(row_set.begin(), row_set.end()); + std::vector merge_rows_cpu(row_set.begin(), row_set.end()); + framework::Vector merge_rows(merge_rows_cpu); auto input_width = input.value().dims()[1]; @@ -313,8 +318,9 @@ struct MergeAdd { "all input should have same height"); merged_row_set.insert(input->rows().begin(), input->rows().end()); } - std::vector merge_rows(merged_row_set.begin(), + std::vector merge_rows_cpu(merged_row_set.begin(), merged_row_set.end()); + framework::Vector merge_rows(merge_rows_cpu); out.set_rows(merge_rows); out.set_height(input_height); @@ -334,6 +340,9 @@ struct MergeAdd { for (auto* input : inputs) { auto* input_data = input->value().data(); auto& input_rows = input->rows(); + if (input_rows.size() == 0) { + continue; + } dim3 grid1(input_rows.size(), 1); MergeAddKernel<<>>( -- GitLab