Commit de967fce authored by: Q qijun

set gemm support continuous memory now

Parent 7eb07b33
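The change drops the explicit leading-dimension arguments (lda, ldb, ldc) from the gemm interface and instead assumes contiguous row-major storage, so the leading dimensions follow from the shapes: lda = K, ldb = N, ldc = N. A minimal standalone sketch of that calling convention against CBLAS directly (hypothetical example outside the Paddle tree, assuming a CBLAS implementation such as OpenBLAS is linked; the values are only illustrative):

#include <cblas.h>
#include <cstdio>

int main() {
  // Row-major, contiguous matrices: A is MxK, B is KxN, C is MxN.
  const int M = 2, K = 3, N = 2;
  float A[M * K] = {1, 2, 3,
                    4, 5, 6};
  float B[K * N] = {1, 0,
                    0, 1,
                    1, 1};
  float C[M * N] = {0, 0, 0, 0};

  // With contiguous storage the leading dimensions are fixed by the shapes:
  // lda = K, ldb = N, ldc = N -- exactly what the new gemm derives internally.
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
              /*alpha=*/1.0f, A, /*lda=*/K, B, /*ldb=*/N,
              /*beta=*/0.0f, C, /*ldc=*/N);

  // Expected result: C = A * B = [[4, 5], [10, 11]].
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) printf("%g ", C[i * N + j]);
    printf("\n");
  }
  return 0;
}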
@@ -19,21 +19,30 @@ namespace operators {
 namespace math {
 
 template <>
-void gemm<platform::CPUPlace, float>(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const float alpha, const float* A, const int lda,
-    const float* B, const int ldb, const float beta, float* C, const int ldc,
-    platform::DeviceContext* context) {
+void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
+                                     const CBLAS_TRANSPOSE transB, const int M,
+                                     const int N, const int K,
+                                     const float alpha, const float* A,
+                                     const float* B, const float beta, float* C,
+                                     platform::DeviceContext* context) {
+  int lda = K;
+  int ldb = N;
+  int ldc = N;
   cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
               beta, C, ldc);
 }
 
 template <>
-void gemm<platform::CPUPlace, double>(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const double alpha, const double* A,
-    const int lda, const double* B, const int ldb, const double beta, double* C,
-    const int ldc, platform::DeviceContext* context) {
+void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
+                                      const CBLAS_TRANSPOSE transB, const int M,
+                                      const int N, const int K,
+                                      const double alpha, const double* A,
+                                      const double* B, const double beta,
+                                      double* C,
+                                      platform::DeviceContext* context) {
+  int lda = K;
+  int ldb = N;
+  int ldc = N;
   cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
               beta, C, ldc);
 }
@@ -67,8 +76,8 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T,
   CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
 
   gemm<platform::CPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                  in1.data<float>(), K, in2.data<float>(), N,
-                                  beta, out->data<float>(), N, context);
+                                  in1.data<float>(), in2.data<float>(), beta,
+                                  out->data<float>(), context);
 }
 
 template <>
@@ -100,8 +109,8 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
   CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
 
   gemm<platform::CPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                   in1.data<double>(), K, in2.data<double>(), N,
-                                   beta, out->data<double>(), N, context);
+                                   in1.data<double>(), in2.data<double>(), beta,
+                                   out->data<double>(), context);
 }
 
 }  // namespace math
@@ -18,14 +18,16 @@ namespace operators {
 namespace math {
 
 template <>
-void gemm<platform::GPUPlace, float>(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const float alpha, const float* A, const int lda,
-    const float* B, const int ldb, const float beta, float* C, const int ldc,
-    platform::DeviceContext* context) {
+void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
+                                     const CBLAS_TRANSPOSE transB, const int M,
+                                     const int N, const int K,
+                                     const float alpha, const float* A,
+                                     const float* B, const float beta, float* C,
+                                     platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
-  /*
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
   cublasOperation_t cuTransA =
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
@@ -34,8 +36,6 @@ void gemm<platform::GPUPlace, float>(
   PADDLE_ENFORCE(platform::dynload::cublasSgemm(
       reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
-  */
-  PADDLE_THROW("not implemented now");
 }
 
 template <>
@@ -46,7 +46,8 @@ void gemm<platform::GPUPlace, double>(
     const int ldc, platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
-  /*
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
   cublasOperation_t cuTransA =
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
@@ -54,8 +55,6 @@ void gemm<platform::GPUPlace, double>(
   PADDLE_ENFORCE(platform::dynload::cublasDgemm(
       reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
-  */
-  PADDLE_THROW("not implemented now");
 }
 
 template <>
@@ -87,8 +86,8 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& in1, bool in1_T,
   CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
 
   gemm<platform::GPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                  in1.data<float>(), K, in2.data<float>(), N,
-                                  beta, out->data<float>(), N, context);
+                                  in1.data<float>(), in2.data<float>(), beta,
+                                  out->data<float>(), context);
 }
 
 template <>
@@ -120,8 +119,8 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
   CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
 
   gemm<platform::GPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                   in1.data<double>(), K, in2.data<double>(), N,
-                                   beta, out->data<double>(), N, context);
+                                   in1.data<double>(), in2.data<double>(), beta,
+                                   out->data<double>(), context);
 }
 
 }  // namespace math
 }  // namespace operators
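As the in-diff comment says, cuBLAS follows Fortran (column-major) order, so the GPU path obtains the row-major product by swapping the operands and the M/N extents: a row-major M x N matrix reads as a column-major N x M matrix, hence row-major C = A * B is computed as column-major C^T = B^T * A^T. A minimal host-side sketch of the same trick against cuBLAS directly (hypothetical standalone program, not Paddle code, assuming CUDA and cuBLAS are available; error handling reduced to asserts):

#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cassert>
#include <cstdio>

int main() {
  const int M = 2, K = 3, N = 2;
  // Row-major host data, the same layout the contiguous-memory gemm assumes.
  float hA[M * K] = {1, 2, 3, 4, 5, 6};
  float hB[K * N] = {1, 0, 0, 1, 1, 1};
  float hC[M * N] = {0};

  float *dA, *dB, *dC;
  cudaMalloc(reinterpret_cast<void**>(&dA), sizeof(hA));
  cudaMalloc(reinterpret_cast<void**>(&dB), sizeof(hB));
  cudaMalloc(reinterpret_cast<void**>(&dC), sizeof(hC));
  cudaMemcpy(dA, hA, sizeof(hA), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB, sizeof(hB), cudaMemcpyHostToDevice);
  cudaMemcpy(dC, hC, sizeof(hC), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasStatus_t stat = cublasCreate(&handle);
  assert(stat == CUBLAS_STATUS_SUCCESS);

  // cuBLAS is column-major, so swap the operands and the M/N dimensions:
  // column-major C^T (NxM) = B^T (NxK) * A^T (KxM) is the row-major C = A * B.
  // Leading dimensions for the no-transpose case: ldb = N, lda = K, ldc = N.
  const float alpha = 1.0f, beta = 0.0f;
  stat = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K, &alpha,
                     dB, N, dA, K, &beta, dC, N);
  assert(stat == CUBLAS_STATUS_SUCCESS);

  cudaMemcpy(hC, dC, sizeof(hC), cudaMemcpyDeviceToHost);
  // Expected: C = [[4, 5], [10, 11]], the same result as the CPU path.
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) printf("%g ", hC[i * N + j]);
    printf("\n");
  }

  cublasDestroy(handle);
  cudaFree(dA);
  cudaFree(dB);
  cudaFree(dC);
  return 0;
}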
@@ -60,11 +60,11 @@ namespace paddle {
 namespace operators {
 namespace math {
 
+// support continuous memory now
 template <typename Place, typename T>
 void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
           const int M, const int N, const int K, const T alpha, const T* A,
-          const int lda, const T* B, const int ldb, const T beta, T* C,
-          const int ldc, platform::DeviceContext* context);
+          const T* B, const T beta, T* C, platform::DeviceContext* context);
 
 // matrix multiply with continuous memory
 template <typename Place, typename T>
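The declaration now takes only raw pointers plus M, N, K; the leading dimensions are implied by contiguous row-major storage. As a reference for what a specialization is expected to compute in the no-transpose case, here is a naive sketch (gemm_ref is a hypothetical name, not part of Paddle):

// C = alpha * A * B + beta * C with row-major, contiguous A (MxK), B (KxN),
// C (MxN) and implicit leading dimensions lda = K, ldb = N, ldc = N.
template <typename T>
void gemm_ref(const int M, const int N, const int K, const T alpha, const T* A,
              const T* B, const T beta, T* C) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      T acc = 0;
      for (int k = 0; k < K; ++k) {
        acc += A[i * K + k] * B[k * N + j];
      }
      C[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
  }
}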
@@ -16,5 +16,4 @@
 #include "paddle/operators/mul_op.h"
 
 namespace ops = paddle::operators;
-// REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace,
-// float>);
+REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);