From d48a0e4eae939f3615fabc9f86f11670fcfad6e3 Mon Sep 17 00:00:00 2001
From: typhoonzero
Date: Wed, 27 Dec 2017 21:04:51 +0800
Subject: [PATCH] WIP: adding generic scatter functors

---
 .../operators/math/selected_rows_functor.cc   | 47 +++++++++++++
 .../operators/math/selected_rows_functor.cu   | 67 +++++++++++++++++++
 paddle/operators/math/selected_rows_functor.h | 47 +++++++++++++
 3 files changed, 161 insertions(+)

diff --git a/paddle/operators/math/selected_rows_functor.cc b/paddle/operators/math/selected_rows_functor.cc
index ab758d1e7fd..21418ba4b02 100644
--- a/paddle/operators/math/selected_rows_functor.cc
+++ b/paddle/operators/math/selected_rows_functor.cc
@@ -179,6 +179,53 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, double>;
 template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int>;
 template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
 
+// This is a separate namespace for manipulating SelectedRows-typed
+// data, e.g. merging duplicated rows or adding two SelectedRows
+// objects.
+//
+// Another group of functors is called "scatter updates", which means
+// using a SelectedRows object to update a dense tensor with different
+// ops, like add or mul.
+namespace scatter {
+
+size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
+  return std::find(rows.begin(), rows.end(), value) - rows.begin();
+}
+
+template <typename T>
+struct MergeAdd<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* out) {
+    auto input_rows = input.rows();
+    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
+    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
+
+    auto input_width = input.value().dims()[1];
+    // std::unique_ptr<framework::SelectedRows> out{
+    //     new framework::SelectedRows()};
+    out->set_rows(merge_rows);
+    out->set_height(input.height());
+    out->mutable_value()->mutable_data<T>(
+        framework::make_ddim(
+            {static_cast<int64_t>(merge_rows.size()), input_width}),
+        context.GetPlace());
+
+    math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
+    constant_functor(context, out->mutable_value(), 0.0);
+
+    auto* out_data = out->mutable_value()->data<T>();
+    auto* input_data = input.value().data<T>();
+
+    for (size_t i = 0; i < input_rows.size(); i++) {
+      size_t out_i = FindPos(merge_rows, input_rows[i]);
+      for (int64_t j = 0; j < input_width; j++) {
+        out_data[out_i * input_width + j] += input_data[i * input_width + j];
+      }
+    }
+  }
+};
+
+}  // namespace scatter
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu
index 9fddd97a36f..b2c0fe7bc3d 100644
--- a/paddle/operators/math/selected_rows_functor.cu
+++ b/paddle/operators/math/selected_rows_functor.cu
@@ -222,6 +222,73 @@ template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, float>;
 template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, double>;
 template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int>;
 template struct SelectedRowsAddToTensor<platform::CUDADeviceContext, int64_t>;
+
+namespace scatter {
+
+template <typename T, int block_size>
+__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
+                               T* out, const int64_t* out_rows,
+                               size_t out_rows_size, int64_t row_numel) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+  __shared__ size_t out_idx;
+
+  if (tid == 0) {
+    for (size_t i = 0; i < out_rows_size; i++) {
+      if (input_rows[ty] == out_rows[i]) {
+        out_idx = i;
+      }
+    }
+  }
+
+  __syncthreads();
+
+  input += ty * row_numel;
+  out += out_idx * row_numel;
+  for (int index = tid; index < row_numel; index += block_size) {
+    paddle::platform::CudaAtomicAdd(out + index, input[index]);
+  }
+}
+
+template <typename T>
+struct MergeAdd<platform::GPUDeviceContext, T> {
+  void operator()(const platform::GPUDeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* out) {
+    auto input_rows = input.rows();
+    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
+    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
+
+    auto input_width = input.value().dims()[1];
+    // std::unique_ptr<framework::SelectedRows> out{
+    //     new framework::SelectedRows()};
+    out->set_rows(merge_rows);
+    out->set_height(input.height());
+    out->mutable_value()->mutable_data<T>(
+        framework::make_ddim(
+            {static_cast<int64_t>(merge_rows.size()), input_width}),
+        context.GetPlace());
+
+    math::SetConstant<platform::GPUDeviceContext, T> constant_functor;
+    constant_functor(context, out->mutable_value(), 0.0);
+
+    auto* out_data = out->mutable_value()->data<T>();
+    auto* input_data = input.value().data<T>();
+
+    const int block_size = 256;
+    dim3 threads(block_size, 1);
+    dim3 grid1(1, input_rows.size());
+
+    MergeAddKernel<
+        T, 256><<<grid1, threads, 0,
+                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                      .stream()>>>(input_data, input.rows().data(), out_data,
+                                   out->rows().data(), out->rows().size(),
+                                   input_width);
+  }
+};
+
+}  // namespace scatter
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/selected_rows_functor.h b/paddle/operators/math/selected_rows_functor.h
index 1149075abf1..8adfca77f69 100644
--- a/paddle/operators/math/selected_rows_functor.h
+++ b/paddle/operators/math/selected_rows_functor.h
@@ -52,6 +52,53 @@ struct SelectedRowsAddToTensor {
                   framework::Tensor* input2);
 };
 
+namespace scatter {
+// Functors for manipulating SelectedRows data.
+
+template <typename DeviceContext, typename T>
+struct MergeAdd {
+  // Unary functor: merges the duplicated rows in the input
+  // SelectedRows object by adding them up.
+  void operator()(const DeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* out);
+};
+
+template <typename DeviceContext, typename T>
+struct Add {
+  void operator()(const DeviceContext& context,
+                  const framework::SelectedRows& input1,
+                  const framework::SelectedRows& input2,
+                  framework::SelectedRows* out) {
+    out->set_rows(input1.rows());
+    out->set_height(input1.height());
+    out->mutable_value()->mutable_data<T>(input1.value().dims(),
+                                          context.GetPlace());
+    auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
+    auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
+    auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
+    e_out.device(*context.eigen_device()) = e_in1 + e_in2;
+  }
+};
+
+template <typename DeviceContext, typename T>
+struct Mul {
+  void operator()(const DeviceContext& context,
+                  const framework::SelectedRows& input1,
+                  const framework::SelectedRows& input2,
+                  framework::SelectedRows* out) {
+    out->set_rows(input1.rows());
+    out->set_height(input1.height());
+    out->mutable_value()->mutable_data<T>(input1.value().dims(),
+                                          context.GetPlace());
+    auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
+    auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
+    auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
+    e_out.device(*context.eigen_device()) = e_in1 * e_in2;
+  }
+};
+
+}  // namespace scatter
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
--
GitLab
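
Background for reviewers: a SelectedRows object represents a conceptual `height x width` dense tensor sparsely, as a vector of row indices plus a value tensor holding only those rows, and the same row index may appear more than once. The standalone sketch below models the semantics that both MergeAdd specializations in this patch are meant to compute: coalesce duplicated row indices and sum their value rows. It is a minimal reference model only; the names `SimpleSelectedRows` and `MergeAddRef` are illustrative and are not part of this patch or of Paddle.

// Standalone reference model of scatter::MergeAdd (illustrative only;
// compiles with any C++11 compiler, no Paddle dependency).
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

struct SimpleSelectedRows {
  std::vector<int64_t> rows;  // row indices; duplicates allowed
  std::vector<double> value;  // rows.size() x width, row-major
  int64_t width;
};

// Coalesce duplicated row indices, summing their value rows. This is
// the result MergeAdd produces (modulo row ordering and, on the GPU,
// floating-point summation order).
SimpleSelectedRows MergeAddRef(const SimpleSelectedRows& in) {
  std::map<int64_t, std::vector<double>> acc;  // unique rows, sorted
  for (size_t i = 0; i < in.rows.size(); ++i) {
    auto& row = acc[in.rows[i]];
    row.resize(in.width, 0.0);  // zero-init, like SetConstant in the patch
    for (int64_t j = 0; j < in.width; ++j) {
      row[j] += in.value[i * in.width + j];
    }
  }
  SimpleSelectedRows out{{}, {}, in.width};
  for (const auto& kv : acc) {
    out.rows.push_back(kv.first);
    out.value.insert(out.value.end(), kv.second.begin(), kv.second.end());
  }
  return out;
}

int main() {
  // Two updates to row 4 and one to row 7, width 2.
  SimpleSelectedRows in{{4, 7, 4}, {1, 1, 2, 2, 3, 3}, 2};
  SimpleSelectedRows out = MergeAddRef(in);
  for (size_t i = 0; i < out.rows.size(); ++i) {
    std::cout << "row " << out.rows[i] << ":";
    for (int64_t j = 0; j < out.width; ++j) {
      std::cout << " " << out.value[i * out.width + j];
    }
    std::cout << "\n";  // expect: "row 4: 4 4" then "row 7: 2 2"
  }
  return 0;
}

The CPU specialization in the patch reaches the same result by deduplicating the row list through a std::set, zero-filling the merged value tensor with SetConstant, and accumulating via FindPos; the CUDA specialization performs the same row lookup once per block and accumulates with CudaAtomicAdd, so the summation order (and hence floating-point rounding) may differ across runs.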