From 960a52555064d0496c8b76ce726c604d3fba66d4 Mon Sep 17 00:00:00 2001 From: qijun Date: Mon, 14 Aug 2017 07:20:16 +0000 Subject: [PATCH] fix gpu build error --- paddle/operators/math/CMakeLists.txt | 2 +- paddle/operators/math/math_function.cc | 38 ----------------------- paddle/operators/math/math_function.cu | 38 ----------------------- paddle/operators/math/math_function.h | 43 ++++++++++++++++++++++++-- 4 files changed, 42 insertions(+), 79 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 84fffe68437..abcaf940ab0 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -7,7 +7,7 @@ endif() if(WITH_GPU) nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) else() - cc_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) + cc_library(math_function SRCS math_function.cc DEPS ${BLAS_LIB} device_context) endif() nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc index 03a63d063f8..affdd1ac2cd 100644 --- a/paddle/operators/math/math_function.cc +++ b/paddle/operators/math/math_function.cc @@ -12,44 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_USE_MKLML -#include -#include -#include -#endif - -#ifdef PADDLE_USE_MKL -#include -#include -#endif - -#ifdef PADDLE_USE_ATLAS -extern "C" { -#include -#include -} -#endif - -#ifdef PADDLE_USE_OPENBLAS -#include -#include -#endif - -#ifndef LAPACK_FOUND -extern "C" { -#include -int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, - int* ipiv); -int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, - int* ipiv); -int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, - const int* ipiv); -int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, - const int* ipiv); -} -#endif - -#include #include "paddle/operators/math/math_function.h" namespace paddle { diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu index c1ec2d93eda..da40b27c948 100644 --- a/paddle/operators/math/math_function.cu +++ b/paddle/operators/math/math_function.cu @@ -12,44 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifdef PADDLE_USE_MKLML -#include -#include -#include -#endif - -#ifdef PADDLE_USE_MKL -#include -#include -#endif - -#ifdef PADDLE_USE_ATLAS -extern "C" { -#include -#include -} -#endif - -#ifdef PADDLE_USE_OPENBLAS -#include -#include -#endif - -#ifndef LAPACK_FOUND -extern "C" { -#include -int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, - int* ipiv); -int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, - int* ipiv); -int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, - const int* ipiv); -int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, - const int* ipiv); -} -#endif - -#include #include "paddle/operators/math/math_function.h" namespace paddle { diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h index c20e6a3b39f..155589fadb3 100644 --- a/paddle/operators/math/math_function.h +++ b/paddle/operators/math/math_function.h @@ -13,6 +13,44 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once +#ifdef PADDLE_USE_MKLML +#include +#include +#include +#endif + +#ifdef PADDLE_USE_MKL +#include +#include +#endif + +#ifdef PADDLE_USE_ATLAS +extern "C" { +#include +#include +} +#endif + +#ifdef PADDLE_USE_OPENBLAS +#include +#include +#endif + +#ifndef LAPACK_FOUND +extern "C" { +#include +int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda, + int* ipiv); +int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda, + int* ipiv); +int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda, + const int* ipiv); +int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, + const int* ipiv); +} +#endif + +#include #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" @@ -27,6 +65,7 @@ namespace math { // Then matrixA: M * K, matrixB: K * N matrixC : M * N // For more detailed info, please refer to // http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html +template void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, const T alpha, const T* A, const T* B, const T beta, T* C, platform::DeviceContext* context); @@ -34,8 +73,8 @@ void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, // matrix multiply with continuous memory template void matmul(const framework::Tensor& matrix_a, bool trans_a, - const framework::Tensor& matrix_b, bool trans_b, float alpha, - framework::Tensor* matrix_out, float beta, + const framework::Tensor& matrix_b, bool trans_b, T alpha, + framework::Tensor* matrix_out, T beta, platform::DeviceContext* context); } // namespace math -- GitLab