Commit 0a13d3c6 authored by Yu Yang

Move MatMul to blas_impl.h

Rename MatDim to MatDescriptor
Parent 3dd01823
@@ -18,8 +18,9 @@
 namespace paddle {
 namespace operators {
 namespace math {
-MatDim GetMatDim(const framework::DDim& dim, int num_flatten_cols, bool trans) {
-  MatDim retv;
+MatDescriptor GetMatDim(const framework::DDim& dim, int num_flatten_cols,
+                        bool trans) {
+  MatDescriptor retv;
   if (num_flatten_cols > 1) {
     auto flatten_dim = framework::flatten_to_2d(dim, num_flatten_cols);
     retv.height_ = flatten_dim[0];
......
@@ -46,7 +46,7 @@ namespace paddle {
 namespace operators {
 namespace math {
 
-struct MatDim {
+struct MatDescriptor {
   int64_t height_;
   int64_t width_;
   int64_t stride_{0};
@@ -54,8 +54,8 @@ struct MatDim {
   bool trans_;
 };
 
-extern MatDim GetMatDim(const framework::DDim& tensor, int num_flatten_cols,
-                        bool trans);
+extern MatDescriptor GetMatDim(const framework::DDim& tensor,
+                               int num_flatten_cols, bool trans);
 
 template <typename DeviceContext>
 class Blas {
@@ -102,26 +102,9 @@ class Blas {
                    int batchCount, int64_t strideA, int64_t strideB) const;
 
   template <typename T>
-  void MatMul(const framework::Tensor& mat_a, const MatDim& dim_a,
-              const framework::Tensor& mat_b, const MatDim& dim_b, T alpha,
-              framework::Tensor* mat_out, T beta) const {
-    PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
-    CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
-    CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
-    if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
-      this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
-                             dim_a.width_, alpha, mat_a.data<T>(),
-                             mat_b.data<T>(), beta, mat_out->data<T>());
-    } else {
-      PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
-                     dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
-      this->template BatchedGEMM<T>(
-          transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
-          mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
-          dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
-          dim_a.stride_, dim_b.stride_);
-    }
-  }
+  void MatMul(const framework::Tensor& mat_a, const MatDescriptor& dim_a,
+              const framework::Tensor& mat_b, const MatDescriptor& dim_b,
+              T alpha, framework::Tensor* mat_out, T beta) const;
 
  private:
   const DeviceContext& context_;
......
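Note on the move: because MatMul is a member template of the class template Blas, its out-of-line definition in blas_impl.h must repeat both template parameter lists, outer first, as the next hunk shows:

template <typename DeviceContext>  // parameters of the enclosing class template
template <typename T>              // parameters of the member template itself
void Blas<DeviceContext>::MatMul(/* ... */) const { /* ... */ }

Assuming blas.h still includes blas_impl.h at the end (callers include only blas.h), the template body stays visible at every instantiation site, so this is purely a code-organization change.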
@@ -180,6 +180,31 @@ void Blas<platform::CPUDeviceContext>::BatchedGEMM(
 #endif
 }
 
+template <typename DeviceContext>
+template <typename T>
+void Blas<DeviceContext>::MatMul(const framework::Tensor &mat_a,
+                                 const MatDescriptor &dim_a,
+                                 const framework::Tensor &mat_b,
+                                 const MatDescriptor &dim_b, T alpha,
+                                 framework::Tensor *mat_out, T beta) const {
+  PADDLE_ENFORCE_EQ(dim_a.width_, dim_b.height_);
+  CBLAS_TRANSPOSE transA = !dim_a.trans_ ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE transB = !dim_b.trans_ ? CblasNoTrans : CblasTrans;
+  if (dim_a.batch_size_ == 0 && dim_b.batch_size_ == 0) {
+    this->template GEMM<T>(transA, transB, dim_a.height_, dim_b.width_,
+                           dim_a.width_, alpha, mat_a.data<T>(),
+                           mat_b.data<T>(), beta, mat_out->data<T>());
+  } else {
+    PADDLE_ENFORCE(dim_a.batch_size_ == dim_b.batch_size_ ||
+                   dim_a.batch_size_ == 0 || dim_b.batch_size_ == 0);
+    this->template BatchedGEMM<T>(
+        transA, transB, dim_a.height_, dim_b.width_, dim_a.width_, alpha,
+        mat_a.data<T>(), mat_b.data<T>(), beta, mat_out->data<T>(),
+        dim_a.batch_size_ == 0 ? dim_b.batch_size_ : dim_a.batch_size_,
+        dim_a.stride_, dim_b.stride_);
+  }
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
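For context, a minimal usage sketch of the relocated MatMul (hypothetical, not part of this commit: `ctx`, `x`, `y`, `out`, their shapes, and the assumption that GetMatDim maps a 3-D shape to a batched descriptor are all mine). The interesting case is mixed batching: a descriptor with batch_size_ == 0 keeps stride_ at its default of 0, so BatchedGEMM reuses that matrix for every batch element.

// Hypothetical caller: x is a batched [N, M, K] tensor, y a shared [K, W]
// weight, out a preallocated [N, M, W] result; ctx is a CPUDeviceContext.
auto dim_x = math::GetMatDim(x.dims(), 0, false);  // assumed: batch_size_ = N
auto dim_y = math::GetMatDim(y.dims(), 0, false);  // 2-D, so batch_size_ = 0
math::Blas<platform::CPUDeviceContext> blas(ctx);
// Takes the BatchedGEMM branch: the batch count comes from dim_x, and
// dim_y.stride_ == 0 broadcasts y across all N multiplications.
blas.MatMul(x, dim_x, y, dim_y, 1.0f, &out, 0.0f);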
@@ -91,7 +91,7 @@ inline framework::Tensor CombineBatchAndN(const DeviceContext& context,
 }
 
 inline void NormalizeTensorShape(framework::Tensor* x,
-                                 const math::MatDim& mat_dim_x) {
+                                 const math::MatDescriptor& mat_dim_x) {
   int64_t h, w;
   h = mat_dim_x.height_;
   w = mat_dim_x.width_;
......