Commit f190a795 authored by Q qijun

fix gpu build error

Parent 22dac40c
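This commit fixes compile errors in the operators/math GEMM wrappers. Two things change: the `platform::DeviceContext*` parameter loses its `const` qualifier in every signature (presumably because `cublas_handle()` is a non-const member of `CUDADeviceContext` and so cannot be called through a pointer-to-const), and the broken `axpy`/`dotProduct` specializations are deleted outright. A minimal sketch of the qualifier problem, with hypothetical names standing in for the Paddle types:

// Calling a non-const member through a pointer-to-const does not compile;
// CudaContextLike is a hypothetical stand-in for platform::CUDADeviceContext.
struct CudaContextLike {
  void* cublas_handle() { return handle_; }  // non-const accessor, as the diff suggests
  void* handle_ = nullptr;
};

void broken(const CudaContextLike* ctx) {
  // ctx->cublas_handle();  // error: 'this' argument discards qualifiers
}

void fixed(CudaContextLike* ctx) {
  ctx->cublas_handle();  // fine once the parameter is non-const
}

Dropping `const` from the parameter, rather than marking the accessor `const`, pushes a `const_cast` to the call site (see the last hunk below).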
@@ -32,7 +32,7 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
                                      const float beta,
                                      float* C,
                                      const int ldc,
-                                     const platform::DeviceContext* context) {
+                                     platform::DeviceContext* context) {
   cblas_sgemm(CblasRowMajor,
               transA,
               transB,
@@ -63,7 +63,7 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
                                       const double beta,
                                       double* C,
                                       const int ldc,
-                                      const platform::DeviceContext* context) {
+                                      platform::DeviceContext* context) {
   cblas_dgemm(CblasRowMajor,
               transA,
               transB,
@@ -80,42 +80,6 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
               ldc);
 }
-
-template <>
-void axpy<platform::CPUPlace, float>(const int n,
-                                     const float alpha,
-                                     const float* x,
-                                     float* y,
-                                     const platform::DeviceContext* context) {
-  cblas_saxpy(n, alpha, x, 1, y, 1);
-}
-
-template <>
-void axpy<platform::CPUPlace, double>(const int n,
-                                      const double alpha,
-                                      const double* x,
-                                      double* y,
-                                      const platform::DeviceContext* context) {
-  cblas_daxpy(n, alpha, x, 1, y, 1);
-}
-
-template <>
-float dotProduct<platform::CPUPlace, float>(
-    const int n,
-    const float* x,
-    const float* y,
-    const platform::DeviceContext* context) {
-  return cblas_sdot(n, x, 1, y, 1);
-}
-
-template <>
-double dotProduct<platform::CPUPlace, double>(
-    const int n,
-    const double* x,
-    const double* y,
-    const platform::DeviceContext* context) {
-  return cblas_ddot(n, x, 1, y, 1);
-}
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
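The deleted CPU `axpy` and `dotProduct` specializations above were thin unit-stride wrappers over cBLAS, and nothing replaces them in this commit; their declarations disappear from the header in a later hunk. A caller that still needs these operations can reach cBLAS directly; a small sketch using the standard cBLAS signatures the deleted code wrapped:

// Equivalent direct cBLAS calls (unit strides, as in the deleted wrappers).
#include <cblas.h>

float saxpy_then_dot(int n, float alpha, const float* x, float* y) {
  cblas_saxpy(n, alpha, x, 1, y, 1);  // y = alpha * x + y
  return cblas_sdot(n, x, 1, y, 1);   // x . y
}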
@@ -20,29 +20,29 @@ namespace operators {
 namespace math {
 template <>
-void gemm<platform::GPUPlace float>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
                                      const CBLAS_TRANSPOSE transB,
                                      const int M,
                                      const int N,
                                      const int K,
                                      const float alpha,
                                      const float* A,
                                      const int lda,
                                      const float* B,
                                      const int ldb,
                                      const float beta,
                                      float* C,
                                      const int ldc,
-                                     const platform::DeviceContext* context) {
+                                     platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   cublasOperation_t cuTransA =
-      (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
-      (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   PADDLE_ENFORCE(platform::dynload::cublasSgemm(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
+      reinterpret_cast<platform::CUDADeviceContext*>(context)->
           cublas_handle(),
       cuTransB,
       cuTransA,
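Besides the `const` change, this hunk fixes two errors that would break any GPU build: the old template argument list `gemm<platform::GPUPlace float>` was missing a comma, and the ternaries read `TransA`/`TransB` while the parameters are spelled `transA`/`transB`. The surviving call passes `cuTransB` and `B` before `cuTransA` and `A` because cuBLAS is column-major, as the code comment notes: a row-major matrix reinterpreted column-major is its transpose, so computing C^T = B^T * A^T leaves a row-major C in place. A self-contained sketch of that trick with a raw cuBLAS call (handle and device pointers assumed valid):

// Row-major C(MxN) = A(MxK) * B(KxN) on column-major cuBLAS:
// compute C^T = B^T * A^T by passing B first and swapping M and N.
#include <cublas_v2.h>

cublasStatus_t row_major_sgemm(cublasHandle_t handle, int M, int N, int K,
                               const float* d_A, const float* d_B, float* d_C) {
  const float alpha = 1.0f, beta = 0.0f;
  return cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                     N, M, K,          // dimensions of C^T: N x M
                     &alpha, d_B, N,   // B^T is N x K, ldb = N
                     d_A, K,           // A^T is K x M, lda = K
                     &beta, d_C, N);   // C^T is N x M, ldc = N
}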
@@ -73,15 +73,15 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
                                       const double beta,
                                       double* C,
                                       const int ldc,
-                                      const platform::DeviceContext* context) {
+                                      platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
   cublasOperation_t cuTransA =
-      (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+      (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
-      (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+      (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   PADDLE_ENFORCE(platform::dynload::cublasDgemm(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
+      reinterpret_cast<platform::CUDADeviceContext*>(context)->
           cublas_handle(),
       cuTransB,
       cuTransA,
@@ -99,48 +99,6 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
 }
-
-template <>
-void axpy<platform::GPUPlace, float>(const int n,
-                                     const float alpha,
-                                     const float* x,
-                                     float* y,
-                                     const platform::DeviceContext* context) {
-  CUBLAS_ENFORCE(platform::dynload::cublasSaxpy(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-          cublas_handle(), N, &alpha, X, 1, Y, 1));
-}
-
-template <>
-void axpy<platform::GPUPlace, double>(const int n,
-                                      const double alpha,
-                                      const double* x,
-                                      double* y,
-                                      const platform::DeviceContext* context) {
-  CUBLAS_ENFORCE(platform::dynload::cublasDaxpy(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-          cublas_handle(), N, &alpha, X, 1, Y, 1));
-}
-
-template <>
-float dotProduct<platform::GPUPlace, float>(const int n,
-                                            const float* x,
-                                            const float* y,
-                                            const platform::DeviceContext* context) {
-  CUBLAS_ENFORCE(platform::dynload::cublasSdot(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-          cublas_handle(), n, a, 1, b, 1, &result));
-}
-
-template <>
-double dotProduct<platform::GPUPlace, double>(const int n,
-                                              const double* x,
-                                              const double* y,
-                                              const platform::DeviceContext* context) {
-  CUBLAS_ENFORCE(platform::dynload::cublasDdot(
-      reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-          cublas_handle(), n, a, 1, b, 1, &result));
-}
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
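The GPU `axpy`/`dotProduct` specializations deleted above are the likely source of the build error in the commit title: they pass `N`, `X`, `Y` where the parameters are named `n`, `x`, `y`; the dot products reference undeclared `a`, `b`, and `result` and return nothing despite a non-void return type; and `CUBLAS_ENFORCE` is used where the rest of the file uses `PADDLE_ENFORCE`. Rather than repair them, the commit removes them along with their declarations in the header (next hunk). For reference, a hedged sketch of a correct raw `cublasSdot` call under the v2 API, which writes its result through an output pointer:

// Correct cublasSdot usage: the result comes back through a pointer.
// Assumes a valid handle and device pointers d_x, d_y of length n.
#include <cublas_v2.h>

float dot_on_gpu(cublasHandle_t handle, int n, const float* d_x, const float* d_y) {
  float result = 0.0f;
  cublasSdot(handle, n, d_x, 1, d_y, 1, &result);
  return result;
}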
@@ -58,20 +58,7 @@ void gemm(const CBLAS_TRANSPOSE transA,
           const T beta,
           T* C,
           const int ldc,
-          const platform::DeviceContext* context);
+          platform::DeviceContext* context);
-
-template <typename Place, typename T>
-void axpy(const int n,
-          const T alpha,
-          const T* x,
-          T* y,
-          const platform::DeviceContext* context);
-
-template <typename Place, typename T>
-T dotProduct(const int n,
-             const T* x,
-             const T* y,
-             const platform::DeviceContext* context);
 }  // namespace math
 }  // namespace operators
@@ -37,20 +37,21 @@ public:
   int N = out_dim[1];
   int K = in0_dim[1];
-  paddle::operators::math::template gemm<Place, T>(CblasNoTrans,
-                                                    CblasNoTrans,
-                                                    M,
-                                                    N,
-                                                    K,
-                                                    1,
-                                                    input0->data<T>(),
-                                                    K,
-                                                    input1->data<T>(),
-                                                    N,
-                                                    0,
-                                                    output->data<T>(),
-                                                    N,
-                                                    &context.device_context());
+  paddle::operators::math::template gemm<Place, T>(
+      CblasNoTrans,
+      CblasNoTrans,
+      M,
+      N,
+      K,
+      1,
+      input0->data<T>(),
+      K,
+      input1->data<T>(),
+      N,
+      0,
+      output->data<T>(),
+      N,
+      &const_cast<platform::DeviceContext&>(context.device_context()));
 }
 };
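Because `gemm` now takes a non-const `platform::DeviceContext*`, the caller must shed the `const` that `context.device_context()` evidently returns, hence the new `const_cast` above. The pattern in isolation, with hypothetical stand-in types:

// Stand-ins for the Paddle types: device_context() returns a const
// reference, so the caller casts constness away before taking the address.
struct DeviceContext {};

struct ExecutionContextLike {
  const DeviceContext& device_context() const { return ctx_; }
  DeviceContext ctx_;
};

void takes_mutable(DeviceContext* ctx) {}  // stands in for gemm

void caller(const ExecutionContextLike& context) {
  takes_mutable(&const_cast<DeviceContext&>(context.device_context()));
}

This is well-defined only because the underlying context object is not itself declared `const`; a cleaner long-term fix would be a mutable accessor or a const-correct `cublas_handle()`.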