diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index b99115e44b31536f0fd0a9078b40d07949be86f0..647d4f14842ee38bbd8a5d07563ea29ff0432e1a 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -296,6 +296,7 @@ struct MergeAdd { auto input_height = has_value_input->height(); framework::SelectedRows& out = *output; std::set merged_row_set; + size_t row_num = 0; for (auto* input : inputs) { if (input->rows().size() == 0) { continue; @@ -305,42 +306,71 @@ struct MergeAdd { "dimension except for the first one"); PADDLE_ENFORCE_EQ(input_height, input->height(), "all input should have same height"); + row_num += input->rows().size(); merged_row_set.insert(input->rows().begin(), input->rows().end()); } - std::vector merge_rows(merged_row_set.begin(), - merged_row_set.end()); - if (sorted_result) { - std::sort(merge_rows.begin(), merge_rows.end()); - } - std::unordered_map rows_to_id; - for (size_t i = 0; i < merge_rows.size(); ++i) { - rows_to_id[merge_rows[i]] = i; - } - out.set_rows(merge_rows); + out.set_height(input_height); out.mutable_value()->mutable_data( framework::make_ddim( - {static_cast(merge_rows.size()), input_width}), + {static_cast(merged_row_set.size()), input_width}), context.GetPlace()); + auto* out_data = out.mutable_value()->data(); - math::SetConstant constant_functor; - constant_functor(context, out.mutable_value(), 0.0); + if (merged_row_set.size() == row_num && !sorted_result) { + // no duplicated ids, just concat the result together + std::vector merge_rows; + merge_rows.reserve(row_num); + // concat rows + for (auto* in : inputs) { + merge_rows.insert(merge_rows.end(), in->rows().begin(), + in->rows().end()); + } + out.set_rows(merge_rows); + auto in_place = inputs[0]->place(); + auto out_place = out.place(); + int64_t copied_numel = 0; + for (auto* in : inputs) { + auto* in_data = in->value().data(); + auto in_numel = in->value().numel(); + memory::Copy(boost::get(out_place), + out_data + copied_numel, + boost::get(in_place), in_data, + in_numel * sizeof(T)); + copied_numel += in_numel; + } + } else { + std::vector merge_rows(merged_row_set.begin(), + merged_row_set.end()); - auto* out_data = out.mutable_value()->data(); + if (sorted_result) { + std::sort(merge_rows.begin(), merge_rows.end()); + } - auto blas = math::GetBlas(context); - for (auto* input : inputs) { - if (input->rows().size() == 0) { - continue; + out.set_rows(merge_rows); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + std::unordered_map rows_to_id; + for (size_t i = 0; i < merge_rows.size(); ++i) { + rows_to_id[merge_rows[i]] = i; } - auto* input_data = input->value().data(); - auto& input_rows = input->rows(); - - for (size_t i = 0; i < input_rows.size(); i++) { - size_t out_i = rows_to_id[input_rows[i]]; - elementwise_add_to( - context, &blas, static_cast(input_width), - &input_data[i * input_width], &out_data[out_i * input_width]); + + auto blas = math::GetBlas(context); + for (auto* input : inputs) { + if (input->rows().size() == 0) { + continue; + } + auto* input_data = input->value().data(); + auto& input_rows = input->rows(); + + for (size_t i = 0; i < input_rows.size(); i++) { + size_t out_i = rows_to_id[input_rows[i]]; + elementwise_add_to( + context, &blas, static_cast(input_width), + &input_data[i * input_width], &out_data[out_i * input_width]); + } } } }