diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 77a3603eb63ff5d6a4ab5b96140af4a5cef555f4..72ce8585045b5166df424a401442db39b47ab098 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,18 +1,22 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) - nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor selected_rows) + nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor) + nv_library(selected_rows_functor SRCS selected_rows_functor.cc selected_rows_functor.cu DEPS selected_rows math_function) + nv_test(selected_rows_functor_gpu_test SRCS selected_rows_functor_test.cu DEPS selected_rows_functor) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) + cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS device_context) cc_library(vol2col SRCS vol2col.cc DEPS device_context) endif() -cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor selected_rows) +cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) +cc_test(selected_rows_functor_test SRCS selected_rows_functor_test.cc DEPS selected_rows_functor) cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor) cc_test(vol2col_test SRCS vol2col_test.cc DEPS vol2col tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index a1faafb7c451dc205f95583cdde798257a24b17a..77a1e22b41e8dcd1fe78f3c4730653dee04db80e 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/math/math_function.h" -#include "paddle/framework/eigen.h" -#include "paddle/memory/memcpy.h" namespace paddle { namespace operators { @@ -134,97 +132,6 @@ void matmul( template struct SetConstant; -template -struct SelectedRowsAdd { - void operator()(const platform::DeviceContext& context, - const framework::SelectedRows& input1, - const framework::SelectedRows& input2, - framework::SelectedRows* output) { - auto in1_height = input1.height(); - PADDLE_ENFORCE_EQ(in1_height, input2.height()); - output->set_height(in1_height); - - auto& in1_rows = input1.rows(); - auto& in2_rows = input2.rows(); - std::vector out_rows; - out_rows.reserve(in1_rows.size() + in2_rows.size()); - - // concat rows - out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end()); - out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end()); - output->set_rows(out_rows); - - auto* out_value = output->mutable_value(); - auto& in1_value = input1.value(); - auto& in2_value = input2.value(); - - auto in1_row_numel = in1_value.numel() / in1_rows.size(); - PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size()); - PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size()); - - auto in1_place = input1.place(); - PADDLE_ENFORCE(platform::is_cpu_place(in1_place)); - auto in2_place = input2.place(); - PADDLE_ENFORCE(platform::is_cpu_place(in2_place)); - auto out_place = context.GetPlace(); - PADDLE_ENFORCE(platform::is_cpu_place(out_place)); - - auto* out_data = out_value->data(); - auto* in1_data = in1_value.data(); - memory::Copy(boost::get(out_place), out_data, - boost::get(in1_place), in1_data, - in1_value.numel() * sizeof(T)); - - auto* in2_data = in2_value.data(); - memory::Copy(boost::get(out_place), - out_data + in1_value.numel(), - boost::get(in2_place), in2_data, - in2_value.numel() * sizeof(T)); - } -}; - -template struct SelectedRowsAdd; - -template -struct SelectedRowsAddTensor { - void operator()(const platform::DeviceContext& context, - const framework::SelectedRows& input1, - const framework::Tensor& input2, framework::Tensor* output) { - auto in1_height = input1.height(); - auto in2_dims = input2.dims(); - auto out_dims = output->dims(); - PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); - PADDLE_ENFORCE_EQ(in1_height, out_dims[0]); - - auto& in1_value = input1.value(); - auto& in1_rows = input1.rows(); - - int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); - PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height); - PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height); - - SetConstant functor; - functor(context, output, 0.0); - - auto* in1_data = in1_value.data(); - auto* out_data = output->data(); - - for (size_t i = 0; i < in1_rows.size(); i++) { - for (int64_t j = 0; j < in1_row_numel; j++) { - out_data[in1_rows[i] * in1_row_numel + j] += - in1_data[i * in1_row_numel + j]; - } - } - - auto out_eigen = framework::EigenVector::Flatten(*output); - auto in2_eigen = framework::EigenVector::Flatten(input2); - out_eigen.device(*context.GetEigenDevice()) = - out_eigen + in2_eigen; - } -}; - -template struct SelectedRowsAddTensor; - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index fc16d1b0a7a7e0d08b1c361272ed54b839d3cf49..7fbc03acf22231a6fa386aa67e43f738eadb18d3 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/operators/math/math_function.h" -#include "paddle/platform/cuda_helper.h" namespace paddle { namespace operators { @@ -156,123 +155,7 @@ void matmul( matrix_b.data(), beta, matrix_out->data()); } -template -struct SelectedRowsAdd { - void operator()(const platform::DeviceContext& context, - const framework::SelectedRows& input1, - const framework::SelectedRows& input2, - framework::SelectedRows* output) { - auto in1_height = input1.height(); - PADDLE_ENFORCE_EQ(in1_height, input2.height()); - output->set_height(in1_height); - - auto& in1_rows = input1.rows(); - auto& in2_rows = input2.rows(); - std::vector out_rows; - out_rows.reserve(in1_rows.size() + in2_rows.size()); - - // concat rows - out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end()); - out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end()); - output->set_rows(out_rows); - - auto* out_value = output->mutable_value(); - auto& in1_value = input1.value(); - auto& in2_value = input2.value(); - - auto in1_row_numel = in1_value.numel() / in1_rows.size(); - PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size()); - PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size()); - - auto* out_data = out_value->data(); - auto* in1_data = in1_value.data(); - - auto in1_place = input1.place(); - PADDLE_ENFORCE(platform::is_gpu_place(in1_place)); - auto in2_place = input2.place(); - PADDLE_ENFORCE(platform::is_gpu_place(in2_place)); - auto out_place = context.GetPlace(); - PADDLE_ENFORCE(platform::is_gpu_place(out_place)); - - memory::Copy( - boost::get(out_place), out_data, - boost::get(in1_place), in1_data, - in1_value.numel() * sizeof(T), - reinterpret_cast(context).stream()); - - auto* in2_data = in2_value.data(); - memory::Copy( - boost::get(out_place), out_data + in1_value.numel(), - boost::get(in2_place), in2_data, - in2_value.numel() * sizeof(T), - reinterpret_cast(context).stream()); - } -}; - -template struct SelectedRowsAdd; - -namespace { -template -__global__ void SelectedRowsAddTensorKernel(const T* selected_rows, - const int64_t* rows, T* tensor_out, - int64_t row_numel, int block_size) { - const int ty = blockIdx.y; - int tid = threadIdx.x; - - selected_rows += ty * row_numel; - tensor_out += rows[ty] * row_numel; - - for (int index = tid; index < row_numel; index += block_size) { - // Since index in rows of SelectedRows can be duplicate, we can not use - // tensor_out[index] += selected_rows[index]; Instead, we have to use - // AtomicAdd to avoid concurrent write error. - paddle::platform::CudaAtomicAdd(&tensor_out[index], selected_rows[index]); - } -} -} // namespace - -template -struct SelectedRowsAddTensor { - void operator()(const platform::DeviceContext& context, - const framework::SelectedRows& input1, - const framework::Tensor& input2, framework::Tensor* output) { - auto in1_height = input1.height(); - auto in2_dims = input2.dims(); - auto out_dims = output->dims(); - PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); - PADDLE_ENFORCE_EQ(in1_height, out_dims[0]); - - auto& in1_value = input1.value(); - auto& in1_rows = input1.rows(); - - int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); - PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height); - PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height); - - auto* in1_data = in1_value.data(); - auto* in2_data = input2.data(); - auto* out_data = output->data(); - - SetConstant functor; - functor(context, output, 0.0); - - int block_size = 256; - dim3 threads(block_size, 1); - dim3 grid(1, in1_height); - SelectedRowsAddTensorKernel< - T><<(context) - .stream()>>>(in1_data, in1_rows.data(), out_data, - in1_row_numel, block_size); - - auto out_eigen = framework::EigenVector::Flatten(*output); - auto in2_eigen = framework::EigenVector::Flatten(input2); - out_eigen.device(*context.GetEigenDevice()) = - out_eigen + in2_eigen; - } -}; - -template struct SelectedRowsAddTensor; +template struct SetConstant; } // namespace math } // namespace operators diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index 0d0d4cdd7396362ee154896fe57615ac02ad09ae..6f92d83aabbc77f7ea7d4159869e07126b270740 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -53,7 +53,6 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include #include "paddle/framework/eigen.h" -#include "paddle/framework/selected_rows.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/enforce.h" @@ -96,23 +95,6 @@ struct SetConstant { } }; -// SelectedRows + SelectedRows will simplely concat value and rows. -// The real computation happens in dealing with LoDTensor. -template -struct SelectedRowsAdd { - void operator()(const platform::DeviceContext& context, - const framework::SelectedRows& input1, - const framework::SelectedRows& input2, - framework::SelectedRows* output); -}; - -template -struct SelectedRowsAddTensor { - void operator()(const platform::DeviceContext& context, - const framework::SelectedRows& input1, - const framework::Tensor& input2, framework::Tensor* output); -}; - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index 33c561f6c6692f4ae84d841e6680bea0451d956b..3b9f92e7ae5f34dd0fb1ba8fb0c67ff5ae1628c4 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -89,93 +89,3 @@ TEST(math_function, zero) { EXPECT_EQ(t[2], 1); EXPECT_EQ(t[3], 1); } - -TEST(math_function, selected_rows_add) { - using namespace paddle::framework; - using namespace paddle::platform; - using namespace paddle::operators::math; - - CPUPlace cpu_place; - CPUDeviceContext ctx(cpu_place); - SetConstant functor; - int64_t height = 10; - int64_t row_numel = 10; - - std::vector rows1{0, 4, 7}; - std::unique_ptr selected_rows1{new SelectedRows(rows1, height)}; - auto* in1_value = selected_rows1->mutable_value(); - in1_value->mutable_data( - make_ddim({static_cast(rows1.size()), row_numel}), cpu_place); - functor(ctx, in1_value, 1.0); - - std::vector rows2{0, 5, 7, 9}; - std::unique_ptr selected_rows2{new SelectedRows(rows2, height)}; - auto* in2_value = selected_rows2->mutable_value(); - in2_value->mutable_data( - make_ddim({static_cast(rows2.size()), row_numel}), cpu_place); - functor(ctx, in2_value, 2.0); - - std::unique_ptr output{new SelectedRows()}; - auto* out_value = output->mutable_value(); - - // simplely concat two SelectedRows - out_value->mutable_data(make_ddim({7, 10}), cpu_place); - - SelectedRowsAdd add_functor; - add_functor(ctx, *selected_rows1, *selected_rows2, output.get()); - - auto out_height = output->height(); - EXPECT_EQ(out_height, height); - - auto& out_rows = output->rows(); - - // input1 rows - EXPECT_EQ(out_rows[0], 0); - EXPECT_EQ(out_rows[1], 4); - EXPECT_EQ(out_rows[2], 7); - // input2 rows - EXPECT_EQ(out_rows[3], 0); - EXPECT_EQ(out_rows[4], 5); - EXPECT_EQ(out_rows[5], 7); - EXPECT_EQ(out_rows[6], 9); - - auto* out_data = output->value().data(); - // input1 value - EXPECT_EQ(out_data[0 * row_numel + 0], 1.0); - EXPECT_EQ(out_data[0 * row_numel + 8], 1.0); - EXPECT_EQ(out_data[1 * row_numel + 1], 1.0); - EXPECT_EQ(out_data[2 * row_numel + 6], 1.0); - // input2 value - EXPECT_EQ(out_data[3 * row_numel + 3], 2.0); - EXPECT_EQ(out_data[3 * row_numel + 8], 2.0); - EXPECT_EQ(out_data[4 * row_numel + 4], 2.0); - EXPECT_EQ(out_data[5 * row_numel + 7], 2.0); - EXPECT_EQ(out_data[6 * row_numel + 9], 2.0); - - std::unique_ptr tensor1{new Tensor()}; - tensor1->mutable_data(make_ddim({height, row_numel}), cpu_place); - SetConstant constant_functor; - constant_functor(ctx, tensor1.get(), 3.0); - - std::unique_ptr tensor2{new Tensor()}; - tensor2->mutable_data(make_ddim({height, row_numel}), cpu_place); - - SelectedRowsAddTensor add_tensor_functor; - add_tensor_functor(ctx, *output, *tensor1, tensor2.get()); - - auto* tensor2_data = tensor2->data(); - // row0: 1.0 + 2.0 + 3.0 - EXPECT_EQ(tensor2_data[0 * row_numel + 0], 6.0); - // row1: 3.0 - EXPECT_EQ(tensor2_data[1 * row_numel + 1], 3.0); - // row4 : 1.0 + 3.0 - EXPECT_EQ(tensor2_data[4 * row_numel + 6], 4.0); - // row5: 2.0 + 3.0 - EXPECT_EQ(tensor2_data[5 * row_numel + 7], 5.0); - // row6: 3.0 - EXPECT_EQ(tensor2_data[6 * row_numel + 1], 3.0); - // row7: 1.0 + 2.0 + 3.0 - EXPECT_EQ(tensor2_data[7 * row_numel + 3], 6.0); - // row9: 2.0 + 3.0 - EXPECT_EQ(tensor2_data[9 * row_numel + 6], 5.0); -} diff --git a/paddle/operators/math/math_function_test.cu b/paddle/operators/math/math_function_test.cu index 1acc5f66a69f88d4eb2cf47c70277ca843b02627..14359d835bba794703a313d70f34082868474b20 100644 --- a/paddle/operators/math/math_function_test.cu +++ b/paddle/operators/math/math_function_test.cu @@ -177,102 +177,3 @@ TEST(math_function, gemm_trans_cublas) { EXPECT_EQ(input3_ptr[7], 99); delete gpu_place; } - -TEST(math_function, selected_rows_add) { - using namespace paddle::framework; - using namespace paddle::platform; - using namespace paddle::operators::math; - - GPUPlace gpu_place(0); - CPUPlace cpu_place; - CUDADeviceContext ctx(gpu_place); - SetConstant functor; - int64_t height = 10; - int64_t row_numel = 10; - - std::vector rows1{0, 4, 7}; - std::unique_ptr selected_rows1{new SelectedRows(rows1, height)}; - auto* in1_value = selected_rows1->mutable_value(); - in1_value->mutable_data( - make_ddim({static_cast(rows1.size()), row_numel}), gpu_place); - functor(ctx, in1_value, 1.0); - - std::vector rows2{0, 5, 7, 9}; - std::unique_ptr selected_rows2{new SelectedRows(rows2, height)}; - auto* in2_value = selected_rows2->mutable_value(); - in2_value->mutable_data( - make_ddim({static_cast(rows2.size()), row_numel}), gpu_place); - functor(ctx, in2_value, 2.0); - - std::unique_ptr output{new SelectedRows()}; - auto* out_value = output->mutable_value(); - - // simplely concat two SelectedRows - out_value->mutable_data(make_ddim({7, 10}), gpu_place); - - SelectedRowsAdd add_functor; - add_functor(ctx, *selected_rows1, *selected_rows2, output.get()); - - auto out_height = output->height(); - EXPECT_EQ(out_height, height); - - auto& out_rows = output->rows(); - - // input1 rows - EXPECT_EQ(out_rows[0], 0); - EXPECT_EQ(out_rows[1], 4); - EXPECT_EQ(out_rows[2], 7); - // input2 rows - EXPECT_EQ(out_rows[3], 0); - EXPECT_EQ(out_rows[4], 5); - EXPECT_EQ(out_rows[5], 7); - EXPECT_EQ(out_rows[6], 9); - - Tensor out_cpu; - out_cpu.CopyFrom(*out_value, cpu_place, ctx); - ctx.Wait(); - - auto* out_cpu_data = out_cpu.data(); - // input1 value - EXPECT_EQ(out_cpu_data[0 * row_numel + 0], 1.0); - EXPECT_EQ(out_cpu_data[0 * row_numel + 8], 1.0); - EXPECT_EQ(out_cpu_data[1 * row_numel + 1], 1.0); - EXPECT_EQ(out_cpu_data[2 * row_numel + 6], 1.0); - // input2 value - EXPECT_EQ(out_cpu_data[3 * row_numel + 3], 2.0); - EXPECT_EQ(out_cpu_data[3 * row_numel + 8], 2.0); - EXPECT_EQ(out_cpu_data[4 * row_numel + 4], 2.0); - EXPECT_EQ(out_cpu_data[5 * row_numel + 7], 2.0); - EXPECT_EQ(out_cpu_data[6 * row_numel + 9], 2.0); - - std::unique_ptr tensor1{new Tensor()}; - tensor1->mutable_data(make_ddim({height, row_numel}), gpu_place); - SetConstant constant_functor; - constant_functor(ctx, tensor1.get(), 3.0); - - std::unique_ptr tensor2{new Tensor()}; - tensor2->mutable_data(make_ddim({height, row_numel}), gpu_place); - - SelectedRowsAddTensor add_tensor_functor; - add_tensor_functor(ctx, *output, *tensor1, tensor2.get()); - - Tensor tensor2_cpu; - tensor2_cpu.CopyFrom(*tensor2, cpu_place, ctx); - ctx.Wait(); - - auto* tensor2_cpu_data = tensor2_cpu.data(); - // row0: 1.0 + 2.0 + 3.0 - EXPECT_EQ(tensor2_cpu_data[0 * row_numel + 0], 6.0); - // row1: 3.0 - EXPECT_EQ(tensor2_cpu_data[1 * row_numel + 1], 3.0); - // row4 : 1.0 + 3.0 - EXPECT_EQ(tensor2_cpu_data[4 * row_numel + 6], 4.0); - // row5: 2.0 + 3.0 - EXPECT_EQ(tensor2_cpu_data[5 * row_numel + 7], 5.0); - // row6: 3.0 - EXPECT_EQ(tensor2_cpu_data[6 * row_numel + 1], 3.0); - // row7: 1.0 + 2.0 + 3.0 - EXPECT_EQ(tensor2_cpu_data[7 * row_numel + 3], 6.0); - // row9: 2.0 + 3.0 - EXPECT_EQ(tensor2_cpu_data[9 * row_numel + 6], 5.0); -} diff --git a/paddle/operators/math/selected_rows_functor.cc b/paddle/operators/math/selected_rows_functor.cc new file mode 100644 index 0000000000000000000000000000000000000000..f2305ea16913e927dca17e5a80201368f03ca253 --- /dev/null +++ b/paddle/operators/math/selected_rows_functor.cc @@ -0,0 +1,114 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/selected_rows_functor.h" +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { +template +struct SelectedRowsAdd { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input1, + const framework::SelectedRows& input2, + framework::SelectedRows* output) { + auto in1_height = input1.height(); + PADDLE_ENFORCE_EQ(in1_height, input2.height()); + output->set_height(in1_height); + + auto& in1_rows = input1.rows(); + auto& in2_rows = input2.rows(); + std::vector out_rows; + out_rows.reserve(in1_rows.size() + in2_rows.size()); + + // concat rows + out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end()); + out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end()); + output->set_rows(out_rows); + + auto* out_value = output->mutable_value(); + auto& in1_value = input1.value(); + auto& in2_value = input2.value(); + + auto in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size()); + PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size()); + + auto in1_place = input1.place(); + PADDLE_ENFORCE(platform::is_cpu_place(in1_place)); + auto in2_place = input2.place(); + PADDLE_ENFORCE(platform::is_cpu_place(in2_place)); + auto out_place = context.GetPlace(); + PADDLE_ENFORCE(platform::is_cpu_place(out_place)); + + auto* out_data = out_value->data(); + auto* in1_data = in1_value.data(); + memory::Copy(boost::get(out_place), out_data, + boost::get(in1_place), in1_data, + in1_value.numel() * sizeof(T)); + + auto* in2_data = in2_value.data(); + memory::Copy(boost::get(out_place), + out_data + in1_value.numel(), + boost::get(in2_place), in2_data, + in2_value.numel() * sizeof(T)); + } +}; + +template struct SelectedRowsAdd; + +template +struct SelectedRowsAddTensor { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input1, + const framework::Tensor& input2, framework::Tensor* output) { + auto in1_height = input1.height(); + auto in2_dims = input2.dims(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + PADDLE_ENFORCE_EQ(in1_height, out_dims[0]); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height); + PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height); + + SetConstant functor; + functor(context, output, 0.0); + + auto* in1_data = in1_value.data(); + auto* out_data = output->data(); + + for (size_t i = 0; i < in1_rows.size(); i++) { + for (int64_t j = 0; j < in1_row_numel; j++) { + out_data[in1_rows[i] * in1_row_numel + j] += + in1_data[i * in1_row_numel + j]; + } + } + + auto out_eigen = framework::EigenVector::Flatten(*output); + auto in2_eigen = framework::EigenVector::Flatten(input2); + out_eigen.device(*context.GetEigenDevice()) = + out_eigen + in2_eigen; + } +}; + +template struct SelectedRowsAddTensor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/selected_rows_functor.cu b/paddle/operators/math/selected_rows_functor.cu new file mode 100644 index 0000000000000000000000000000000000000000..a406bef39ad5f71b1eaccc28e9ef5b09d2e6b59b --- /dev/null +++ b/paddle/operators/math/selected_rows_functor.cu @@ -0,0 +1,142 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/selected_rows_functor.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { +template +struct SelectedRowsAdd { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input1, + const framework::SelectedRows& input2, + framework::SelectedRows* output) { + auto in1_height = input1.height(); + PADDLE_ENFORCE_EQ(in1_height, input2.height()); + output->set_height(in1_height); + + auto& in1_rows = input1.rows(); + auto& in2_rows = input2.rows(); + std::vector out_rows; + out_rows.reserve(in1_rows.size() + in2_rows.size()); + + // concat rows + out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end()); + out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end()); + output->set_rows(out_rows); + + auto* out_value = output->mutable_value(); + auto& in1_value = input1.value(); + auto& in2_value = input2.value(); + + auto in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size()); + PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size()); + + auto* out_data = out_value->data(); + auto* in1_data = in1_value.data(); + + auto in1_place = input1.place(); + PADDLE_ENFORCE(platform::is_gpu_place(in1_place)); + auto in2_place = input2.place(); + PADDLE_ENFORCE(platform::is_gpu_place(in2_place)); + auto out_place = context.GetPlace(); + PADDLE_ENFORCE(platform::is_gpu_place(out_place)); + + memory::Copy( + boost::get(out_place), out_data, + boost::get(in1_place), in1_data, + in1_value.numel() * sizeof(T), + reinterpret_cast(context).stream()); + + auto* in2_data = in2_value.data(); + memory::Copy( + boost::get(out_place), out_data + in1_value.numel(), + boost::get(in2_place), in2_data, + in2_value.numel() * sizeof(T), + reinterpret_cast(context).stream()); + } +}; + +template struct SelectedRowsAdd; + +namespace { +template +__global__ void SelectedRowsAddTensorKernel(const T* selected_rows, + const int64_t* rows, T* tensor_out, + int64_t row_numel, int block_size) { + const int ty = blockIdx.y; + int tid = threadIdx.x; + + selected_rows += ty * row_numel; + tensor_out += rows[ty] * row_numel; + + for (int index = tid; index < row_numel; index += block_size) { + // Since index in rows of SelectedRows can be duplicate, we can not use + // tensor_out[index] += selected_rows[index]; Instead, we have to use + // AtomicAdd to avoid concurrent write error. + paddle::platform::CudaAtomicAdd(&tensor_out[index], selected_rows[index]); + } +} +} // namespace + +template +struct SelectedRowsAddTensor { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input1, + const framework::Tensor& input2, framework::Tensor* output) { + auto in1_height = input1.height(); + auto in2_dims = input2.dims(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + PADDLE_ENFORCE_EQ(in1_height, out_dims[0]); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height); + PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height); + + auto* in1_data = in1_value.data(); + auto* in2_data = input2.data(); + auto* out_data = output->data(); + + SetConstant functor; + functor(context, output, 0.0); + + int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(1, in1_height); + SelectedRowsAddTensorKernel< + T><<(context) + .stream()>>>(in1_data, in1_rows.data(), out_data, + in1_row_numel, block_size); + + auto out_eigen = framework::EigenVector::Flatten(*output); + auto in2_eigen = framework::EigenVector::Flatten(input2); + out_eigen.device(*context.GetEigenDevice()) = + out_eigen + in2_eigen; + } +}; + +template struct SelectedRowsAddTensor; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/selected_rows_functor.h b/paddle/operators/math/selected_rows_functor.h new file mode 100644 index 0000000000000000000000000000000000000000..53ab240ca600cd4a817afa2c19fb8d9427c6f3da --- /dev/null +++ b/paddle/operators/math/selected_rows_functor.h @@ -0,0 +1,41 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#include "paddle/framework/selected_rows.h" +#include "paddle/platform/device_context.h" + +namespace paddle { +namespace operators { +namespace math { + +// SelectedRows + SelectedRows will simplely concat value and rows. +// The real computation happens in dealing with LoDTensor. +template +struct SelectedRowsAdd { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input1, + const framework::SelectedRows& input2, + framework::SelectedRows* output); +}; + +template +struct SelectedRowsAddTensor { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input1, + const framework::Tensor& input2, framework::Tensor* output); +}; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/selected_rows_functor_test.cc b/paddle/operators/math/selected_rows_functor_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..4f7760cb713b6bf58c82f38fb043d7d53d82710a --- /dev/null +++ b/paddle/operators/math/selected_rows_functor_test.cc @@ -0,0 +1,106 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/selected_rows_functor.h" +#include "gtest/gtest.h" +#include "paddle/operators/math/math_function.h" + +TEST(selected_rows_functor, cpu_add) { + using namespace paddle::framework; + using namespace paddle::platform; + using namespace paddle::operators::math; + + CPUPlace cpu_place; + CPUDeviceContext ctx(cpu_place); + SetConstant functor; + int64_t height = 10; + int64_t row_numel = 10; + + std::vector rows1{0, 4, 7}; + std::unique_ptr selected_rows1{new SelectedRows(rows1, height)}; + auto* in1_value = selected_rows1->mutable_value(); + in1_value->mutable_data( + make_ddim({static_cast(rows1.size()), row_numel}), cpu_place); + functor(ctx, in1_value, 1.0); + + std::vector rows2{0, 5, 7, 9}; + std::unique_ptr selected_rows2{new SelectedRows(rows2, height)}; + auto* in2_value = selected_rows2->mutable_value(); + in2_value->mutable_data( + make_ddim({static_cast(rows2.size()), row_numel}), cpu_place); + functor(ctx, in2_value, 2.0); + + std::unique_ptr output{new SelectedRows()}; + auto* out_value = output->mutable_value(); + + // simplely concat two SelectedRows + out_value->mutable_data(make_ddim({7, 10}), cpu_place); + + SelectedRowsAdd add_functor; + add_functor(ctx, *selected_rows1, *selected_rows2, output.get()); + + auto out_height = output->height(); + EXPECT_EQ(out_height, height); + + auto& out_rows = output->rows(); + + // input1 rows + EXPECT_EQ(out_rows[0], 0); + EXPECT_EQ(out_rows[1], 4); + EXPECT_EQ(out_rows[2], 7); + // input2 rows + EXPECT_EQ(out_rows[3], 0); + EXPECT_EQ(out_rows[4], 5); + EXPECT_EQ(out_rows[5], 7); + EXPECT_EQ(out_rows[6], 9); + + auto* out_data = output->value().data(); + // input1 value + EXPECT_EQ(out_data[0 * row_numel + 0], 1.0); + EXPECT_EQ(out_data[0 * row_numel + 8], 1.0); + EXPECT_EQ(out_data[1 * row_numel + 1], 1.0); + EXPECT_EQ(out_data[2 * row_numel + 6], 1.0); + // input2 value + EXPECT_EQ(out_data[3 * row_numel + 3], 2.0); + EXPECT_EQ(out_data[3 * row_numel + 8], 2.0); + EXPECT_EQ(out_data[4 * row_numel + 4], 2.0); + EXPECT_EQ(out_data[5 * row_numel + 7], 2.0); + EXPECT_EQ(out_data[6 * row_numel + 9], 2.0); + + std::unique_ptr tensor1{new Tensor()}; + tensor1->mutable_data(make_ddim({height, row_numel}), cpu_place); + functor(ctx, tensor1.get(), 3.0); + + std::unique_ptr tensor2{new Tensor()}; + tensor2->mutable_data(make_ddim({height, row_numel}), cpu_place); + + SelectedRowsAddTensor add_tensor_functor; + add_tensor_functor(ctx, *output, *tensor1, tensor2.get()); + + auto* tensor2_data = tensor2->data(); + // row0: 1.0 + 2.0 + 3.0 + EXPECT_EQ(tensor2_data[0 * row_numel + 0], 6.0); + // row1: 3.0 + EXPECT_EQ(tensor2_data[1 * row_numel + 1], 3.0); + // row4 : 1.0 + 3.0 + EXPECT_EQ(tensor2_data[4 * row_numel + 6], 4.0); + // row5: 2.0 + 3.0 + EXPECT_EQ(tensor2_data[5 * row_numel + 7], 5.0); + // row6: 3.0 + EXPECT_EQ(tensor2_data[6 * row_numel + 1], 3.0); + // row7: 1.0 + 2.0 + 3.0 + EXPECT_EQ(tensor2_data[7 * row_numel + 3], 6.0); + // row9: 2.0 + 3.0 + EXPECT_EQ(tensor2_data[9 * row_numel + 6], 5.0); +} diff --git a/paddle/operators/math/selected_rows_functor_test.cu b/paddle/operators/math/selected_rows_functor_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..8a9f25b98263c3bef50c38f358a20ea98ebe6324 --- /dev/null +++ b/paddle/operators/math/selected_rows_functor_test.cu @@ -0,0 +1,115 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "gtest/gtest.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/selected_rows_functor.h" + +TEST(selected_rows_functor, gpu_add) { + using namespace paddle::framework; + using namespace paddle::platform; + using namespace paddle::operators::math; + + GPUPlace gpu_place(0); + CPUPlace cpu_place; + CUDADeviceContext ctx(gpu_place); + SetConstant functor; + int64_t height = 10; + int64_t row_numel = 10; + + std::vector rows1{0, 4, 7}; + std::unique_ptr selected_rows1{new SelectedRows(rows1, height)}; + auto* in1_value = selected_rows1->mutable_value(); + in1_value->mutable_data( + make_ddim({static_cast(rows1.size()), row_numel}), gpu_place); + functor(ctx, in1_value, 1.0); + + std::vector rows2{0, 5, 7, 9}; + std::unique_ptr selected_rows2{new SelectedRows(rows2, height)}; + auto* in2_value = selected_rows2->mutable_value(); + in2_value->mutable_data( + make_ddim({static_cast(rows2.size()), row_numel}), gpu_place); + functor(ctx, in2_value, 2.0); + + std::unique_ptr output{new SelectedRows()}; + auto* out_value = output->mutable_value(); + + // simplely concat two SelectedRows + out_value->mutable_data(make_ddim({7, 10}), gpu_place); + + SelectedRowsAdd add_functor; + add_functor(ctx, *selected_rows1, *selected_rows2, output.get()); + + auto out_height = output->height(); + EXPECT_EQ(out_height, height); + + auto& out_rows = output->rows(); + + // input1 rows + EXPECT_EQ(out_rows[0], 0); + EXPECT_EQ(out_rows[1], 4); + EXPECT_EQ(out_rows[2], 7); + // input2 rows + EXPECT_EQ(out_rows[3], 0); + EXPECT_EQ(out_rows[4], 5); + EXPECT_EQ(out_rows[5], 7); + EXPECT_EQ(out_rows[6], 9); + + Tensor out_cpu; + out_cpu.CopyFrom(*out_value, cpu_place, ctx); + ctx.Wait(); + + auto* out_cpu_data = out_cpu.data(); + // input1 value + EXPECT_EQ(out_cpu_data[0 * row_numel + 0], 1.0); + EXPECT_EQ(out_cpu_data[0 * row_numel + 8], 1.0); + EXPECT_EQ(out_cpu_data[1 * row_numel + 1], 1.0); + EXPECT_EQ(out_cpu_data[2 * row_numel + 6], 1.0); + // input2 value + EXPECT_EQ(out_cpu_data[3 * row_numel + 3], 2.0); + EXPECT_EQ(out_cpu_data[3 * row_numel + 8], 2.0); + EXPECT_EQ(out_cpu_data[4 * row_numel + 4], 2.0); + EXPECT_EQ(out_cpu_data[5 * row_numel + 7], 2.0); + EXPECT_EQ(out_cpu_data[6 * row_numel + 9], 2.0); + + std::unique_ptr tensor1{new Tensor()}; + tensor1->mutable_data(make_ddim({height, row_numel}), gpu_place); + functor(ctx, tensor1.get(), 3.0); + + std::unique_ptr tensor2{new Tensor()}; + tensor2->mutable_data(make_ddim({height, row_numel}), gpu_place); + + SelectedRowsAddTensor add_tensor_functor; + add_tensor_functor(ctx, *output, *tensor1, tensor2.get()); + + Tensor tensor2_cpu; + tensor2_cpu.CopyFrom(*tensor2, cpu_place, ctx); + ctx.Wait(); + + auto* tensor2_cpu_data = tensor2_cpu.data(); + // row0: 1.0 + 2.0 + 3.0 + EXPECT_EQ(tensor2_cpu_data[0 * row_numel + 0], 6.0); + // row1: 3.0 + EXPECT_EQ(tensor2_cpu_data[1 * row_numel + 1], 3.0); + // row4 : 1.0 + 3.0 + EXPECT_EQ(tensor2_cpu_data[4 * row_numel + 6], 4.0); + // row5: 2.0 + 3.0 + EXPECT_EQ(tensor2_cpu_data[5 * row_numel + 7], 5.0); + // row6: 3.0 + EXPECT_EQ(tensor2_cpu_data[6 * row_numel + 1], 3.0); + // row7: 1.0 + 2.0 + 3.0 + EXPECT_EQ(tensor2_cpu_data[7 * row_numel + 3], 6.0); + // row9: 2.0 + 3.0 + EXPECT_EQ(tensor2_cpu_data[9 * row_numel + 6], 5.0); +}