From dd78b5df93ad9369a501568fa541316b06515cf1 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sat, 27 Oct 2018 20:39:56 +0800 Subject: [PATCH] sum op handle empty input --- .../operators/math/selected_rows_functor.cc | 27 ++++++++++++++++--- .../math/selected_rows_functor_test.cc | 2 -- .../math/selected_rows_functor_test.cu | 2 -- paddle/fluid/operators/sum_op.h | 10 +++++-- 4 files changed, 31 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 2679f501dac..305743b082a 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -269,12 +269,29 @@ struct MergeAdd { void operator()(const platform::CPUDeviceContext& context, const std::vector& inputs, framework::SelectedRows* output) { - PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input"); - auto input_width = inputs[0]->value().dims()[1]; - auto input_height = inputs[0]->height(); + if (inputs.size() == 0) { + VLOG(3) << "no input! return"; + return; + } + const framework::SelectedRows* has_value_input = nullptr; + for (auto* in : inputs) { + if (!in->rows().empty()) { + has_value_input = in; + break; + } + } + if (has_value_input == nullptr) { + VLOG(3) << "no input has value! just return" << std::endl; + return; + } + auto input_width = has_value_input->value().dims()[1]; + auto input_height = has_value_input->height(); framework::SelectedRows& out = *output; std::set merged_row_set; for (auto* input : inputs) { + if (input->rows().empty()) { + continue; + } PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], "all input should have same " "dimension except for the first one"); @@ -288,7 +305,6 @@ struct MergeAdd { for (size_t i = 0; i < merge_rows.size(); ++i) { rows_to_id[merge_rows[i]] = i; } - out.set_rows(merge_rows); out.set_height(input_height); out.mutable_value()->mutable_data( @@ -303,6 +319,9 @@ struct MergeAdd { auto blas = math::GetBlas(context); for (auto* input : inputs) { + if (input->rows().empty()) { + continue; + } auto* input_data = input->value().data(); auto& input_rows = input->rows(); diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cc b/paddle/fluid/operators/math/selected_rows_functor_test.cc index f5165fa5353..f15b37a1e3f 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cc +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cc @@ -356,9 +356,7 @@ TEST(selected_rows_functor, cpu_merge_add_multi) { for (size_t i = 0; i < ret_rows.size(); ++i) { for (size_t j = 0; j < row_numel; ++j) { EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]); - std::cout << out_data[i * row_numel + j] << " "; } - std::cout << "\n"; } } diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cu b/paddle/fluid/operators/math/selected_rows_functor_test.cu index 93e55e88ca3..17af3e3999c 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cu +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cu @@ -302,8 +302,6 @@ TEST(selected_rows_functor, gpu_merge_add) { for (size_t i = 0; i < ret_rows.size(); ++i) { for (size_t j = 0; j < row_numel; ++j) { EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]); - std::cout << out_data[i * row_numel + j] << " "; } - std::cout << "\n"; } } diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index de81693c755..69e619a5305 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -99,11 +99,17 @@ class SumKernel : public framework::OpKernel { temp_in0.mutable_value()); inputs.push_back(&temp_in0); for (size_t i = 1; i < in_vars.size(); ++i) { - inputs.push_back(&in_vars[i]->Get()); + auto &in = in_vars[i]->Get(); + if (!in.rows().empty()) { + inputs.push_back(&in); + } } } else { for (auto &in_var : in_vars) { - inputs.push_back(&in_var->Get()); + auto &in = in_var->Get(); + if (!in.rows().empty()) { + inputs.push_back(&in_var->Get()); + } } } -- GitLab