diff --git a/paddle/fluid/operators/math/selected_rows_functor.cc b/paddle/fluid/operators/math/selected_rows_functor.cc index 34fb168036cd2498860bf4f3f1d10e875f650804..a4f584623acf99ca2040955514b28670125bb6c0 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cc +++ b/paddle/fluid/operators/math/selected_rows_functor.cc @@ -228,8 +228,25 @@ template struct SelectedRowsAddToTensor; // add or mul. namespace scatter { -static size_t FindPos(const std::vector& rows, int64_t value) { - return std::find(rows.begin(), rows.end(), value) - rows.begin(); +template +typename std::enable_if< + std::is_floating_point::value && + std::is_same::value>::type +elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in, + T* out) { + auto blas = math::GetBlas(ctx); + blas.AXPY(data_len, 1., in, out); +} + +template +typename std::enable_if< + !std::is_floating_point::value && + std::is_same::value>::type +elementwise_add(const DeviceContext& ctx, size_t data_len, const T* in, + T* out) { + for (int64_t i = 0; i < data_len; i++) { + out[i] += in[i]; + } } template @@ -290,9 +307,9 @@ struct MergeAdd { for (size_t i = 0; i < input_rows.size(); i++) { size_t out_i = rows_to_id[input_rows[i]]; - for (int64_t j = 0; j < input_width; j++) { - out_data[out_i * input_width + j] += input_data[i * input_width + j]; - } + elementwise_add( + context, static_cast(input_width), + &input_data[i * input_width], &out_data[out_i * input_width]); } } } @@ -300,6 +317,8 @@ struct MergeAdd { template struct MergeAdd; template struct MergeAdd; +template struct MergeAdd; +template struct MergeAdd; template struct UpdateToTensor { diff --git a/paddle/fluid/operators/math/selected_rows_functor.h b/paddle/fluid/operators/math/selected_rows_functor.h index f003bcd8db20664d37066615abe619998b0fcd94..8dc17478e6a4e8c60741b72233eee3cd8d33428e 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.h +++ b/paddle/fluid/operators/math/selected_rows_functor.h @@ -87,108 +87,6 @@ struct MergeAdd { framework::SelectedRows* output); }; -template <> -struct MergeAdd { - framework::SelectedRows operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input) { - framework::SelectedRows out; - (*this)(context, input, &out); - return out; - } - - void operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input, - framework::SelectedRows* output) { - framework::SelectedRows& out = *output; - auto input_rows = input.rows(); - std::vector merge_rows; - merge_rows.reserve(input_rows.size()); - std::unordered_map rows_pos_map; - rows_pos_map.reserve(input_rows.size()); - size_t idx = 0u; - for (std::vector::iterator iter = input_rows.begin(); - iter != input_rows.end(); ++iter) { - if (rows_pos_map.find(*iter) == rows_pos_map.end()) { - rows_pos_map[*iter] = idx++; - merge_rows.emplace_back(*iter); - } - } - - auto input_width = input.value().dims()[1]; - out.set_rows(merge_rows); - out.set_height(input.height()); - out.mutable_value()->mutable_data( - framework::make_ddim( - {static_cast(merge_rows.size()), input_width}), - context.GetPlace()); - - math::SetConstant constant_functor; - constant_functor(context, out.mutable_value(), 0.0); - - auto* out_data = out.mutable_value()->data(); - auto* input_data = input.value().data(); - - auto blas = GetBlas(context); - for (size_t i = 0; i < input_rows.size(); i++) { - size_t out_i = rows_pos_map[input_rows[i]]; - float* y = out_data + out_i * input_width; - const float* x = input_data + i * input_width; - blas.AXPY(input_width, 1., x, y); - } - } -}; - -template <> -struct MergeAdd { - framework::SelectedRows operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input) { - framework::SelectedRows out; - (*this)(context, input, &out); - return out; - } - - void operator()(const platform::CPUDeviceContext& context, - const framework::SelectedRows& input, - framework::SelectedRows* output) { - framework::SelectedRows& out = *output; - auto input_rows = input.rows(); - std::vector merge_rows; - merge_rows.reserve(input_rows.size()); - std::unordered_map rows_pos_map; - rows_pos_map.reserve(input_rows.size()); - size_t idx = 0u; - for (std::vector::iterator iter = input_rows.begin(); - iter != input_rows.end(); ++iter) { - if (rows_pos_map.find(*iter) == rows_pos_map.end()) { - rows_pos_map[*iter] = idx++; - merge_rows.emplace_back(*iter); - } - } - - auto input_width = input.value().dims()[1]; - out.set_rows(merge_rows); - out.set_height(input.height()); - out.mutable_value()->mutable_data( - framework::make_ddim( - {static_cast(merge_rows.size()), input_width}), - context.GetPlace()); - - math::SetConstant constant_functor; - constant_functor(context, out.mutable_value(), 0.0); - - auto* out_data = out.mutable_value()->data(); - auto* input_data = input.value().data(); - - auto blas = GetBlas(context); - for (size_t i = 0; i < input_rows.size(); i++) { - size_t out_i = rows_pos_map[input_rows[i]]; - double* y = out_data + out_i * input_width; - const double* x = input_data + i * input_width; - blas.AXPY(input_width, 1., x, y); - } - } -}; - template struct Add { framework::SelectedRows operator()(const DeviceContext& context, diff --git a/paddle/fluid/operators/math/selected_rows_functor_test.cc b/paddle/fluid/operators/math/selected_rows_functor_test.cc index e114e58dee23443750b8fe470c20a1dae84d5903..f5165fa5353c65c6a0fd0fa3c6497d70fe12c8fa 100644 --- a/paddle/fluid/operators/math/selected_rows_functor_test.cc +++ b/paddle/fluid/operators/math/selected_rows_functor_test.cc @@ -303,6 +303,65 @@ TEST(selected_rows_functor, cpu_merge_add_int) { EXPECT_EQ(out_data[2 * row_numel], 1); } +TEST(selected_rows_functor, cpu_merge_add_multi) { + paddle::platform::CPUPlace cpu_place; + paddle::platform::CPUDeviceContext ctx(cpu_place); + paddle::operators::math::SetConstant + set_const; + + int64_t height = 10; + int64_t row_numel = 8; + + std::vector rows1{5, 2, 5, 3, 5}; + std::unique_ptr selected_rows1{ + new paddle::framework::SelectedRows(rows1, height)}; + auto* in1_value = selected_rows1->mutable_value(); + in1_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows1.size()), row_numel}), + cpu_place); + set_const(ctx, in1_value, 1.0); + + std::vector rows2{2, 5, 3, 5, 3}; + std::unique_ptr selected_rows2{ + new paddle::framework::SelectedRows(rows2, height)}; + auto* in2_value = selected_rows2->mutable_value(); + in2_value->mutable_data( + paddle::framework::make_ddim( + {static_cast(rows2.size()), row_numel}), + cpu_place); + set_const(ctx, in2_value, 1.0); + + std::unique_ptr output{ + new paddle::framework::SelectedRows()}; + output->set_height(height); + paddle::operators::math::scatter::MergeAdd + merge_add_functor; + + std::vector inputs; + inputs.push_back(selected_rows1.get()); + inputs.push_back(selected_rows2.get()); + merge_add_functor(ctx, inputs, output.get()); + + EXPECT_EQ(output->height(), height); + EXPECT_EQ(output->value().dims(), + paddle::framework::make_ddim({3, row_numel})); + + std::vector ret_rows{2, 3, 5}; + EXPECT_EQ(output->rows(), ret_rows); + + auto* out_data = output->value().data(); + for (size_t i = 0; i < ret_rows.size(); ++i) { + for (size_t j = 0; j < row_numel; ++j) { + EXPECT_EQ(out_data[i * row_numel + j], ret_rows[i]); + std::cout << out_data[i * row_numel + j] << " "; + } + std::cout << "\n"; + } +} + TEST(selected_rows_functor, cpu_sum_to) { paddle::platform::CPUPlace cpu_place; paddle::platform::CPUDeviceContext ctx(cpu_place);