diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h index c76fc17d78cce514b5e35ce8e5ca890d7cec1e98..d84c88cb3bc1a13acb83b3444dbd1bfca3cba503 100644 --- a/paddle/fluid/operators/math/blas_impl.cu.h +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -96,10 +96,22 @@ struct CUBlas { reinterpret_cast<__half *>(C), ldc)); } - template - static void GEMM_BATCH(ARGS... args) { + static void GEMM_BATCH(cublasHandle_t handle, cublasOperation_t transa, + cublasOperation_t transb, int m, int n, int k, + const float16 *alpha, const float16 *A, int lda, + long long int strideA, const float16 *B, // NOLINT + int ldb, long long int strideB, // NOLINT + const float16 *beta, float16 *C, int ldc, + long long int strideC, // NOLINT + int batchCount) { #if CUDA_VERSION >= 8000 - PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched(args...)); + PADDLE_ENFORCE(platform::dynload::cublasHgemmStridedBatched( + handle, transa, transb, m, n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, strideA, + reinterpret_cast(B), ldb, strideB, + reinterpret_cast(beta), reinterpret_cast<__half *>(C), + ldc, strideC, batchCount)); #else PADDLE_THROW("HgemmStridedBatched is not supported on cuda <= 7.5"); #endif diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h index 577cbe3beb806ffcb2f1a7d7a469402be9b69224..14b3624b420cb883b36268c0a5a9e8692dbb5b43 100644 --- a/paddle/fluid/operators/math/blas_impl.h +++ b/paddle/fluid/operators/math/blas_impl.h @@ -172,9 +172,9 @@ void Blas::BatchedGEMM( c_array.data(), &ldc, 1 /* group_count */, &batchCount); #else for (int k = 0; k < batchCount; ++k) { - const float *Ak = &A[k * strideA]; - const float *Bk = &B[k * strideB]; - float *Ck = &C[k * M * N]; + auto *Ak = &A[k * strideA]; + auto *Bk = &B[k * strideB]; + auto *Ck = &C[k * M * N]; this->template GEMM(transA, transB, M, N, K, alpha, Ak, Bk, beta, Ck); } #endif diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index b5bf84e5178c143de35ec6dcb16b1bde5577c166..d5af718723e8d44da0971ea7756b8c36e771cca2 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -33,9 +33,10 @@ template struct SetConstant; template struct SetConstant; template struct SetConstant; -#define DEFINE_GPU_TRANS(RANK) \ - template struct Transpose; \ - template struct Transpose; +#define DEFINE_GPU_TRANS(RANK) \ + template struct Transpose; \ + template struct Transpose; \ + template struct Transpose; DEFINE_GPU_TRANS(1); DEFINE_GPU_TRANS(2); diff --git a/paddle/fluid/operators/matmul_op.cc b/paddle/fluid/operators/matmul_op.cc index 2d05449822e6addc42c4fea5af8422ae6dcfd37d..7182149164854038bb67a9f06cdbec8a4a0f1fb2 100644 --- a/paddle/fluid/operators/matmul_op.cc +++ b/paddle/fluid/operators/matmul_op.cc @@ -25,7 +25,7 @@ namespace operators { * Get row matrix shape from a vector shape. If the rank of x_dim > 1, the * original x_dim is returned. */ -static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) { +static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) { if (x_dim.size() > 1) { return x_dim; } @@ -36,7 +36,7 @@ static framework::DDim RowMatrixFromVector(const framework::DDim& x_dim) { * Get column matrix shape from a vector shape. If the ran of y_dim > 1, the * original y_dim is returned. */ -static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) { +static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) { if (y_dim.size() > 1) { return y_dim; } @@ -46,12 +46,12 @@ static framework::DDim ColumnMatrixFromVector(const framework::DDim& y_dim) { template class MatMulKernel : public framework::OpKernel { public: - void Compute(const framework::ExecutionContext& context) const override { - auto& x = + void Compute(const framework::ExecutionContext &context) const override { + auto &x = detail::Ref(context.Input("X"), "Cannot find X"); - auto& y = + auto &y = detail::Ref(context.Input("Y"), "Cannot find Y"); - auto* out = context.Output("Out"); + auto *out = context.Output("Out"); out->mutable_data(context.GetPlace()); auto blas = math::GetBlas(context); @@ -65,7 +65,7 @@ class MatMulKernel : public framework::OpKernel { // Reshape a rank-3 tensor from P x M x N to (P * M) x N. // Identity op if the tensor is not of rank 3. -static framework::Tensor FoldInitDims(const framework::Tensor& input) { +static framework::Tensor FoldInitDims(const framework::Tensor &input) { auto output = input; auto in_dims = input.dims(); if (in_dims.size() == 3) { @@ -78,8 +78,8 @@ static framework::Tensor FoldInitDims(const framework::Tensor& input) { // (Warning: This requires transposing data and writes into new memory.) // Identity op if the tensor is not of rank 3. template -static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context, - const framework::Tensor& input) { +static framework::Tensor FoldHeadAndLastDims(const DeviceContext &context, + const framework::Tensor &input) { auto in_dims = input.dims(); if (in_dims.size() != 3) { return input; @@ -102,7 +102,7 @@ static framework::Tensor FoldHeadAndLastDims(const DeviceContext& context, * If transposed, `H,W` will be swapped. */ static void ReshapeTensorIntoMatrixSequence( - framework::Tensor* x, const math::MatDescriptor& descriptor) { + framework::Tensor *x, const math::MatDescriptor &descriptor) { int64_t h, w; h = descriptor.height_; w = descriptor.width_; @@ -130,9 +130,9 @@ static void ReshapeTensorIntoMatrixSequence( * If any of `X` and `Y` has batch size BatchSize, the out will have the * BatchSize. */ -static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, - framework::Tensor* y, - framework::Tensor* out, bool trans_x, +static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x, + framework::Tensor *y, + framework::Tensor *out, bool trans_x, bool trans_y) { auto x_dim = RowMatrixFromVector(x->dims()); auto y_dim = ColumnMatrixFromVector(y->dims()); @@ -177,10 +177,10 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor* x, template class MatMulGradKernel : public framework::OpKernel { public: - void MatMul(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - const framework::Tensor& b, bool trans_b, - framework::Tensor* out) const { + void MatMul(const framework::ExecutionContext &context, + const framework::Tensor &a, bool trans_a, + const framework::Tensor &b, bool trans_b, + framework::Tensor *out) const { out->mutable_data(context.GetPlace()); auto blas = math::GetBlas(context); auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); @@ -188,18 +188,18 @@ class MatMulGradKernel : public framework::OpKernel { blas.MatMul(a, mat_dim_a, b, mat_dim_b, T(1), out, T(0)); } - void CalcInputGrad(const framework::ExecutionContext& context, - const framework::Tensor& a, bool trans_a, - bool is_fold_init_dims_a, const framework::Tensor& b, + void CalcInputGrad(const framework::ExecutionContext &context, + const framework::Tensor &a, bool trans_a, + bool is_fold_init_dims_a, const framework::Tensor &b, bool trans_b, bool is_fold_init_dims_b, - framework::Tensor* out) const { + framework::Tensor *out) const { if (out == nullptr) return; bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && out->dims().size() == 2; if (!need_combine) { MatMul(context, a, trans_a, b, trans_b, out); } else { - auto& ctx = context.template device_context(); + auto &ctx = context.template device_context(); MatMul(context, is_fold_init_dims_a ? FoldInitDims(a) : FoldHeadAndLastDims(ctx, a), @@ -210,13 +210,13 @@ class MatMulGradKernel : public framework::OpKernel { } } - void Compute(const framework::ExecutionContext& context) const override { + void Compute(const framework::ExecutionContext &context) const override { auto x = *context.Input("X"); auto y = *context.Input("Y"); auto dout = *context.Input(framework::GradVarName("Out")); - auto* dx = context.Output(framework::GradVarName("X")); - auto* dy = context.Output(framework::GradVarName("Y")); + auto *dx = context.Output(framework::GradVarName("X")); + auto *dy = context.Output(framework::GradVarName("Y")); bool transpose_x = context.Attr("transpose_X"); bool transpose_y = context.Attr("transpose_Y"); @@ -269,7 +269,7 @@ class MatMulOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext* context) const override { + void InferShape(framework::InferShapeContext *context) const override { PADDLE_ENFORCE(context->HasInput("X"), "Input(X) of MatMulOp should not be null."); PADDLE_ENFORCE(context->HasInput("Y"), @@ -375,7 +375,7 @@ class MatMulOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContext* context) const override { + void InferShape(framework::InferShapeContext *context) const override { PADDLE_ENFORCE(context->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(context->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(context->HasInput(framework::GradVarName("Out")), @@ -401,7 +401,7 @@ class MatMulOpGradMaker : public framework::SingleGradOpDescMaker { protected: std::unique_ptr Apply() const override { - auto* retv = new framework::OpDesc(); + auto *retv = new framework::OpDesc(); retv->SetType("matmul_grad"); retv->SetInput("X", Input("X")); retv->SetInput("Y", Input("Y")); @@ -420,15 +420,27 @@ REGISTER_OPERATOR(matmul, ops::MatMulOp, ops::MatMulOpMaker, ops::MatMulOpGradMaker); REGISTER_OPERATOR(matmul_grad, ops::MatMulOpGrad); REGISTER_OP_CPU_KERNEL( - matmul, ops::MatMulKernel); + matmul, ops::MatMulKernel, + ops::MatMulKernel, + ops::MatMulKernel); REGISTER_OP_CPU_KERNEL( matmul_grad, - ops::MatMulGradKernel); + ops::MatMulGradKernel, + ops::MatMulGradKernel, + ops::MatMulGradKernel); #ifdef PADDLE_WITH_CUDA REGISTER_OP_CUDA_KERNEL( - matmul, ops::MatMulKernel); + matmul, ops::MatMulKernel, + ops::MatMulKernel, + ops::MatMulKernel); REGISTER_OP_CUDA_KERNEL( matmul_grad, - ops::MatMulGradKernel); + ops::MatMulGradKernel, + ops::MatMulGradKernel, + ops::MatMulGradKernel); #endif