diff --git a/paddle/fluid/framework/scope_guard.h b/paddle/fluid/framework/scope_guard.h index 9c741f7bfc5734e4966d12d409d4d8fbc8d2a350..cb4f37b12cd57cb7fbda23e4f0d79548c72d0dff 100644 --- a/paddle/fluid/framework/scope_guard.h +++ b/paddle/fluid/framework/scope_guard.h @@ -1,4 +1,4 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -14,40 +14,4 @@ #pragma once -#include -#include - -#include "paddle/fluid/platform/macros.h" - -namespace paddle { -namespace framework { - -template -class ScopeGuard { - public: - explicit ScopeGuard(const ReleaseCallback &callback) : callback_(callback) {} - - ~ScopeGuard() { callback_(); } - - private: - DISABLE_COPY_AND_ASSIGN(ScopeGuard); - - private: - ReleaseCallback callback_; -}; - -// Two macros are needed here. -// See: -// https://stackoverflow.com/questions/10379691/creating-macro-using-line-for-different-variable-names -#define _PADDLE_CONCAT_TOKEN(x, y) x##y -#define PADDLE_CONCAT_TOKEN(x, y) _PADDLE_CONCAT_TOKEN(x, y) - -#define DEFINE_PADDLE_SCOPE_GUARD(...) \ - auto PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__) = __VA_ARGS__; \ - ::paddle::framework::ScopeGuard::type> \ - PADDLE_CONCAT_TOKEN(__scope_guard, __LINE__)( \ - PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__)) - -} // namespace framework -} // namespace paddle +#include "paddle/phi/core/scope_guard.h" diff --git a/paddle/fluid/operators/fused/attn_gemm.h b/paddle/fluid/operators/fused/attn_gemm.h index 1d6320e70132c69e97b1fc0bf1e4b8de33e5d201..9ec25c110e56ad56f84ee8ef2036af27571b4ce2 100644 --- a/paddle/fluid/operators/fused/attn_gemm.h +++ b/paddle/fluid/operators/fused/attn_gemm.h @@ -14,13 +14,13 @@ limitations under the License. */ #pragma once -#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/platform/float16.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h" +#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h" namespace paddle { @@ -129,21 +129,21 @@ class AttnMatMul { bool fused = false) { #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 if (compute_bias_ && fused) { - ComputeFusedGemmEpilogueBackward(dev_ctx_, - d_output, - input, - weight, - nullptr, - bsz_seq_, // M - output_size_, // N - input_size_, // K - transA_, - transB_, - "none", - d_input, - d_weight, - d_bias, - use_addto); + phi::funcs::ComputeFusedGemmEpilogueBackward(dev_ctx_, + d_output, + input, + weight, + nullptr, + bsz_seq_, // M + output_size_, // N + input_size_, // K + transA_, + transB_, + "none", + d_input, + d_weight, + d_bias, + use_addto); return; } #endif diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc index 31494ad0093ddf8f629cabb91d066e3bc121012f..dc1c9c3f0af3775d23f5a2b53c0fe1b60bc04147 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cc @@ -13,9 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu index 8e523e88e028fa1877b346052801880178024875..0e66ebb3d5a9d18b806afbae0904b5ceb23efca7 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu @@ -13,12 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/float16.h" #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" +#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" namespace paddle { namespace operators { @@ -151,20 +151,20 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { << ", activation=" << activation_grad << ", reserve_space=" << reserve_space; - ComputeFusedGemmEpilogueBackward(dev_ctx, - dout, - x, - y, - reserve_space, - M, - N, - K, - trans_x, - trans_y, - activation_grad, - dx, - dy, - dbias); + phi::funcs::ComputeFusedGemmEpilogueBackward(dev_ctx, + dout, + x, + y, + reserve_space, + M, + N, + K, + trans_x, + trans_y, + activation_grad, + dx, + dy, + dbias); } }; #endif diff --git a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h index 5d0dab032e012c22d6911dada1770350ce065bc9..4769433317f0f9983118138ab52771240d633499 100644 --- a/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h +++ b/paddle/fluid/operators/fused/fused_multi_transformer_op.cu.h @@ -30,11 +30,11 @@ limitations under the License. */ #include "paddle/fluid/operators/fused/attn_gemm.h" #include "paddle/fluid/operators/fused/fmha_ref.h" #include "paddle/fluid/operators/fused/fused_dropout_helper.h" -#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/phi/api/include/tensor.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" +#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" #include "paddle/phi/kernels/funcs/math_function.h" #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) @@ -1871,19 +1871,20 @@ class CublasFusedMLP { const auto *x_data = x->data(); const auto *w_data = weight->data(); - auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle, - operation_desc_, - w_desc_, - x_desc_, - out_desc_, - alpha, - beta, - w_data, - x_data, - out_data, - stream, - workspace->ptr(), - workspace_size); + auto algo = phi::funcs::GemmEpilogueAlgoCache::Instance().GetGemmAlgo( + lt_handle, + operation_desc_, + w_desc_, + x_desc_, + out_desc_, + alpha, + beta, + w_data, + x_data, + out_data, + stream, + workspace->ptr(), + workspace_size); PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmul(lt_handle, diff --git a/paddle/phi/api/yaml/legacy_ops.yaml b/paddle/phi/api/yaml/legacy_ops.yaml index 54329cef0ad7a5ee423d502d0b5a520b98b2eb1c..f399ab6e577399b4361d444cc4b99603e5499862 100755 --- a/paddle/phi/api/yaml/legacy_ops.yaml +++ b/paddle/phi/api/yaml/legacy_ops.yaml @@ -726,6 +726,16 @@ optional : skip_update, master_params inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) +- op : fused_linear_param_grad_add + args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true) + output : Tensor(dweight_out), Tensor(dbias_out) + infer_meta: + func : FusedLinearParamGradAddInferMeta + optional : dweight, dbias + kernel: + func : fused_linear_param_grad_add + data_type : dout + - op : gather args : (Tensor x, Tensor index, Scalar(int) axis=0) output : Tensor(out) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index 9b9fe6afb349647e3bc5fb1debaf0d8157c2b1a9..ccf2f9bfbab81860e65e6062b4c0309ab093f8a1 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -614,16 +614,6 @@ data_type : x backward : fused_dropout_add_grad -- op : fused_linear_param_grad_add - args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true) - output : Tensor(dweight_out), Tensor(dbias_out) - infer_meta: - func : FusedLinearParamGradAddInferMeta - optional : dweight, dbias - kernel: - func : fused_linear_param_grad_add - data_type : dout - - op : gather_nd args : (Tensor x, Tensor index) output : Tensor diff --git a/paddle/phi/core/scope_guard.h b/paddle/phi/core/scope_guard.h new file mode 100644 index 0000000000000000000000000000000000000000..d682d92117f91d0c76bb13976da7bc9af4eb644b --- /dev/null +++ b/paddle/phi/core/scope_guard.h @@ -0,0 +1,51 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include + +#include "paddle/phi/core/macros.h" + +namespace phi { + +template +class ScopeGuard { + public: + explicit ScopeGuard(const ReleaseCallback &callback) : callback_(callback) {} + + ~ScopeGuard() { callback_(); } + + private: + DISABLE_COPY_AND_ASSIGN(ScopeGuard); + + private: + ReleaseCallback callback_; +}; + +// Two macros are needed here. +// See: +// https://stackoverflow.com/questions/10379691/creating-macro-using-line-for-different-variable-names +#define _PADDLE_CONCAT_TOKEN(x, y) x##y +#define PADDLE_CONCAT_TOKEN(x, y) _PADDLE_CONCAT_TOKEN(x, y) + +#define DEFINE_PADDLE_SCOPE_GUARD(...) \ + auto PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__) = __VA_ARGS__; \ + ::phi::ScopeGuard::type> \ + PADDLE_CONCAT_TOKEN(__scope_guard, __LINE__)( \ + PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__)) + +} // namespace phi diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h b/paddle/phi/kernels/funcs/fused_gemm_epilogue.h similarity index 67% rename from paddle/fluid/operators/fused/fused_gemm_epilogue_op.h rename to paddle/phi/kernels/funcs/fused_gemm_epilogue.h index 9fee9dc119cff11c4fbb9b88c5acd5d29a9fad6f..8edfea5f0789cb9e53be056b9b37d1c8c50593e0 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h +++ b/paddle/phi/kernels/funcs/fused_gemm_epilogue.h @@ -27,21 +27,21 @@ limitations under the License. */ #if CUDA_VERSION >= 11060 #include "gflags/gflags.h" -#include "paddle/fluid/framework/scope_guard.h" -#include "paddle/fluid/memory/memory.h" -#include "paddle/fluid/platform/dynload/cublasLt.h" -#include "paddle/fluid/platform/enforce.h" -#include "paddle/fluid/platform/float16.h" #include "paddle/phi/backends/all_context.h" +#include "paddle/phi/backends/dynload/cublasLt.h" #include "paddle/phi/backends/gpu/cuda/cuda_helper.h" #include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/common/float16.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/scope_guard.h" #include "paddle/utils/optional.h" DECLARE_int64(cublaslt_exhaustive_search_times); -namespace paddle { -namespace operators { +namespace phi { +namespace funcs { class GemmEpilogueAlgoCache { public: @@ -88,9 +88,9 @@ class GemmEpilogueAlgoCache { cublasLtMatmulPreference_t preference; PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulPreferenceCreate(&preference)); + phi::dynload::cublasLtMatmulPreferenceCreate(&preference)); PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulPreferenceSetAttribute( + phi::dynload::cublasLtMatmulPreferenceSetAttribute( preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, @@ -100,25 +100,24 @@ class GemmEpilogueAlgoCache { std::vector heuristic_results( requested_algo_count_); PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulAlgoGetHeuristic( - lt_handle, - op_desc, - a_desc, - b_desc, - c_desc, - c_desc, - preference, - requested_algo_count_, - heuristic_results.data(), - &returned_results)); + phi::dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle, + op_desc, + a_desc, + b_desc, + c_desc, + c_desc, + preference, + requested_algo_count_, + heuristic_results.data(), + &returned_results)); PADDLE_ENFORCE_GT( returned_results, 0, - platform::errors::Unavailable("No GEMM epilogue algorithm support!")); + phi::errors::Unavailable("No GEMM epilogue algorithm support!")); PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulPreferenceDestroy(preference)); + phi::dynload::cublasLtMatmulPreferenceDestroy(preference)); int best_algo_idx = -1; float best_algo_time = 0; @@ -126,29 +125,29 @@ class GemmEpilogueAlgoCache { // Run 100 times for warmup int warmup_algo_idx = 0; for (int t = 0; t < 100; t++) { - cublasStatus_t status = platform::dynload::cublasLtMatmul( - lt_handle, - op_desc, - alpha, - a, - a_desc, - b, - b_desc, - beta, - c, - c_desc, - c, - c_desc, - &heuristic_results[warmup_algo_idx].algo, - workspace, - workspace_size, - stream); + cublasStatus_t status = + phi::dynload::cublasLtMatmul(lt_handle, + op_desc, + alpha, + a, + a_desc, + b, + b_desc, + beta, + c, + c_desc, + c, + c_desc, + &heuristic_results[warmup_algo_idx].algo, + workspace, + workspace_size, + stream); if (status != CUBLAS_STATUS_SUCCESS) { t = -1; warmup_algo_idx += 1; if (warmup_algo_idx == requested_algo_count_) { - PADDLE_THROW(platform::errors::Unavailable( - "No GEMM epilogue algorithm support!")); + PADDLE_THROW( + phi::errors::Unavailable("No GEMM epilogue algorithm support!")); } } } @@ -164,22 +163,22 @@ class GemmEpilogueAlgoCache { PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream)); cublasStatus_t status = - platform::dynload::cublasLtMatmul(lt_handle, - op_desc, - alpha, - a, - a_desc, - b, - b_desc, - beta, - c, - c_desc, - c, - c_desc, - &heuristic_results[algo_idx].algo, - workspace, - workspace_size, - stream); + phi::dynload::cublasLtMatmul(lt_handle, + op_desc, + alpha, + a, + a_desc, + b, + b_desc, + beta, + c, + c_desc, + c, + c_desc, + &heuristic_results[algo_idx].algo, + workspace, + workspace_size, + stream); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(stop_event)); @@ -204,7 +203,7 @@ class GemmEpilogueAlgoCache { if (best_algo_idx == -1) { PADDLE_THROW( - platform::errors::Unavailable("No GEMM epilogue algorithm support!")); + phi::errors::Unavailable("No GEMM epilogue algorithm support!")); } ret = heuristic_results[best_algo_idx].algo; @@ -235,31 +234,28 @@ class GemmEpilogueAlgoCache { int trans_a, trans_b; uint32_t epilogue; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescGetAttribute( - desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &trans_a, - sizeof(trans_a), - &size_to_write)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescGetAttribute( + desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &trans_a, + sizeof(trans_a), + &size_to_write)); HashValue_(seed, hash_fn, static_cast(trans_a)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescGetAttribute( - desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &trans_b, - sizeof(trans_b), - &size_to_write)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescGetAttribute( + desc, + CUBLASLT_MATMUL_DESC_TRANSB, + &trans_b, + sizeof(trans_b), + &size_to_write)); HashValue_(seed, hash_fn, static_cast(trans_b)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescGetAttribute( - desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epilogue, - sizeof(epilogue), - &size_to_write)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescGetAttribute( + desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue, + sizeof(epilogue), + &size_to_write)); HashValue_(seed, hash_fn, static_cast(epilogue)); } @@ -272,54 +268,40 @@ class GemmEpilogueAlgoCache { uint64_t row, col; int64_t ld, batch_offset; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutGetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_TYPE, - &dtype, - sizeof(dtype), - &size_to_write)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_TYPE, + &dtype, + sizeof(dtype), + &size_to_write)); HashValue_(seed, hash_fn, static_cast(dtype)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutGetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, - &batch, - sizeof(batch), - &size_to_write)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, + &batch, + sizeof(batch), + &size_to_write)); HashValue_(seed, hash_fn, static_cast(batch)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutGetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_ROWS, - &row, - sizeof(row), - &size_to_write)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), &size_to_write)); HashValue_(seed, hash_fn, static_cast(row)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutGetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_COLS, - &col, - sizeof(col), - &size_to_write)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), &size_to_write)); HashValue_(seed, hash_fn, static_cast(col)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutGetAttribute( - desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); HashValue_(seed, hash_fn, static_cast(ld)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutGetAttribute( - desc, - CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, - &batch_offset, - sizeof(batch_offset), - &size_to_write)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute( + desc, + CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, + &batch_offset, + sizeof(batch_offset), + &size_to_write)); HashValue_(seed, hash_fn, static_cast(batch_offset)); } @@ -341,7 +323,7 @@ static cublasLtEpilogue_t GetEpilogueType(const std::string& activation, } else if (activation == "none") { return CUBLASLT_EPILOGUE_BIAS; } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "The activation attribute of fused_gemm_epilogue op should be" " one of {\"none\", \"relu\", \"gelu\"}. But received %s." "But received activation=%s.", @@ -381,24 +363,24 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx, } cublasLtMatmulDesc_t operation_desc = NULL; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate( &operation_desc, compute_type, scale_type)); cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transx, sizeof(transx))); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transy, sizeof(transy))); cublasLtEpilogue_t epiloque_func = GetEpilogueType(activation, enable_auxiliary); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epiloque_func, sizeof(epiloque_func))); const T* bias_data = bias->data(); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias_data, @@ -420,45 +402,43 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx, void* aux_data = reserve_space->data(); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, - &aux_data, - sizeof(aux_data))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &aux_data, + sizeof(aux_data))); int64_t aux_ld = N; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, - &aux_ld, - sizeof(aux_ld))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + &aux_ld, + sizeof(aux_ld))); } cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL; if (trans_x) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &x_desc, mat_type, M, K, M)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, M, K, M)); } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &x_desc, mat_type, K, M, K)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, K, M, K)); } if (trans_y) { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &y_desc, mat_type, K, N, K)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, K, N, K)); } else { - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &y_desc, mat_type, N, K, N)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, N, K, N)); } - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( - &out_desc, mat_type, N, M, N)); + PADDLE_ENFORCE_GPU_SUCCESS( + phi::dynload::cublasLtMatrixLayoutCreate(&out_desc, mat_type, N, M, N)); cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); // NOTE(zengjinle): I do not know whether the 4MB workspace size is // "enough". I just followed the settings from the NVIDIA MLPerf BERT code. size_t workspace_size = static_cast(4) * 1024 * 1024; cudaStream_t stream = dev_ctx.stream(); - memory::allocation::AllocationPtr workspace = memory::Alloc( + auto workspace = memory_utils::Alloc( dev_ctx.GetPlace(), workspace_size, phi::Stream(reinterpret_cast(dev_ctx.stream()))); @@ -482,31 +462,29 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx, stream, workspace->ptr(), workspace_size); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(lt_handle, - operation_desc, - &alpha, - y_data, - y_desc, - x_data, - x_desc, - &beta, - out_data, - out_desc, - out_data, - out_desc, - algo, - workspace->ptr(), - workspace_size, - stream)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle, + operation_desc, + &alpha, + y_data, + y_desc, + x_data, + x_desc, + &beta, + out_data, + out_desc, + out_data, + out_desc, + algo, + workspace->ptr(), + workspace_size, + stream)); PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescDestroy(operation_desc)); + phi::dynload::cublasLtMatmulDescDestroy(operation_desc)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutDestroy(y_desc)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutDestroy(x_desc)); PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(y_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(x_desc)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(out_desc)); + phi::dynload::cublasLtMatrixLayoutDestroy(out_desc)); } enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 }; @@ -579,7 +557,7 @@ static cublasLtEpilogue_t GetEpilogueGradType( } else if (activation_grad == "gelu_grad") { return CUBLASLT_EPILOGUE_DGELU; } else { - PADDLE_THROW(platform::errors::InvalidArgument( + PADDLE_THROW(phi::errors::InvalidArgument( "The activation_grad attribute of fused_gemm_epilogue op should " "be one of {\"none\", \"relu\", \"gelu\"}. But received %s." "But received activation_grad=%s.", @@ -644,18 +622,18 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, for (auto desc : descs) { if (desc) { PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutDestroy(desc)); + phi::dynload::cublasLtMatrixLayoutDestroy(desc)); } } if (dx_operation_desc) { PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc)); + phi::dynload::cublasLtMatmulDescDestroy(dx_operation_desc)); } if (dy_operation_desc) { PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc)); + phi::dynload::cublasLtMatmulDescDestroy(dy_operation_desc)); } }); @@ -673,16 +651,16 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, if (TransX) { dx_dout_desc = &dout_trans_desc; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate( dx_dout_desc, mat_type, z_row, z_col, z_row)); } else { dx_dout_desc = &dout_desc; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate( dx_dout_desc, mat_type, z_col, z_row, z_col)); } dx_y_desc = &y_trans_desc; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate( dx_y_desc, mat_type, y_col, y_row, y_col)); auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc); @@ -690,55 +668,50 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans); auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate( &dx_desc, phi::backends::gpu::ToCudaDataType(), x_col, x_row, x_col)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate( &dx_operation_desc, compute_type, scale_type)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &a_trans, - sizeof(a_trans))); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &b_trans, - sizeof(b_trans))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, + CUBLASLT_MATMUL_DESC_TRANSB, + &a_trans, + sizeof(a_trans))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &b_trans, + sizeof(b_trans))); cublasLtEpilogue_t epiloque_func_for_dx = GetEpilogueGradType(activation_grad); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epiloque_func_for_dx, - sizeof(epiloque_func_for_dx))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epiloque_func_for_dx, + sizeof(epiloque_func_for_dx))); if (activation_grad != "none") { auto* aux_data = reserve_space->data(); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, - &aux_data, - sizeof(aux_data))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, + &aux_data, + sizeof(aux_data))); int64_t aux_ld = TransX ? M : K; - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, - &aux_ld, - sizeof(aux_ld))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + dx_operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, + &aux_ld, + sizeof(aux_ld))); } - auto dx_workspace = memory::Alloc( + auto dx_workspace = memory_utils::Alloc( dev_ctx.GetPlace(), workspace_size, phi::Stream(reinterpret_cast(dev_ctx.stream()))); @@ -764,23 +737,22 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, dx_workspace->ptr(), workspace_size); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmul(lt_handle, - dx_operation_desc, - &alpha, - b_data, - b_desc, - a_data, - a_desc, - &beta_dx, - dx_data, - dx_desc, - dx_data, - dx_desc, - algo, - dx_workspace->ptr(), - workspace_size, - stream)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle, + dx_operation_desc, + &alpha, + b_data, + b_desc, + a_data, + a_desc, + &beta_dx, + dx_data, + dx_desc, + dx_data, + dx_desc, + algo, + dx_workspace->ptr(), + workspace_size, + stream)); } // dy = func(dout, x) @@ -791,21 +763,19 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, if (TransX) { dy_dout_desc = &dout_trans_desc; if (dout_trans_desc == nullptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutCreate( - dy_dout_desc, mat_type, z_row, z_col, z_row)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate( + dy_dout_desc, mat_type, z_row, z_col, z_row)); } } else { dy_dout_desc = &dout_desc; if (dout_desc == nullptr) { - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatrixLayoutCreate( - dy_dout_desc, mat_type, z_col, z_row, z_col)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate( + dy_dout_desc, mat_type, z_col, z_row, z_col)); } } dy_x_desc = &x_trans_desc; - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate( dy_x_desc, mat_type, x_col, x_row, x_col)); auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc); @@ -813,28 +783,26 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans); auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate( &dy_desc, phi::backends::gpu::ToCudaDataType(), y_col, y_row, y_col)); - PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate( &dy_operation_desc, compute_type, scale_type)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dy_operation_desc, - CUBLASLT_MATMUL_DESC_TRANSB, - &a_trans, - sizeof(a_trans))); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dy_operation_desc, - CUBLASLT_MATMUL_DESC_TRANSA, - &b_trans, - sizeof(b_trans))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, + CUBLASLT_MATMUL_DESC_TRANSB, + &a_trans, + sizeof(a_trans))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, + CUBLASLT_MATMUL_DESC_TRANSA, + &b_trans, + sizeof(b_trans))); cublasLtEpilogue_t epiloque_func_for_dy; if (dbias == nullptr) { @@ -847,25 +815,23 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, } } - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dy_operation_desc, - CUBLASLT_MATMUL_DESC_EPILOGUE, - &epiloque_func_for_dy, - sizeof(epiloque_func_for_dy))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, + CUBLASLT_MATMUL_DESC_EPILOGUE, + &epiloque_func_for_dy, + sizeof(epiloque_func_for_dy))); if (dbias) { auto* dbias_data = dev_ctx.Alloc(dbias, dbias->numel() * sizeof(DYT)); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmulDescSetAttribute( - dy_operation_desc, - CUBLASLT_MATMUL_DESC_BIAS_POINTER, - &dbias_data, - sizeof(dbias_data))); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute( + dy_operation_desc, + CUBLASLT_MATMUL_DESC_BIAS_POINTER, + &dbias_data, + sizeof(dbias_data))); } - auto dy_workspace = memory::Alloc( + auto dy_workspace = memory_utils::Alloc( dev_ctx.GetPlace(), workspace_size, phi::Stream(reinterpret_cast(dev_ctx.stream()))); @@ -890,23 +856,22 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, dy_workspace->ptr(), workspace_size); - PADDLE_ENFORCE_GPU_SUCCESS( - platform::dynload::cublasLtMatmul(lt_handle, - dy_operation_desc, - &alpha, - b_data, - b_desc, - a_data, - a_desc, - &beta_dy, - dy_data, - dy_desc, - dy_data, - dy_desc, - algo, - dy_workspace->ptr(), - workspace_size, - stream)); + PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle, + dy_operation_desc, + &alpha, + b_data, + b_desc, + a_data, + a_desc, + &beta_dy, + dy_data, + dy_desc, + dy_data, + dy_desc, + algo, + dy_workspace->ptr(), + workspace_size, + stream)); } } @@ -1002,7 +967,7 @@ void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx, } } -} // namespace operators -} // namespace paddle +} // namespace funcs +} // namespace phi #endif #endif diff --git a/paddle/phi/kernels/fusion/gpu/fused_linear_param_grad_add_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_linear_param_grad_add_kernel.cu index 7e13776da25146609c43b83cba33ca3607ee2176..a3a31cf069f4b3033a679d9a9ddeb17c1fc8d5eb 100644 --- a/paddle/phi/kernels/fusion/gpu/fused_linear_param_grad_add_kernel.cu +++ b/paddle/phi/kernels/fusion/gpu/fused_linear_param_grad_add_kernel.cu @@ -15,7 +15,7 @@ #include "paddle/phi/kernels/fusion/fused_linear_param_grad_add_kernel.h" #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 -#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" +#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h" #endif #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/data_type.h" @@ -41,7 +41,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx, const bool fuse_bias_grad = kIsMultiPrecision && dweight_out; if (dweight_out) { - paddle::operators::ComputeFusedGemmEpilogueBackward( + phi::funcs::ComputeFusedGemmEpilogueBackward( ctx, &dout, &x, @@ -184,10 +184,6 @@ void FusedLinearParamGradAdd(const Context &ctx, FusedLinearParamGradAddImpl( ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out); } - - if (VLOG_IS_ON(kLogLevel)) { - ctx.Wait(); - } } #else