提交 08159359 编写于 作者: Q qijun

fix typo error

上级 090247dd
...@@ -19,74 +19,29 @@ namespace operators { ...@@ -19,74 +19,29 @@ namespace operators {
namespace math { namespace math {
template <> template <>
void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA, void gemm<platform::CPUPlace, float>(
const CBLAS_TRANSPOSE transB, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
const int M, const int N, const int K, const float alpha, const float* A, const int lda,
const int N, const float* B, const int ldb, const float beta, float* C, const int ldc,
const int K, platform::DeviceContext* context) {
const float alpha, cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
const float* A, beta, C, ldc);
const int lda,
const float* B,
const int ldb,
const float beta,
float* C,
const int ldc,
platform::DeviceContext* context) {
cblas_sgemm(CblasRowMajor,
transA,
transB,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc);
} }
template <> template <>
void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA, void gemm<platform::CPUPlace, double>(
const CBLAS_TRANSPOSE transB, const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
const int M, const int N, const int K, const double alpha, const double* A,
const int N, const int lda, const double* B, const int ldb, const double beta, double* C,
const int K, const int ldc, platform::DeviceContext* context) {
const double alpha, cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
const double* A, beta, C, ldc);
const int lda,
const double* B,
const int ldb,
const double beta,
double* C,
const int ldc,
platform::DeviceContext* context) {
cblas_dgemm(CblasRowMajor,
transA,
transB,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc);
} }
template <> template <>
void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T,
bool in1_T, const framework::Tensor& in2, bool in2_T,
const framework::Tensor& in2, float alpha, framework::Tensor* out,
bool in2_T,
float alpha,
framework::Tensor* out,
float beta, float beta,
platform::DeviceContext* context) { platform::DeviceContext* context) {
auto in1_dim = in1.dims(); auto in1_dim = in1.dims();
...@@ -111,30 +66,17 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, ...@@ -111,30 +66,17 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1,
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, float>(in1_Trans, gemm<platform::CPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
in2_Trans, in1.data<float>(), K, in2.data<float>(), N,
M, beta, out->data<float>(), N, context);
N,
K,
alpha,
in1.data<float>(),
K,
in2.data<float>(),
N,
beta,
out->data<float>(),
N,
context);
} }
template <> template <>
void matmul<platform::GPUPlace, double>(const framework::Tensor& in1, void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
bool in1_T, bool in1_T,
const framework::Tensor& in2, const framework::Tensor& in2,
bool in2_T, bool in2_T, float alpha,
float alpha, framework::Tensor* out, float beta,
framework::Tensor* out,
float beta,
platform::DeviceContext* context) { platform::DeviceContext* context) {
auto in1_dim = in1.dims(); auto in1_dim = in1.dims();
auto in2_dim = in2.dims(); auto in2_dim = in2.dims();
...@@ -157,20 +99,9 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1, ...@@ -157,20 +99,9 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, double>(in1_Trans, gemm<platform::CPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
in2_Trans, in1.data<double>(), K, in2.data<double>(), N,
M, beta, out->data<double>(), N, context);
N,
K,
alpha,
in1.data<double>(),
K,
in2.data<double>(),
N,
beta,
out->data<double>(),
N,
context);
} }
} // namespace math } // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册