From 748ee35c8968f1f288d89a34c2d15338036e06ff Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sat, 27 Oct 2018 20:52:25 +0800 Subject: [PATCH] sum op handle empty input update selected_rows_functor.cu --- .../operators/math/selected_rows_functor.cu | 29 +++++++++++++++---- 1 file changed, 23 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index 9e6a8706ad2..7d94a452890 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -305,12 +305,29 @@ struct MergeAdd { void operator()(const platform::CUDADeviceContext& context, const std::vector& inputs, framework::SelectedRows* output) { - PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input"); - auto input_width = inputs[0]->value().dims()[1]; - auto input_height = inputs[0]->height(); + if (inputs.size() == 0) { + VLOG(3) << "no input! return"; + return; + } + const framework::SelectedRows* has_value_input = nullptr; + for (auto* in : inputs) { + if (!in->rows().empty()) { + has_value_input = in; + break; + } + } + if (has_value_input == nullptr) { + VLOG(3) << "no input has value! just return" << std::endl; + return; + } + auto input_width = has_value_input->value().dims()[1]; + auto input_height = has_value_input->height(); framework::SelectedRows& out = *output; std::set merged_row_set; for (auto* input : inputs) { + if (input->rows().empty()) { + continue; + } PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], "all input should have same " "dimension except for the first one"); @@ -338,11 +355,11 @@ struct MergeAdd { dim3 threads(block_size, 1); for (auto* input : inputs) { - auto* input_data = input->value().data(); - auto& input_rows = input->rows(); - if (input_rows.size() == 0) { + if (input->rows().empty()) { continue; } + auto* input_data = input->value().data(); + auto& input_rows = input->rows(); dim3 grid1(input_rows.size(), 1); MergeAddKernel<<>>( -- GitLab