Commit de967fce authored by Q qijun

gemm now supports only continuous memory

Parent 7eb07b33
@@ -19,21 +19,30 @@ namespace operators {
 namespace math {

 template <>
-void gemm<platform::CPUPlace, float>(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const float alpha, const float* A, const int lda,
-    const float* B, const int ldb, const float beta, float* C, const int ldc,
-    platform::DeviceContext* context) {
+void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
+                                     const CBLAS_TRANSPOSE transB, const int M,
+                                     const int N, const int K,
+                                     const float alpha, const float* A,
+                                     const float* B, const float beta, float* C,
+                                     platform::DeviceContext* context) {
+  int lda = K;
+  int ldb = N;
+  int ldc = N;
   cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
               beta, C, ldc);
 }

 template <>
-void gemm<platform::CPUPlace, double>(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const double alpha, const double* A,
-    const int lda, const double* B, const int ldb, const double beta, double* C,
-    const int ldc, platform::DeviceContext* context) {
+void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
+                                      const CBLAS_TRANSPOSE transB,
+                                      const int M, const int N, const int K,
+                                      const double alpha, const double* A,
+                                      const double* B, const double beta,
+                                      double* C,
+                                      platform::DeviceContext* context) {
+  int lda = K;
+  int ldb = N;
+  int ldc = N;
   cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
               beta, C, ldc);
 }
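With the leading-dimension parameters gone, the CPU implementation derives them from the shape arguments: for a contiguous row-major buffer the leading dimension is simply the row width, so lda = K, ldb = N and ldc = N when neither operand is transposed. A minimal standalone sketch of the equivalent raw CBLAS call (not Paddle code; it assumes a CBLAS implementation such as OpenBLAS is linked):

```cpp
// Standalone sketch: C(2x2) = A(2x3) * B(3x2), all buffers contiguous and
// row-major, with the leading dimensions derived exactly as the new gemm
// does (lda = K, ldb = N, ldc = N).
#include <cblas.h>
#include <cstdio>

int main() {
  const int M = 2, K = 3, N = 2;
  float A[M * K] = {1, 2, 3, 4, 5, 6};  // M x K, row-major
  float B[K * N] = {1, 0, 0, 1, 1, 1};  // K x N, row-major
  float C[M * N] = {0};                 // M x N, row-major
  cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, M, N, K,
              1.0f, A, K, B, N, 0.0f, C, N);
  // Prints: 4.0 5.0 / 10.0 11.0
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) printf("%6.1f ", C[i * N + j]);
    printf("\n");
  }
  return 0;
}
```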
@@ -67,8 +76,8 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T,
   CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;

   gemm<platform::CPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                  in1.data<float>(), K, in2.data<float>(), N,
-                                  beta, out->data<float>(), N, context);
+                                  in1.data<float>(), in2.data<float>(), beta,
+                                  out->data<float>(), context);
 }

 template <>
@@ -100,8 +109,8 @@ void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
   CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;

   gemm<platform::CPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                   in1.data<double>(), K, in2.data<double>(), N,
-                                   beta, out->data<double>(), N, context);
+                                   in1.data<double>(), in2.data<double>(), beta,
+                                   out->data<double>(), context);
 }

 } // namespace math
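One caveat worth noting: the fixed lda = K and ldb = N in the CPU code above are only correct when the corresponding operand is not transposed; a transposed operand stored contiguously has its own row width as leading dimension. A hedged sketch of the more general computation, which matches what the CUDA implementation below does (this generalization is not part of the CPU code in this commit):

```cpp
// After applying op(), A is M x K and B is K x N. For contiguous row-major
// buffers the leading dimension is the physical row width of the buffer:
// an untransposed A is stored M x K (lda = K), a transposed A is stored
// K x M (lda = M); likewise for B. C is always M x N, so ldc = N.
int lda = (transA == CblasNoTrans) ? K : M;
int ldb = (transB == CblasNoTrans) ? N : K;
int ldc = N;
```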
@@ -18,14 +18,16 @@ namespace operators {
 namespace math {

 template <>
-void gemm<platform::GPUPlace, float>(
-    const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
-    const int N, const int K, const float alpha, const float* A, const int lda,
-    const float* B, const int ldb, const float beta, float* C, const int ldc,
-    platform::DeviceContext* context) {
+void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
+                                     const CBLAS_TRANSPOSE transB, const int M,
+                                     const int N, const int K,
+                                     const float alpha, const float* A,
+                                     const float* B, const float beta, float* C,
+                                     platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
-  /*
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
   cublasOperation_t cuTransA =
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
@@ -34,8 +36,6 @@ void gemm<platform::GPUPlace, float>(
   PADDLE_ENFORCE(platform::dynload::cublasSgemm(
       reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
-  */
-  PADDLE_THROW("not implemented now");
 }

 template <>
@@ -46,7 +46,8 @@ void gemm<platform::GPUPlace, double>(
     const int ldc, platform::DeviceContext* context) {
   // Note that cublas follows fortran order, so the order is different from
   // the cblas convention.
-  /*
+  int lda = (transA == CblasNoTrans) ? K : M;
+  int ldb = (transB == CblasNoTrans) ? N : K;
   cublasOperation_t cuTransA =
       (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
   cublasOperation_t cuTransB =
@@ -54,8 +55,6 @@ void gemm<platform::GPUPlace, double>(
   PADDLE_ENFORCE(platform::dynload::cublasDgemm(
       reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
       cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
-  */
-  PADDLE_THROW("not implemented now");
 }

 template <>
@@ -87,8 +86,8 @@ void matmul<platform::GPUPlace, float>(const framework::Tensor& in1, bool in1_T,
   CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;

   gemm<platform::GPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                  in1.data<float>(), K, in2.data<float>(), N,
-                                  beta, out->data<float>(), N, context);
+                                  in1.data<float>(), in2.data<float>(), beta,
+                                  out->data<float>(), context);
 }

 template <>
@@ -120,8 +119,8 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
   CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;

   gemm<platform::GPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
-                                   in1.data<double>(), K, in2.data<double>(), N,
-                                   beta, out->data<double>(), N, context);
+                                   in1.data<double>(), in2.data<double>(), beta,
+                                   out->data<double>(), context);
 }

 } // namespace math
 } // namespace operators
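The "fortran order" comment above is doing real work: cuBLAS operates on column-major data, and a row-major M x N matrix occupies the same bytes as a column-major N x M matrix. So rather than computing C = op(A) * op(B) directly, the call computes C^T = op(B)^T * op(A)^T, which is why the operands and the M/N dimensions are swapped in the cublasSgemm/cublasDgemm calls (cuTransB, cuTransA, N, M, K, ..., B, ..., A, ...). A standalone sketch of the same trick (not Paddle code; assumes the CUDA runtime and the cuBLAS v2 API are available, and omits error checking for brevity):

```cpp
// Row-major C(2x2) = A(2x3) * B(3x2) via column-major cuBLAS: ask for
// C^T = B^T * A^T by swapping the operands and the M/N dimensions.
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <cstdio>

int main() {
  const int M = 2, K = 3, N = 2;
  float hA[M * K] = {1, 2, 3, 4, 5, 6};  // row-major M x K
  float hB[K * N] = {1, 0, 0, 1, 1, 1};  // row-major K x N
  float hC[M * N] = {0};

  float *dA, *dB, *dC;
  cudaMalloc(&dA, sizeof(hA));
  cudaMalloc(&dB, sizeof(hB));
  cudaMalloc(&dC, sizeof(hC));
  cudaMemcpy(dA, hA, sizeof(hA), cudaMemcpyHostToDevice);
  cudaMemcpy(dB, hB, sizeof(hB), cudaMemcpyHostToDevice);

  cublasHandle_t handle;
  cublasCreate(&handle);
  const float alpha = 1.0f, beta = 0.0f;
  // dB seen column-major with ldb = N is B^T (N x K); dA seen column-major
  // with lda = K is A^T (K x M); the N x M column-major result is C^T,
  // whose bytes are exactly the row-major M x N matrix C.
  cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, N, M, K,
              &alpha, dB, N, dA, K, &beta, dC, N);

  cudaMemcpy(hC, dC, sizeof(hC), cudaMemcpyDeviceToHost);
  // Prints: 4.0 5.0 / 10.0 11.0
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) printf("%6.1f ", hC[i * N + j]);
    printf("\n");
  }
  cublasDestroy(handle);
  cudaFree(dA); cudaFree(dB); cudaFree(dC);
  return 0;
}
```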
@@ -60,11 +60,11 @@ namespace paddle {
 namespace operators {
 namespace math {

+// support continuous memory now
 template <typename Place, typename T>
 void gemm(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB,
           const int M, const int N, const int K, const T alpha, const T* A,
-          const int lda, const T* B, const int ldb, const T beta, T* C,
-          const int ldc, platform::DeviceContext* context);
+          const T* B, const T beta, T* C, platform::DeviceContext* context);

 // matrix multiply with continuous memory
 template <typename Place, typename T>
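For reference, a hedged call-site sketch against the new declaration. The include path and the way the device context reaches the caller are assumptions about the surrounding Paddle code, not something this commit shows:

```cpp
#include "paddle/operators/math/math_function.h"  // assumed header path

// Computes c = a * b for contiguous row-major buffers; all strides are now
// derived inside gemm from M, N and K, so the caller passes only shapes,
// scalars, raw pointers and a device context.
void RowMajorMatMul(const float* a, const float* b, float* c, int M, int N,
                    int K, paddle::platform::DeviceContext* ctx) {
  paddle::operators::math::gemm<paddle::platform::CPUPlace, float>(
      CblasNoTrans, CblasNoTrans, M, N, K, 1.0f, a, b, 0.0f, c, ctx);
}
```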
@@ -16,5 +16,4 @@
 #include "paddle/operators/mul_op.h"

 namespace ops = paddle::operators;
-// REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace,
-// float>);
+REGISTER_OP_GPU_KERNEL(mul, ops::MulKernel<paddle::platform::GPUPlace, float>);
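Re-enabling this registration fits the rest of the commit: with the cuBLAS path uncommented above, the mul operator's GPU kernel is now backed by a working gemm. As a hedged sketch, the kernel's compute step presumably reduces to a single matmul call; the parameter order after (in1, in1_T) is inferred from the matmul call sites shown earlier and is an assumption, not verified code:

```cpp
// Hypothetical helper illustrating what MulKernel's Compute() plausibly does
// with the math functions above (the matmul argument order past (x, false)
// is assumed, not shown by this commit).
template <typename Place, typename T>
void MulViaMatmul(const paddle::framework::Tensor& x,
                  const paddle::framework::Tensor& y,
                  paddle::framework::Tensor* out,
                  paddle::platform::DeviceContext* ctx) {
  // out = 1 * x * y + 0 * out, both inputs untransposed and contiguous.
  paddle::operators::math::matmul<Place, T>(x, false, y, false, T(1), out,
                                            T(0), ctx);
}
```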