From 0a13d3c67a7be24cb541b964b99b42ce851de618 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 8 May 2018 11:16:36 +0800 Subject: [PATCH] Move MatMul to blas_impl.h Rename MatDim to MatDescriptor --- paddle/fluid/operators/math/blas.cc | 5 +++-- paddle/fluid/operators/math/blas.h | 29 +++++-------------------- paddle/fluid/operators/math/blas_impl.h | 25 +++++++++++++++++++++ paddle/fluid/operators/matmul_op.h | 2 +- 4 files changed, 35 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/math/blas.cc b/paddle/fluid/operators/math/blas.cc index 7427ceac6..3e09ef7a2 100644 --- a/paddle/fluid/operators/math/blas.cc +++ b/paddle/fluid/operators/math/blas.cc @@ -18,8 +18,9 @@ namespace paddle { namespace operators { namespace math { -MatDim GetMatDim(const framework::DDim& dim, int num_flatten_cols, bool trans) { - MatDim retv; +MatDescriptor GetMatDim(const framework::DDim& dim, int num_flatten_cols, + bool trans) { + MatDescriptor retv; if (num_flatten_cols > 1) { auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols); retv.height_ = flatten_dim[0]; diff --git a/paddle/fluid/operators/math/blas.h b/paddle/fluid/operators/math/blas.h index cca967f33..0c0794125 100644 --- a/paddle/fluid/operators/math/blas.h +++ b/paddle/fluid/operators/math/blas.h @@ -46,7 +46,7 @@ namespace paddle { namespace operators { namespace math { -struct MatDim { +struct MatDescriptor { int64_t height_; int64_t width_; int64_t stride_{0}; @@ -54,8 +54,8 @@ struct MatDim { bool trans_; }; -extern MatDim GetMatDim(const framework::DDim& tensor, int num_flatten_cols, - bool trans); +extern MatDescriptor GetMatDim(const framework::DDim& tensor, + int num_flatten_cols, bool trans); template class Blas { @@ -102,26 +102,9 @@ class Blas { int batchCount, int64_t strideA, int64_t strideB) const; template - void MatMul(const framework::Tensor& mat_a, const MatDim& dim_a, - const framework::Tensor& mat_b, const MatDim& dim_b, T alpha, - framework::Tensor* mat_out, T beta) const { - PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_); - CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; - CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; - if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { - this->template GEMM(transA, transB, dim_a.height_, dim_b.width_, - dim_a.width_, alpha, mat_a.data(), - mat_b.data(), beta, mat_out->data()); - } else { - PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ || - dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0); - this->template BatchedGEMM( - transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, - mat_a.data(), mat_b.data(), beta, mat_out->data(), - dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, - dim_a.stride_, dim_b.stride_); - } - } + void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a, + const framework::Tensor& mat_b, const MatDescriptor& dim_b, + T alpha, framework::Tensor* mat_out, T beta) const; private: const DeviceContext& context_; diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 7360cc0a9..577cbe3be 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -180,6 +180,31 @@ void Blas::BatchedGEMM( #endif } +template +template +void Blas::MatMul(const framework::Tensor &mat_a, + const MatDescriptor &dim_a, + const framework::Tensor &mat_b, + const MatDescriptor &dim_b, T alpha, + framework::Tensor *mat_out, T beta) const { + PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_); + CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans; + CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans; + if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) { + this->template GEMM(transA, transB, dim_a.height_, dim_b.width_, + dim_a.width_, alpha, mat_a.data(), + mat_b.data(), beta, mat_out->data()); + } else { + PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ || + dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0); + this->template BatchedGEMM( + transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha, + mat_a.data(), mat_b.data(), beta, mat_out->data(), + dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_, + dim_a.stride_, dim_b.stride_); + } +} + } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/matmul_op.h b/paddle/fluid/operators/matmul_op.h index 7b484d124..9bf39026f 100644 --- a/paddle/fluid/operators/matmul_op.h +++ b/paddle/fluid/operators/matmul_op.h @@ -91,7 +91,7 @@ inline framework::Tensor CombineBatchAndN(const DeviceContext& context, } inline void NormalizeTensorShape(framework::Tensor* x, - const math::MatDim& mat_dim_x) { + const math::MatDescriptor& mat_dim_x) { int64_t h, w; h = mat_dim_x.height_; w = mat_dim_x.width_; -- GitLab