From 97d8175a5e19dbd60ea55cb21640cd7187d60974 Mon Sep 17 00:00:00 2001
From: qijun
Date: Mon, 7 Aug 2017 11:45:00 +0800
Subject: [PATCH] add global matmul function for Tensor

---
 paddle/framework/tensor.h              |  2 +
 paddle/operators/math/math_function.cc | 93 ++++++++++++++++++++++++++
 paddle/operators/math/math_function.cu | 73 ++++++++++++++++++++
 paddle/operators/math/math_function.h  | 12 ++++
 paddle/operators/mul_op.h              | 31 +++------
 5 files changed, 189 insertions(+), 22 deletions(-)

diff --git a/paddle/framework/tensor.h b/paddle/framework/tensor.h
index 4c3b14b83..2aac8a128 100644
--- a/paddle/framework/tensor.h
+++ b/paddle/framework/tensor.h
@@ -103,6 +103,8 @@ class Tensor {
   template <typename T>
   inline Tensor Slice(const int& begin_idx, const int& end_idx) const;
 
+  platform::Place place() const { return holder_->place(); }
+
  private:
   template <typename T>
   inline void check_memory_size() const;
diff --git a/paddle/operators/math/math_function.cc b/paddle/operators/math/math_function.cc
index c678b3761..1bfbc7557 100644
--- a/paddle/operators/math/math_function.cc
+++ b/paddle/operators/math/math_function.cc
@@ -80,6 +80,99 @@ void gemm(const CBLAS_TRANSPOSE transA,
              ldc);
 }
 
+template <>
+void matmul<platform::CPUPlace, float>(const framework::Tensor& in1,
+                                       bool in1_T,
+                                       const framework::Tensor& in2,
+                                       bool in2_T,
+                                       float alpha,
+                                       framework::Tensor* out,
+                                       float beta,
+                                       platform::DeviceContext* context) {
+  auto in1_dim = in1.dims();
+  auto in2_dim = in2.dims();
+  auto out_dim = out->dims();
+  PADDLE_ENFORCE(
+      in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
+      "The input and output of matmul must be matrices");
+  PADDLE_ENFORCE(
+      in1_dim[1] == in2_dim[0],
+      "First matrix's width must be equal with second matrix's height.");
+
+  PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) &&
+                     platform::is_cpu_place(in2.place()) &&
+                     platform::is_cpu_place(out->place()),
+                 "Matrices must all be in CPUPlace");
+
+  int M = out_dim[0];
+  int N = out_dim[1];
+  int K = in1_dim[1];
+
+  CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
+
+  gemm<platform::CPUPlace, float>(in1_Trans,
+                                  in2_Trans,
+                                  M,
+                                  N,
+                                  K,
+                                  alpha,
+                                  in1.data<float>(),
+                                  K,
+                                  in2.data<float>(),
+                                  N,
+                                  beta,
+                                  out->data<float>(),
+                                  N,
+                                  context);
+}
+
+template <>
+void matmul<platform::CPUPlace, double>(const framework::Tensor& in1,
+                                        bool in1_T,
+                                        const framework::Tensor& in2,
+                                        bool in2_T,
+                                        float alpha,
+                                        framework::Tensor* out,
+                                        float beta,
+                                        platform::DeviceContext* context) {
+  auto in1_dim = in1.dims();
+  auto in2_dim = in2.dims();
+  auto out_dim = out->dims();
+  PADDLE_ENFORCE(
+      in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
+      "The input and output of matmul must be matrices");
+  PADDLE_ENFORCE(
+      in1_dim[1] == in2_dim[0],
+      "First matrix's width must be equal with second matrix's height.");
+
+  PADDLE_ENFORCE(platform::is_cpu_place(in1.place()) &&
+                     platform::is_cpu_place(in2.place()) &&
+                     platform::is_cpu_place(out->place()),
+                 "Matrices must all be in CPUPlace");
+
+  int M = out_dim[0];
+  int N = out_dim[1];
+  int K = in1_dim[1];
+
+  CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
+
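+  // BLAS convention: out = alpha * op(in1) * op(in2) + beta * out, where op()
+  // is the optional transpose selected by the flags above.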
+  gemm<platform::CPUPlace, double>(in1_Trans,
+                                   in2_Trans,
+                                   M,
+                                   N,
+                                   K,
+                                   alpha,
+                                   in1.data<double>(),
+                                   K,
+                                   in2.data<double>(),
+                                   N,
+                                   beta,
+                                   out->data<double>(),
+                                   N,
+                                   context);
+}
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/math/math_function.cu b/paddle/operators/math/math_function.cu
index 190312e59..e1ac85608 100644
--- a/paddle/operators/math/math_function.cu
+++ b/paddle/operators/math/math_function.cu
@@ -98,6 +98,79 @@ void gemm(const CBLAS_TRANSPOSE transA,
              ldc));
 }
 
+template <>
+void matmul<platform::GPUPlace, float>(const framework::Tensor& in1,
+                                       bool in1_T,
+                                       const framework::Tensor& in2,
+                                       bool in2_T,
+                                       float alpha,
+                                       framework::Tensor* out,
+                                       float beta,
+                                       platform::DeviceContext* context) {
+  auto in1_dim = in1.dims();
+  auto in2_dim = in2.dims();
+  auto out_dim = out->dims();
+  PADDLE_ENFORCE(
+      in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
+      "The input and output of matmul must be matrices");
+  PADDLE_ENFORCE(
+      in1_dim[1] == in2_dim[0],
+      "First matrix's width must be equal with second matrix's height.");
+
+  PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) &&
+                     platform::is_gpu_place(in2.place()) &&
+                     platform::is_gpu_place(out->place()),
+                 "Matrices must all be in GPUPlace");
+
+  int M = out_dim[0];
+  int N = out_dim[1];
+  int K = in1_dim[1];
+
+  CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
+
+  gemm<platform::GPUPlace, float>(in1_Trans,
+                                  in2_Trans,
+                                  M,
+                                  N,
+                                  K,
+                                  alpha,
+                                  in1.data<float>(),
+                                  K,
+                                  in2.data<float>(),
+                                  N,
+                                  beta,
+                                  out->data<float>(),
+                                  N,
+                                  context);
+}
+
+template <>
+void matmul<platform::GPUPlace, double>(const framework::Tensor& in1,
+                                        bool in1_T,
+                                        const framework::Tensor& in2,
+                                        bool in2_T,
+                                        float alpha,
+                                        framework::Tensor* out,
+                                        float beta,
+                                        platform::DeviceContext* context) {
+  auto in1_dim = in1.dims();
+  auto in2_dim = in2.dims();
+  auto out_dim = out->dims();
+  PADDLE_ENFORCE(
+      in1_dim.size() == 2 && in2_dim.size() == 2 && out_dim.size() == 2,
+      "The input and output of matmul must be matrices");
+  PADDLE_ENFORCE(
+      in1_dim[1] == in2_dim[0],
+      "First matrix's width must be equal with second matrix's height.");
+
+  PADDLE_ENFORCE(platform::is_gpu_place(in1.place()) &&
+                     platform::is_gpu_place(in2.place()) &&
+                     platform::is_gpu_place(out->place()),
+                 "Matrices must all be in GPUPlace");
+
+  int M = out_dim[0];
+  int N = out_dim[1];
+  int K = in1_dim[1];
+
+  CBLAS_TRANSPOSE in1_Trans = (in1_T == false) ? CblasNoTrans : CblasTrans;
+  CBLAS_TRANSPOSE in2_Trans = (in2_T == false) ? CblasNoTrans : CblasTrans;
+
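+  // Row-major shapes as in the CPU path: out is M x N, in1 is M x K and
+  // in2 is K x N, so the leading dimensions passed below are K, N and N.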
+  gemm<platform::GPUPlace, double>(in1_Trans,
+                                   in2_Trans,
+                                   M,
+                                   N,
+                                   K,
+                                   alpha,
+                                   in1.data<double>(),
+                                   K,
+                                   in2.data<double>(),
+                                   N,
+                                   beta,
+                                   out->data<double>(),
+                                   N,
+                                   context);
+}
+
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/operators/math/math_function.h b/paddle/operators/math/math_function.h
index f1f87ac5f..f068f4a15 100644
--- a/paddle/operators/math/math_function.h
+++ b/paddle/operators/math/math_function.h
@@ -38,6 +38,7 @@ extern "C" {
 #endif
 
 #include <cmath>
+#include "paddle/framework/tensor.h"
 #include "paddle/platform/device_context.h"
 
 namespace paddle {
@@ -60,6 +61,17 @@ void gemm(const CBLAS_TRANSPOSE transA,
           const int ldc,
           platform::DeviceContext* context);
 
+// matrix multiply with continuous memory
+template <typename Place, typename T>
+void matmul(const framework::Tensor& in1,
+            bool in1_T,
+            const framework::Tensor& in2,
+            bool in2_T,
+            float alpha,
+            framework::Tensor* out,
+            float beta,
+            platform::DeviceContext* context);
+
 }  // namespace math
 }  // namespace operators
 }  // namespace paddle
diff --git a/paddle/operators/mul_op.h b/paddle/operators/mul_op.h
index 0bffe79a1..d5d8e220a 100644
--- a/paddle/operators/mul_op.h
+++ b/paddle/operators/mul_op.h
@@ -24,33 +24,20 @@ template <typename Place, typename T>
 class MulKernel : public OpKernel {
  public:
   void Compute(const ExecutionContext& context) const override {
-    auto input0 = context.Input<Tensor>("X");
-    auto input1 = context.Input<Tensor>("Y");
-    auto output = context.Output<Tensor>(0);
+    auto* input0 = context.Input<Tensor>("X");
+    auto* input1 = context.Input<Tensor>("Y");
+    auto* output = context.Output<Tensor>(0);
 
     output->mutable_data<T>(context.GetPlace());
 
-    auto out_dim = output->dims();
-    auto in0_dim = input0->dims();
-
-    int M = out_dim[0];
-    int N = out_dim[1];
-    int K = in0_dim[1];
-
-    paddle::operators::math::template gemm<Place, T>(
-        CblasNoTrans,
-        CblasNoTrans,
-        M,
-        N,
-        K,
+    paddle::operators::math::template matmul<Place, T>(
+        *input0,
+        false,
+        *input1,
+        false,
         1,
-        input0->data<T>(),
-        K,
-        input1->data<T>(),
-        N,
+        output,
         0,
-        output->data<T>(),
-        N,
         &const_cast<platform::DeviceContext&>(context.device_context()));
   }
 };
--
GitLab
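For reference, a minimal usage sketch of the new helper, not part of the patch: it assumes two 2-D float tensors already allocated on the CPU, and the function name, tensor names, shapes, and the locally constructed CPUDeviceContext are illustrative only.

#include "paddle/framework/tensor.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/platform/device_context.h"

// Illustrative sketch only: computes c = a * b for small CPU tensors.
void MatmulSketch() {
  using paddle::framework::Tensor;
  using paddle::framework::make_ddim;
  using paddle::platform::CPUPlace;

  // a is 2x3 and b is 3x4, so c = a * b is 2x4; all buffers live on the CPU.
  Tensor a, b, c;
  a.mutable_data<float>(make_ddim({2, 3}), CPUPlace());
  b.mutable_data<float>(make_ddim({3, 4}), CPUPlace());
  c.mutable_data<float>(make_ddim({2, 4}), CPUPlace());
  // ... fill a and b with data ...

  paddle::platform::CPUDeviceContext ctx;
  // c = 1 * a * b + 0 * c, with neither input transposed.
  paddle::operators::math::matmul<CPUPlace, float>(a, false, b, false, 1, &c, 0, &ctx);
}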