diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc
index b99115e44b31536f0fd0a9078b40d07949be86f0..647d4f14842ee38bbd8a5d07563ea29ff0432e1a 100644
--- a/paddle/fluid/operators/math/selected_rows_functor.cc
+++ b/paddle/fluid/operators/math/selected_rows_functor.cc
@@ -296,6 +296,7 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
     auto input_height = has_value_input->height();
     framework::SelectedRows& out = *output;
     std::set<int64_t> merged_row_set;
+    size_t row_num = 0;
     for (auto* input : inputs) {
       if (input->rows().size() == 0) {
         continue;
@@ -305,42 +306,71 @@ struct MergeAdd<platform::CPUDeviceContext, T> {
                         "dimension except for the first one");
       PADDLE_ENFORCE_EQ(input_height, input->height(),
                         "all input should have same height");
+      row_num += input->rows().size();
       merged_row_set.insert(input->rows().begin(), input->rows().end());
     }
-    std::vector<int64_t> merge_rows(merged_row_set.begin(),
-                                    merged_row_set.end());
-    if (sorted_result) {
-      std::sort(merge_rows.begin(), merge_rows.end());
-    }
-    std::unordered_map<int64_t, size_t> rows_to_id;
-    for (size_t i = 0; i < merge_rows.size(); ++i) {
-      rows_to_id[merge_rows[i]] = i;
-    }
-    out.set_rows(merge_rows);
+
     out.set_height(input_height);
     out.mutable_value()->mutable_data<T>(
         framework::make_ddim(
-            {static_cast<int64_t>(merge_rows.size()), input_width}),
+            {static_cast<int64_t>(merged_row_set.size()), input_width}),
         context.GetPlace());
+    auto* out_data = out.mutable_value()->data<T>();
 
-    math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
-    constant_functor(context, out.mutable_value(), 0.0);
+    if (merged_row_set.size() == row_num && !sorted_result) {
+      // no duplicated ids, just concat the result together
+      std::vector<int64_t> merge_rows;
+      merge_rows.reserve(row_num);
+      // concat rows
+      for (auto* in : inputs) {
+        merge_rows.insert(merge_rows.end(), in->rows().begin(),
+                          in->rows().end());
+      }
+      out.set_rows(merge_rows);
+      auto in_place = inputs[0]->place();
+      auto out_place = out.place();
+      int64_t copied_numel = 0;
+      for (auto* in : inputs) {
+        auto* in_data = in->value().data<T>();
+        auto in_numel = in->value().numel();
+        memory::Copy(boost::get<platform::CPUPlace>(out_place),
+                     out_data + copied_numel,
+                     boost::get<platform::CPUPlace>(in_place), in_data,
+                     in_numel * sizeof(T));
+        copied_numel += in_numel;
+      }
+    } else {
+      std::vector<int64_t> merge_rows(merged_row_set.begin(),
+                                      merged_row_set.end());
 
-    auto* out_data = out.mutable_value()->data<T>();
+      if (sorted_result) {
+        std::sort(merge_rows.begin(), merge_rows.end());
+      }
 
-    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
-    for (auto* input : inputs) {
-      if (input->rows().size() == 0) {
-        continue;
+      out.set_rows(merge_rows);
+
+      math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
+      constant_functor(context, out.mutable_value(), 0.0);
+
+      std::unordered_map<int64_t, size_t> rows_to_id;
+      for (size_t i = 0; i < merge_rows.size(); ++i) {
+        rows_to_id[merge_rows[i]] = i;
       }
-      auto* input_data = input->value().data<T>();
-      auto& input_rows = input->rows();
-
-      for (size_t i = 0; i < input_rows.size(); i++) {
-        size_t out_i = rows_to_id[input_rows[i]];
-        elementwise_add_to<platform::CPUDeviceContext, T>(
-            context, &blas, static_cast<size_t>(input_width),
-            &input_data[i * input_width], &out_data[out_i * input_width]);
+
+      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+      for (auto* input : inputs) {
+        if (input->rows().size() == 0) {
+          continue;
+        }
+        auto* input_data = input->value().data<T>();
+        auto& input_rows = input->rows();
+
+        for (size_t i = 0; i < input_rows.size(); i++) {
+          size_t out_i = rows_to_id[input_rows[i]];
+          elementwise_add_to<platform::CPUDeviceContext, T>(
+              context, &blas, static_cast<size_t>(input_width),
+              &input_data[i * input_width], &out_data[out_i * input_width]);
+        }
       }
     }
   }
diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cc b/paddle/fluid/operators/math/selected_rows_functor_test.cc
index aedb82da2f0fb2f15e1586d351af7c9d4364852b..5581b9e040272e224669d612409f88d61f794443 100644
--- a/paddle/fluid/operators/math/selected_rows_functor_test.cc
+++ b/paddle/fluid/operators/math/selected_rows_functor_test.cc
@@ -13,8 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/math/selected_rows_functor.h"
+
+#include <memory>
 #include <vector>
 #include "gtest/gtest.h"
+
 #include "paddle/fluid/operators/math/math_function.h"
 
 TEST(selected_rows_functor, cpu_add) {
@@ -360,6 +363,69 @@ TEST(selected_rows_functor, cpu_merge_add_multi) {
   }
 }
 
+TEST(selected_rows_functor, cpu_merge_add_multi_noduplicated) {
+  paddle::platform::CPUPlace cpu_place;
+  paddle::platform::CPUDeviceContext ctx(cpu_place);
+  paddle::operators::math::SetConstant<paddle::platform::CPUDeviceContext,
+                                       float>
+      set_const;
+
+  int64_t height = 10;
+  int64_t row_numel = 8;
+
+  std::vector<int64_t> rows1{1, 3, 5, 7, 9};
+  std::unique_ptr<paddle::framework::SelectedRows> selected_rows1{
+      new paddle::framework::SelectedRows(rows1, height)};
+  auto* in1_value = selected_rows1->mutable_value();
+  in1_value->mutable_data<float>(
+      paddle::framework::make_ddim(
+          {static_cast<int64_t>(rows1.size()), row_numel}),
+      cpu_place);
+  set_const(ctx, in1_value, 1.0);
+
+  std::vector<int64_t> rows2{0, 2, 4, 6, 8};
+  std::unique_ptr<paddle::framework::SelectedRows> selected_rows2{
+      new paddle::framework::SelectedRows(rows2, height)};
+  auto* in2_value = selected_rows2->mutable_value();
+  in2_value->mutable_data<float>(
+      paddle::framework::make_ddim(
+          {static_cast<int64_t>(rows2.size()), row_numel}),
+      cpu_place);
+  set_const(ctx, in2_value, 2.0);
+
+  std::unique_ptr<paddle::framework::SelectedRows> output{
+      new paddle::framework::SelectedRows()};
+  output->set_height(height);
+  paddle::operators::math::scatter::MergeAdd<paddle::platform::CPUDeviceContext,
+                                             float>
+      merge_add_functor;
+
+  std::vector<const paddle::framework::SelectedRows*> inputs;
+  inputs.push_back(selected_rows1.get());
+  inputs.push_back(selected_rows2.get());
+  merge_add_functor(ctx, inputs, output.get());
+
+  EXPECT_EQ(output->height(), height);
+  EXPECT_EQ(output->value().dims(),
+            paddle::framework::make_ddim({10, row_numel}));
+
+  std::vector<int64_t> ret_rows{1, 3, 5, 7, 9, 0, 2, 4, 6, 8};
+  EXPECT_EQ(output->rows(), ret_rows);
+
+  auto* out_data = output->value().data<float>();
+  for (size_t i = 0; i < ret_rows.size(); ++i) {
+    float data_value = 0;
+    if (i < 5) {
+      data_value = 1.0;
+    } else {
+      data_value = 2.0;
+    }
+    for (size_t j = 0; j < static_cast<size_t>(row_numel); ++j) {
+      EXPECT_EQ(out_data[i * row_numel + j], data_value);
+    }
+  }
+}
+
 TEST(selected_rows_functor, cpu_sum_to) {
   paddle::platform::CPUPlace cpu_place;
   paddle::platform::CPUDeviceContext ctx(cpu_place);
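
A minimal standalone sketch of the merge logic this patch introduces, with illustrative names (SparseInput, MergeAddRef) rather than the Paddle types: the real code works on framework::SelectedRows, copies with memory::Copy on the fast path, and accumulates through a BLAS-backed elementwise_add_to on the fallback path, which is also taken whenever sorted_result is requested.

#include <cstdint>
#include <cstring>
#include <set>
#include <unordered_map>
#include <vector>

// One sparse input: row ids plus a dense [rows.size() x width] value block.
struct SparseInput {
  std::vector<int64_t> rows;
  std::vector<float> values;  // rows.size() * width elements
};

// Merge-add several sparse inputs. If no row id is duplicated across inputs,
// just concatenate the row ids and copy the value blocks back to back;
// otherwise accumulate each input row into its slot for the unique row id.
void MergeAddRef(const std::vector<SparseInput>& inputs, int64_t width,
                 std::vector<int64_t>* out_rows,
                 std::vector<float>* out_values) {
  std::set<int64_t> merged_row_set;
  size_t row_num = 0;
  for (const auto& in : inputs) {
    row_num += in.rows.size();
    merged_row_set.insert(in.rows.begin(), in.rows.end());
  }

  if (merged_row_set.size() == row_num) {
    // Fast path: no duplicated ids, concat rows and memcpy the value blocks.
    out_rows->clear();
    out_rows->reserve(row_num);
    out_values->resize(row_num * width);
    size_t copied = 0;
    for (const auto& in : inputs) {
      out_rows->insert(out_rows->end(), in.rows.begin(), in.rows.end());
      std::memcpy(out_values->data() + copied, in.values.data(),
                  in.values.size() * sizeof(float));
      copied += in.values.size();
    }
    return;
  }

  // Fallback: map each unique row id to an output slot and accumulate.
  out_rows->assign(merged_row_set.begin(), merged_row_set.end());
  out_values->assign(out_rows->size() * width, 0.0f);
  std::unordered_map<int64_t, size_t> rows_to_id;
  for (size_t i = 0; i < out_rows->size(); ++i) {
    rows_to_id[(*out_rows)[i]] = i;
  }
  for (const auto& in : inputs) {
    for (size_t i = 0; i < in.rows.size(); ++i) {
      size_t out_i = rows_to_id[in.rows[i]];
      for (int64_t j = 0; j < width; ++j) {
        (*out_values)[out_i * width + j] += in.values[i * width + j];
      }
    }
  }
}

The new test above exercises exactly the fast path: rows1 and rows2 are disjoint, so the merged row list is the plain concatenation {1, 3, 5, 7, 9, 0, 2, 4, 6, 8} and the value blocks are copied through unchanged.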