From f6f18835c76460656f17ab1c46c8f0012c2bd356 Mon Sep 17 00:00:00 2001 From: limingshu <61349199+JamesLim-sy@users.noreply.github.com> Date: Wed, 19 Apr 2023 12:20:25 +0800 Subject: [PATCH] Support Linear operation in cuBlaslt and plug into attn_gemm and fusedLinear backward op (#52028) * first commit * restruct c++ interface to divide linear from matmulwithcublaslt * finish building in cublaslt impl * fix code bugs * fix host cost * add some changes --- paddle/fluid/operators/fused/attn_gemm.h | 33 +- .../operators/fused/fused_gemm_epilogue_op.cc | 62 +- .../operators/fused/fused_gemm_epilogue_op.cu | 34 +- paddle/phi/kernels/autotune/cache.cc | 2 +- .../phi/kernels/funcs/blas/blaslt_impl.cu.h | 620 ++++++++++++------ paddle/phi/kernels/funcs/common_shape.h | 3 +- paddle/phi/kernels/funcs/dropout_impl.cu.h | 6 +- .../phi/kernels/funcs/fused_gemm_epilogue.h | 175 +++-- .../phi/kernels/gpu/cross_entropy_kernel.cu | 4 +- 9 files changed, 602 insertions(+), 337 deletions(-) diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h index 9ec25c110e5..9709f60bbc1 100644 --- a/paddle/fluid/operators/fused/attn_gemm.h +++ b/paddle/fluid/operators/fused/attn_gemm.h @@ -68,25 +68,20 @@ class AttnMatMul { "The output (= input * weight) is expected to be nullptr or the " "same as bias_out when fused is true.")); - auto fused_impl = - phi::funcs::MatmulPlanner(vectorize(input->dims()), - vectorize(weight->dims()), - transA_, - transB_, - phi::CppTypeToDataType::Type(), - phi::funcs::MatmulFusedType::kMatmulBias, - static_cast(bias->data()), - nullptr); - phi::funcs::MatmulWithCublasLt::Run(dev_ctx_, - input->data(), - weight->data(), - bias_out->data(), - bsz_seq_, // M - output_size_, // N - input_size_, // K - transA_, - transB_, - &fused_impl); + phi::funcs::LinearWithCublasLt::Run( + dev_ctx_, + input, // x + weight, // y + bias_out, // out + static_cast(bias->data()), // bias + nullptr, + bsz_seq_, // M + output_size_, // N + input_size_, // K + transA_, + transB_, + phi::funcs::MatmulFusedType::kMatmulBias); + return; } #endif diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc index dc1c9c3f0af..d87bb97edf3 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc @@ -36,7 +36,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel { auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); auto bias_dims = ctx->GetInputDim("Bias"); - auto trans_x = ctx->Attrs().Get("trans_x"); auto trans_y = ctx->Attrs().Get("trans_y"); @@ -88,27 +87,6 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel { K_from_x, K_from_y)); - auto activation = ctx->Attrs().Get("activation"); - if (activation == "none" && ctx->HasOutput("ReserveSpace")) { - PADDLE_THROW(platform::errors::InvalidArgument( - "The ReserveSpace would not be used when activation = \"none\"")); - } - - // cublasLt's restriction for auxiliary. - if (ctx->HasOutput("ReserveSpace") && activation != "none") { - int min_size_of_n = activation == "relu" ? 128 : 8; - int N_size = trans_y ? 
y_dims[0] : y_dims[1]; - PADDLE_ENFORCE_EQ(N_size % min_size_of_n, - 0, - platform::errors::InvalidArgument( - "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) " - "should be multiple of %d when auxiliary_key given " - "and activation=%s, but got N = %d.", - min_size_of_n, - activation, - N_size)); - } - std::vector out_dims; out_dims.reserve(static_cast(x_dims.size())); if (trans_x) { @@ -122,11 +100,29 @@ class FusedGemmEpilogueOp : public framework::OperatorWithKernel { } else { out_dims.push_back(y_dims[1]); } - ctx->SetOutputDim("Out", phi::make_ddim(out_dims)); + auto activation = ctx->Attrs().Get("activation"); if (ctx->HasOutput("ReserveSpace")) { ctx->SetOutputDim("ReserveSpace", phi::make_ddim(out_dims)); + + if (activation == "none") { + PADDLE_THROW(platform::errors::InvalidArgument( + "The ReserveSpace would not be used when activation = \"none\"")); + } else { + int min_size_of_n = activation == "relu" ? 128 : 8; + int N_size = trans_y ? y_dims[0] : y_dims[1]; + PADDLE_ENFORCE_EQ( + N_size % min_size_of_n, + 0, + platform::errors::InvalidArgument( + "The output dimension N (X(MxK) * Y(KxN) = C(MxN)) " + "should be multiple of %d when auxiliary_key given " + "and activation=%s, but got N = %d.", + min_size_of_n, + activation, + N_size)); + } } } @@ -202,7 +198,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { auto dout_dims = ctx->GetInputDim("DOut"); auto x_dims = ctx->GetInputDim("X"); auto y_dims = ctx->GetInputDim("Y"); - auto trans_x = ctx->Attrs().Get("trans_x"); auto trans_y = ctx->Attrs().Get("trans_y"); @@ -241,7 +236,6 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { x_dims.size())); auto dout_mat_dims = phi::flatten_to_2d(dout_dims, dout_dims.size() - 1); - auto x_mat_dims = phi::flatten_to_2d(x_dims, x_dims.size() - 1); PADDLE_ENFORCE_EQ( @@ -268,25 +262,17 @@ class FusedGemmEpilogueGradOp : public framework::OperatorWithKernel { false, platform::errors::InvalidArgument( "The ReserveSpace should not be empty. " - "when activation_grad == {relu_grad, gelu_grad}.")); + "when activation == {relu_grad, gelu_grad}.")); } if (ctx->HasOutput("DX")) { - std::vector dx_dims; - dx_dims.reserve(static_cast(x_dims.size())); - for (int i = 0; i < x_dims.size(); ++i) { - dx_dims.push_back(x_dims[i]); - } - ctx->SetOutputDim("DX", phi::make_ddim(dx_dims)); + ctx->SetOutputDim("DX", x_dims); } - - std::vector dy_dims(y_dims.Get(), y_dims.Get() + y_dims.size()); - ctx->SetOutputDim("DY", phi::make_ddim(dy_dims)); + ctx->SetOutputDim("DY", y_dims); if (ctx->HasOutput("DBias")) { - std::vector dbias_dims; - dbias_dims.push_back(trans_y ? y_dims[0] : y_dims[1]); - ctx->SetOutputDim("DBias", phi::make_ddim(dbias_dims)); + int64_t dbias_dim = trans_y ? y_dims[0] : y_dims[1]; + ctx->SetOutputDim("DBias", phi::make_ddim({dbias_dim})); } } diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu index 483194f7c47..7fad2871827 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu @@ -17,7 +17,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/float16.h" -#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" namespace paddle { @@ -101,26 +100,19 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { << ", activation=" << activation << ", fused_type=" << fused_type << ", reserve_space=" << reserve_space; - auto fused_impl = - phi::funcs::MatmulPlanner(vectorize(x->dims()), - vectorize(y->dims()), - trans_x, - trans_y, - phi::CppTypeToDataType::Type(), - fused_type, - static_cast(bias->data()), - reserve_data); - - phi::funcs::MatmulWithCublasLt::Run(dev_ctx, - x->data(), - y->data(), - out->data(), - M, - N, - K, - trans_x, - trans_y, - &fused_impl); + phi::funcs::LinearWithCublasLt::Run( + dev_ctx, + x, + y, + out, + static_cast(bias->data()), + reserve_data, + M, + N, + K, + trans_x, + trans_y, + fused_type); } }; diff --git a/paddle/phi/kernels/autotune/cache.cc b/paddle/phi/kernels/autotune/cache.cc index aacb4f6e268..6ff1296b513 100644 --- a/paddle/phi/kernels/autotune/cache.cc +++ b/paddle/phi/kernels/autotune/cache.cc @@ -25,7 +25,7 @@ size_t TransposeKey(const std::vector& x_dims, const std::vector& perm, phi::DataType dtype) { const auto rank = perm.size(); - return GenKey(x_dims, perm, rank, static_cast(dtype)); + return GenKey(x_dims, perm, rank, static_cast(dtype)); } std::string AlgorithmTypeString(int64_t algo_type) { diff --git a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h index 284d866d08f..1bc409bd0df 100644 --- a/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h +++ b/paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h @@ -33,20 +33,87 @@ namespace phi { namespace funcs { #if (defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060) + // Set this enum according to // https://docs.nvidia.com/cuda/cublas/index.html#cublasltepilogue-t +// While kMatmul, kMatmulGrad, kMatmulGradWithoutBias share the same +// enum value, but if all elements for MatmulPlanner->GetKey() is same, +// no matter forward or backward, they could share the same descriptor +// cache, in that the descritpor is for decription of matmul operation. enum MatmulFusedType { - kMatmul = CUBLASLT_EPILOGUE_DEFAULT, // No special postprocessing. + kMatmul = CUBLASLT_EPILOGUE_DEFAULT, + kMatmulGrad = CUBLASLT_EPILOGUE_DEFAULT, + kMatmulGradWithoutBias = CUBLASLT_EPILOGUE_DEFAULT, kMatmulBias = CUBLASLT_EPILOGUE_BIAS, kMatmulRelu = CUBLASLT_EPILOGUE_RELU, - kMatmulBiasRelu = - CUBLASLT_EPILOGUE_RELU_BIAS, // Apply bias and then ReLU transform. - kMatmulBiasGelu = - CUBLASLT_EPILOGUE_GELU_BIAS, // Apply Bias and then GELU transform. 
+ kMatmulBiasRelu = CUBLASLT_EPILOGUE_RELU_BIAS, + kMatmulBiasGelu = CUBLASLT_EPILOGUE_GELU_BIAS, kMatmulBiasReluWithReservedData = CUBLASLT_EPILOGUE_RELU_AUX_BIAS, - kMatmulBiasGeluWithReservedData = CUBLASLT_EPILOGUE_GELU_AUX_BIAS + kMatmulBiasGeluWithReservedData = CUBLASLT_EPILOGUE_GELU_AUX_BIAS, + kMatmulReluGrad = CUBLASLT_EPILOGUE_DRELU, + kMatmulGeluGrad = CUBLASLT_EPILOGUE_DGELU, + kMatmulBiasGradToA = CUBLASLT_EPILOGUE_BGRADA, + kMatmulBiasGradToB = CUBLASLT_EPILOGUE_BGRADB +}; + +enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 }; + +template +struct FusedGEMMGradTrait; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradB = FusedGEMMGradInType::kDY; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDX; + static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDY; + static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDX; + static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradATrans = false; + static constexpr auto kYGradBTrans = false; }; +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradB = FusedGEMMGradInType::kDY; + static constexpr auto kXGradATrans = false; + static constexpr auto kXGradBTrans = false; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradB = FusedGEMMGradInType::kDX; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = false; +}; + +template <> +struct FusedGEMMGradTrait { + static constexpr auto kXGradA = FusedGEMMGradInType::kDY; + static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; + static constexpr auto kXGradATrans = true; + static constexpr auto kXGradBTrans = true; + + static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; + static constexpr auto kYGradB = FusedGEMMGradInType::kDX; + static constexpr auto kYGradATrans = true; + static constexpr auto kYGradBTrans = true; +}; + +// To tell any matmul or fused matmul operation from each other. struct MatmulPlanner { public: const void* bias{nullptr}; @@ -60,23 +127,31 @@ struct MatmulPlanner { phi::DataType dtype, MatmulFusedType impl_type, const void* bias_data = nullptr, - void* reserve_data = nullptr) - : bias(bias_data), aux_data(reserve_data) { - type = impl_type; - key = phi::autotune::GenKey(x_dims, - y_dims, - static_cast(trans_x), - static_cast(trans_y), - static_cast(dtype)); + void* reserve_data = nullptr, // Commonly for ReLu bit-mask. 
+ bool use_addto = false, + bool no_exchange = true) + : bias(bias_data), aux_data(reserve_data), impl_type_(impl_type) { + use_addto_ = use_addto; + key_ = phi::autotune::GenKey(x_dims, + y_dims, + static_cast(trans_x), + static_cast(trans_y), + static_cast(dtype), + static_cast(no_exchange)); } - MatmulFusedType ImplType() const { return type; } - size_t GetKey() const { return key; } - size_t GenSubKey(int idx) const { return phi::autotune::GenKey(key, idx); } + bool UseAddTo() const { return use_addto_; } + size_t GetKey() const { return key_; } + MatmulFusedType ImplType() const { return impl_type_; } + + size_t GenSubKey(int idx) const { + return phi::autotune::GenKey(key_, static_cast(use_addto_), idx); + } private: - MatmulFusedType type; - size_t key; + MatmulFusedType impl_type_; + bool use_addto_; + size_t key_; }; template @@ -124,19 +199,19 @@ struct MatmulDescriptor { } // x_desc, y_desc, op_desc are allocated in heap memory. - template - void Create(const int M, - const int N, - const int K, + template + void Create(const int64_t M, + const int64_t N, + const int64_t K, const bool trans_x, const bool trans_y, phi::funcs::MatmulPlanner* planner, const int batch_size = 1, - int64_t stride_x = 0, - int64_t stride_y = 0, - int64_t stride_out = 0) { + const int64_t stride_x = 0, + const int64_t stride_y = 0, + const int64_t stride_out = 0, + bool grad_for_dx = true) { using MT = typename phi::dtype::MPTypeTrait::Type; - cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); cublasComputeType_t compute_type = GetCudaComputeType(); @@ -145,18 +220,7 @@ struct MatmulDescriptor { // details about defaults; just need to set the transforms for A and B PADDLE_ENFORCE_GPU_SUCCESS( dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); - cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; - cublasOperation_t cublas_trans_y = trans_y ? 
CUBLAS_OP_T : CUBLAS_OP_N; - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescSetAttribute(op_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &cublas_trans_x, - sizeof(cublas_trans_x))); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescSetAttribute(op_desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &cublas_trans_y, - sizeof(cublas_trans_y))); + SetFusedEpilogueOpDescriptor(planner, trans_x, trans_y, N); // Create matrix descriptors CreateMatrixLayout(&x_desc, mat_type, M, K, trans_x); @@ -169,7 +233,6 @@ struct MatmulDescriptor { SetBatchAndStride(y_desc, batch_size, stride_y); SetBatchAndStride(out_desc, batch_size, stride_out); } - SetFusedEpilogueOpDescriptor(planner, N); } cublasLtMatmulAlgo_t* SetAlgo() { @@ -188,14 +251,13 @@ struct MatmulDescriptor { CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_data, sizeof(bias_data))); - - if (planner->aux_data != nullptr) { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( - op_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, - &(planner->aux_data), - sizeof(planner->aux_data))); - } + } + if (planner->aux_data != nullptr) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( + op_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &(planner->aux_data), + sizeof(planner->aux_data))); } } @@ -223,7 +285,42 @@ struct MatmulDescriptor { return out.str(); } - private: + void ExchangeXYDesc(bool no_exchange) {} + + protected: + void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner, + const bool trans_x, + const bool trans_y, + int64_t lead_dim) { + cublasOperation_t cublas_trans_x = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t cublas_trans_y = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N; + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulDescSetAttribute(op_desc, + CUBLASLT_MATMUL_DESC_TRANSB, + &cublas_trans_x, + sizeof(cublas_trans_x))); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulDescSetAttribute(op_desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &cublas_trans_y, + sizeof(cublas_trans_y))); + if (planner->ImplType() != kMatmul) { + auto fused_type = static_cast(planner->ImplType()); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulDescSetAttribute(op_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &fused_type, + sizeof(fused_type))); + } + if (planner->aux_data) { + PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( + op_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + &lead_dim, + sizeof(lead_dim))); + } + } + void CreateMatrixLayout(cublasLtMatrixLayout_t* desc, cudaDataType type, uint64_t rows, @@ -252,145 +349,62 @@ struct MatmulDescriptor { &stride, sizeof(stride))); } - - void SetFusedEpilogueOpDescriptor(phi::funcs::MatmulPlanner* planner, - int64_t lead_dim) { - if (planner->bias) { - auto fuse_type = static_cast(planner->ImplType()); - PADDLE_ENFORCE_GPU_SUCCESS( - dynload::cublasLtMatmulDescSetAttribute(op_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &fuse_type, - sizeof(fuse_type))); - if (planner->aux_data) { - PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmulDescSetAttribute( - op_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, - &lead_dim, - sizeof(lead_dim))); - } - } - } }; -template -struct DescriptorSetter { - MatmulDescriptor desc; - size_t sub_key{std::numeric_limits::min()}; +struct MatmulGradDescriptor : MatmulDescriptor { + public: + MatmulGradDescriptor() {} - DescriptorSetter(phi::funcs::MatmulPlanner* planner, - const int M, - const int N, - const int K, - const bool trans_x, - const bool trans_y, - const int batch_size = 1, - int64_t stride_x = 0, - 
int64_t stride_y = 0, - int64_t stride_out = 0) { - if (planner != nullptr) { - sub_key = planner->GenSubKey(static_cast(planner->ImplType())); - } + template + void Create(const int64_t M, + const int64_t N, + const int64_t K, + const bool trans_x, + const bool trans_y, + phi::funcs::MatmulPlanner* planner, + const int batch_size = 1, + int64_t stride_x = 0, + int64_t stride_y = 0, + int64_t stride_out = 0, + bool grad_for_dx = true) { + using MT = typename phi::dtype::MPTypeTrait::Type; + cudaDataType_t mat_type = phi::backends::gpu::ToCudaDataType(); + cudaDataType_t scale_type = phi::backends::gpu::ToCudaDataType(); + cublasComputeType_t compute_type = GetCudaComputeType(); - auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); - if (mamtul_cache.FindSubKey(sub_key)) { - desc = *( - reinterpret_cast(mamtul_cache.GetSubKey(sub_key))); - desc.SetFusedEpiloguePtr(planner); - VLOG(6) << desc.GetDescResultString("[Heap MatmulDescriptor] "); + PADDLE_ENFORCE_GPU_SUCCESS( + dynload::cublasLtMatmulDescCreate(&op_desc, compute_type, scale_type)); + this->SetFusedEpilogueOpDescriptor( + planner, trans_x, trans_y, TransX ? M : K); + + // Create operation desciriptor; see cublasLtMatmulDescAttributes_t for + // details about defaults; just need to set the transforms for A and B + this->CreateMatrixLayout(&x_desc, mat_type, N, M, true); + if (grad_for_dx) { + this->CreateMatrixLayout(&y_desc, mat_type, K, N, TransY); + this->CreateMatrixLayout( + &out_desc, phi::backends::gpu::ToCudaDataType(), M, K, TransX); } else { - desc.Create(M, - N, - K, - trans_x, - trans_y, - planner, - batch_size, - stride_x, - stride_y, - stride_out); - if (planner != nullptr) { - desc.SetFusedEpiloguePtr(planner); - } - VLOG(6) << desc.GetDescResultString("[Stack MatmulDescriptor] ", false); + this->CreateMatrixLayout(&y_desc, mat_type, M, K, TransX); + this->CreateMatrixLayout( + &out_desc, phi::backends::gpu::ToCudaDataType(), K, N, TransY); } } -}; - -template -struct MatmulWithCublasLt { - public: - using MT = typename phi::dtype::MPTypeTrait::Type; - - static void Run(const phi::GPUContext& ctx, - const T* x_data, - const T* y_data, - T* out_data, - const int M, - const int N, - const int K, - const bool trans_x, - const bool trans_y, - phi::funcs::MatmulPlanner* planner = nullptr) { - auto setter = DescriptorSetter(planner, M, N, K, trans_x, trans_y); - RunImpl( - ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); - } - static void RunWithBatch(const phi::GPUContext& ctx, - const T* x_data, - const T* y_data, - T* out_data, - const int M, - const int N, - const int K, - bool trans_x, - bool trans_y, - int batch_size, - int64_t stride_x, - int64_t stride_y, - int64_t stride_out, - phi::funcs::MatmulPlanner* planner = nullptr) { - auto setter = DescriptorSetter(planner, - M, - N, - K, - trans_x, - trans_y, - batch_size, - stride_x, - stride_y, - stride_out); - RunImpl( - ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); - } - - static void RunWithBatch(const phi::GPUContext& ctx, - const T** x_data, - const T** y_data, - T** out_data, - const int M, - const int N, - const int K, - bool trans_x, - bool trans_y, - int batch_size, - phi::funcs::MatmulPlanner* planner = nullptr) { - for (int i = 0; i < batch_size; ++i) { - Run(ctx, - x_data[i], - y_data[i], - out_data[i], - M, - N, - K, - trans_x, - trans_y, - planner); + void ExchangeXYDesc(bool no_exchange) { + if (no_exchange) { + return; } + auto* temp = y_desc; + y_desc = x_desc; + x_desc = temp; } +}; 
- private: +template +struct CublasLtBase { + public: + using MT = typename phi::dtype::MPTypeTrait::Type; static phi::Allocator::AllocationPtr GetWorkspace(const phi::GPUContext& ctx, size_t workspace_size) { return phi::memory_utils::Alloc( @@ -400,16 +414,19 @@ struct MatmulWithCublasLt { } static void RunImpl(const phi::GPUContext& ctx, - MatmulDescriptor* desc, + MatmulDescT* desc, const size_t sub_key, const T* x_ptr, const T* y_ptr, - T* out_ptr, + OutT* out_ptr, phi::funcs::MatmulPlanner* planner) { MT alpha = static_cast(1); - MT beta = static_cast(0); - + MT beta = planner->UseAddTo() ? static_cast(1) : static_cast(0); cublasLtHandle_t cublaslt_handle = ctx.cublaslt_handle(); + + // NOTE(limingshu): As workspace_size varies from different DL framework, + // I wonder is there any smarter idea for workspace setting, currently I + // just followed the settings from the NVIDIA colleague`s setting. size_t workspace_size = static_cast(4) * 1024 * 1024; phi::Allocator::AllocationPtr workspace = GetWorkspace(ctx, workspace_size); @@ -426,16 +443,16 @@ struct MatmulWithCublasLt { out_ptr, workspace->ptr(), workspace_size); - MatmulDescriptor* best_desc = new MatmulDescriptor(*desc); + MatmulDescT* best_desc = new MatmulDescT(*desc); VLOG(6) << best_desc->GetDescResultString( - "[Searched MatmulDescriptor] "); + "[Searched CublasltDescriptor] "); auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); cache.SetSubKey(sub_key, reinterpret_cast(best_desc)); } } - VLOG(6) << desc->GetDescResultString("[Impl MatmulDescriptor] "); + VLOG(6) << desc->GetDescResultString("[Impl CublasltDescriptor] "); PADDLE_ENFORCE_GPU_SUCCESS( dynload::cublasLtMatmul(cublaslt_handle, desc->op_desc, @@ -457,7 +474,7 @@ struct MatmulWithCublasLt { static void SearchBestAlgo(const phi::GPUContext& ctx, const cublasLtHandle_t& lt_handle, - MatmulDescriptor* desc, + MatmulDescT* desc, const void* alpha, const void* beta, const void* y_data, @@ -526,7 +543,7 @@ struct MatmulWithCublasLt { } } float time_cnt = (cur_time / (repeats - 1)); - VLOG(4) << "Time cost in MatmulWithCublaslt algo[" << algo_idx << "]" + VLOG(6) << "Time cost in MatmulWithCublaslt algo[" << algo_idx << "]" << "is : " << time_cnt << " s"; if (cur_time < min_time_cost) { @@ -534,12 +551,241 @@ struct MatmulWithCublasLt { min_time_cost = cur_time; } } - VLOG(4) << "Best_algo_idx in MatmulWithCublaslt is : " << best_algo_idx; + VLOG(6) << "Best_algo_idx in MatmulWithCublaslt is : " << best_algo_idx; *best_algo = heuristic_results[best_algo_idx].algo; PADDLE_ENFORCE_GPU_SUCCESS( dynload::cublasLtMatmulPreferenceDestroy(preference)); } }; + +// To judge if desc is cached or not. 
+template +struct DescriptorSetter { + public: + DescT desc; + size_t sub_key{std::numeric_limits::min()}; + + DescriptorSetter(phi::funcs::MatmulPlanner* planner, + const int64_t M, + const int64_t N, + const int64_t K, + const bool trans_x, + const bool trans_y, + const int batch_size = 1, + int64_t stride_x = 0, + int64_t stride_y = 0, + int64_t stride_out = 0, + const bool no_exchange = true, + bool grad_for_dx = true) { + if (planner != nullptr) { + sub_key = planner->GenSubKey(static_cast(planner->ImplType())); + } + + auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul(); + if (mamtul_cache.FindSubKey(sub_key)) { + desc = *(reinterpret_cast(mamtul_cache.GetSubKey(sub_key))); + desc.template SetFusedEpiloguePtr(planner); + VLOG(6) << desc.GetDescResultString("[Heap CublasltDescriptor] "); + } else { + desc.template Create(M, + N, + K, + trans_x, + trans_y, + planner, + batch_size, + stride_x, + stride_y, + stride_out, + grad_for_dx); + desc.ExchangeXYDesc(no_exchange); + if (planner != nullptr) { + desc.template SetFusedEpiloguePtr(planner); + } + VLOG(6) << desc.GetDescResultString("[Stack CublasltDescriptor] ", false); + } + } +}; + +// For matmul with kernels autotune +template +struct MatmulWithCublasLt : public CublasLtBase { + public: + static void Run(const phi::GPUContext& ctx, + const T* x_data, + const T* y_data, + T* out_data, + const int64_t M, + const int64_t N, + const int64_t K, + const bool trans_x, + const bool trans_y, + phi::funcs::MatmulPlanner* planner = nullptr) { + auto setter = DescriptorSetter( + planner, M, N, K, trans_x, trans_y); + CublasLtBase::RunImpl( + ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); + } + + static void RunWithBatch(const phi::GPUContext& ctx, + const T* x_data, + const T* y_data, + T* out_data, + const int64_t M, + const int64_t N, + const int64_t K, + bool trans_x, + bool trans_y, + int batch_size, + int64_t stride_x, + int64_t stride_y, + int64_t stride_out, + phi::funcs::MatmulPlanner* planner = nullptr) { + auto setter = DescriptorSetter(planner, + M, + N, + K, + trans_x, + trans_y, + batch_size, + stride_x, + stride_y, + stride_out); + CublasLtBase::RunImpl( + ctx, &setter.desc, setter.sub_key, x_data, y_data, out_data, planner); + } + + static void RunWithBatch(const phi::GPUContext& ctx, + const T** x_data, + const T** y_data, + T** out_data, + const int64_t M, + const int64_t N, + const int64_t K, + bool trans_x, + bool trans_y, + int batch_size, + phi::funcs::MatmulPlanner* planner = nullptr) { + for (int i = 0; i < batch_size; ++i) { + Run(ctx, + x_data[i], + y_data[i], + out_data[i], + M, + N, + K, + trans_x, + trans_y, + planner); + } + } +}; + +// As for just Linear fused ephilogue below: out = matmul(x, y) + bias. 
+template +struct LinearWithCublasLt : public CublasLtBase { + static void Run(const phi::GPUContext& ctx, + const phi::DenseTensor* x, + const phi::DenseTensor* y, + phi::DenseTensor* out, + const void* bias_data, + void* reserve_data, + const int64_t M, + const int64_t N, + const int64_t K, + const bool trans_x, + const bool trans_y, + const MatmulFusedType fused_type) { + auto planner = phi::funcs::MatmulPlanner(vectorize(x->dims()), + vectorize(y->dims()), + trans_x, + trans_y, + phi::CppTypeToDataType::Type(), + fused_type, + bias_data, + reserve_data); + auto setter = DescriptorSetter( + &planner, M, N, K, trans_x, trans_y); + CublasLtBase::RunImpl(ctx, + &setter.desc, + setter.sub_key, + x->data(), + y->data(), + out->data(), + &planner); + } +}; + +template +struct LinearGradWithCublasLt : public CublasLtBase { + static void Run( + const phi::GPUContext& ctx, + const phi::DenseTensor* x, + const phi::DenseTensor* y, + phi::DenseTensor* out, + const void* bias_data, + void* reserve_data, + const int64_t M, + const int64_t N, + const int64_t K, + const MatmulFusedType fused_type, + const bool trans_x, + const bool trans_y, + const bool use_addto, + const bool no_exchange, // exchange x_desc and y_desc for grad. + bool grad_for_dx = true) { + auto planner = phi::funcs::MatmulPlanner(vectorize(x->dims()), + vectorize(y->dims()), + trans_x, + trans_y, + phi::CppTypeToDataType::Type(), + fused_type, + bias_data, + reserve_data, + use_addto, + no_exchange); + auto setter = + DescriptorSetter( + &planner, + M, + N, + K, + trans_x, + trans_y, + /*batch_size=*/1, + /*stride_x=*/0, + /*stride_y=*/0, + /*stride_out=*/0, + /*exchange_x_y_desc=*/no_exchange, + /*grad_for_dx=*/grad_for_dx); + + // To setting data type for different kinda out_data. + if (grad_for_dx) { + CublasLtBase::RunImpl( + ctx, + &setter.desc, + setter.sub_key, + no_exchange ? x->data() : y->data(), + no_exchange ? y->data() : x->data(), + out->data(), + &planner); + } else { + CublasLtBase::RunImpl( + ctx, + &setter.desc, + setter.sub_key, + no_exchange ? x->data() : y->data(), + no_exchange ? y->data() : x->data(), + out->data(), + &planner); + } + } +}; #else // A void structure just for successfully complile. 
struct MatmulPlanner {}; diff --git a/paddle/phi/kernels/funcs/common_shape.h b/paddle/phi/kernels/funcs/common_shape.h index f7524320e88..8db9a92f47d 100644 --- a/paddle/phi/kernels/funcs/common_shape.h +++ b/paddle/phi/kernels/funcs/common_shape.h @@ -52,6 +52,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, "Axis should be less than or equal to %d, but received axis is %d.", max_dim, axis)); + if (x_dims.size() > y_dims.size()) { std::fill(y_dims_array, y_dims_array + axis, 1); if (axis + y_dims.size() < max_dim) { @@ -68,7 +69,7 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims, std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array); } - for (int i = 0; i < max_dim; i++) { + for (int i = 0; i < max_dim; ++i) { PADDLE_ENFORCE_EQ( x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 || y_dims_array[i] <= 1, diff --git a/paddle/phi/kernels/funcs/dropout_impl.cu.h b/paddle/phi/kernels/funcs/dropout_impl.cu.h index 494d95fcf83..0b47febb0d3 100644 --- a/paddle/phi/kernels/funcs/dropout_impl.cu.h +++ b/paddle/phi/kernels/funcs/dropout_impl.cu.h @@ -350,8 +350,10 @@ void DropoutFwGPUKernelDriver( auto dst_functor = DstFunctor(1.0f - dropout_prob, upscale_in_train, x_numel); - std::vector out_dims = phi::vectorize(x.dims()); - std::vector in_dims = phi::vectorize(mask->dims()); + std::vector out_dims = + std::move(phi::vectorize(x.dims())); + std::vector in_dims = + std::move(phi::vectorize(mask->dims())); std::reverse(out_dims.begin(), out_dims.end()); std::reverse(in_dims.begin(), in_dims.end()); kps::details::BroadcastConfig broadcast_config( diff --git a/paddle/phi/kernels/funcs/fused_gemm_epilogue.h b/paddle/phi/kernels/funcs/fused_gemm_epilogue.h index f0e57619940..4f1a1c6f0bd 100644 --- a/paddle/phi/kernels/funcs/fused_gemm_epilogue.h +++ b/paddle/phi/kernels/funcs/fused_gemm_epilogue.h @@ -37,6 +37,7 @@ limitations under the License. */ #include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/scope_guard.h" +#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/utils/optional.h" DECLARE_int64(cublaslt_exhaustive_search_times); @@ -488,62 +489,103 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx, phi::dynload::cublasLtMatrixLayoutDestroy(out_desc)); } -enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 }; - -template -struct FusedGEMMGradTrait; - -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradB = FusedGEMMGradInType::kDY; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDX; - static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = false; -}; +struct BwdFusedEpilogueSetter { + public: + static phi::funcs::MatmulFusedType SetForDx( + const std::string& activation_grad) { + if (activation_grad == "none") { + return kMatmulGrad; + } else if (activation_grad == "relu_grad") { + return kMatmulReluGrad; + } else if (activation_grad == "gelu_grad") { + return kMatmulGeluGrad; + } else { + PADDLE_THROW(phi::errors::InvalidArgument( + "Fued linear epilogue type should be one of {none, relu, gelu}." 
+ "But received activation is %s, please check", + activation_grad)); + } + } -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDY; - static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDX; - static constexpr auto kYGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradATrans = false; - static constexpr auto kYGradBTrans = false; + template + static phi::funcs::MatmulFusedType SetForDy(const phi::GPUContext& dev_ctx, + phi::DenseTensor* dbias) { + if (dbias != nullptr) { + dev_ctx.Alloc(dbias, dbias->numel() * sizeof(DYT)); + return TransY ? kMatmulBiasGradToB : kMatmulBiasGradToA; + } else { + return kMatmulGradWithoutBias; + } + } }; -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradB = FusedGEMMGradInType::kDY; - static constexpr auto kXGradATrans = false; - static constexpr auto kXGradBTrans = false; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradB = FusedGEMMGradInType::kDX; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = false; -}; +template +void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, + const phi::DenseTensor* dout, + const phi::DenseTensor* x, + const phi::DenseTensor* y, + const phi::DenseTensor* reserve_space, + int64_t M, + int64_t N, + int64_t K, + const std::string activation_grad, + phi::DenseTensor* dx, + phi::DenseTensor* dy, + phi::DenseTensor* dbias, + bool use_addto_dx, + bool use_addto_dy) { + using MT = typename phi::dtype::MPTypeTrait::Type; + static_assert(std::is_same::value || std::is_same::value); + static_assert(std::is_same::value || std::is_same::value); + using Trait = FusedGEMMGradTrait; -template <> -struct FusedGEMMGradTrait { - static constexpr auto kXGradA = FusedGEMMGradInType::kDY; - static constexpr auto kXGradB = FusedGEMMGradInType::kDZ; - static constexpr auto kXGradATrans = true; - static constexpr auto kXGradBTrans = true; - - static constexpr auto kYGradA = FusedGEMMGradInType::kDZ; - static constexpr auto kYGradB = FusedGEMMGradInType::kDX; - static constexpr auto kYGradATrans = true; - static constexpr auto kYGradBTrans = true; -}; + if (dx) { + constexpr auto kXGradAIsDZ = (Trait::kXGradA == FusedGEMMGradInType::kDZ); + auto fused_type = BwdFusedEpilogueSetter::SetForDx(activation_grad); + void* reserve_data = (fused_type == kMatmulGrad) + ? nullptr + : const_cast(reserve_space->data()); + dev_ctx.Alloc(dx, dx->numel() * sizeof(DXT)); + phi::funcs::LinearGradWithCublasLt::Run( + dev_ctx, + dout, + y, + dx, + nullptr, + reserve_data, + M, + N, + K, + fused_type, + Trait::kXGradATrans, + Trait::kXGradBTrans, + use_addto_dx, + kXGradAIsDZ); + } + if (dy) { + auto fused_type = + BwdFusedEpilogueSetter::SetForDy(dev_ctx, dbias); + constexpr auto kYGradAIsDZ = (Trait::kYGradA == FusedGEMMGradInType::kDZ); + // Caution: DYT is in front of DXT in this template arguments. + dev_ctx.Alloc(dy, dy->numel() * sizeof(DYT)); + phi::funcs::LinearGradWithCublasLt::Run( + dev_ctx, + dout, + x, + dy, + dbias ? static_cast(dbias->data()) : nullptr, + nullptr, + M, + N, + K, + fused_type, + Trait::kYGradATrans, + Trait::kYGradBTrans, + use_addto_dy, + kYGradAIsDZ, + /*is_dx=*/false); + } +} static constexpr auto BoolToCuBlasEnum(bool transpose) { return transpose ? 
CUBLAS_OP_T : CUBLAS_OP_N; @@ -567,20 +609,21 @@ static cublasLtEpilogue_t GetEpilogueGradType( } template -void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, - const phi::DenseTensor* dout, - const phi::DenseTensor* x, - const phi::DenseTensor* y, - const phi::DenseTensor* reserve_space, - int64_t M, - int64_t N, - int64_t K, - const std::string activation_grad, - phi::DenseTensor* dx, - phi::DenseTensor* dy, - phi::DenseTensor* dbias, - bool use_addto_dx, - bool use_addto_dy) { +void ComputeFusedGemmEpilogueBackwardImplDev( + const phi::GPUContext& dev_ctx, + const phi::DenseTensor* dout, + const phi::DenseTensor* x, + const phi::DenseTensor* y, + const phi::DenseTensor* reserve_space, + int64_t M, + int64_t N, + int64_t K, + const std::string activation_grad, + phi::DenseTensor* dx, + phi::DenseTensor* dy, + phi::DenseTensor* dbias, + bool use_addto_dx, + bool use_addto_dy) { using MT = typename phi::dtype::MPTypeTrait::Type; constexpr bool kIsValidDataType = (std::is_same::value || std::is_same::value) && diff --git a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu index a223cd7c738..0920c15d358 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu @@ -559,7 +559,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, // max index to read int idx_max = (i < local_batches) ? element_count : 0; int idx_max_v = idx_max / kVSize; - +#pragma unroll // read data for (int it = 0; it < kIterationsV; ++it) { int src_idx = threadIdx.x + it * kWarpSize; @@ -659,7 +659,7 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss, // loss phi::WarpReduceSum(sumloss); - +#pragma unroll for (int i = 0; i < kBatchSize; i++) { if (i >= local_batches) break; loss[first_batch + i] = sumloss[i]; -- GitLab
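Usage note (not part of the patch): the new LinearWithCublasLt helper now owns the MatmulPlanner construction, so call sites only pass tensors, the fused epilogue type, and the GEMM shape. The sketch below shows one plausible call for out = relu(x * y + bias); the <T> template argument and the wrapper function are illustrative assumptions (the patch itself instantiates the helper from attn_gemm.h with kMatmulBias and from fused_gemm_epilogue_op.cu with the op's fused_type), and out is assumed to be allocated before the call.

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"

// Sketch: forward fused linear, out = relu(x * y + bias).
// Assumes x: M x K, y: K x N, bias: N, out: M x N, all on the current GPU.
template <typename T>
void LinearBiasReluForward(const phi::GPUContext& dev_ctx,
                           const phi::DenseTensor* x,
                           const phi::DenseTensor* y,
                           const phi::DenseTensor* bias,
                           phi::DenseTensor* out,
                           int64_t M, int64_t N, int64_t K) {
  phi::funcs::LinearWithCublasLt<T>::Run(
      dev_ctx,
      x,
      y,
      out,
      static_cast<const void*>(bias->data<T>()),  // fused bias epilogue
      /*reserve_data=*/nullptr,                   // no aux buffer for plain RELU_BIAS
      M, N, K,
      /*trans_x=*/false,
      /*trans_y=*/false,
      phi::funcs::MatmulFusedType::kMatmulBiasRelu);
}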
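On the backward side, the patch picks one cublasLt epilogue per output: the dx GEMM fuses the activation gradient (DRELU/DGELU consume the ReserveSpace bit-mask saved by the forward pass), and the dy GEMM fuses the bias reduction (BGRADA/BGRADB) when DBias is requested. Below is a minimal sketch of that dispatch, mirroring BwdFusedEpilogueSetter from fused_gemm_epilogue.h; the free-function names are illustrative and not part of the patch.

#include <string>
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"

// Epilogue for the dx GEMM, chosen from the activation_grad attribute.
phi::funcs::MatmulFusedType EpilogueForDx(const std::string& activation_grad) {
  if (activation_grad == "none") return phi::funcs::kMatmulGrad;
  if (activation_grad == "relu_grad") return phi::funcs::kMatmulReluGrad;  // CUBLASLT_EPILOGUE_DRELU
  if (activation_grad == "gelu_grad") return phi::funcs::kMatmulGeluGrad;  // CUBLASLT_EPILOGUE_DGELU
  PADDLE_THROW(phi::errors::InvalidArgument(
      "activation_grad must be one of {none, relu_grad, gelu_grad}, but got %s.",
      activation_grad));
}

// Epilogue for the dy GEMM: fuse the bias gradient only when DBias exists.
phi::funcs::MatmulFusedType EpilogueForDy(bool has_dbias, bool trans_y) {
  if (!has_dbias) return phi::funcs::kMatmulGradWithoutBias;
  return trans_y ? phi::funcs::kMatmulBiasGradToB   // CUBLASLT_EPILOGUE_BGRADB
                 : phi::funcs::kMatmulBiasGradToA;  // CUBLASLT_EPILOGUE_BGRADA
}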
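Caching note: MatmulPlanner hashes {x_dims, y_dims, trans_x, trans_y, dtype, no_exchange} into a key, and GenSubKey additionally folds in use_addto plus the fused type; DescriptorSetter looks that sub-key up in the autotune cache and reuses the previously searched cublasLt descriptor on a hit. Because kMatmul, kMatmulGrad, and kMatmulGradWithoutBias all alias CUBLASLT_EPILOGUE_DEFAULT, forward and backward GEMMs that describe the same problem can share one cached descriptor. A rough sketch of the lookup follows, under the assumption that the planner is constructed with its default bias/reserve/use_addto arguments and that the helper function name is purely illustrative.

#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/autotune/cache.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"

// Sketch: returns true when a tuned descriptor for this plain matmul is cached.
template <typename T>
bool HasTunedMatmulDescriptor(const phi::DenseTensor& x,
                              const phi::DenseTensor& y,
                              bool trans_x,
                              bool trans_y) {
  auto planner = phi::funcs::MatmulPlanner(phi::vectorize(x.dims()),
                                           phi::vectorize(y.dims()),
                                           trans_x,
                                           trans_y,
                                           phi::CppTypeToDataType<T>::Type(),
                                           phi::funcs::MatmulFusedType::kMatmul);
  size_t sub_key = planner.GenSubKey(static_cast<int>(planner.ImplType()));
  auto& cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
  return cache.FindSubKey(sub_key);
}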