Commit 5703eb50 authored by: Q qijun

add .clang-format file

Parent 01a198a5
---
Language: Cpp
BasedOnStyle: Google
Standard: Cpp11
...
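With this configuration file at the repository root, the reformatting shown in the diff below can be reproduced by running clang-format in place on the affected source file, for example (the exact path is inferred from the #include below and is not stated in the commit):

clang-format -i paddle/operators/math/math_function.cu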
@@ -14,25 +14,15 @@ limitations under the License. */
#include "paddle/operators/math/math_function.h"
namespace paddle {
namespace operators {
namespace math {
template <>
void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const float alpha,
const float* A,
const int lda,
const float* B,
const int ldb,
const float beta,
float* C,
const int ldc,
void gemm<platform::GPUPlace, float>(
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K, const float alpha, const float* A, const int lda,
const float* B, const int ldb, const float beta, float* C, const int ldc,
platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
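// cuBLAS works on column-major matrices, while these tensors are row-major.
// A row-major M x N matrix occupies the same memory as the column-major
// N x M matrix holding its transpose, so the call below computes
// C^T = B^T * A^T by passing B before A and exchanging M and N.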
@@ -42,38 +32,16 @@ void gemm<platform::GPUPlace, float>(const CBLAS_TRANSPOSE transA,
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasSgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->
cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
ldb,
A,
lda,
&beta,
C,
ldc));
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
}
template <>
void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
const CBLAS_TRANSPOSE transB,
const int M,
const int N,
const int K,
const double alpha,
const double* A,
const int lda,
const double* B,
const int ldb,
const double beta,
double* C,
const int ldc,
platform::DeviceContext* context) {
void gemm<platform::GPUPlace, double>(
const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M,
const int N, const int K, const double alpha, const double* A,
const int lda, const double* B, const int ldb, const double beta, double* C,
const int ldc, platform::DeviceContext* context) {
// Note that cublas follows fortran order, so the order is different from
// the cblas convention.
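// (Same column-major trick as the float specialization above: B and A are
// swapped and M and N are exchanged in the cuBLAS call.)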
cublasOperation_t cuTransA =
@@ -81,36 +49,30 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
cublasOperation_t cuTransB =
(transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
PADDLE_ENFORCE(platform::dynload::cublasDgemm(
reinterpret_cast<platform::CUDADeviceContext*>(context)->
cublas_handle(),
cuTransB,
cuTransA,
N,
M,
K,
&alpha,
B,
ldb,
A,
lda,
&beta,
C,
ldc));
reinterpret_cast<platform::CUDADeviceContext*>(context)->cublas_handle(),
cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, lda, &beta, C, ldc));
}
template <>
void matmul<platform::GPUPlace, float>(const framework::Tensor& in1, bool in1_T, const framework::Tensor& in2, bool in2_T, float alpha,
framework::Tensor* out, float beta, platform::DeviceContext* context) {
void matmul<platform::GPUPlace, float>(const framework::Tensor& in1, bool in1_T,
const framework::Tensor& in2, bool in2_T,
float alpha, framework::Tensor* out,
float beta,
platform::DeviceContext* context) {
auto in1_dim = in1.dims();
auto in2_dim = in2.dims();
auto out_dim = out->dims();
PADDLE_ENFORCE(in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
PADDLE_ENFORCE(
in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(
in1_dim[1] == in2_dim[0],
"First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place())&& platform::is_gpu_place(out->place()), "Matrix must all be in GPUPlace");
PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) &&
platform::is_gpu_place(in2.place()) &&
platform::is_gpu_place(out->place()),
"Matrix must all be in GPUPlace");
int M = out_dim[0];
int N = out_dim[1];
@@ -119,59 +81,44 @@ framework::Tensor* out, float beta, platform::DeviceContext* context) {
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
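// The tensors are stored row-major, so the leading dimensions passed to
// gemm are the row widths: K for in1 (M x K), N for in2 (K x N), and
// N for out (M x N).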
gemm<platform::GPUPlace, float>(in1_Trans,
in2_Trans,
M,
N,
K,
alpha,
in1.data<float>(),
K,
in2.data<float>(),
N,
beta,
out->data<float>(),
N,
context);
gemm<platform::GPUPlace, float>(in1_Trans, in2_Trans, M, N, K, alpha,
in1.data<float>(), K, in2.data<float>(), N,
beta, out->data<float>(), N, context);
}
template <>
void matmul<platform::GPUPlace, double>(const framework::Tensor& in1, bool in1_T, const framework::Tensor& in2, bool in2_T, float alpha,
framework::Tensor* out, float beta, platform::DeviceContext* context) {
void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
bool in1_T,
const framework::Tensor& in2,
bool in2_T, float alpha,
framework::Tensor* out, float beta,
platform::DeviceContext* context) {
auto in1_dim = in1.dims();
auto in2_dim = in2.dims();
auto out_dim = out->dims();
PADDLE_ENFORCE(in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
PADDLE_ENFORCE(
in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(
in1_dim[1] == in2_dim[0],
"First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place())&& platform::is_gpu_place(out->place()), "Matrix must all be in GPUPlace");
PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) &&
platform::is_gpu_place(in2.place()) &&
platform::is_gpu_place(out->place()),
"Matrix must all be in GPUPlace");
int M = out_dim[0];
int N = out_dim[1];
int K = in1_dim[1];
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
gemm<platform::GPUPlace, double>(in1_Trans,
in2_Trans,
M,
N,
K,
alpha,
in1.data<double>(),
K,
in2.data<double>(),
N,
beta,
out->data<double>(),
N,
context);
gemm<platform::GPUPlace, double>(in1_Trans, in2_Trans, M, N, K, alpha,
in1.data<double>(), K, in2.data<double>(), N,
beta, out->data<double>(), N, context);
}
} // namespace math
} // namespace operators
} // namespace paddle
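A minimal usage sketch for the matmul wrapper above (not part of the commit; the helper name MatMulExample and the assumption that all three tensors are already allocated in GPUPlace are for illustration only):

#include "paddle/operators/math/math_function.h"

// Computes c = a * b on the GPU: a is M x K, b is K x N, c is M x N,
// with alpha = 1, beta = 0, and neither input transposed.
void MatMulExample(const paddle::framework::Tensor& a,
                   const paddle::framework::Tensor& b,
                   paddle::framework::Tensor* c,
                   paddle::platform::DeviceContext* ctx) {
  paddle::operators::math::matmul<paddle::platform::GPUPlace, float>(
      a, false, b, false, 1.0f, c, 0.0f, ctx);
}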