提交 0a13d3c6 编写于 作者: Y Yu Yang

Move MatMul to blas_impl.h

Rename MatDim to MatDescriptor
上级 3dd01823
...@@ -18,8 +18,9 @@ ...@@ -18,8 +18,9 @@
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
MatDim GetMatDim(const framework::DDim& dim, int num_flatten_cols, bool trans) { MatDescriptor GetMatDim(const framework::DDim& dim, int num_flatten_cols,
MatDim retv; bool trans) {
MatDescriptor retv;
if (num_flatten_cols > 1) { if (num_flatten_cols > 1) {
auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols); auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols);
retv.height_ = flatten_dim[0]; retv.height_ = flatten_dim[0];
......
...@@ -46,7 +46,7 @@ namespace paddle { ...@@ -46,7 +46,7 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
struct MatDim { struct MatDescriptor {
int64_t height_; int64_t height_;
int64_t width_; int64_t width_;
int64_t stride_{0}; int64_t stride_{0};
...@@ -54,8 +54,8 @@ struct MatDim { ...@@ -54,8 +54,8 @@ struct MatDim {
bool trans_; bool trans_;
}; };
extern MatDim GetMatDim(const framework::DDim& tensor, int num_flatten_cols, extern MatDescriptor GetMatDim(const framework::DDim& tensor,
bool trans); int num_flatten_cols, bool trans);
template <typename DeviceContext> template <typename DeviceContext>
class Blas { class Blas {
...@@ -102,26 +102,9 @@ class Blas { ...@@ -102,26 +102,9 @@ class Blas {
int batchCount, int64_t strideA, int64_t strideB) const; int batchCount, int64_t strideA, int64_t strideB) const;
template <typename T> template <typename T>
void MatMul(const framework::Tensor& mat_a, const MatDim& dim_a, void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
const framework::Tensor& mat_b, const MatDim& dim_b, T alpha, const framework::Tensor& mat_b, const MatDescriptor& dim_b,
framework::Tensor* mat_out, T beta) const { T alpha, framework::Tensor* mat_out, T beta) const;
PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
dim_a.width_, alpha, mat_a.data<T>(),
mat_b.data<T>(), beta, mat_out->data<T>());
} else {
PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
this->template BatchedGEMM<T>(
transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
dim_a.stride_, dim_b.stride_);
}
}
private: private:
const DeviceContext& context_; const DeviceContext& context_;
......
...@@ -180,6 +180,31 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM( ...@@ -180,6 +180,31 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
#endif #endif
} }
template <typename DeviceContext>
template <typename T>
void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
                                 const MatDescriptor &dim_a,
                                 const framework::Tensor &mat_b,
                                 const MatDescriptor &dim_b, T alpha,
                                 framework::Tensor *mat_out, T beta) const {
  // Computes mat_out = alpha * op(mat_a) * op(mat_b) + beta * mat_out,
  // dispatching to a single GEMM or a strided batched GEMM depending on
  // the batch sizes recorded in the matrix descriptors.
  // The contracted dimensions of the two operands must agree.
  PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
  CBLAS_TRANSPOSE trans_a = dim_a.trans_ ? CblasTrans : CblasNoTrans;
  CBLAS_TRANSPOSE trans_b = dim_b.trans_ ? CblasTrans : CblasNoTrans;
  bool batched = dim_a.batch_size_ != 0 || dim_b.batch_size_ != 0;
  if (batched) {
    // A batch size of 0 on either side means that operand is shared
    // (broadcast) across every batch of the other operand.
    PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
                   dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
    int64_t batch_count =
        dim_a.batch_size_ != 0 ? dim_a.batch_size_ : dim_b.batch_size_;
    this->template BatchedGEMM<T>(trans_a, trans_b, dim_a.height_,
                                  dim_b.width_, dim_a.width_, alpha,
                                  mat_a.data<T>(), mat_b.data<T>(), beta,
                                  mat_out->data<T>(), batch_count,
                                  dim_a.stride_, dim_b.stride_);
  } else {
    // Plain, non-batched matrix product.
    this->template GEMM<T>(trans_a, trans_b, dim_a.height_, dim_b.width_,
                           dim_a.width_, alpha, mat_a.data<T>(),
                           mat_b.data<T>(), beta, mat_out->data<T>());
  }
}
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -91,7 +91,7 @@ inline framework::Tensor CombineBatchAndN(const DeviceContext& context, ...@@ -91,7 +91,7 @@ inline framework::Tensor CombineBatchAndN(const DeviceContext& context,
} }
inline void NormalizeTensorShape(framework::Tensor* x, inline void NormalizeTensorShape(framework::Tensor* x,
const math::MatDim& mat_dim_x) { const math::MatDescriptor& mat_dim_x) {
int64_t h, w; int64_t h, w;
h = mat_dim_x.height_; h = mat_dim_x.height_;
w = mat_dim_x.width_; w = mat_dim_x.width_;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册