diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index e145649d300d57425b9c83bd7daa4149cb698e2c..adca120638d305efad9f10e85b6e73a9109740ec 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -264,6 +264,10 @@ class ExecutionContext : public InferShapeContext { platform::Place GetPlace() const { return device_context_->GetPlace(); } + const platform::DeviceContext* device_context() const { + return device_context_; + } + const platform::DeviceContext* device_context_; }; diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h index cd1b4de426a49fa66dbbf8cf7d09990ac8d21227..b8c779f4e5fc7bc51298cdd35b26c2c8ac98edf6 100644 --- a/paddle/framework/tensor.h +++ b/paddle/framework/tensor.h @@ -105,6 +105,8 @@ class Tensor { template inline Tensor Slice(const int& begin_idx, const int& end_idx) const; + platform::Place place() const { return holder_->place(); } + private: template inline void check_memory_size() const; diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 7f56aaa92cc45d81440084cdeb3c6eb3b6fda3df..373611cc0ee952de813f01d32d1516e1a8384750 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -41,6 +41,7 @@ function(op_library TARGET) endif() endfunction() +add_subdirectory(math) cc_test(gather_test SRCS gather_test.cc DEPS tensor) cc_library(net_op SRCS net_op.cc DEPS op_registry) @@ -50,7 +51,7 @@ op_library(add_op SRCS add_op.cc add_op.cu) op_library(mean_op SRCS mean_op.cc mean_op.cu) -op_library(mul_op SRCS mul_op.cc mul_op.cu) +op_library(mul_op SRCS mul_op.cc mul_op.cu DEPS math_function) op_library(rowwise_add_op SRCS rowwise_add_op.cu rowwise_add_op.cc) op_library(sigmoid_op SRCS sigmoid_op.cc sigmoid_op.cu) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt new file mode 100644 index 0000000000000000000000000000000000000000..abcaf940ab0128d6948acc620d678632c8f48960 --- /dev/null +++ b/paddle/operators/math/CMakeLists.txt @@ -0,0 +1,13 @@ +if(WITH_MKLML) + set(BLAS_LIB mklml) +else() + set(BLAS_LIB cblas) +endif() + +if(WITH_GPU) + nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) +else() + cc_library(math_function SRCS math_function.cc DEPS ${BLAS_LIB} device_context) +endif() + +nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc new file mode 100644 index 0000000000000000000000000000000000000000..affdd1ac2cd486930881ee6b34a4b32f41df7ee9 --- /dev/null +++ b/paddle/operators/math/math_function.cc @@ -0,0 +1,114 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const float alpha, const float* A, + const float* B, const float beta, float* C, + platform::DeviceContext* context) { + int lda = K; + int ldb = N; + int ldc = N; + cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const double alpha, const double* A, + const double* B, const double beta, + double* C, + platform::DeviceContext* context) { + int lda = K; + int ldb = N; + int ldc = N; + cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, float alpha, + framework::Tensor* matrix_out, + float beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) && + platform::is_cpu_place(matrix_b.place()) && + platform::is_cpu_place(matrix_out->place()), + "Matrix must all be in CPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, double alpha, + framework::Tensor* matrix_out, + double beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_cpu_place(matrix_a.place()) && + platform::is_cpu_place(matrix_b.place()) && + platform::is_cpu_place(matrix_out->place()), + "Matrix must all be in CPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu new file mode 100644 index 0000000000000000000000000000000000000000..da40b27c948918e4997f4a046d2145552296158b --- /dev/null +++ b/paddle/operators/math/math_function.cu @@ -0,0 +1,127 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const float alpha, const float* A, + const float* B, const float beta, float* C, + platform::DeviceContext* context) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + PADDLE_ENFORCE(platform::dynload::cublasSgemm( + reinterpret_cast(context)->cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); +} + +template <> +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, + const double alpha, const double* A, + const double* B, const double beta, + double* C, + platform::DeviceContext* context) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + PADDLE_ENFORCE(platform::dynload::cublasDgemm( + reinterpret_cast(context)->cublas_handle(), + cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, N)); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, float alpha, + framework::Tensor* matrix_out, + float beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && + platform::is_gpu_place(matrix_b.place()) && + platform::is_gpu_place(matrix_out->place()), + "Matrix must all be in GPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +template <> +void matmul(const framework::Tensor& matrix_a, + bool trans_a, + const framework::Tensor& matrix_b, + bool trans_b, double alpha, + framework::Tensor* matrix_out, + double beta, + platform::DeviceContext* context) { + auto dim_a = matrix_a.dims(); + auto dim_b = matrix_b.dims(); + auto dim_out = matrix_out->dims(); + PADDLE_ENFORCE(dim_a.size() == 2 && dim_b.size() == 2 && dim_out.size() == 2, + "The input and output of matmul be matrix"); + + PADDLE_ENFORCE(platform::is_gpu_place(matrix_a.place()) && + platform::is_gpu_place(matrix_b.place()) && + platform::is_gpu_place(matrix_out->place()), + "Matrix must all be in GPUPlace"); + + int M = dim_out[0]; + int N = dim_out[1]; + int K = (trans_a == false) ? dim_a[1] : dim_a[0]; + + CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; + + gemm( + transA, transB, M, N, K, alpha, matrix_a.data(), + matrix_b.data(), beta, matrix_out->data(), context); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h new file mode 100644 index 0000000000000000000000000000000000000000..155589fadb3ed9f59160a750d546dd8093a56cbe --- /dev/null +++ b/paddle/operators/math/math_function.h @@ -0,0 +1,82 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#ifdef PADDLE_USE_MKLML +#include +#include +#include +#endif + +#ifdef PADDLE_USE_MKL +#include +#include +#endif + +#ifdef PADDLE_USE_ATLAS +extern "C" { +#include +#include +} +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#include +#endif + +#ifndef LAPACK_FOUND +extern "C" { +#include +int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, + int* ipiv); +int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, + int* ipiv); +int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, + const int* ipiv); +int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, + const int* ipiv); +} +#endif + +#include + +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/enforce.h" + +namespace paddle { +namespace operators { +namespace math { + +// Support continuous memory now +// If transA = N, and transB = N +// Then matrixA: M * K, matrixB: K * N matrixC : M * N +// For more detailed info, please refer to +// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html +template +void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, + const int M, const int N, const int K, const T alpha, const T* A, + const T* B, const T beta, T* C, platform::DeviceContext* context); + +// matrix multiply with continuous memory +template +void matmul(const framework::Tensor& matrix_a, bool trans_a, + const framework::Tensor& matrix_b, bool trans_b, T alpha, + framework::Tensor* matrix_out, T beta, + platform::DeviceContext* context); + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/math_function_test.cc b/paddle/operators/math/math_function_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..6c020c4ff7285b43bc5836d80c173d3a068e72b3 --- /dev/null +++ b/paddle/operators/math/math_function_test.cc @@ -0,0 +1,75 @@ +#include "paddle/operators/math/math_function.h" +#include "gtest/gtest.h" + +#ifndef PADDLE_ONLY_CPU +TEST(math_function, notrans_mul_trans) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr, 6 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::DeviceContext* context = + new paddle::platform::CUDADeviceContext(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input1, *gpu_place); + + out_gpu.mutable_data({2, 2}, *gpu_place); + + paddle::operators::math::matmul( + input1_gpu, false, input2_gpu, true, 1, &out_gpu, 0, context); + + out.CopyFrom(out_gpu, *cpu_place); + + float* out_ptr = out.data(); + EXPECT_EQ(out_ptr[0], 5); + EXPECT_EQ(out_ptr[1], 14); + EXPECT_EQ(out_ptr[2], 14); + EXPECT_EQ(out_ptr[3], 50); +} + +TEST(math_function, trans_mul_notrans) { + paddle::framework::Tensor input1; + paddle::framework::Tensor input1_gpu; + paddle::framework::Tensor input2_gpu; + paddle::framework::Tensor out_gpu; + paddle::framework::Tensor out; + + auto* cpu_place = new paddle::platform::CPUPlace(); + float* input1_ptr = input1.mutable_data({2, 3}, *cpu_place); + float arr[6] = {0, 1, 2, 3, 4, 5}; + memcpy(input1_ptr, arr, 6 * sizeof(float)); + + auto* gpu_place = new paddle::platform::GPUPlace(0); + paddle::platform::DeviceContext* context = + new paddle::platform::CUDADeviceContext(*gpu_place); + + input1_gpu.CopyFrom(input1, *gpu_place); + input2_gpu.CopyFrom(input1, *gpu_place); + + out_gpu.mutable_data({3, 3}, *gpu_place); + + paddle::operators::math::matmul( + input1_gpu, true, input2_gpu, false, 1, &out_gpu, 0, context); + + out.CopyFrom(out_gpu, *cpu_place); + + float* out_ptr = out.data(); + EXPECT_EQ(out_ptr[0], 9); + EXPECT_EQ(out_ptr[1], 12); + EXPECT_EQ(out_ptr[2], 15); + EXPECT_EQ(out_ptr[3], 12); + EXPECT_EQ(out_ptr[4], 17); + EXPECT_EQ(out_ptr[5], 22); + EXPECT_EQ(out_ptr[6], 15); + EXPECT_EQ(out_ptr[7], 22); + EXPECT_EQ(out_ptr[8], 29); +} +#endif diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index ae924375c2fb27104ffeb98268aec36fafde3c69..92322c75690f9d7506c089623d720c745b4f6f54 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -13,6 +13,7 @@ limitations under the License. */ #include "paddle/operators/mul_op.h" +#include "paddle/operators/math/math_function.h" namespace paddle { namespace operators { diff --git a/paddle/operators/mul_op.cu b/paddle/operators/mul_op.cu index 43debbc21a365a15c914e60e151f7782b82080cb..346a7e505d123b5e4e831daa39a1f6349b3dcccf 100644 --- a/paddle/operators/mul_op.cu +++ b/paddle/operators/mul_op.cu @@ -16,5 +16,4 @@ #include "paddle/operators/mul_op.h" namespace ops = paddle::operators; - REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel); diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index ca3105fa4f158064c822a319e2c9c5a40e39d481..b7812fd1a7a72f5ce543e18c8b7b5b51deff2204 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -13,6 +13,9 @@ limitations under the License. */ #pragma once + +#include "paddle/operators/math/math_function.h" + #include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" diff --git a/paddle/platform/dynload/cublas.h b/paddle/platform/dynload/cublas.h index aad8097dbb33cbf6c0f2b4b3efb1376fbe96bc74..9d8343c0b5e200b390ccda760f09816959952e9d 100644 --- a/paddle/platform/dynload/cublas.h +++ b/paddle/platform/dynload/cublas.h @@ -62,12 +62,12 @@ extern void *cublas_dso_handle; DECLARE_DYNAMIC_LOAD_CUBLAS_WRAP(__name) #define CUBLAS_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasSgemv); \ - __macro(cublasDgemv); \ - __macro(cublasSgemm); \ - __macro(cublasDgemm); \ - __macro(cublasSgeam); \ - __macro(cublasDgeam); \ + __macro(cublasSgemv_v2); \ + __macro(cublasDgemv_v2); \ + __macro(cublasSgemm_v2); \ + __macro(cublasDgemm_v2); \ + __macro(cublasSgeam_v2); \ + __macro(cublasDgeam_v2); \ __macro(cublasCreate_v2); \ __macro(cublasDestroy_v2); \ __macro(cublasSetStream_v2); \