diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 0532e8f034ccc8ddb3a0a6fac37b0415de6056b4..c678b37616a73ae7239ba133059344b4ac55f56e 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -32,7 +32,7 @@ void gemm(const CBLAS_TRANSPOSE transA, const float beta, float* C, const int ldc, - const platform::DeviceContext* context) { + platform::DeviceContext* context) { cblas_sgemm(CblasRowMajor, transA, transB, @@ -63,7 +63,7 @@ void gemm(const CBLAS_TRANSPOSE transA, const double beta, double* C, const int ldc, - const platform::DeviceContext* context) { + platform::DeviceContext* context) { cblas_dgemm(CblasRowMajor, transA, transB, @@ -80,42 +80,6 @@ void gemm(const CBLAS_TRANSPOSE transA, ldc); } -template <> -void axpy(const int n, - const float alpha, - const float* x, - float* y, - const platform::DeviceContext* context) { - cblas_saxpy(n, alpha, x, 1, y, 1); -} - -template <> -void axpy(const int n, - const double alpha, - const double* x, - double* y, - const platform::DeviceContext* context) { - cblas_daxpy(n, alpha, x, 1, y, 1); -} - -template <> -float dotProduct( - const int n, - const float* x, - const float* y, - const platform::DeviceContext* context) { - return cblas_sdot(n, x, 1, y, 1); -} - -template <> -double dotProduct( - const int n, - const double* x, - const double* y, - const platform::DeviceContext* context) { - return cblas_ddot(n, x, 1, y, 1); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index 46301df8f9d4f82f0d915a131965e5fd76038be6..190312e59d45041d97b3b434d52fe43a2db2ad95 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -20,29 +20,29 @@ namespace operators { namespace math { template <> -void gemm(const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, - const int M, - const int N, - const int K, - const float alpha, - const float* A, - const int lda, - const float* B, - const int ldb, - const float beta, - float* C, - const int ldc, - const platform::DeviceContext* context) { +void gemm(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, + const int M, + const int N, + const int K, + const float alpha, + const float* A, + const int lda, + const float* B, + const int ldb, + const float beta, + float* C, + const int ldc, + platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasSgemm( - reinterpret_cast(context)-> + reinterpret_cast(context)-> cublas_handle(), cuTransB, cuTransA, @@ -73,15 +73,15 @@ void gemm(const CBLAS_TRANSPOSE transA, const double beta, double* C, const int ldc, - const platform::DeviceContext* context) { + platform::DeviceContext* context) { // Note that cublas follows fortran order, so the order is different from // the cblas convention. cublasOperation_t cuTransA = - (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; cublasOperation_t cuTransB = - (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; PADDLE_ENFORCE(platform::dynload::cublasDgemm( - reinterpret_cast(context)-> + reinterpret_cast(context)-> cublas_handle(), cuTransB, cuTransA, @@ -99,48 +99,6 @@ void gemm(const CBLAS_TRANSPOSE transA, } -template <> -void axpy(const int n, - const float alpha, - const float* x, - float* y, - const platform::DeviceContext* context) { - CUBLAS_ENFORCE(platform::dynload::cublasSaxpy( - reinterpret_cast(context)-> - cublas_handle(), N, &alpha, X, 1, Y, 1)); -} - -template <> -void axpy(const int n, - const double alpha, - const double* x, - double* y, - const platform::DeviceContext* context) { - CUBLAS_ENFORCE(platform::dynload::cublasDaxpy( - reinterpret_cast(context)-> - cublas_handle(), N, &alpha, X, 1, Y, 1)); -} - -template <> -float dotProduct(const int n, - const float* x, - const float* y, - const platform::DeviceContext* context) { - CUBLAS_ENFORCE(platform::dynload::cublasSdot( - reinterpret_cast(context)-> - cublas_handle(), n, a, 1, b, 1, &result)); -} - -template <> -double dotProduct(const int n, - const double* x, - const double* y, - const platform::DeviceContext* context) { - CUBLAS_ENFORCE(platform::dynload::cublasDdot( - reinterpret_cast(context)-> - cublas_handle(), n, a, 1, b, 1, &result)); -} - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index c5b7fe8793c952afa4af7bae02434f6d1df86ca0..f1f87ac5f2d277e82977faeb97ada691a9c8c5a8 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -58,20 +58,7 @@ void gemm(const CBLAS_TRANSPOSE transA, const T beta, T* C, const int ldc, - const platform::DeviceContext* context); - -template -void axpy(const int n, - const T alpha, - const T* x, - T* y, - const platform::DeviceContext* context); - -template -T dotProduct(const int n, - const T* x, - const T* y, - const platform::DeviceContext* context); + platform::DeviceContext* context); } // namespace math } // namespace operators diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h index e1759d00c55ab9caf5e6714883d7b187deb05363..0bffe79a1e23a23b8a5fcdf298fcd63bac9e4ed5 100644 --- a/paddle/operators/mul_op.h +++ b/paddle/operators/mul_op.h @@ -37,20 +37,21 @@ public: int N = out_dim[1]; int K = in0_dim[1]; - paddle::operators::math::template gemm(CblasNoTrans, - CblasNoTrans, - M, - N, - K, - 1, - input0->data(), - K, - input1->data(), - N, - 0, - output->data(), - N, - &context.device_context()); + paddle::operators::math::template gemm( + CblasNoTrans, + CblasNoTrans, + M, + N, + K, + 1, + input0->data(), + K, + input1->data(), + N, + 0, + output->data(), + N, + &const_cast(context.device_context())); } };