Commit 97d8175a authored by: Q qijun

add global matmul function for Tensor

Parent bf740a3f
......@@ -103,6 +103,8 @@ class Tensor {
template <typename T>
inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
platform::Place place() const { return holder_->place(); }
private:
template <typename T>
inline void check_memory_size() const;
......
......@@ -80,6 +80,99 @@ void gemm<platform::CPUPlace, double>(const CBLAS_TRANSPOSE transA,
ldc);
}
template <>
void matmul<platform::CPUPlace, float>(const framework::Tensor& in1,
bool in1_T,
const framework::Tensor& in2,
bool in2_T,
float alpha,
framework::Tensor* out,
float beta,
platform::DeviceContext* context) {
auto in1_dim = in1.dims();
auto in2_dim = in2.dims();
auto out_dim = out->dims();
PADDLE_ENFORCE(
in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(
in1_dim[1] == in2_dim[0],
"First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) &&
platform::is_cpu_place(in2.place()) &&
platform::is_cpu_place(out->place()),
"Matrix must all be in CPUPlace");
int M = out_dim[0];
int N = out_dim[1];
int K = in1_dim[1];
  CBLAS_TRANSPOSE in1_Trans = in1_T ? CblasTrans : CblasNoTrans;
  CBLAS_TRANSPOSE in2_Trans = in2_T ? CblasTrans : CblasNoTrans;
gemm<platform::CPUPlace, float>(in1_Trans,
in2_Trans,
M,
N,
K,
alpha,
in1.data<float>(),
K,
in2.data<float>(),
N,
beta,
out->data<float>(),
N,
context);
}
template <>
void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
bool in1_T,
const framework::Tensor& in2,
bool in2_T,
float alpha,
framework::Tensor* out,
float beta,
platform::DeviceContext* context) {
auto in1_dim = in1.dims();
auto in2_dim = in2.dims();
auto out_dim = out->dims();
PADDLE_ENFORCE(
in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
"The input and output of matmul be matrix");
PADDLE_ENFORCE(
in1_dim[1] == in2_dim[0],
"First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) &&
platform::is_cpu_place(in2.place()) &&
platform::is_cpu_place(out->place()),
"Matrix must all be in CPUPlace");
int M = out_dim[0];
int N = out_dim[1];
int K = in1_dim[1];
  CBLAS_TRANSPOSE in1_Trans = in1_T ? CblasTrans : CblasNoTrans;
  CBLAS_TRANSPOSE in2_Trans = in2_T ? CblasTrans : CblasNoTrans;
gemm<platform::CPUPlace, double>(in1_Trans,
in2_Trans,
M,
N,
K,
alpha,
in1.data<double>(),
K,
in2.data<double>(),
N,
beta,
out->data<double>(),
N,
context);
}
} // namespace math
} // namespace operators
} // namespace paddle
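Both CPU specializations above hard-code the leading dimensions passed to gemm as K, N, and N, which is only valid for contiguous, row-major matrices (hence the "continuous memory" note in the header further down). The standalone sketch below is not part of this commit; it uses a naive reference gemm purely to illustrate that row-major leading-dimension convention, and all names and values in it are illustrative.

// --- standalone sketch, not part of this commit ---
#include <cstdio>

// Reference row-major gemm: C = alpha * A * B + beta * C, with
// A (MxK, lda = K), B (KxN, ldb = N), C (MxN, ldc = N) stored contiguously.
void naive_gemm(int M, int N, int K, float alpha, const float* A, int lda,
                const float* B, int ldb, float beta, float* C, int ldc) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[m * lda + k] * B[k * ldb + n];
      }
      C[m * ldc + n] = alpha * acc + beta * C[m * ldc + n];
    }
  }
}

int main() {
  const int M = 2, N = 2, K = 3;
  float A[M * K] = {1, 2, 3, 4, 5, 6};  // 2x3
  float B[K * N] = {1, 0, 0, 1, 1, 1};  // 3x2
  float C[M * N] = {0, 0, 0, 0};        // 2x2
  // Same leading-dimension choice as the matmul specializations: K, N, N.
  naive_gemm(M, N, K, 1.0f, A, K, B, N, 0.0f, C, N);
  printf("%g %g\n%g %g\n", C[0], C[1], C[2], C[3]);  // prints: 4 5 / 10 11
  return 0;
}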
......@@ -98,6 +98,79 @@ void gemm<platform::GPUPlace, double>(const CBLAS_TRANSPOSE transA,
ldc));
}
template <>
void matmul<platform::GPUPlace, float>(const framework::Tensor& in1,
                                       bool in1_T,
                                       const framework::Tensor& in2,
                                       bool in2_T,
                                       float alpha,
                                       framework::Tensor* out,
                                       float beta,
                                       platform::DeviceContext* context) {
auto in1_dim = in1.dims();
auto in2_dim = in2.dims();
auto out_dim = out->dims();
  PADDLE_ENFORCE(
      in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
      "The input and output of matmul must be matrices");
PADDLE_ENFORCE(
in1_dim[1] == in2_dim[0],
"First matrix's width must be equal with second matrix's height.");
  PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) &&
                     platform::is_gpu_place(in2.place()) &&
                     platform::is_gpu_place(out->place()),
                 "All matrices must be in GPUPlace");
int M = out_dim[0];
int N = out_dim[1];
int K = in1_dim[1];
  CBLAS_TRANSPOSE in1_Trans = in1_T ? CblasTrans : CblasNoTrans;
  CBLAS_TRANSPOSE in2_Trans = in2_T ? CblasTrans : CblasNoTrans;
gemm<platform::GPUPlace, float>(in1_Trans,
in2_Trans,
M,
N,
K,
alpha,
in1.data<float>(),
K,
in2.data<float>(),
N,
beta,
out->data<float>(),
N,
context);
}
template <>
void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
                                        bool in1_T,
                                        const framework::Tensor& in2,
                                        bool in2_T,
                                        float alpha,
                                        framework::Tensor* out,
                                        float beta,
                                        platform::DeviceContext* context) {
auto in1_dim = in1.dims();
auto in2_dim = in2.dims();
auto out_dim = out->dims();
  PADDLE_ENFORCE(
      in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
      "The input and output of matmul must be matrices");
PADDLE_ENFORCE(
in1_dim[1] == in2_dim[0],
"First matrix's width must be equal with second matrix's height.");
PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) && platform::is_gpu_place(in2.place())&& platform::is_gpu_place(out->place()), "Matrix must all be in GPUPlace");
CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
CBLAS_TRANSPOSE in2_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
gemm<platform::GPUPlace, double>(in1_Trans,
in2_Trans,
M,
N,
K,
alpha,
in1.data<double>(),
K,
in2.data<double>(),
N,
beta,
out->data<double>(),
N,
context);
}
} // namespace math
} // namespace operators
......
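The gemm<platform::GPUPlace, T> wrapper invoked by the two GPU matmul specializations is only partially visible in this diff; it accepts CBLAS-style, row-major arguments but ultimately drives column-major cuBLAS. A common way to bridge that gap is the operand-swap trick sketched below. This is an independent illustration of the technique, not the actual Paddle implementation, and the helper name row_major_sgemm is made up.

// --- standalone sketch of the usual row-major-to-cuBLAS mapping ---
#include <cublas_v2.h>

// Row-major C = alpha * op(A) * op(B) + beta * C expressed through
// column-major cuBLAS: the column-major view of a row-major matrix is its
// transpose, so compute C^T = alpha * op(B)^T * op(A)^T by swapping the
// operands and exchanging M and N.
cublasStatus_t row_major_sgemm(cublasHandle_t handle, bool transA, bool transB,
                               int M, int N, int K, float alpha,
                               const float* A, int lda, const float* B,
                               int ldb, float beta, float* C, int ldc) {
  cublasOperation_t opA = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
  cublasOperation_t opB = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
  return cublasSgemm(handle, opB, opA, N, M, K, &alpha, B, ldb, A, lda, &beta,
                     C, ldc);
}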
......@@ -38,6 +38,7 @@ extern "C" {
#endif
#include <cmath>
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
namespace paddle {
......@@ -60,6 +61,17 @@ void gemm(const CBLAS_TRANSPOSE transA,
const int ldc,
platform::DeviceContext* context);
// matrix multiply with continuous memory
template <typename Place, typename T>
void matmul(const framework::Tensor& in1,
bool in1_T,
const framework::Tensor& in2,
bool in2_T,
float alpha,
framework::Tensor* out,
float beta,
platform::DeviceContext* context);
} // namespace math
} // namespace operators
} // namespace paddle
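For reference, here is a hedged sketch of how the matmul entry point declared above might be called directly on two CPU tensors. The construction details (make_ddim, mutable_data, CPUDeviceContext) follow the Tensor and DeviceContext APIs visible elsewhere in the tree, the header path is assumed, and the variable names are illustrative; the real in-tree caller is the MulKernel in the next hunk.

// --- illustrative usage sketch, not part of this commit ---
#include "paddle/framework/tensor.h"
#include "paddle/operators/math/math_function.h"  // header path assumed
#include "paddle/platform/device_context.h"

void matmul_example() {
  using paddle::framework::Tensor;
  using paddle::framework::make_ddim;
  using paddle::platform::CPUPlace;

  Tensor a, b, c;
  a.mutable_data<float>(make_ddim({2, 3}), CPUPlace());  // 2x3 operand
  b.mutable_data<float>(make_ddim({3, 4}), CPUPlace());  // 3x4 operand
  c.mutable_data<float>(make_ddim({2, 4}), CPUPlace());  // 2x4 result
  // ... fill a and b ...

  paddle::platform::CPUDeviceContext ctx;
  // c = 1.0f * a * b + 0.0f * c, neither operand transposed.
  paddle::operators::math::matmul<CPUPlace, float>(a, false, b, false, 1.0f,
                                                   &c, 0.0f, &ctx);
}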
......@@ -24,33 +24,20 @@ template <typename Place, typename T>
class MulKernel : public OpKernel {
public:
void Compute(const ExecutionContext& context) const override {
-    auto input0 = context.Input<Tensor>("X");
-    auto input1 = context.Input<Tensor>("Y");
-    auto output = context.Output<Tensor>(0);
+    auto* input0 = context.Input<Tensor>("X");
+    auto* input1 = context.Input<Tensor>("Y");
+    auto* output = context.Output<Tensor>(0);
     output->mutable_data<T>(context.GetPlace());
-    auto out_dim = output->dims();
-    auto in0_dim = input0->dims();
-    int M = out_dim[0];
-    int N = out_dim[1];
-    int K = in0_dim[1];
-    paddle::operators::math::template gemm<Place, T>(
-        CblasNoTrans,
-        CblasNoTrans,
-        M,
-        N,
-        K,
+    paddle::operators::math::template matmul<Place, T>(
+        *input0,
+        false,
+        *input1,
+        false,
         1,
-        input0->data<T>(),
-        K,
-        input1->data<T>(),
-        N,
+        output,
         0,
-        output->data<T>(),
-        N,
         &const_cast<platform::DeviceContext&>(context.device_context()));
}
};
......