From ab3e36da80149b7840f6651d69f070bddebd3b4c Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Mon, 15 Oct 2018 11:13:58 +0800 Subject: [PATCH] update MergeAdd for selected_rows_functor.cu --- .../operators/math/selected_rows_functor.cu | 46 +++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/paddle/fluid/operators/math/selected_rows_functor.cu b/paddle/fluid/operators/math/selected_rows_functor.cu index ba8eccf8204..4c6e2ee7c23 100644 --- a/paddle/fluid/operators/math/selected_rows_functor.cu +++ b/paddle/fluid/operators/math/selected_rows_functor.cu @@ -296,6 +296,52 @@ struct MergeAdd { out.mutable_rows()->CUDAMutableData(context.GetPlace()), out.rows().size(), input_width); } + + void operator()(const platform::CUDADeviceContext& context, + const std::vector& inputs, + framework::SelectedRows* output) { + PADDLE_ENFORCE_GT(inputs.size(), 0, "should have at least one input"); + auto input_width = inputs[0]->value().dims()[1]; + auto input_height = inputs[0]->height(); + framework::SelectedRows& out = *output; + std::set merged_row_set; + for (auto* input : inputs) { + PADDLE_ENFORCE_EQ(input_width, input->value().dims()[1], + "all input should have same " + "dimension except for the first one"); + PADDLE_ENFORCE_EQ(input_height, input->height(), + "all input should have same height"); + merged_row_set.insert(input->rows().begin(), input->rows().end()); + } + std::vector merge_rows(merged_row_set.begin(), + merged_row_set.end()); + + out.set_rows(merge_rows); + out.set_height(input_height); + out.mutable_value()->mutable_data( + framework::make_ddim( + {static_cast(merge_rows.size()), input_width}), + context.GetPlace()); + + math::SetConstant constant_functor; + constant_functor(context, out.mutable_value(), 0.0); + + auto* out_data = out.mutable_value()->data(); + + const int block_size = 256; + dim3 threads(block_size, 1); + + for (auto* input : inputs) { + auto* input_data = input->value().data(); + auto& input_rows = input->rows(); + dim3 grid1(input_rows.size(), 1); + + MergeAddKernel<<>>( + input_data, input_rows.CUDAData(context.GetPlace()), out_data, + out.mutable_rows()->CUDAMutableData(context.GetPlace()), + out.rows().size(), input_width); + } + } }; template struct MergeAdd; -- GitLab