From 0a8ff2ecd4a674f7232876949e5815c0bea8fa54 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Fri, 12 Apr 2019 15:46:07 +0800 Subject: [PATCH] add cpu_merge_add_multi_noduplicated_test test=develop --- .../math/selected_rows_functor_test.cc | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cc b/paddle/fluid/operators/math/selected_rows_functor_test.cc index aedb82da2..9b348d2cf 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cc +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cc @@ -360,6 +360,69 @@ TEST(selected_rows_functor, cpu_merge_add_multi) { } } +TEST(selected_rows_functor, cpu_merge_add_multi_noduplicated) { + paddle::platform::CPUPlace cpu_place; + paddle::platform::CPUDeviceContext ctx(cpu_place); + paddle::operators::math::SetConstant + set_const; + + int64_t height = 10; + int64_t row_numel = 8; + + std::vector rows1{1, 3, 5, 7, 9}; + std::unique_ptr selected_rows1{ + new paddle::framework::SelectedRows(rows1, height)}; + auto* in1_value = selected_rows1->mutable_value(); + in1_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows1.size()), row_numel}), + cpu_place); + set_const(ctx, in1_value, 1.0); + + std::vector rows2{0, 2, 4, 6, 8}; + std::unique_ptr selected_rows2{ + new paddle::framework::SelectedRows(rows2, height)}; + auto* in2_value = selected_rows2->mutable_value(); + in2_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows2.size()), row_numel}), + cpu_place); + set_const(ctx, in2_value, 2.0); + + std::unique_ptr output{ + new paddle::framework::SelectedRows()}; + output->set_height(height); + paddle::operators::math::scatter::MergeAdd + merge_add_functor; + + std::vector inputs; + inputs.push_back(selected_rows1.get()); + inputs.push_back(selected_rows2.get()); + merge_add_functor(ctx, inputs, output.get()); + + EXPECT_EQ(output->height(), height); + EXPECT_EQ(output->value().dims(), + paddle::framework::make_ddim({10, row_numel})); + + std::vector ret_rows{1, 3, 5, 7, 9, 0, 2, 4, 6, 8}; + EXPECT_EQ(output->rows(), ret_rows); + + auto* out_data = output->value().data(); + for (size_t i = 0; i < ret_rows.size(); ++i) { + float data_value = 0; + if (i < 5) { + data_value = 1.0; + } else { + data_value = 2.0; + } + for (size_t j = 0; j < static_cast(row_numel); ++j) { + EXPECT_EQ(out_data[i * row_numel + j], data_value); + } + } +} + TEST(selected_rows_functor, cpu_sum_to) { paddle::platform::CPUPlace cpu_place; paddle::platform::CPUDeviceContext ctx(cpu_place); -- GitLab