From c888e01660ff1258352a537521d0c725d091e6df Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sat, 28 Apr 2018 17:21:29 +0800 Subject: [PATCH] Refactor GEMM in blas --- .../operators/bilinear_tensor_product_op.h | 23 ++- paddle/fluid/operators/gru_unit_op.h | 52 +++--- paddle/fluid/operators/math/blas_impl.cu.h | 145 ++++++++++++++++ paddle/fluid/operators/math/blas_impl.h | 68 ++++++++ paddle/fluid/operators/math/gru_compute.cc | 50 +++--- paddle/fluid/operators/math/gru_compute.cu | 51 +++--- paddle/fluid/operators/math/math_function.cc | 82 +-------- paddle/fluid/operators/math/math_function.cu | 163 +----------------- paddle/fluid/operators/math/math_function.h | 53 +++++- paddle/fluid/operators/math/matmul.h | 5 +- 10 files changed, 357 insertions(+), 335 deletions(-) create mode 100644 paddle/fluid/operators/math/blas_impl.cu.h create mode 100644 paddle/fluid/operators/math/blas_impl.h diff --git a/paddle/fluid/operators/bilinear_tensor_product_op.h b/paddle/fluid/operators/bilinear_tensor_product_op.h index ca80e6085c4..7191711a731 100644 --- a/paddle/fluid/operators/bilinear_tensor_product_op.h +++ b/paddle/fluid/operators/bilinear_tensor_product_op.h @@ -61,9 +61,9 @@ class BilinearTensorProductKernel : public framework::OpKernel { auto output_col_vec = output_mat.chip(i, 1); Tensor weight_mat = weight->Slice(i, i + 1).Resize(framework::make_ddim({x_dim, y_dim})); - math::gemm(dev_ctx, CblasNoTrans, CblasNoTrans, - batch_size, y_dim, x_dim, 1, x->data(), - weight_mat.data(), 0, left_mul.data()); + math::GetBlas(dev_ctx).GEMM( + CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, x->data(), + weight_mat.data(), 0, left_mul.data()); output_col_vec.device(place) = (left_mul_mat * y_mat).sum(Eigen::DSizes(1)); } @@ -125,6 +125,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { set_zero(dev_ctx, d_y, static_cast(0)); } + auto blas = math::GetBlas(ctx); + // Caculate the Output(X@Grad) and Output(Y@Grad). if (d_x || d_y) { Eigen::DSizes bcast_for_x(1, y_dim); @@ -138,18 +140,16 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { output_vec.reshape(Eigen::DSizes(batch_size, 1)) .broadcast(bcast_for_x) * y_mat; - math::gemm( - dev_ctx, CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1, - y_scale.data(), weight_i.data(), 1, d_x->data()); + blas.GEMM(CblasNoTrans, CblasTrans, batch_size, x_dim, y_dim, 1, + y_scale.data(), weight_i.data(), 1, d_x->data()); } if (d_y) { x_scale_mat.device(place) = output_vec.reshape(Eigen::DSizes(batch_size, 1)) .broadcast(bcast_for_y) * x_mat; - math::gemm( - dev_ctx, CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, - x_scale.data(), weight_i.data(), 1, d_y->data()); + blas.GEMM(CblasNoTrans, CblasNoTrans, batch_size, y_dim, x_dim, 1, + x_scale.data(), weight_i.data(), 1, d_y->data()); } } } @@ -166,9 +166,8 @@ class BilinearTensorProductGradKernel : public framework::OpKernel { output_vec.reshape(Eigen::DSizes(batch_size, 1)) .broadcast(bcast_for_weight) * x_mat; - math::gemm(dev_ctx, CblasTrans, CblasNoTrans, x_dim, - y_dim, batch_size, 1, x_scale.data(), - y->data(), 0, d_weight_i.data()); + blas.GEMM(CblasTrans, CblasNoTrans, x_dim, y_dim, batch_size, 1, + x_scale.data(), y->data(), 0, d_weight_i.data()); } } diff --git a/paddle/fluid/operators/gru_unit_op.h b/paddle/fluid/operators/gru_unit_op.h index 15d91ca3059..49e657a272c 100644 --- a/paddle/fluid/operators/gru_unit_op.h +++ b/paddle/fluid/operators/gru_unit_op.h @@ -87,10 +87,10 @@ class GRUUnitKernel : public framework::OpKernel { const T* weight_data = weight->data(); T* gate_data = gate->data(); T* reset_hidden_prev_data = reset_hidden_prev->data(); - math::gemm( - context.template device_context(), false, false, - batch_size, 2 * frame_size, frame_size, 1, hidden_prev_data, frame_size, - weight_data, frame_size * 2, 1, gate_data, frame_size * 3); + auto blas = math::GetBlas(context); + blas.GEMM(false, false, batch_size, 2 * frame_size, frame_size, 1, + hidden_prev_data, frame_size, weight_data, frame_size * 2, 1, + gate_data, frame_size * 3); // calculate activited gate Eigen::array extents({{batch_size, frame_size}}); @@ -103,11 +103,10 @@ class GRUUnitKernel : public framework::OpKernel { g.slice(r_offsets, extents), g.slice(r_offsets, extents)); auto r = g.slice(r_offsets, extents); // reset gate r_h_p.device(place) = r * h_p; // reset previous hidden state - math::gemm( - context.template device_context(), false, false, - batch_size, frame_size, frame_size, 1, reset_hidden_prev_data, - frame_size, weight_data + frame_size * frame_size * 2, frame_size, 1, - gate_data + frame_size * 2, frame_size * 3); + blas.GEMM(false, false, batch_size, frame_size, frame_size, 1, + reset_hidden_prev_data, frame_size, + weight_data + frame_size * frame_size * 2, frame_size, 1, + gate_data + frame_size * 2, frame_size * 3); Eigen::array c_offsets({{0, frame_size * 2}}); ActCompute(context.Attr("activation"), place, @@ -188,11 +187,11 @@ class GRUUnitGradKernel : public framework::OpKernel { ActGradCompute(context.Attr("activation"), place, c, c, d_g.slice(c_offsets, extents), d_h * u); // backward for reset_hidden_prev - math::gemm( - context.template device_context(), false, true, - batch_size, frame_size, frame_size, 1, gate_grad_data + frame_size * 2, - frame_size * 3, weight_data + frame_size * frame_size * 2, frame_size, - 0, reset_hidden_prev_grad_data, frame_size); + auto blas = math::GetBlas(context); + blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, + gate_grad_data + frame_size * 2, frame_size * 3, + weight_data + frame_size * frame_size * 2, frame_size, 0, + reset_hidden_prev_grad_data, frame_size); // backward for unactivated reset gate ActGradCompute(context.Attr("gate_activation"), place, r, r, d_g.slice(r_offsets, extents), d_r_h_p * h_p); @@ -200,18 +199,15 @@ class GRUUnitGradKernel : public framework::OpKernel { if (weight_grad) { T* weight_grad_data = weight_grad->mutable_data(context.GetPlace()); // backward for state_weight - math::gemm( - context.template device_context(), true, false, - frame_size, frame_size, batch_size, 1, reset_hidden_prev_data, - frame_size, gate_grad_data + frame_size * 2, frame_size * 3, 0, - weight_grad_data + frame_size * frame_size * 2, frame_size); + blas.GEMM(true, false, frame_size, frame_size, batch_size, 1, + reset_hidden_prev_data, frame_size, + gate_grad_data + frame_size * 2, frame_size * 3, 0, + weight_grad_data + frame_size * frame_size * 2, frame_size); // backward for update_gate_weight and reset_gate_weight - math::gemm( - context.template device_context(), true, false, - frame_size, frame_size * 2, batch_size, 1, hidden_prev_data, - frame_size, gate_grad_data, frame_size * 3, 0, weight_grad_data, - frame_size * 2); + blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1, + hidden_prev_data, frame_size, gate_grad_data, frame_size * 3, 0, + weight_grad_data, frame_size * 2); } // backward for hidden_prev if (hidden_prev_grad) { @@ -219,11 +215,9 @@ class GRUUnitGradKernel : public framework::OpKernel { hidden_prev_grad->mutable_data(context.GetPlace()); auto d_h_p = EigenMatrix::From(*hidden_prev_grad); d_h_p.device(place) = d_r_h_p * r + d_h * (u.constant(T(1)) - u); - math::gemm( - context.template device_context(), false, true, - batch_size, frame_size, frame_size * 2, 1, gate_grad_data, - frame_size * 3, weight_data, frame_size * 2, 1, hidden_prev_grad_data, - frame_size); + blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1, + gate_grad_data, frame_size * 3, weight_data, frame_size * 2, 1, + hidden_prev_grad_data, frame_size); } // backward for input if (input_grad) { diff --git a/paddle/fluid/operators/math/blas_impl.cu.h b/paddle/fluid/operators/math/blas_impl.cu.h new file mode 100644 index 00000000000..b7bd8f1d042 --- /dev/null +++ b/paddle/fluid/operators/math/blas_impl.cu.h @@ -0,0 +1,145 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/operators/math/math_function.h" +#include "paddle/fluid/platform/dynload/cublas.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct CUBlas; + +template <> +struct CUBlas { + template + static void GEMM(ARGS... args) { + PADDLE_ENFORCE(platform::dynload::cublasSgemm(args...)); + } +}; + +template <> +struct CUBlas { + template + static void GEMM(ARGS... args) { + PADDLE_ENFORCE(platform::dynload::cublasDgemm(args...)); + } +}; + +template <> +struct CUBlas { + template + static void GEMM(ARGS... args) { + PADDLE_ENFORCE(platform::dynload::cublasHgemm(args...)); + } +}; + +template <> +template +void Blas::GEMM(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, + const int M, const int N, + const int K, const T alpha, + const T *A, const T *B, + const T beta, T *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, + B, ldb, A, lda, &beta, C, N); +} + +template <> +template <> +inline void Blas::GEMM( + const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, const int M, + const int N, const int K, const platform::float16 alpha, + const platform::float16 *A, const platform::float16 *B, + const platform::float16 beta, platform::float16 *C) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + + // TODO(kexinzhao): add processing code for compute capability < 53 case + PADDLE_ENFORCE_GE(context_.GetComputeCapability(), 53, + "cublas fp16 gemm requires GPU compute capability >= 53"); + +#if CUDA_VERSION >= 8000 + float h_alpha = static_cast(alpha); + float h_beta = static_cast(beta); + + cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; +#if CUDA_VERSION >= 9000 + if (context_.GetComputeCapability() >= 70) { + PADDLE_ENFORCE(platform::dynload::cublasSetMathMode( + context_.cublas_handle(), CUBLAS_TENSOR_OP_MATH)); + algo = CUBLAS_GEMM_DFALT_TENSOR_OP; + } else { + PADDLE_ENFORCE(platform::dynload::cublasSetMathMode( + context_.cublas_handle(), CUBLAS_DEFAULT_MATH)); + } +#endif // CUDA_VERSION >= 9000 + + // cublasHgemm does true FP16 computation which is slow for non-Volta + // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: + // input/output in fp16, computation in fp32, which can also be accelerated + // using tensor cores in volta GPUs. + PADDLE_ENFORCE(platform::dynload::cublasGemmEx( + context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B, + CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, + CUDA_R_32F, algo)); +#else + // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm + const half h_alpha = static_cast(alpha); + const half h_beta = static_cast(beta); + const half *h_A = reinterpret_cast(A); + const half *h_B = reinterpret_cast(B); + half *h_C = reinterpret_cast(C); + + CUBlas(context_.cublas_handle(), cuTransB, cuTransA, N, M, + K, &h_alpha, h_B, ldb, h_A, lda, &h_beta, h_C, N); +#endif // CUDA_VERSION >= 8000 +} + +template <> +template +void Blas::GEMM( + const bool transA, const bool transB, const int M, const int N, const int K, + const T alpha, const T *A, const int lda, const T *B, const int ldb, + const T beta, T *C, const int ldc) const { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T; + CUBlas::GEMM(context_.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, + B, ldb, A, lda, &beta, C, ldc); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/blas_impl.h b/paddle/fluid/operators/math/blas_impl.h new file mode 100644 index 00000000000..4934afd8bb1 --- /dev/null +++ b/paddle/fluid/operators/math/blas_impl.h @@ -0,0 +1,68 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { +namespace math { + +template +struct CBlas; + +template <> +struct CBlas { + static constexpr auto GEMM = cblas_sgemm; +}; + +template <> +struct CBlas { + static constexpr auto GEMM = cblas_dgemm; +}; + +template <> +struct CBlas { + void GEMM(...) { PADDLE_THROW("float16 GEMM not supported on CPU"); } +}; + +template <> +template +void Blas::GEMM(const CBLAS_TRANSPOSE transA, + const CBLAS_TRANSPOSE transB, + const int M, const int N, + const int K, const T alpha, + const T *A, const T *B, + const T beta, T *C) const { + int lda = (transA == CblasNoTrans) ? K : M; + int ldb = (transB == CblasNoTrans) ? N : K; + int ldc = N; + CBlas::GEMM(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, + beta, C, ldc); +} + +template <> +template +void Blas::GEMM( + const bool transA, const bool transB, const int M, const int N, const int K, + const T alpha, const T *A, const int lda, const T *B, const int ldb, + const T beta, T *C, const int ldc) const { + CBlas::GEMM(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, + transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, + lda, B, ldb, beta, C, ldc); +} + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/math/gru_compute.cc b/paddle/fluid/operators/math/gru_compute.cc index 3f044b77513..d7862502712 100644 --- a/paddle/fluid/operators/math/gru_compute.cc +++ b/paddle/fluid/operators/math/gru_compute.cc @@ -25,21 +25,21 @@ struct GRUUnitFunctor { const detail::ActivationType active_node, const detail::ActivationType active_gate) { #ifndef __NVCC__ + auto blas = math::GetBlas(context); if (value.prev_out_value) { - math::gemm( - context, false, false, batch_size, frame_size * 2, frame_size, 1, - value.prev_out_value, frame_size, value.gate_weight, frame_size * 2, - 1, value.gate_value, frame_size * 3); + blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1, + value.prev_out_value, frame_size, value.gate_weight, + frame_size * 2, 1, value.gate_value, frame_size * 3); } detail::forward_reset_output(detail::forward::gru_resetOutput(), value, frame_size, batch_size, active_gate); if (value.prev_out_value) { - math::gemm( - context, false, false, batch_size, frame_size, frame_size, 1, - value.reset_output_value, frame_size, value.state_weight, frame_size, - 1, value.gate_value + frame_size * 2, frame_size * 3); + blas.GEMM(false, false, batch_size, frame_size, frame_size, 1, + value.reset_output_value, frame_size, value.state_weight, + frame_size, 1, value.gate_value + frame_size * 2, + frame_size * 3); } detail::forward_final_output(detail::forward::gru_finalOutput(), value, @@ -58,36 +58,32 @@ struct GRUUnitGradFunctor { #ifndef __NVCC__ detail::backward_state_grad(detail::backward::gru_stateGrad(), value, grad, frame_size, batch_size, active_node); - + auto blas = math::GetBlas(context); if (value.prev_out_value && grad.prev_out_grad) { - math::gemm( - context, false, true, batch_size, frame_size, frame_size, 1, - grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight, - frame_size, 0, grad.reset_output_grad, frame_size); + blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, + grad.gate_grad + frame_size * 2, frame_size * 3, + value.state_weight, frame_size, 0, grad.reset_output_grad, + frame_size); if (grad.state_weight_grad) { - math::gemm( - context, true, false, frame_size, frame_size, batch_size, 1, - value.reset_output_value, frame_size, - grad.gate_grad + frame_size * 2, frame_size * 3, 1, - grad.state_weight_grad, frame_size); + blas.GEMM(true, false, frame_size, frame_size, batch_size, 1, + value.reset_output_value, frame_size, + grad.gate_grad + frame_size * 2, frame_size * 3, 1, + grad.state_weight_grad, frame_size); } } detail::backward_reset_grad(detail::backward::gru_resetGrad(), value, grad, frame_size, batch_size, active_gate); - if (grad.prev_out_grad && value.prev_out_value) { - math::gemm( - context, false, true, batch_size, frame_size, frame_size * 2, 1, - grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1, - grad.prev_out_grad, frame_size); + blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1, + grad.gate_grad, frame_size * 3, value.gate_weight, + frame_size * 2, 1, grad.prev_out_grad, frame_size); if (grad.gate_weight_grad) { - math::gemm( - context, true, false, frame_size, frame_size * 2, batch_size, 1, - value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1, - grad.gate_weight_grad, frame_size * 2); + blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1, + value.prev_out_value, frame_size, grad.gate_grad, + frame_size * 3, 1, grad.gate_weight_grad, frame_size * 2); } } #endif diff --git a/paddle/fluid/operators/math/gru_compute.cu b/paddle/fluid/operators/math/gru_compute.cu index 27caf3383dd..f26bec41095 100644 --- a/paddle/fluid/operators/math/gru_compute.cu +++ b/paddle/fluid/operators/math/gru_compute.cu @@ -9,6 +9,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ +#include #include "paddle/fluid/operators/math/detail/gru_gpu_kernel.h" #include "paddle/fluid/operators/math/detail/gru_kernel.h" #include "paddle/fluid/operators/math/gru_compute.h" @@ -36,12 +37,11 @@ struct GRUUnitFunctor { threads = dim3(32, 32); grid = dim3((frame_size + 32 - 1) / 32, (batch_size + 32 - 1) / 32); } - + auto blas = math::GetBlas(context); if (value.prev_out_value) { - math::gemm( - context, false, false, batch_size, frame_size * 2, frame_size, 1, - value.prev_out_value, frame_size, value.gate_weight, frame_size * 2, - 1, value.gate_value, frame_size * 3); + blas.GEMM(false, false, batch_size, frame_size * 2, frame_size, 1, + value.prev_out_value, frame_size, value.gate_weight, + frame_size * 2, 1, value.gate_value, frame_size * 3); } if (batch_size == 1) { @@ -61,10 +61,10 @@ struct GRUUnitFunctor { } if (value.prev_out_value) { - math::gemm( - context, false, false, batch_size, frame_size, frame_size, 1, - value.reset_output_value, frame_size, value.state_weight, frame_size, - 1, value.gate_value + frame_size * 2, frame_size * 3); + blas.GEMM(false, false, batch_size, frame_size, frame_size, 1, + value.reset_output_value, frame_size, value.state_weight, + frame_size, 1, value.gate_value + frame_size * 2, + frame_size * 3); } if (batch_size == 1) { @@ -121,18 +121,19 @@ struct GRUUnitGradFunctor { grad.output_grad, frame_size, batch_size, active_node); } + auto blas = math::GetBlas(context); + if (value.prev_out_value && grad.prev_out_grad) { - math::gemm( - context, false, true, batch_size, frame_size, frame_size, 1, - grad.gate_grad + frame_size * 2, frame_size * 3, value.state_weight, - frame_size, 0, grad.reset_output_grad, frame_size); + blas.GEMM(false, true, batch_size, frame_size, frame_size, 1, + grad.gate_grad + frame_size * 2, frame_size * 3, + value.state_weight, frame_size, 0, grad.reset_output_grad, + frame_size); if (grad.state_weight_grad) { - math::gemm( - context, true, false, frame_size, frame_size, batch_size, 1, - value.reset_output_value, frame_size, - grad.gate_grad + frame_size * 2, frame_size * 3, 1, - grad.state_weight_grad, frame_size); + blas.GEMM(true, false, frame_size, frame_size, batch_size, 1, + value.reset_output_value, frame_size, + grad.gate_grad + frame_size * 2, frame_size * 3, 1, + grad.state_weight_grad, frame_size); } } @@ -153,16 +154,14 @@ struct GRUUnitGradFunctor { } if (grad.prev_out_grad && value.prev_out_value) { - math::gemm( - context, false, true, batch_size, frame_size, frame_size * 2, 1, - grad.gate_grad, frame_size * 3, value.gate_weight, frame_size * 2, 1, - grad.prev_out_grad, frame_size); + blas.GEMM(false, true, batch_size, frame_size, frame_size * 2, 1, + grad.gate_grad, frame_size * 3, value.gate_weight, + frame_size * 2, 1, grad.prev_out_grad, frame_size); if (grad.gate_weight_grad) { - math::gemm( - context, true, false, frame_size, frame_size * 2, batch_size, 1, - value.prev_out_value, frame_size, grad.gate_grad, frame_size * 3, 1, - grad.gate_weight_grad, frame_size * 2); + blas.GEMM(true, false, frame_size, frame_size * 2, batch_size, 1, + value.prev_out_value, frame_size, grad.gate_grad, + frame_size * 3, 1, grad.gate_weight_grad, frame_size * 2); } } } diff --git a/paddle/fluid/operators/math/math_function.cc b/paddle/fluid/operators/math/math_function.cc index b5ae41c8f9d..b63676f961b 100644 --- a/paddle/fluid/operators/math/math_function.cc +++ b/paddle/fluid/operators/math/math_function.cc @@ -24,72 +24,6 @@ namespace math { using float16 = paddle::platform::float16; -template <> -void gemm( - const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const float16 alpha, const float16* A, const float16* B, const float16 beta, - float16* C) { - PADDLE_THROW("float16 GEMM not supported on CPU"); -} - -template <> -void gemm( - const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const float alpha, const float* A, const float* B, const float beta, - float* C) { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cblas_sgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} - -template <> -void gemm( - const platform::CPUDeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const double alpha, const double* A, const double* B, const double beta, - double* C) { - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - int ldc = N; - cblas_dgemm(CblasRowMajor, transA, transB, M, N, K, alpha, A, lda, B, ldb, - beta, C, ldc); -} - -template <> -void gemm( - const platform::CPUDeviceContext& context, const bool transA, - const bool transB, const int M, const int N, const int K, - const float16 alpha, const float16* A, const int lda, const float16* B, - const int ldb, const float16 beta, float16* C, const int ldc) { - PADDLE_THROW("float16 GEMM not supported on CPU"); -} - -template <> -void gemm( - const platform::CPUDeviceContext& context, const bool transA, - const bool transB, const int M, const int N, const int K, const float alpha, - const float* A, const int lda, const float* B, const int ldb, - const float beta, float* C, const int ldc) { - cblas_sgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); -} - -template <> -void gemm( - const platform::CPUDeviceContext& context, const bool transA, - const bool transB, const int M, const int N, const int K, - const double alpha, const double* A, const int lda, const double* B, - const int ldb, const double beta, double* C, const int ldc) { - cblas_dgemm(CblasRowMajor, transA == false ? CblasNoTrans : CblasTrans, - transB == false ? CblasNoTrans : CblasTrans, M, N, K, alpha, A, - lda, B, ldb, beta, C, ldc); -} - template <> void matmul( const platform::CPUDeviceContext& context, @@ -123,8 +57,8 @@ void matmul( CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - gemm( - context, transA, transB, M, N, K, alpha, matrix_a.data(), + Blas(context).GEMM( + transA, transB, M, N, K, alpha, matrix_a.data(), matrix_b.data(), beta, matrix_out->data()); } @@ -152,8 +86,8 @@ void matmul( CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - gemm( - context, transA, transB, M, N, K, alpha, matrix_a.data(), + Blas(context).GEMM( + transA, transB, M, N, K, alpha, matrix_a.data(), matrix_b.data(), beta, matrix_out->data()); } @@ -230,8 +164,8 @@ void batched_gemm( const float* Ak = &A[k * strideA]; const float* Bk = &B[k * strideB]; float* Ck = &C[k * M * N]; - gemm(context, transA, transB, M, N, K, - alpha, Ak, Bk, beta, Ck); + Blas(context).GEMM(transA, transB, M, N, K, + alpha, Ak, Bk, beta, Ck); } } @@ -246,8 +180,8 @@ void batched_gemm( const double* Ak = &A[k * strideA]; const double* Bk = &B[k * strideB]; double* Ck = &C[k * M * N]; - gemm(context, transA, transB, M, N, K, - alpha, Ak, Bk, beta, Ck); + Blas(context).GEMM(transA, transB, M, N, K, + alpha, Ak, Bk, beta, Ck); } } #endif diff --git a/paddle/fluid/operators/math/math_function.cu b/paddle/fluid/operators/math/math_function.cu index 2aa819625e0..7bf816ac190 100644 --- a/paddle/fluid/operators/math/math_function.cu +++ b/paddle/fluid/operators/math/math_function.cu @@ -25,157 +25,6 @@ namespace math { using float16 = paddle::platform::float16; -template <> -void gemm( - const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const float16 alpha, const float16* A, const float16* B, const float16 beta, - float16* C) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53, - "cublas fp16 gemm requires GPU compute capability >= 53"); - -#if CUDA_VERSION >= 8000 - float h_alpha = static_cast(alpha); - float h_beta = static_cast(beta); - - cublasGemmAlgo_t algo = CUBLAS_GEMM_DFALT; -#if CUDA_VERSION >= 9000 - if (context.GetComputeCapability() >= 70) { - PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(), - CUBLAS_TENSOR_OP_MATH)); - algo = CUBLAS_GEMM_DFALT_TENSOR_OP; - } else { - PADDLE_ENFORCE(platform::dynload::cublasSetMathMode(context.cublas_handle(), - CUBLAS_DEFAULT_MATH)); - } -#endif // CUDA_VERSION >= 9000 - - // cublasHgemm does true FP16 computation which is slow for non-Volta - // GPUs. So use cublasGemmEx instead which does pesudo FP16 computation: - // input/output in fp16, computation in fp32, which can also be accelerated - // using tensor cores in volta GPUs. - PADDLE_ENFORCE(platform::dynload::cublasGemmEx( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, B, - CUDA_R_16F, ldb, A, CUDA_R_16F, lda, &h_beta, C, CUDA_R_16F, N, - CUDA_R_32F, algo)); -#else - // CUDA 7.5 does not support cublasGemmEx, hence we fall back to use hgemm - const half h_alpha = static_cast(alpha); - const half h_beta = static_cast(beta); - const half* h_A = reinterpret_cast(A); - const half* h_B = reinterpret_cast(B); - half* h_C = reinterpret_cast(C); - - PADDLE_ENFORCE(platform::dynload::cublasHgemm( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, - h_A, lda, &h_beta, h_C, N)); -#endif // CUDA_VERSION >= 8000 -} - -template <> -void gemm( - const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const float alpha, const float* A, const float* B, const float beta, - float* C) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - - PADDLE_ENFORCE(platform::dynload::cublasSgemm( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, - lda, &beta, C, N)); -} - -template <> -void gemm( - const platform::CUDADeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const double alpha, const double* A, const double* B, const double beta, - double* C) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - int lda = (transA == CblasNoTrans) ? K : M; - int ldb = (transB == CblasNoTrans) ? N : K; - cublasOperation_t cuTransA = - (transA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = - (transB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; - PADDLE_ENFORCE(platform::dynload::cublasDgemm( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, - lda, &beta, C, N)); -} - -template <> -void gemm( - const platform::CUDADeviceContext& context, const bool transA, - const bool transB, const int M, const int N, const int K, - const float16 alpha, const float16* A, const int lda, const float16* B, - const int ldb, const float16 beta, float16* C, const int ldc) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T; - - const half h_alpha = static_cast(alpha); - const half h_beta = static_cast(beta); - const half* h_A = reinterpret_cast(A); - const half* h_B = reinterpret_cast(B); - half* h_C = reinterpret_cast(C); - - // TODO(kexinzhao): add processing code for compute capability < 53 case - PADDLE_ENFORCE_GE(context.GetComputeCapability(), 53, - "cublas Hgemm requires GPU compute capability >= 53"); - PADDLE_ENFORCE(platform::dynload::cublasHgemm( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &h_alpha, h_B, ldb, - h_A, lda, &h_beta, h_C, ldc)); -} - -template <> -void gemm( - const platform::CUDADeviceContext& context, const bool transA, - const bool transB, const int M, const int N, const int K, const float alpha, - const float* A, const int lda, const float* B, const int ldb, - const float beta, float* C, const int ldc) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T; - PADDLE_ENFORCE(platform::dynload::cublasSgemm( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, - lda, &beta, C, ldc)); -} - -template <> -void gemm( - const platform::CUDADeviceContext& context, const bool transA, - const bool transB, const int M, const int N, const int K, - const double alpha, const double* A, const int lda, const double* B, - const int ldb, const double beta, double* C, const int ldc) { - // Note that cublas follows fortran order, so the order is different from - // the cblas convention. - cublasOperation_t cuTransA = transA == false ? CUBLAS_OP_N : CUBLAS_OP_T; - cublasOperation_t cuTransB = transB == false ? CUBLAS_OP_N : CUBLAS_OP_T; - PADDLE_ENFORCE(platform::dynload::cublasDgemm( - context.cublas_handle(), cuTransB, cuTransA, N, M, K, &alpha, B, ldb, A, - lda, &beta, C, ldc)); -} - template <> void matmul( const platform::CUDADeviceContext& context, @@ -200,8 +49,8 @@ void matmul( CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - gemm( - context, transA, transB, M, N, K, alpha, matrix_a.data(), + Blas(context).GEMM( + transA, transB, M, N, K, alpha, matrix_a.data(), matrix_b.data(), beta, matrix_out->data()); } @@ -229,8 +78,8 @@ void matmul( CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - gemm( - context, transA, transB, M, N, K, alpha, matrix_a.data(), + Blas(context).GEMM( + transA, transB, M, N, K, alpha, matrix_a.data(), matrix_b.data(), beta, matrix_out->data()); } @@ -258,8 +107,8 @@ void matmul( CBLAS_TRANSPOSE transA = (trans_a == false) ? CblasNoTrans : CblasTrans; CBLAS_TRANSPOSE transB = (trans_b == false) ? CblasNoTrans : CblasTrans; - gemm( - context, transA, transB, M, N, K, alpha, matrix_a.data(), + Blas(context).GEMM( + transA, transB, M, N, K, alpha, matrix_a.data(), matrix_b.data(), beta, matrix_out->data()); } diff --git a/paddle/fluid/operators/math/math_function.h b/paddle/fluid/operators/math/math_function.h index cdd02974722..9950c09ea61 100644 --- a/paddle/fluid/operators/math/math_function.h +++ b/paddle/fluid/operators/math/math_function.h @@ -42,6 +42,7 @@ int LAPACKE_dgetri(int matrix_layout, int n, double* a, int lda, #include #include "paddle/fluid/framework/eigen.h" +#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/platform/device_context.h" @@ -56,17 +57,48 @@ namespace math { // Then matrixA: M * K, matrixB: K * N, matrixC : M * N // For more detailed info, please refer to // http://www.netlib.org/lapack/explore-html/d4/de2/sgemm_8f.html + +template +class Blas { + public: + explicit Blas(const DeviceContext& context) : context_(context) {} + + template + void GEMM(const CBLAS_TRANSPOSE transA, const CBLAS_TRANSPOSE transB, + const int M, const int N, const int K, const T alpha, const T* A, + const T* B, const T beta, T* C) const; + + template + void GEMM(const bool transA, const bool transB, const int M, const int N, + const int K, const T alpha, const T* A, const int lda, const T* B, + const int ldb, const T beta, T* C, const int ldc) const; + + private: + const DeviceContext& context_; +}; + template -void gemm(const DeviceContext& context, const CBLAS_TRANSPOSE transA, - const CBLAS_TRANSPOSE transB, const int M, const int N, const int K, - const T alpha, const T* A, const T* B, const T beta, T* C); +class BlasT : private Blas { + public: + using Blas::Blas; + + template + void GEMM(ARGS... args) const { + static_cast*>(this)->template GEMM(args...); + } +}; -// gemm wrapper with stride args for matrix uncontinuous in memory template -void gemm(const DeviceContext& context, const bool transA, const bool transB, - const int M, const int N, const int K, const T alpha, const T* A, - const int lda, const T* B, const int ldb, const T beta, T* C, - const int ldc); +inline BlasT GetBlas( + const framework::ExecutionContext& exe_ctx) { + return BlasT( + exe_ctx.template device_context()); +} + +template +inline BlasT GetBlas(const DeviceContext& dev_ctx) { + return BlasT(dev_ctx); +} // matrix multiply with continuous memory template @@ -137,3 +169,8 @@ struct RowwiseMean { } // namespace math } // namespace operators } // namespace paddle + +#include "paddle/fluid/operators/math/blas_impl.h" +#ifdef PADDLE_WITH_CUDA +#include "paddle/fluid/operators/math/blas_impl.cu.h" +#endif diff --git a/paddle/fluid/operators/math/matmul.h b/paddle/fluid/operators/math/matmul.h index 0006c5062f3..67efd1be532 100644 --- a/paddle/fluid/operators/math/matmul.h +++ b/paddle/fluid/operators/math/matmul.h @@ -131,8 +131,9 @@ class MatMulFunctor { if (!batchCount) { // regular matrix multiplication - gemm(context, transA, transB, M, N, kA, alpha, - a.data(), b.data(), beta, out->data()); + Blas(context).GEMM(transA, transB, M, N, kA, alpha, + a.data(), b.data(), beta, + out->data()); } else { // batched matrix multiplication batched_gemm( -- GitLab