From f59a7c1d36d2e930c48e118ad90f35a541e00223 Mon Sep 17 00:00:00 2001 From: qijun Date: Sat, 14 Oct 2017 14:57:07 -0700 Subject: [PATCH] add gpu functor for SelectedRows --- paddle/framework/lod_tensor.h | 3 - paddle/framework/selected_rows.h | 7 +- paddle/framework/selected_rows_test.cc | 2 +- paddle/framework/type_defs.h | 6 +- paddle/operators/math/CMakeLists.txt | 2 +- paddle/operators/math/math_function.cc | 17 +- paddle/operators/math/math_function.cu | 107 ++++++++ paddle/operators/math/math_function_test.cc | 179 ------------- paddle/operators/math/math_function_test.cu | 277 ++++++++++++++++++++ 9 files changed, 405 insertions(+), 195 deletions(-) create mode 100644 paddle/operators/math/math_function_test.cu diff --git a/paddle/framework/lod_tensor.h b/paddle/framework/lod_tensor.h index 4db36ee766..ee040a9144 100644 --- a/paddle/framework/lod_tensor.h +++ b/paddle/framework/lod_tensor.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/paddle/framework/selected_rows.h b/paddle/framework/selected_rows.h index ddc6dec194..cd90781371 100644 --- a/paddle/framework/selected_rows.h +++ b/paddle/framework/selected_rows.h @@ -10,6 +10,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#include "paddle/framework/lod_tensor.h" #include "paddle/framework/tensor.h" namespace paddle { @@ -34,9 +35,9 @@ class SelectedRows { void set_height(int64_t height) { height_ = height; } - const std::vector& rows() const { return rows_; } + const Vector& rows() const { return rows_; } - void set_rows(const std::vector& rows) { rows_ = rows; } + void set_rows(const Vector& rows) { rows_ = rows; } DDim GetCompleteDims() const { std::vector dims = vectorize(value_->dims()); @@ -48,7 +49,7 @@ class SelectedRows { // Notice: rows can be duplicate. We can have {0, 4, 7, 0, 5, 7, 9} here. // SelectedRows are simplely concated when adding together. Until a // SelectedRows add a Tensor, will the duplicate rows be handled. - std::vector rows_; + Vector rows_; std::unique_ptr value_{nullptr}; int64_t height_; }; diff --git a/paddle/framework/selected_rows_test.cc b/paddle/framework/selected_rows_test.cc index 4ee13a65d7..055b867600 100644 --- a/paddle/framework/selected_rows_test.cc +++ b/paddle/framework/selected_rows_test.cc @@ -18,7 +18,7 @@ namespace framework { class SelectedRowsTester : public ::testing::Test { public: virtual void SetUp() override { - std::vector rows{0, 4, 7}; + Vector rows{0, 4, 7}; int64_t height = 10; int64_t row_numel = 100; selected_rows_.reset(new SelectedRows(rows, height)); diff --git a/paddle/framework/type_defs.h b/paddle/framework/type_defs.h index 7e1b79c97b..0c0a72de31 100644 --- a/paddle/framework/type_defs.h +++ b/paddle/framework/type_defs.h @@ -1,11 +1,8 @@ /* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. - Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -37,7 +34,8 @@ using OpCreator = std::function; using GradOpMakerFN = std::function>( - const OpDescBind&, const std::unordered_set& /*no_grad_set*/)>; + const OpDescBind&, const std::unordered_set& /*no_grad_set*/, + std::unordered_map* /*grad_to_var*/)>; } // namespace framework } // namespace paddle diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 1a2f623ce7..a7f275bae5 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu im2col.cc im2col.cu DEPS cblas device_context operator) - nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) + nv_test(math_function_gpu_test SRCS math_function_test.cu DEPS math_function tensor) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index ddb904aa46..a1faafb7c4 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -162,15 +162,24 @@ struct SelectedRowsAdd { PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size()); PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size()); - auto* out_data = out_value->data(); + auto in1_place = input1.place(); + PADDLE_ENFORCE(platform::is_cpu_place(in1_place)); + auto in2_place = input2.place(); + PADDLE_ENFORCE(platform::is_cpu_place(in2_place)); + auto out_place = context.GetPlace(); + PADDLE_ENFORCE(platform::is_cpu_place(out_place)); + auto* out_data = out_value->data(); auto* in1_data = in1_value.data(); - memory::Copy(platform::CPUPlace(), out_data, platform::CPUPlace(), in1_data, + memory::Copy(boost::get(out_place), out_data, + boost::get(in1_place), in1_data, in1_value.numel() * sizeof(T)); auto* in2_data = in2_value.data(); - memory::Copy(platform::CPUPlace(), out_data + in1_value.numel(), - platform::CPUPlace(), in2_data, in2_value.numel() * sizeof(T)); + memory::Copy(boost::get(out_place), + out_data + in1_value.numel(), + boost::get(in2_place), in2_data, + in2_value.numel() * sizeof(T)); } }; diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 649f1f352c..26bf0ec2f1 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -155,6 +155,113 @@ void matmul( matrix_b.data(), beta, matrix_out->data()); } +template +struct SelectedRowsAdd { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input1, + const framework::SelectedRows& input2, + framework::SelectedRows* output) { + auto in1_height = input1.height(); + PADDLE_ENFORCE_EQ(in1_height, input2.height()); + output->set_height(in1_height); + + auto& in1_rows = input1.rows(); + auto& in2_rows = input2.rows(); + std::vector out_rows; + out_rows.reserve(in1_rows.size() + in2_rows.size()); + + // concat rows + out_rows.insert(out_rows.end(), in1_rows.begin(), in1_rows.end()); + out_rows.insert(out_rows.end(), in2_rows.begin(), in2_rows.end()); + output->set_rows(out_rows); + + auto* out_value = output->mutable_value(); + auto& in1_value = input1.value(); + auto& in2_value = input2.value(); + + auto in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, in2_value.numel() / in2_rows.size()); + PADDLE_ENFORCE_EQ(in1_row_numel, out_value->numel() / out_rows.size()); + + auto* out_data = out_value->data(); + auto* in1_data = in1_value.data(); + + auto in1_place = input1.place(); + PADDLE_ENFORCE(platform::is_gpu_place(in1_place)); + auto in2_place = input2.place(); + PADDLE_ENFORCE(platform::is_gpu_place(in2_place)); + auto out_place = context.GetPlace(); + PADDLE_ENFORCE(platform::is_gpu_place(out_place)) + + memory::Copy( + boost::get(out_place), out_data, + boost::get(in1_place), in1_data, + in1_value.numel() * sizeof(T), + reinterpret_cast(context).stream()); + + auto* in2_data = in2_value.data(); + memory::Copy( + boost::get(out_place), out_data + in1_value.numel(), + boost::get(in2_place), in2_data, + in2_value.numel() * sizeof(T), + reinterpret_cast(context).stream()); + } +}; + +template struct SelectedRowsAdd; + +namespace { +template +__global__ void SelectedRowsAddTensorKernel(T* selected_rows, int64_t* rows, + T* tensor_in, T* tensor_out, + const int64_t row_numel) { + const ty = blockIdx.y; + int tid = threadIdx.x; + + selected_rows += ty * row_numel; + tensor_in += rows[ty] * row_numel; + tensor_out += rows[ty] * row_numel; + + for (int index = tid; index < row_numel; index += block_size) { + tensor_out[index] = tensor_in[index] + selected_rows[index]; + } +} +} + +template +struct SelectedRowsAddTensor { + void operator()(const platform::DeviceContext& context, + const framework::SelectedRows& input1, + const framework::Tensor& input2, framework::Tensor* output) { + auto in1_height = input1.height(); + auto in2_dims = input2.dims(); + auto out_dims = output->dims(); + PADDLE_ENFORCE_EQ(in1_height, in2_dims[0]); + PADDLE_ENFORCE_EQ(in1_height, out_dims[0]); + + auto& in1_value = input1.value(); + auto& in1_rows = input1.rows(); + + int64_t in1_row_numel = in1_value.numel() / in1_rows.size(); + PADDLE_ENFORCE_EQ(in1_row_numel, input2.numel() / in1_height); + PADDLE_ENFORCE_EQ(in1_row_numel, output->numel() / in1_height); + + auto* in1_data = in1_value.data(); + auto* in2_data = input2.data(); + auto* out_data = output->data(); + + const int block_size = 256; + dim3 threads(block_size, 1); + dim3 grid(1, in1_height); + SelectedRowsAddTensorKernel<<< + grid, threads, 0, + reinterpret_cast(ctx).stream()>>>( + in1_data, in1_rows.data(), in2_data, out_data, in1_row_numel); + } +}; + +template struct SelectedRowsAddTensor; + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc index fe0d1981fa..33c561f6c6 100644 --- a/paddle/operators/math/math_function_test.cc +++ b/paddle/operators/math/math_function_test.cc @@ -1,185 +1,6 @@ #include "paddle/operators/math/math_function.h" #include "gtest/gtest.h" -#ifdef PADDLE_WITH_CUDA -TEST(math_function, notrans_mul_trans) { - paddle::framework::Tensor input1; - paddle::framework::Tensor input1_gpu; - paddle::framework::Tensor input2_gpu; - paddle::framework::Tensor out_gpu; - paddle::framework::Tensor out; - - auto* cpu_place = new paddle::platform::CPUPlace(); - float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); - float arr[6] = {0, 1, 2, 3, 4, 5}; - memcpy(input1_ptr, arr, 6 * sizeof(float)); - - auto* gpu_place = new paddle::platform::GPUPlace(0); - paddle::platform::CUDADeviceContext context(*gpu_place); - - input1_gpu.CopyFrom(input1, *gpu_place, context); - input2_gpu.CopyFrom(input1, *gpu_place, context); - - out_gpu.mutable_data({2, 2}, *gpu_place); - - paddle::operators::math::matmul( - context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0); - - out.CopyFrom(out_gpu, *cpu_place, context); - - float* out_ptr = out.data(); - context.Wait(); - EXPECT_EQ(out_ptr[0], 5); - EXPECT_EQ(out_ptr[1], 14); - EXPECT_EQ(out_ptr[2], 14); - EXPECT_EQ(out_ptr[3], 50); - delete gpu_place; -} - -TEST(math_function, trans_mul_notrans) { - paddle::framework::Tensor input1; - paddle::framework::Tensor input1_gpu; - paddle::framework::Tensor input2_gpu; - paddle::framework::Tensor out_gpu; - paddle::framework::Tensor out; - - auto* cpu_place = new paddle::platform::CPUPlace(); - float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); - float arr[6] = {0, 1, 2, 3, 4, 5}; - memcpy(input1_ptr, arr, 6 * sizeof(float)); - - auto* gpu_place = new paddle::platform::GPUPlace(0); - paddle::platform::CUDADeviceContext context(*gpu_place); - - input1_gpu.CopyFrom(input1, *gpu_place, context); - input2_gpu.CopyFrom(input1, *gpu_place, context); - - out_gpu.mutable_data({3, 3}, *gpu_place); - - paddle::operators::math::matmul( - context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0); - - out.CopyFrom(out_gpu, *cpu_place, context); - - float* out_ptr = out.data(); - context.Wait(); - EXPECT_EQ(out_ptr[0], 9); - EXPECT_EQ(out_ptr[1], 12); - EXPECT_EQ(out_ptr[2], 15); - EXPECT_EQ(out_ptr[3], 12); - EXPECT_EQ(out_ptr[4], 17); - EXPECT_EQ(out_ptr[5], 22); - EXPECT_EQ(out_ptr[6], 15); - EXPECT_EQ(out_ptr[7], 22); - EXPECT_EQ(out_ptr[8], 29); - delete gpu_place; -} - -TEST(math_function, gemm_notrans_cublas) { - paddle::framework::Tensor input1; - paddle::framework::Tensor input2; - paddle::framework::Tensor input3; - paddle::framework::Tensor input1_gpu; - paddle::framework::Tensor input2_gpu; - paddle::framework::Tensor input3_gpu; - - int m = 2; - int n = 3; - int k = 3; - auto* cpu_place = new paddle::platform::CPUPlace(); - float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); - float arr1[6] = {0, 1, 2, 3, 4, 5}; - memcpy(input1_ptr, arr1, 6 * sizeof(float)); - float* input2_ptr = input2.mutable_data({3, 4}, *cpu_place); - float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; - memcpy(input2_ptr, arr2, 12 * sizeof(float)); - float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); - float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - memcpy(input3_ptr, arr3, 8 * sizeof(float)); - - auto* gpu_place = new paddle::platform::GPUPlace(0); - paddle::platform::CUDADeviceContext context(*gpu_place); - - input1_gpu.CopyFrom(input1, *gpu_place, context); - input2_gpu.CopyFrom(input2, *gpu_place, context); - input3_gpu.CopyFrom(input3, *gpu_place, context); - float* a = input1_gpu.data(); - float* b = input2_gpu.data(); - float* c = input3_gpu.mutable_data(*gpu_place); - - paddle::operators::math::gemm( - context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4); - - input3.CopyFrom(input3_gpu, *cpu_place, context); - - // numpy code: - // a = np.arange(6).reshape(2, 3) - // b = np.arange(12).reshape(3, 4)[:, 1:] - // c = np.arange(8).reshape(2, 4)[:, 1:] - // out = np.arange(8).reshape(2, 4) - // out[:, 1:] = np.dot(a, b) + c - context.Wait(); - EXPECT_EQ(input3_ptr[0], 0); - EXPECT_EQ(input3_ptr[1], 24); - EXPECT_EQ(input3_ptr[2], 28); - EXPECT_EQ(input3_ptr[3], 32); - EXPECT_EQ(input3_ptr[4], 4); - EXPECT_EQ(input3_ptr[5], 73); - EXPECT_EQ(input3_ptr[6], 86); - EXPECT_EQ(input3_ptr[7], 99); - delete gpu_place; -} - -TEST(math_function, gemm_trans_cublas) { - paddle::framework::Tensor input1; - paddle::framework::Tensor input2; - paddle::framework::Tensor input3; - paddle::framework::Tensor input1_gpu; - paddle::framework::Tensor input2_gpu; - paddle::framework::Tensor input3_gpu; - - int m = 2; - int n = 3; - int k = 3; - auto* cpu_place = new paddle::platform::CPUPlace(); - float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); - float arr1[6] = {0, 1, 2, 3, 4, 5}; - memcpy(input1_ptr, arr1, 6 * sizeof(float)); - float* input2_ptr = input2.mutable_data({4, 3}, *cpu_place); - float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11}; - memcpy(input2_ptr, arr2, 12 * sizeof(float)); - float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); - float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; - memcpy(input3_ptr, arr3, 8 * sizeof(float)); - - auto* gpu_place = new paddle::platform::GPUPlace(0); - paddle::platform::CUDADeviceContext context(*gpu_place); - - input1_gpu.CopyFrom(input1, *gpu_place, context); - input2_gpu.CopyFrom(input2, *gpu_place, context); - input3_gpu.CopyFrom(input3, *gpu_place, context); - float* a = input1_gpu.data(); - float* b = input2_gpu.data(); - float* c = input3_gpu.mutable_data(*gpu_place); - - paddle::operators::math::gemm( - context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4); - - input3.CopyFrom(input3_gpu, *cpu_place, context); - context.Wait(); - - EXPECT_EQ(input3_ptr[0], 0); - EXPECT_EQ(input3_ptr[1], 24); - EXPECT_EQ(input3_ptr[2], 28); - EXPECT_EQ(input3_ptr[3], 32); - EXPECT_EQ(input3_ptr[4], 4); - EXPECT_EQ(input3_ptr[5], 73); - EXPECT_EQ(input3_ptr[6], 86); - EXPECT_EQ(input3_ptr[7], 99); - delete gpu_place; -} -#endif - TEST(math_function, gemm_notrans_cblas) { paddle::framework::Tensor input1; paddle::framework::Tensor input2; diff --git a/paddle/operators/math/math_function_test.cu b/paddle/operators/math/math_function_test.cu new file mode 100644 index 0000000000..e691078bb6 --- /dev/null +++ b/paddle/operators/math/math_function_test.cu @@ -0,0 +1,277 @@ +#include "gtest/gtest.h" +#include "paddle/operators/math/math_function.h" + +TEST(math_function, notrans_mul_trans) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr, 6 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::CUDADeviceContext context(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place, context); + input2_gpu.CopyFrom(input1, *gpu_place, context); + + out_gpu.mutable_data({2, 2}, *gpu_place); + + paddle::operators::math::matmul( + context, input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0); + + out.CopyFrom(out_gpu, *cpu_place, context); + + float* out_ptr = out.data(); + context.Wait(); + EXPECT_EQ(out_ptr[0], 5); + EXPECT_EQ(out_ptr[1], 14); + EXPECT_EQ(out_ptr[2], 14); + EXPECT_EQ(out_ptr[3], 50); + delete gpu_place; +} + +TEST(math_function, trans_mul_notrans) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr, 6 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::CUDADeviceContext context(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place, context); + input2_gpu.CopyFrom(input1, *gpu_place, context); + + out_gpu.mutable_data({3, 3}, *gpu_place); + + paddle::operators::math::matmul( + context, input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0); + + out.CopyFrom(out_gpu, *cpu_place, context); + + float* out_ptr = out.data(); + context.Wait(); + EXPECT_EQ(out_ptr[0], 9); + EXPECT_EQ(out_ptr[1], 12); + EXPECT_EQ(out_ptr[2], 15); + EXPECT_EQ(out_ptr[3], 12); + EXPECT_EQ(out_ptr[4], 17); + EXPECT_EQ(out_ptr[5], 22); + EXPECT_EQ(out_ptr[6], 15); + EXPECT_EQ(out_ptr[7], 22); + EXPECT_EQ(out_ptr[8], 29); + delete gpu_place; +} + +TEST(math_function, gemm_notrans_cublas) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input2; + paddle::framework::Tensor input3; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor input3_gpu; + + int m = 2; + int n = 3; + int k = 3; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr1[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr1, 6 * sizeof(float)); + float* input2_ptr = input2.mutable_data({3, 4}, *cpu_place); + float arr2[12] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}; + memcpy(input2_ptr, arr2, 12 * sizeof(float)); + float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); + float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + memcpy(input3_ptr, arr3, 8 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::CUDADeviceContext context(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place, context); + input2_gpu.CopyFrom(input2, *gpu_place, context); + input3_gpu.CopyFrom(input3, *gpu_place, context); + float* a = input1_gpu.data(); + float* b = input2_gpu.data(); + float* c = input3_gpu.mutable_data(*gpu_place); + + paddle::operators::math::gemm( + context, false, false, m, n, k, 1, a, 3, b + 1, 4, 1, c + 1, 4); + + input3.CopyFrom(input3_gpu, *cpu_place, context); + + // numpy code: + // a = np.arange(6).reshape(2, 3) + // b = np.arange(12).reshape(3, 4)[:, 1:] + // c = np.arange(8).reshape(2, 4)[:, 1:] + // out = np.arange(8).reshape(2, 4) + // out[:, 1:] = np.dot(a, b) + c + context.Wait(); + EXPECT_EQ(input3_ptr[0], 0); + EXPECT_EQ(input3_ptr[1], 24); + EXPECT_EQ(input3_ptr[2], 28); + EXPECT_EQ(input3_ptr[3], 32); + EXPECT_EQ(input3_ptr[4], 4); + EXPECT_EQ(input3_ptr[5], 73); + EXPECT_EQ(input3_ptr[6], 86); + EXPECT_EQ(input3_ptr[7], 99); + delete gpu_place; +} + +TEST(math_function, gemm_trans_cublas) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input2; + paddle::framework::Tensor input3; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor input3_gpu; + + int m = 2; + int n = 3; + int k = 3; + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr1[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr1, 6 * sizeof(float)); + float* input2_ptr = input2.mutable_data({4, 3}, *cpu_place); + float arr2[12] = {0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11}; + memcpy(input2_ptr, arr2, 12 * sizeof(float)); + float* input3_ptr = input3.mutable_data({2, 4}, *cpu_place); + float arr3[8] = {0, 1, 2, 3, 4, 5, 6, 7}; + memcpy(input3_ptr, arr3, 8 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::CUDADeviceContext context(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place, context); + input2_gpu.CopyFrom(input2, *gpu_place, context); + input3_gpu.CopyFrom(input3, *gpu_place, context); + float* a = input1_gpu.data(); + float* b = input2_gpu.data(); + float* c = input3_gpu.mutable_data(*gpu_place); + + paddle::operators::math::gemm( + context, false, true, m, n, k, 1, a, 3, b + 3, 3, 1, c + 1, 4); + + input3.CopyFrom(input3_gpu, *cpu_place, context); + context.Wait(); + + EXPECT_EQ(input3_ptr[0], 0); + EXPECT_EQ(input3_ptr[1], 24); + EXPECT_EQ(input3_ptr[2], 28); + EXPECT_EQ(input3_ptr[3], 32); + EXPECT_EQ(input3_ptr[4], 4); + EXPECT_EQ(input3_ptr[5], 73); + EXPECT_EQ(input3_ptr[6], 86); + EXPECT_EQ(input3_ptr[7], 99); + delete gpu_place; +} + +TEST(math_function, selected_rows_add) { + using namespace paddle::framework; + using namespace paddle::platform; + using namespace paddle::operators::math; + + CPUPlace gpu_place(0); + CUDADeviceContext ctx(gpu_place); + SetConstant functor; + int64_t height = 10; + int64_t row_numel = 10; + + Vector rows1{0, 4, 7}; + std::unique_ptr selected_rows1{new SelectedRows(rows1, height)}; + auto* in1_value = selected_rows1->mutable_value(); + in1_value->mutable_data( + make_ddim({static_cast(rows1.size()), row_numel}), gpu_place); + functor(ctx, in1_value, 1.0); + + Vector rows2{0, 5, 7, 9}; + std::unique_ptr selected_rows2{new SelectedRows(rows2, height)}; + auto* in2_value = selected_rows2->mutable_value(); + in2_value->mutable_data( + make_ddim({static_cast(rows2.size()), row_numel}), gpu_place); + functor(ctx, in2_value, 2.0); + + std::unique_ptr output{new SelectedRows()}; + auto* out_value = output->mutable_value(); + + // simplely concat two SelectedRows + out_value->mutable_data(make_ddim({7, 10}), gpu_place); + + SelectedRowsAdd add_functor; + add_functor(ctx, *selected_rows1, *selected_rows2, output.get()); + + auto out_height = output->height(); + EXPECT_EQ(out_height, height); + + auto& out_rows = output->rows(); + + // input1 rows + EXPECT_EQ(out_rows[0], 0); + EXPECT_EQ(out_rows[1], 4); + EXPECT_EQ(out_rows[2], 7); + // input2 rows + EXPECT_EQ(out_rows[3], 0); + EXPECT_EQ(out_rows[4], 5); + EXPECT_EQ(out_rows[5], 7); + EXPECT_EQ(out_rows[6], 9); + + Tensor out_cpu; + out_cpu.CopyFrom(*out_value, platform::CPUPlace(), ctx); + ctx.Wait(); + + auto* out_cpu_data = out_cpu.data(); + // input1 value + EXPECT_EQ(out_cpu_data[0 * row_numel + 0], 1.0); + EXPECT_EQ(out_cpu_data[0 * row_numel + 8], 1.0); + EXPECT_EQ(out_cpu_data[1 * row_numel + 1], 1.0); + EXPECT_EQ(out_cpu_data[2 * row_numel + 6], 1.0); + // input2 value + EXPECT_EQ(out_cpu_data[3 * row_numel + 3], 2.0); + EXPECT_EQ(out_cpu_data[3 * row_numel + 8], 2.0); + EXPECT_EQ(out_cpu_data[4 * row_numel + 4], 2.0); + EXPECT_EQ(out_cpu_data[5 * row_numel + 7], 2.0); + EXPECT_EQ(out_cpu_data[6 * row_numel + 9], 2.0); + + std::unique_ptr tensor1{new Tensor()}; + tensor1->mutable_data(make_ddim({height, row_numel}), gpu_place); + SetConstant constant_functor; + constant_functor(ctx, tensor1.get(), 3.0); + + std::unique_ptr tensor2{new Tensor()}; + tensor2->mutable_data(make_ddim({height, row_numel}), gpu_place); + + SelectedRowsAddTensor add_tensor_functor; + add_tensor_functor(ctx, *output, *tensor1, tensor2.get()); + + Tensor tensor2_cpu; + tensor2_cpu.CopyFrom(*tensor2, platform::CPUPlace(), ctx); + ctx.Wait(); + + auto* tensor2_cpu_data = tensor2_cpu->data(); + // row0: 1.0 + 2.0 + 3.0 + EXPECT_EQ(tensor2_cpu_data[0 * row_numel + 0], 6.0); + // row1: 3.0 + EXPECT_EQ(tensor2_cpu_data[1 * row_numel + 1], 3.0); + // row4 : 1.0 + 3.0 + EXPECT_EQ(tensor2_cpu_data[4 * row_numel + 6], 4.0); + // row5: 2.0 + 3.0 + EXPECT_EQ(tensor2_cpu_data[5 * row_numel + 7], 5.0); + // row6: 3.0 + EXPECT_EQ(tensor2_cpu_data[6 * row_numel + 1], 3.0); + // row7: 1.0 + 2.0 + 3.0 + EXPECT_EQ(tensor2_cpu_data[7 * row_numel + 3], 6.0); + // row9: 2.0 + 3.0 + EXPECT_EQ(tensor2_cpu_data[9 * row_numel + 6], 5.0); +} -- GitLab