diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 2679f501dac040d01c1dd728bc27b9676fe126c4..305743b082a0c8b38053107a3924cb6cbef98948 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -269,12 +269,29 @@ struct MergeAdd { void operator()(const platform::CPUDeviceContext& context, const std::vector& inputs, framework::SelectedRows* output) { - PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input"); - auto input_width = inputs[0]->value().dims()[1]; - auto input_height = inputs[0]->height(); + if (inputs.size() == 0) { + VLOG(3) << "no input! return"; + return; + } + const framework::SelectedRows* has_value_input = nullptr; + for (auto* in : inputs) { + if (!in->rows().empty()) { + has_value_input = in; + break; + } + } + if (has_value_input == nullptr) { + VLOG(3) << "no input has value! just return" << std::endl; + return; + } + auto input_width = has_value_input->value().dims()[1]; + auto input_height = has_value_input->height(); framework::SelectedRows& out = *output; std::set merged_row_set; for (auto* input : inputs) { + if (input->rows().empty()) { + continue; + } PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], "all input should have same " "dimension except for the first one"); @@ -288,7 +305,6 @@ struct MergeAdd { for (size_t i = 0; i < merge_rows.size(); ++i) { rows_to_id[merge_rows[i]] = i; } - out.set_rows(merge_rows); out.set_height(input_height); out.mutable_value()->mutable_data( @@ -303,6 +319,9 @@ struct MergeAdd { auto blas = math::GetBlas(context); for (auto* input : inputs) { + if (input->rows().empty()) { + continue; + } auto* input_data = input->value().data(); auto& input_rows = input->rows(); diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cc b/paddle/fluid/operators/math/selected_rows_functor_test.cc index f5165fa5353c65c6a0fd0fa3c6497d70fe12c8fa..f15b37a1e3f0ae9c7612c4f74470472393ff4ad6 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cc +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cc @@ -356,9 +356,7 @@ TEST(selected_rows_functor, cpu_merge_add_multi) { for (size_t i = 0; i < ret_rows.size(); ++i) { for (size_t j = 0; j < row_numel; ++j) { EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]); - std::cout << out_data[i * row_numel + j] << " "; } - std::cout << "\n"; } } diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cu b/paddle/fluid/operators/math/selected_rows_functor_test.cu index 93e55e88ca349fd96925456f135aa5db1aa37fc5..17af3e3999ca688c584f636f4c00386f886f9bbf 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cu +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cu @@ -302,8 +302,6 @@ TEST(selected_rows_functor, gpu_merge_add) { for (size_t i = 0; i < ret_rows.size(); ++i) { for (size_t j = 0; j < row_numel; ++j) { EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]); - std::cout << out_data[i * row_numel + j] << " "; } - std::cout << "\n"; } } diff --git a/paddle/fluid/operators/sum_op.h b/paddle/fluid/operators/sum_op.h index de81693c755d87e7ddd0fd422eae7216739aca18..69e619a5305020d2dc91485becb4e369006823f5 100644 --- a/paddle/fluid/operators/sum_op.h +++ b/paddle/fluid/operators/sum_op.h @@ -99,11 +99,17 @@ class SumKernel : public framework::OpKernel { temp_in0.mutable_value()); inputs.push_back(&temp_in0); for (size_t i = 1; i < in_vars.size(); ++i) { - inputs.push_back(&in_vars[i]->Get()); + auto &in = in_vars[i]->Get(); + if (!in.rows().empty()) { + inputs.push_back(&in); + } } } else { for (auto &in_var : in_vars) { - inputs.push_back(&in_var->Get()); + auto &in = in_var->Get(); + if (!in.rows().empty()) { + inputs.push_back(&in_var->Get()); + } } }