Commit f190a795 authored by Q qijun

fix gpu build error

Parent 22dac40c
@@ -32,7 +32,7 @@ void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
 const float beta,
 float* C,
 const int ldc,
-const platform::DeviceContext* context) {
+platform::DeviceContext* context) {
 cblas_sgemm(CblasRowMajor,
 transA,
 transB,
@@ -63,7 +63,7 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
 const double beta,
 double* C,
 const int ldc,
-const platform::DeviceContext* context) {
+platform::DeviceContext* context) {
 cblas_dgemm(CblasRowMajor,
 transA,
 transB,
@@ -80,42 +80,6 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
 ldc);
 }
-template <>
-void axpy<platform::CPUPlace, float>(const int n,
-const float alpha,
-const float* x,
-float* y,
-const platform::DeviceContext* context) {
-cblas_saxpy(n, alpha, x, 1, y, 1);
-}
-template <>
-void axpy<platform::CPUPlace, double>(const int n,
-const double alpha,
-const double* x,
-double* y,
-const platform::DeviceContext* context) {
-cblas_daxpy(n, alpha, x, 1, y, 1);
-}
-template <>
-float dotProduct<platform::CPUPlace, float>(
-const int n,
-const float* x,
-const float* y,
-const platform::DeviceContext* context) {
-return cblas_sdot(n, x, 1, y, 1);
-}
-template <>
-double dotProduct<platform::CPUPlace, double>(
-const int n,
-const double* x,
-const double* y,
-const platform::DeviceContext* context) {
-return cblas_ddot(n, x, 1, y, 1);
-}
 } // namespace math
 } // namespace operators
 } // namespace paddle
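For reference, a minimal sketch of the kind of call the retained CPU specialization forwards to: a row-major single-precision GEMM through the CBLAS interface. The helper name, the sizes, and the `<cblas.h>` header are illustrative assumptions, not part of this commit.

```cpp
#include <cblas.h>  // assumption: any CBLAS provider (OpenBLAS, MKL, ...)

// Row-major C(MxN) = alpha * A(MxK) * B(KxN) + beta * C, the same call shape
// the CPUPlace float specialization uses when neither operand is transposed.
void sgemm_rowmajor_sketch(int M, int N, int K, float alpha, const float* A,
                           const float* B, float beta, float* C) {
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K, alpha,
              A, /*lda=*/K, B, /*ldb=*/N, beta, C, /*ldc=*/N);
}
```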
@@ -20,7 +20,7 @@ namespace operators {
 namespace math {
 template <>
-void gemm<platform::GPUPlace float>(const CBLAS_TRANSPOSE transA,
+void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
 const CBLAS_TRANSPOSE transB,
 const int M,
 const int N,
@@ -33,16 +33,16 @@ void gemm<platform::GPUPlace float>(const CBLAS_TRANSPOSE transA,
 const float beta,
 float* C,
 const int ldc,
-const platform::DeviceContext* context) {
+platform::DeviceContext* context) {
 // Note that cublas follows fortran order, so the order is different from
 // the cblas convention.
 cublasOperation_t cuTransA =
-(TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
 cublasOperation_t cuTransB =
-(TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
 PADDLE_ENFORCE(platform::dynload::cublasSgemm(
-reinterpret_cast<const platform::CUDADeviceContext*>(context)->
+reinterpret_cast<platform::CUDADeviceContext*>(context)->
 cublas_handle(),
 cuTransB,
 cuTransA,
@@ -73,15 +73,15 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
 const double beta,
 double* C,
 const int ldc,
-const platform::DeviceContext* context) {
+platform::DeviceContext* context) {
 // Note that cublas follows fortran order, so the order is different from
 // the cblas convention.
 cublasOperation_t cuTransA =
-(TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+(transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
 cublasOperation_t cuTransB =
-(TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
 PADDLE_ENFORCE(platform::dynload::cublasDgemm(
-reinterpret_cast<const platform::CUDADeviceContext*>(context)->
+reinterpret_cast<platform::CUDADeviceContext*>(context)->
 cublas_handle(),
 cuTransB,
 cuTransA,
@@ -99,48 +99,6 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
 }
-template <>
-void axpy<platform::GPUPlace, float>(const int n,
-const float alpha,
-const float* x,
-float* y,
-const platform::DeviceContext* context) {
-CUBLAS_ENFORCE(platform::dynload::cublasSaxpy(
-reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-cublas_handle(), N, &alpha, X, 1, Y, 1));
-}
-template <>
-void axpy<platform::GPUPlace, double>(const int n,
-const double alpha,
-const double* x,
-double* y,
-const platform::DeviceContext* context) {
-CUBLAS_ENFORCE(platform::dynload::cublasDaxpy(
-reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-cublas_handle(), N, &alpha, X, 1, Y, 1));
-}
-template <>
-float dotProduct<platform::GPUPlace, float>(const int n,
-const float* x,
-const float* y,
-const platform::DeviceContext* context) {
-CUBLAS_ENFORCE(platform::dynload::cublasSdot(
-reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-cublas_handle(), n, a, 1, b, 1, &result));
-}
-template <>
-double dotProduct<platform::GPUPlace, double>(const int n,
-const double* x,
-const double* y,
-const platform::DeviceContext* context) {
-CUBLAS_ENFORCE(platform::dynload::cublasDdot(
-reinterpret_cast<const platform::CUDADeviceContext*>(context)->
-cublas_handle(), n, a, 1, b, 1, &result));
-}
 } // namespace math
 } // namespace operators
 } // namespace paddle
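The "fortran order" note in the GPU path is the usual row-major trick: a row-major C = A * B is computed as the column-major product C^T = B^T * A^T, which is why the calls above pass cuTransB before cuTransA. A hedged, self-contained sketch of the same mapping (the direct cublasSgemm call and the handle argument are assumptions; the real code goes through platform::dynload and the device context's cublas_handle()):

```cpp
#include <cublas_v2.h>

// Row-major C(MxN) = A(MxK) * B(KxN), expressed as the column-major product
// C^T = B^T * A^T: B is passed first and the M/N extents are swapped.
cublasStatus_t rowmajor_sgemm_sketch(cublasHandle_t handle, int M, int N, int K,
                                     const float* A, const float* B, float* C) {
  const float alpha = 1.0f, beta = 0.0f;
  return cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N,
                     N, M, K,   // extents of C^T in column-major terms
                     &alpha,
                     B, N,      // row-major B(KxN) read as column-major B^T
                     A, K,      // row-major A(MxK) read as column-major A^T
                     &beta,
                     C, N);     // written as C^T, i.e. row-major C
}
```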
@@ -58,20 +58,7 @@ void gemm(const CBLAS_TRANSPOSE transA,
 const T beta,
 T* C,
 const int ldc,
-const platform::DeviceContext* context);
-template <typename Place, typename T>
-void axpy(const int n,
-const T alpha,
-const T* x,
-T* y,
-const platform::DeviceContext* context);
-template <typename Place, typename T>
-T dotProduct(const int n,
-const T* x,
-const T* y,
-const platform::DeviceContext* context);
+platform::DeviceContext* context);
 } // namespace math
 } // namespace operators
@@ -37,7 +37,8 @@ public:
 int N = out_dim[1];
 int K = in0_dim[1];
-paddle::operators::math::template gemm<Place, T>(CblasNoTrans,
+paddle::operators::math::template gemm<Place, T>(
+CblasNoTrans,
 CblasNoTrans,
 M,
 N,
@@ -50,7 +51,7 @@ public:
 0,
 output->data<T>(),
 N,
-&context.device_context());
+&const_cast<platform::DeviceContext&>(context.device_context()));
 }
 };
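The const_cast at the call site is the flip side of the signature change above: the operator's execution context exposes its device context as const, while gemm now takes a mutable platform::DeviceContext* (the GPU path needs a non-const pointer to reach cublas_handle()). A minimal sketch of the pattern, with the accessor and types assumed from this diff rather than copied from the framework:

```cpp
namespace platform { class DeviceContext; }

// Hypothetical caller: the framework hands out a const reference, but the
// math helper wants a mutable pointer, so const is stripped at the call site.
void call_gemm_sketch(const platform::DeviceContext& dev_ctx) {
  auto* mutable_ctx = &const_cast<platform::DeviceContext&>(dev_ctx);
  // paddle::operators::math::gemm<Place, T>(..., mutable_ctx);
  (void)mutable_ctx;
}
```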