diff --git a/paddle/operators/math/selected_rows_functor.cc b/paddle/operators/math/selected_rows_functor.cc
index ab758d1e7fd8ab361948b28e8cb735b9a742a339..21418ba4b0201e50302edea66d016e9464891f36 100644
--- a/paddle/operators/math/selected_rows_functor.cc
+++ b/paddle/operators/math/selected_rows_functor.cc
@@ -179,6 +179,53 @@ template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, double>;
 template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int>;
 template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;
 
+// This is a separate namespace for manipulating SelectedRows-typed
+// data, e.g. merging duplicated rows or adding two SelectedRows.
+//
+// Another group of functors is called "scatter updates", which means
+// using SelectedRows to update a dense tensor with different ops,
+// like add or mul.
+namespace scatter {
+
+size_t FindPos(const std::vector<int64_t>& rows, int64_t value) {
+  return std::find(rows.begin(), rows.end(), value) - rows.begin();
+}
+
+template <typename T>
+struct MergeAdd<platform::CPUDeviceContext, T> {
+  void operator()(const platform::CPUDeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* out) {
+    auto input_rows = input.rows();
+    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
+    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
+
+    auto input_width = input.value().dims()[1];
+    // std::unique_ptr<framework::SelectedRows> out{
+    //     new framework::SelectedRows()};
+    out->set_rows(merge_rows);
+    out->set_height(input.height());
+    out->mutable_value()->mutable_data<T>(
+        framework::make_ddim(
+            {static_cast<int64_t>(merge_rows.size()), input_width}),
+        context.GetPlace());
+
+    math::SetConstant<platform::CPUDeviceContext, T> constant_functor;
+    constant_functor(context, out->mutable_value(), 0.0);
+
+    auto* out_data = out->mutable_value()->data<T>();
+    auto* input_data = input.value().data<T>();
+
+    for (size_t i = 0; i < input_rows.size(); i++) {
+      size_t out_i = FindPos(merge_rows, input_rows[i]);
+      for (int64_t j = 0; j < input_width; j++) {
+        out_data[out_i * input_width + j] += input_data[i * input_width + j];
+      }
+    }
+  }
+};
+
+}  // namespace scatter
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
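A quick way to check the merge semantics before moving on to the GPU version: the standalone C++ sketch below is not part of the patch (plain std containers stand in for SelectedRows, and the row ids and values are made up) and reproduces the dedup-then-accumulate logic of the CPU MergeAdd above.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

int main() {
  // Hypothetical input: three rows of width 2, row id 0 duplicated.
  std::vector<int64_t> rows = {0, 4, 0};
  const int64_t width = 2;
  std::vector<double> value = {1, 1,   // row id 0
                               2, 2,   // row id 4
                               3, 3};  // row id 0 (duplicate)

  // Deduplicate the row ids, as MergeAdd does via std::set.
  std::set<int64_t> row_set(rows.begin(), rows.end());
  std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());

  // Zero-init the output, then accumulate each input row into the slot
  // located by linear search (the FindPos helper in the patch).
  std::vector<double> out(merge_rows.size() * width, 0);
  for (size_t i = 0; i < rows.size(); i++) {
    size_t out_i = std::find(merge_rows.begin(), merge_rows.end(), rows[i]) -
                   merge_rows.begin();
    for (int64_t j = 0; j < width; j++) {
      out[out_i * width + j] += value[i * width + j];
    }
  }

  assert((merge_rows == std::vector<int64_t>{0, 4}));
  assert((out == std::vector<double>{4, 4, 2, 2}));  // the two id-0 rows merged
  return 0;
}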
diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu
index 9fddd97a36f7fdb6628d6eeb192cb216fdae3e5b..b2c0fe7bc3da725a9175930657bb4fc8f6da3764 100644
--- a/paddle/operators/math/selected_rows_functor.cu
+++ b/paddle/operators/math/selected_rows_functor.cu
@@ -222,6 +222,73 @@ template struct SelectedRowsAddToTensor<platform::GPUDeviceContext, float>;
 template struct SelectedRowsAddToTensor<platform::GPUDeviceContext, double>;
 template struct SelectedRowsAddToTensor<platform::GPUDeviceContext, int>;
 template struct SelectedRowsAddToTensor<platform::GPUDeviceContext, int64_t>;
+
+namespace scatter {
+
+template <typename T, int block_size>
+__global__ void MergeAddKernel(const T* input, const int64_t* input_rows,
+                               T* out, const int64_t* out_rows,
+                               size_t out_rows_size, int64_t row_numel) {
+  const int ty = blockIdx.y;
+  int tid = threadIdx.x;
+  __shared__ size_t out_idx;
+
+  if (tid == 0) {
+    for (size_t i = 0; i < out_rows_size; i++) {
+      if (input_rows[ty] == out_rows[i]) {
+        out_idx = i;
+      }
+    }
+  }
+
+  __syncthreads();
+
+  input += ty * row_numel;
+  out += out_idx * row_numel;
+  for (int index = tid; index < row_numel; index += block_size) {
+    paddle::platform::CudaAtomicAdd(out + index, input[index]);
+  }
+}
+
+template <typename T>
+struct MergeAdd<platform::GPUDeviceContext, T> {
+  void operator()(const platform::GPUDeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* out) {
+    auto input_rows = input.rows();
+    std::set<int64_t> row_set(input_rows.begin(), input_rows.end());
+    std::vector<int64_t> merge_rows(row_set.begin(), row_set.end());
+
+    auto input_width = input.value().dims()[1];
+    // std::unique_ptr<framework::SelectedRows> out{
+    //     new framework::SelectedRows()};
+    out->set_rows(merge_rows);
+    out->set_height(input.height());
+    out->mutable_value()->mutable_data<T>(
+        framework::make_ddim(
+            {static_cast<int64_t>(merge_rows.size()), input_width}),
+        context.GetPlace());
+
+    math::SetConstant<platform::GPUDeviceContext, T> constant_functor;
+    constant_functor(context, out->mutable_value(), 0.0);
+
+    auto* out_data = out->mutable_value()->data<T>();
+    auto* input_data = input.value().data<T>();
+
+    const int block_size = 256;
+    dim3 threads(block_size, 1);
+    dim3 grid1(1, input_rows.size());
+
+    MergeAddKernel<
+        T, 256><<<grid1, threads, 0,
+                  reinterpret_cast<const platform::CUDADeviceContext&>(context)
+                      .stream()>>>(input_data, input.rows().data(), out_data,
+                                   out->rows().data(), out->rows().size(),
+                                   input_width);
+  }
+};
+
+}  // namespace scatter
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/selected_rows_functor.h b/paddle/operators/math/selected_rows_functor.h
index 1149075abf16547a120ac8928c45b4972409fc72..8adfca77f6930734b9c5dac43caa979d8705ed2a 100644
--- a/paddle/operators/math/selected_rows_functor.h
+++ b/paddle/operators/math/selected_rows_functor.h
@@ -52,6 +52,53 @@ struct SelectedRowsAddToTensor {
                   framework::Tensor* input2);
 };
 
+namespace scatter {
+// functors for manipulating SelectedRows data
+
+template <typename DeviceContext, typename T>
+struct MergeAdd {
+  // Unary functor: merges duplicated rows in the input
+  // SelectedRows object by adding their values.
+  void operator()(const DeviceContext& context,
+                  const framework::SelectedRows& input,
+                  framework::SelectedRows* out);
+};
+
+template <typename DeviceContext, typename T>
+struct Add {
+  void operator()(const DeviceContext& context,
+                  const framework::SelectedRows& input1,
+                  const framework::SelectedRows& input2,
+                  framework::SelectedRows* out) {
+    out->set_rows(input1.rows());
+    out->set_height(input1.height());
+    out->mutable_value()->mutable_data<T>(input1.value().dims(),
+                                          context.GetPlace());
+    auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
+    auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
+    auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
+    e_out.device(*context.eigen_device()) = e_in1 + e_in2;
+  }
+};
+
+template <typename DeviceContext, typename T>
+struct Mul {
+  void operator()(const DeviceContext& context,
+                  const framework::SelectedRows& input1,
+                  const framework::SelectedRows& input2,
+                  framework::SelectedRows* out) {
+    out->set_rows(input1.rows());
+    out->set_height(input1.height());
+    out->mutable_value()->mutable_data<T>(input1.value().dims(),
+                                          context.GetPlace());
+    auto e_out = framework::EigenVector<T>::Flatten(*(out->mutable_value()));
+    auto e_in1 = framework::EigenVector<T>::Flatten(input1.value());
+    auto e_in2 = framework::EigenVector<T>::Flatten(input2.value());
+    e_out.device(*context.eigen_device()) = e_in1 * e_in2;
+  }
+};
+
+}  // namespace scatter
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
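For the MergeAddKernel in the .cu hunk above, the launch shape is grid1 = (1, input_rows.size()), i.e. one thread block per input row (blockIdx.y), with the block's 256 threads striding across that row's columns; duplicated row ids map to the same merged output row, which is why the accumulation uses CudaAtomicAdd. The following serial, host-side emulation of that index arithmetic is not part of the patch and uses hypothetical sizes.

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int block_size = 4;  // small stand-in for the 256 in the launch
  const int64_t row_numel = 6;
  std::vector<int64_t> input_rows = {0, 4, 0};
  std::vector<int64_t> out_rows = {0, 4};  // computed on the host by MergeAdd
  std::vector<float> input(input_rows.size() * row_numel, 1.0f);
  std::vector<float> out(out_rows.size() * row_numel, 0.0f);

  for (size_t ty = 0; ty < input_rows.size(); ty++) {  // blockIdx.y
    // In the kernel, thread 0 finds the output slot and publishes it to
    // the whole block through the __shared__ variable out_idx.
    size_t out_idx = 0;
    for (size_t i = 0; i < out_rows.size(); i++) {
      if (input_rows[ty] == out_rows[i]) out_idx = i;
    }
    for (int tid = 0; tid < block_size; tid++) {  // threadIdx.x
      for (int64_t index = tid; index < row_numel; index += block_size) {
        // Device version: CudaAtomicAdd(out + index, input[index]) after
        // offsetting both pointers by the row base.
        out[out_idx * row_numel + index] += input[ty * row_numel + index];
      }
    }
  }

  // Row id 0 occurs twice, so its merged row accumulates 2.0 per column.
  assert(out[0] == 2.0f && out[row_numel] == 1.0f);
  return 0;
}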
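The header-only Add and Mul functors implicitly assume the two inputs are aligned: same height and the same rows in the same order, since out simply reuses input1's rows and the values are combined elementwise over the flattened buffers. A plain-loop equivalent of the Eigen expressions, not part of the patch and with made-up values:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> rows = {0, 4};  // must be shared by both inputs
  const int64_t width = 2;
  std::vector<float> v1 = {1, 2, 3, 4};      // input1.value(), flattened
  std::vector<float> v2 = {10, 20, 30, 40};  // input2.value(), flattened
  assert(v1.size() == rows.size() * static_cast<size_t>(width));

  std::vector<float> sum(v1.size()), prod(v1.size());
  for (size_t i = 0; i < v1.size(); i++) {
    sum[i] = v1[i] + v2[i];   // scatter::Add:  e_in1 + e_in2
    prod[i] = v1[i] * v2[i];  // scatter::Mul:  e_in1 * e_in2
  }

  assert((sum == std::vector<float>{11, 22, 33, 44}));
  assert((prod == std::vector<float>{10, 40, 90, 160}));
  return 0;
}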