提交 960a5255 编写于 作者: Q qijun

fix gpu build error

上级 2ec8dab4
...@@ -7,7 +7,7 @@ endif() ...@@ -7,7 +7,7 @@ endif()
if(WITH_GPU) if(WITH_GPU)
nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) nv_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context)
else() else()
cc_library(math_function SRCS math_function.cc math_function.cu DEPS ${BLAS_LIB} device_context) cc_library(math_function SRCS math_function.cc DEPS ${BLAS_LIB} device_context)
endif() endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
...@@ -12,44 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,44 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_USE_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_MKL
#include <mkl.h>
#include <mkl_lapacke.h>
#endif
#ifdef PADDLE_USE_ATLAS
extern "C" {
#include <cblas.h>
#include <clapack.h>
}
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#include <lapacke.h>
#endif
#ifndef LAPACK_FOUND
extern "C" {
#include <cblas.h>
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
int* ipiv);
int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
int* ipiv);
int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
const int* ipiv);
int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
const int* ipiv);
}
#endif
#include <cmath>
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
......
...@@ -12,44 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,44 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#ifdef PADDLE_USE_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_MKL
#include <mkl.h>
#include <mkl_lapacke.h>
#endif
#ifdef PADDLE_USE_ATLAS
extern "C" {
#include <cblas.h>
#include <clapack.h>
}
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#include <lapacke.h>
#endif
#ifndef LAPACK_FOUND
extern "C" {
#include <cblas.h>
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
int* ipiv);
int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
int* ipiv);
int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
const int* ipiv);
int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
const int* ipiv);
}
#endif
#include <cmath>
#include "paddle/operators/math/math_function.h" #include "paddle/operators/math/math_function.h"
namespace paddle { namespace paddle {
......
...@@ -13,6 +13,44 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,44 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#ifdef PADDLE_USE_MKLML
#include <mkl_cblas.h>
#include <mkl_lapacke.h>
#include <mkl_vml_functions.h>
#endif
#ifdef PADDLE_USE_MKL
#include <mkl.h>
#include <mkl_lapacke.h>
#endif
#ifdef PADDLE_USE_ATLAS
extern "C" {
#include <cblas.h>
#include <clapack.h>
}
#endif
#ifdef PADDLE_USE_OPENBLAS
#include <cblas.h>
#include <lapacke.h>
#endif
#ifndef LAPACK_FOUND
extern "C" {
#include <cblas.h>
int LAPACKE_sgetrf(int matrix_layout, int m, int n, float* a, int lda,
int* ipiv);
int LAPACKE_dgetrf(int matrix_layout, int m, int n, double* a, int lda,
int* ipiv);
int LAPACKE_sgetri(int matrix_layout, int n, float* a, int lda,
const int* ipiv);
int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda,
const int* ipiv);
}
#endif
#include <cmath>
#include "paddle/framework/tensor.h" #include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h" #include "paddle/platform/device_context.h"
...@@ -27,6 +65,7 @@ namespace math { ...@@ -27,6 +65,7 @@ namespace math {
// Then matrixA: M * K, matrixB: K * N matrixC : M * N // Then matrixA: M * K, matrixB: K * N matrixC : M * N
// For more detailed info, please refer to // For more detailed info, please refer to
// http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html // http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html
template <typename Place, typename T>
void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
const int M, const int N, const int K, const T alpha, const T* A, const int M, const int N, const int K, const T alpha, const T* A,
const T* B, const T beta, T* C, platform::DeviceContext* context); const T* B, const T beta, T* C, platform::DeviceContext* context);
...@@ -34,8 +73,8 @@ void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, ...@@ -34,8 +73,8 @@ void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
// matrix multiply with continuous memory // matrix multiply with continuous memory
template <typename Place, typename T> template <typename Place, typename T>
void matmul(const framework::Tensor& matrix_a, bool trans_a, void matmul(const framework::Tensor& matrix_a, bool trans_a,
const framework::Tensor& matrix_b, bool trans_b, float alpha, const framework::Tensor& matrix_b, bool trans_b, T alpha,
framework::Tensor* matrix_out, float beta, framework::Tensor* matrix_out, T beta,
platform::DeviceContext* context); platform::DeviceContext* context);
} // namespace math } // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册