/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/selected_rows_functor.h"

#include <algorithm>
#include <set>
#include <vector>

#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {
namespace math {

// NOTE(review): the template argument lists in this file (the
// <platform::CPUDeviceContext, T> specialization tags, data<T>(),
// boost::get<platform::CPUPlace>, std::vector<int64_t>, ...) were written out
// from how the values are used below. Confirm them, and in particular the
// element types of the explicit instantiations, against
// selected_rows_functor.h and the callers.

// CPU functor: output = input1 concatenated with input2.
//
// The row indices of input2 are appended after those of input1, and the two
// value buffers are copied back to back into output's value buffer in the
// same order. Row indices that occur more than once are kept as they are;
// nothing is summed here.
//
// This functor does not allocate output's value tensor: it only reads its
// element count and calls data<T>() on it.
//
// Both inputs must contain at least one row, because the per-row width is
// derived by dividing the element count by the row count.
template <typename T>
struct SelectedRowsAdd<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::SelectedRows& input2,
                  framework::SelectedRows* output) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2.height());

    auto& in1_rows = input1.rows();
    auto& in2_rows = input2.rows();
    // Guard the divisions by rows.size() below: an empty row list would be an
    // integer division by zero.
    PADDLE_ENFORCE(in1_rows.size() > 0 && in2_rows.size() > 0,
                   "SelectedRowsAdd requires both inputs to have rows.");

    output->set_height(in1_height);

    std::vector<int64_t> out_rows;
    out_rows.reserve(in1_rows.size() + in2_rows.size());

    // concat rows
    out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end());
    out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end());
    output->set_rows(out_rows);

    auto* out_value = output->mutable_value();
    auto& in1_value = input1.value();
    auto& in2_value = input2.value();

    // All three tensors must agree on the number of elements per row.
    auto in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size());
    PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size());

    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_cpu_place(in1_place));
    auto in2_place = input2.place();
    PADDLE_ENFORCE(platform::is_cpu_place(in2_place));
    auto out_place = context.GetPlace();
    PADDLE_ENFORCE(platform::is_cpu_place(out_place));

    // input1's values go first ...
    auto* out_data = out_value->data<T>();
    auto* in1_data = in1_value.data<T>();
    memory::Copy(boost::get<platform::CPUPlace>(out_place), out_data,
                 boost::get<platform::CPUPlace>(in1_place), in1_data,
                 in1_value.numel() * sizeof(T));

    // ... and input2's values directly behind them.
    auto* in2_data = in2_value.data<T>();
    memory::Copy(boost::get<platform::CPUPlace>(out_place),
                 out_data + in1_value.numel(),
                 boost::get<platform::CPUPlace>(in2_place), in2_data,
                 in2_value.numel() * sizeof(T));
  }
};

template struct SelectedRowsAdd<platform::CPUDeviceContext, float>;
template struct SelectedRowsAdd<platform::CPUDeviceContext, double>;

// CPU functor: output = dense(input1) + input2.
//
// The SetConstant functor is first applied to output with the value 0.0. Each
// row i of input1's value is then added into row input1.rows()[i] of output,
// so a row index that occurs several times accumulates. Finally input2 is
// added element-wise.
//
// input2 and output must both have input1.height() as their first dimension
// and the same number of elements per row as input1's value. input1 must
// contain at least one row, since the per-row width is derived by dividing by
// the row count. Row indices are not range-checked here.
//
// NOTE(review): output is overwritten before input2 is read, so output
// presumably must not share storage with input2 -- confirm with callers.
template <typename T>
struct SelectedRowsAddTensor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::SelectedRows& input1,
                  const framework::Tensor& input2, framework::Tensor* output) {
    auto in1_height = input1.height();
    auto in2_dims = input2.dims();
    auto out_dims = output->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);
    PADDLE_ENFORCE_EQ(in1_height, out_dims[0]);

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    // Guard the division by rows.size() below: an empty row list would be an
    // integer division by zero.
    PADDLE_ENFORCE(in1_rows.size() > 0,
                   "SelectedRowsAddTensor requires input1 to have rows.");

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height);
    PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height);

    SetConstant<platform::CPUDeviceContext, T> functor;
    functor(context, output, 0.0);

    auto* in1_data = in1_value.data<T>();
    auto* out_data = output->data<T>();

    // Scatter-add: source row i lands on destination row in1_rows[i].
    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        out_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }

    auto out_eigen = framework::EigenVector<T>::Flatten(*output);
    auto in2_eigen = framework::EigenVector<T>::Flatten(input2);
    out_eigen.device(*context.eigen_device()) = out_eigen + in2_eigen;
  }
};

template struct SelectedRowsAddTensor<platform::CPUDeviceContext, float>;
template struct SelectedRowsAddTensor<platform::CPUDeviceContext, double>;

// CPU functor: append input1 to input2 in place.
//
// input1's row indices are appended to input2's row list, and input1's value
// buffer is copied into input2's value buffer starting input2_offset elements
// from its beginning. input2's value tensor is not resized or reallocated
// here, and the offset is not range-checked.
//
// Every check that can fail runs before input2's row list is modified, so a
// failed check does not leave input2 with new rows but without their values.
template <typename T>
struct SelectedRowsAddTo<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::SelectedRows& input1,
                  const int64_t input2_offset,
                  framework::SelectedRows* input2) {
    auto in1_height = input1.height();
    PADDLE_ENFORCE_EQ(in1_height, input2->height());

    auto in1_place = input1.place();
    PADDLE_ENFORCE(platform::is_cpu_place(in1_place));
    auto in2_place = input2->place();
    PADDLE_ENFORCE(platform::is_cpu_place(in2_place));

    auto& in1_value = input1.value();
    auto* in2_value = input2->mutable_value();
    auto* in1_data = in1_value.data<T>();
    auto* in2_data = in2_value->data<T>();

    auto& in1_rows = input1.rows();
    auto& in2_rows = *(input2->mutable_rows());

    // concat rows
    in2_rows.insert(in2_rows.end(), in1_rows.begin(), in1_rows.end());

    memory::Copy(boost::get<platform::CPUPlace>(in2_place),
                 in2_data + input2_offset,
                 boost::get<platform::CPUPlace>(in1_place), in1_data,
                 in1_value.numel() * sizeof(T));
  }
};

template struct SelectedRowsAddTo<platform::CPUDeviceContext, float>;
template struct SelectedRowsAddTo<platform::CPUDeviceContext, double>;
template struct SelectedRowsAddTo<platform::CPUDeviceContext, int>;
template struct SelectedRowsAddTo<platform::CPUDeviceContext, int64_t>;

// CPU functor: input2 += dense(input1), in place.
//
// Each row i of input1's value is added into row input1.rows()[i] of input2;
// a row index that occurs several times accumulates. input2 must have
// input1.height() as its first dimension and the same number of elements per
// row as input1's value. Row indices are not range-checked here.
//
// An input1 without rows leaves input2 untouched.
template <typename T>
struct SelectedRowsAddToTensor<platform::CPUDeviceContext, T> {
  void operator()(const platform::CPUDeviceContext& context,
                  const framework::SelectedRows& input1,
                  framework::Tensor* input2) {
    auto in1_height = input1.height();
    auto in2_dims = input2->dims();
    PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]);

    auto& in1_value = input1.value();
    auto& in1_rows = input1.rows();

    // Nothing to add. Returning here also keeps the division below from
    // dividing by zero.
    if (in1_rows.size() == 0) {
      return;
    }

    int64_t in1_row_numel = in1_value.numel() / in1_rows.size();
    PADDLE_ENFORCE_EQ(in1_row_numel, input2->numel() / in1_height);

    auto* in1_data = in1_value.data<T>();
    auto* input2_data = input2->data<T>();

    // Scatter-add: source row i lands on destination row in1_rows[i].
    for (size_t i = 0; i < in1_rows.size(); i++) {
      for (int64_t j = 0; j < in1_row_numel; j++) {
        input2_data[in1_rows[i] * in1_row_numel + j] +=
            in1_data[i * in1_row_numel + j];
      }
    }
  }
};

template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, float>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, double>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int>;
template struct SelectedRowsAddToTensor<platform::CPUDeviceContext, int64_t>;

// This is a separate namespace for functors that manipulate SelectedRows
// typed data, such as merging duplicated rows or adding two SelectedRows.
//
// Another group of functors is called "scatter updates", which means
// using SelectedRows to update a dense tensor with different Ops, like
// add or mul.
namespace scatter { size_t FindPos(const std::vector& rows, int64_t value) { return std::find(rows.begin(), rows.end(), value) - rows.begin(); } template struct MergeAdd { void operator()(const platform::CPUDeviceContext& context, const framework::SelectedRows& input, framework::SelectedRows* out) { auto input_rows = input.rows(); std::set row_set(input_rows.begin(), input_rows.end()); std::vector merge_rows(row_set.begin(), row_set.end()); auto input_width = input.value().dims()[1]; // std::unique_ptr out{ // new framework::SelectedRows()}; out->set_rows(merge_rows); out->set_height(input.height()); out->mutable_value()->mutable_data( framework::make_ddim( {static_cast(merge_rows.size()), input_width}), context.GetPlace()); math::SetConstant constant_functor; constant_functor(context, out->mutable_value(), 0.0); auto* out_data = out->mutable_value()->data(); auto* input_data = input.value().data(); for (size_t i = 0; i < input_rows.size(); i++) { size_t out_i = FindPos(merge_rows, input_rows[i]); for (int64_t j = 0; j < input_width; j++) { out_data[out_i * input_width + j] += input_data[i * input_width + j]; } } } }; } // namespace scatter } // namespace math } // namespace operators } // namespace paddle