提交 08159359 编写于 作者: Q qijun

fix typo error

上级 090247dd
......@@ -19,74 +19,29 @@ namespace operators {
namespace math {
template <>
void gemm<platform::CPUPlace, float>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const float alpha,
const float* A,
const int lda,
const float* B,
const int ldb,
const float beta,
float* C,
const int ldc,
void gemm<platform::CPUPlace, float>(
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K, const float alpha, const float* A, const int lda,
const float* B, const int ldb, const float beta, float* C, const int ldc,
platform::DeviceContext* context) {
cblas_sgemm(CblasRowMajor,
transA,
transB,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc);
cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const double alpha,
const double* A,
const int lda,
const double* B,
const int ldb,
const double beta,
double* C,
const int ldc,
platform::DeviceContext* context) {
cblas_dgemm(CblasRowMajor,
transA,
transB,
M,
N,
K,
alpha,
A,
lda,
B,
ldb,
beta,
C,
ldc);
void gemm<platform::CPUPlace, double>(
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K, const double alpha, const double* A,
const int lda, const double* B, const int ldb, const double beta, double* C,
const int ldc, platform::DeviceContext* context) {
cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb,
beta, C, ldc);
}
template <>
void matmul<platform::CPUPlace, float>(const framework::Tensor& in1,
bool in1_T,
const framework::Tensor& in2,
bool in2_T,
float alpha,
framework::Tensor* out,
void matmul<platform::CPUPlace, float>(const framework::Tensor& in1, bool in1_T,
const framework::Tensor& in2, bool in2_T,
float alpha, framework::Tensor* out,
float beta,
platform::DeviceContext* context) {
auto in1_dim = in1.dims();
......@@ -111,30 +66,17 @@ void matmul<platform::CPUPlace, float>(const framework::Tensor& in1,
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, float>(in1_Trans,
in2_Trans,
M,
N,
K,
alpha,
in1.data<float>(),
K,
in2.data<float>(),
N,
beta,
out->data<float>(),
N,
context);
gemm<platform::CPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
in1.data<float>(), K, in2.data<float>(), N,
beta, out->data<float>(), N, context);
}
template <>
void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
bool in1_T,
const framework::Tensor& in2,
bool in2_T,
float alpha,
framework::Tensor* out,
float beta,
bool in2_T, float alpha,
framework::Tensor* out, float beta,
platform::DeviceContext* context) {
auto in1_dim = in1.dims();
auto in2_dim = in2.dims();
......@@ -157,20 +99,9 @@ void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
gemm<platform::CPUPlace, double>(in1_Trans,
in2_Trans,
M,
N,
K,
alpha,
in1.data<double>(),
K,
in2.data<double>(),
N,
beta,
out->data<double>(),
N,
context);
gemm<platform::CPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
in1.data<double>(), K, in2.data<double>(), N,
beta, out->data<double>(), N, context);
}
} // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册