Unverified commit 5da1a27b authored by sneaxiy, committed by GitHub

Remove fluid deps in fused_linear_param_grad_add_kernel.cu (#51975)

* remove fluid deps in fused_linear_param_grad_add_kernel

* fix compile error

* fix ut error

* follow comments
Parent 101c9bb0
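In outline, the change replaces fluid-layer headers and namespaces with their phi equivalents. A rough mapping, summarized from the hunks below (not exhaustive):

// fluid (before)                               -> phi (after)
// paddle/fluid/platform/float16.h              -> paddle/phi/common/float16.h
// paddle/fluid/platform/dynload/cublasLt.h     -> paddle/phi/backends/dynload/cublasLt.h
// paddle/fluid/platform/enforce.h              -> paddle/phi/core/enforce.h
// paddle/fluid/memory/memory.h (memory::Alloc) -> paddle/phi/common/memory_utils.h (memory_utils::Alloc)
// paddle/fluid/framework/scope_guard.h         -> paddle/phi/core/scope_guard.h
// platform::errors::* / platform::dynload::*   -> phi::errors::* / phi::dynload::*
// namespace paddle::operators                  -> namespace phi::funcs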
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -14,40 +14,4 @@
#pragma once
#include <type_traits>
#include <utility>
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
template <typename ReleaseCallback>
class ScopeGuard {
public:
explicit ScopeGuard(const ReleaseCallback &callback) : callback_(callback) {}
~ScopeGuard() { callback_(); }
private:
DISABLE_COPY_AND_ASSIGN(ScopeGuard);
private:
ReleaseCallback callback_;
};
// Two macros are needed here.
// See:
// https://stackoverflow.com/questions/10379691/creating-macro-using-line-for-different-variable-names
#define _PADDLE_CONCAT_TOKEN(x, y) x##y
#define PADDLE_CONCAT_TOKEN(x, y) _PADDLE_CONCAT_TOKEN(x, y)
#define DEFINE_PADDLE_SCOPE_GUARD(...) \
auto PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__) = __VA_ARGS__; \
::paddle::framework::ScopeGuard<typename std::remove_reference< \
decltype(PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__))>::type> \
PADDLE_CONCAT_TOKEN(__scope_guard, __LINE__)( \
PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__))
} // namespace framework
} // namespace paddle
#include "paddle/phi/core/scope_guard.h"
@@ -14,13 +14,13 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace paddle {
@@ -129,7 +129,7 @@ class AttnMatMul {
bool fused = false) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
if (compute_bias_ && fused) {
ComputeFusedGemmEpilogueBackward<T>(dev_ctx_,
phi::funcs::ComputeFusedGemmEpilogueBackward<T>(dev_ctx_,
d_output,
input,
weight,
......
@@ -13,9 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
namespace paddle {
namespace operators {
......
@@ -13,12 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
namespace paddle {
namespace operators {
@@ -151,7 +151,7 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
<< ", activation=" << activation_grad
<< ", reserve_space=" << reserve_space;
ComputeFusedGemmEpilogueBackward<T>(dev_ctx,
phi::funcs::ComputeFusedGemmEpilogueBackward<T>(dev_ctx,
dout,
x,
y,
......
@@ -30,11 +30,11 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
@@ -1871,7 +1871,8 @@ class CublasFusedMLP {
const auto *x_data = x->data<T>();
const auto *w_data = weight->data<T>();
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle,
auto algo = phi::funcs::GemmEpilogueAlgoCache::Instance().GetGemmAlgo(
lt_handle,
operation_desc_,
w_desc_,
x_desc_,
......
@@ -726,6 +726,16 @@
optional : skip_update, master_params
inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
- op : fused_linear_param_grad_add
args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
output : Tensor(dweight_out), Tensor(dbias_out)
infer_meta:
func : FusedLinearParamGradAddInferMeta
optional : dweight, dbias
kernel:
func : fused_linear_param_grad_add
data_type : dout
- op : gather
args : (Tensor x, Tensor index, Scalar(int) axis=0)
output : Tensor(out)
......
@@ -614,16 +614,6 @@
data_type : x
backward : fused_dropout_add_grad
- op : fused_linear_param_grad_add
args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
output : Tensor(dweight_out), Tensor(dbias_out)
infer_meta:
func : FusedLinearParamGradAddInferMeta
optional : dweight, dbias
kernel:
func : fused_linear_param_grad_add
data_type : dout
- op : gather_nd
args : (Tensor x, Tensor index)
output : Tensor
......
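Note: the fused_linear_param_grad_add YAML entry above was moved between op definition files, not changed. Read through phi's kernel conventions, it implies roughly the declaration below; this is a hedged reconstruction from the args/output/optional fields, not a copy of the actual header:

// Hypothetical signature implied by the YAML entry (optional inputs map to
// paddle::optional, attributes become trailing arguments, outputs are pointers).
template <typename T, typename Context>
void FusedLinearParamGradAdd(const Context& ctx,
                             const DenseTensor& x,
                             const DenseTensor& dout,
                             const paddle::optional<DenseTensor>& dweight,
                             const paddle::optional<DenseTensor>& dbias,
                             bool multi_precision,
                             DenseTensor* dweight_out,
                             DenseTensor* dbias_out);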
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <utility>
#include "paddle/phi/core/macros.h"
namespace phi {
template <typename ReleaseCallback>
class ScopeGuard {
public:
explicit ScopeGuard(const ReleaseCallback &callback) : callback_(callback) {}
~ScopeGuard() { callback_(); }
private:
DISABLE_COPY_AND_ASSIGN(ScopeGuard);
private:
ReleaseCallback callback_;
};
// Two macros are needed here.
// See:
// https://stackoverflow.com/questions/10379691/creating-macro-using-line-for-different-variable-names
#define _PADDLE_CONCAT_TOKEN(x, y) x##y
#define PADDLE_CONCAT_TOKEN(x, y) _PADDLE_CONCAT_TOKEN(x, y)
#define DEFINE_PADDLE_SCOPE_GUARD(...) \
auto PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__) = __VA_ARGS__; \
::phi::ScopeGuard<typename std::remove_reference< \
decltype(PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__))>::type> \
PADDLE_CONCAT_TOKEN(__scope_guard, __LINE__)( \
PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__))
} // namespace phi
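A minimal usage sketch of the guard above (the call site and descriptor name are illustrative). The callback runs when the scope exits, which is how the cublasLt descriptors in the backward implementation below get released on every path:

{
  cublasLtMatmulDesc_t desc = /* created elsewhere */ nullptr;
  DEFINE_PADDLE_SCOPE_GUARD([&] {
    if (desc) {
      phi::dynload::cublasLtMatmulDescDestroy(desc);
    }
  });
  // ... early returns or exceptions here still trigger the destroy ...
}  // the guard's destructor invokes the lambda

// Why two concat macros: with a single-level paste, __LINE__ would not expand.
//   _PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__) -> __scope_guard_func__LINE__
//   PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__)  -> e.g. __scope_guard_func42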
@@ -27,21 +27,21 @@ limitations under the License. */
#if CUDA_VERSION >= 11060
#include "gflags/gflags.h"
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/dynload/cublasLt.h"
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/scope_guard.h"
#include "paddle/utils/optional.h"
DECLARE_int64(cublaslt_exhaustive_search_times);
namespace paddle {
namespace operators {
namespace phi {
namespace funcs {
class GemmEpilogueAlgoCache {
public:
@@ -88,9 +88,9 @@ class GemmEpilogueAlgoCache {
cublasLtMatmulPreference_t preference;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceCreate(&preference));
phi::dynload::cublasLtMatmulPreferenceCreate(&preference));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceSetAttribute(
phi::dynload::cublasLtMatmulPreferenceSetAttribute(
preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size,
@@ -100,8 +100,7 @@ class GemmEpilogueAlgoCache {
std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results(
requested_algo_count_);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulAlgoGetHeuristic(
lt_handle,
phi::dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle,
op_desc,
a_desc,
b_desc,
@@ -115,10 +114,10 @@ class GemmEpilogueAlgoCache {
PADDLE_ENFORCE_GT(
returned_results,
0,
platform::errors::Unavailable("No GEMM epilogue algorithm support!"));
phi::errors::Unavailable("No GEMM epilogue algorithm support!"));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceDestroy(preference));
phi::dynload::cublasLtMatmulPreferenceDestroy(preference));
int best_algo_idx = -1;
float best_algo_time = 0;
@@ -126,8 +125,8 @@ class GemmEpilogueAlgoCache {
// Run 100 times for warmup
int warmup_algo_idx = 0;
for (int t = 0; t < 100; t++) {
cublasStatus_t status = platform::dynload::cublasLtMatmul(
lt_handle,
cublasStatus_t status =
phi::dynload::cublasLtMatmul(lt_handle,
op_desc,
alpha,
a,
@@ -147,8 +146,8 @@ class GemmEpilogueAlgoCache {
t = -1;
warmup_algo_idx += 1;
if (warmup_algo_idx == requested_algo_count_) {
PADDLE_THROW(platform::errors::Unavailable(
"No GEMM epilogue algorithm support!"));
PADDLE_THROW(
phi::errors::Unavailable("No GEMM epilogue algorithm support!"));
}
}
}
@@ -164,7 +163,7 @@ class GemmEpilogueAlgoCache {
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream));
cublasStatus_t status =
platform::dynload::cublasLtMatmul(lt_handle,
phi::dynload::cublasLtMatmul(lt_handle,
op_desc,
alpha,
a,
@@ -204,7 +203,7 @@ class GemmEpilogueAlgoCache {
if (best_algo_idx == -1) {
PADDLE_THROW(
platform::errors::Unavailable("No GEMM epilogue algorithm support!"));
phi::errors::Unavailable("No GEMM epilogue algorithm support!"));
}
ret = heuristic_results[best_algo_idx].algo;
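For context, the search above first warms up one candidate 100 times, then times each heuristic result with CUDA events and keeps the fastest; a simplified sketch of that timing pattern (standard CUDA runtime API, condensed from the hunks above):

// Time one candidate algorithm with CUDA events (simplified).
cudaEvent_t start_event, stop_event;
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream));
// ... phi::dynload::cublasLtMatmul(...) with heuristic_results[i].algo ...
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(stop_event));
float elapsed_ms = 0.0f;
PADDLE_ENFORCE_GPU_SUCCESS(
    cudaEventElapsedTime(&elapsed_ms, start_event, stop_event));
// The fastest candidate wins; FLAGS_cublaslt_exhaustive_search_times
// (declared above) controls how many timed repetitions are run.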
@@ -235,8 +234,7 @@ class GemmEpilogueAlgoCache {
int trans_a, trans_b;
uint32_t epilogue;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescGetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescGetAttribute(
desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&trans_a,
@@ -244,8 +242,7 @@ class GemmEpilogueAlgoCache {
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(trans_a));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescGetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescGetAttribute(
desc,
CUBLASLT_MATMUL_DESC_TRANSB,
&trans_b,
@@ -253,8 +250,7 @@ class GemmEpilogueAlgoCache {
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(trans_b));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescGetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescGetAttribute(
desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epilogue,
@@ -272,8 +268,7 @@ class GemmEpilogueAlgoCache {
uint64_t row, col;
int64_t ld, batch_offset;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_TYPE,
&dtype,
@@ -281,8 +276,7 @@ class GemmEpilogueAlgoCache {
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(dtype));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
&batch,
@@ -290,31 +284,19 @@ class GemmEpilogueAlgoCache {
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(batch));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_ROWS,
&row,
sizeof(row),
&size_to_write));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), &size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(row));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_COLS,
&col,
sizeof(col),
&size_to_write));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), &size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(col));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(ld));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
desc,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
&batch_offset,
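Each attribute read above feeds HashValue_, which folds the value into a running seed used as the algo-cache key. The helper itself is outside this diff; a plausible sketch (boost-style hash combining, an assumption rather than the actual implementation):

// Assumed shape of HashValue_: fold `value` into `*seed`.
static void HashValue_(int64_t* seed,
                       const std::hash<int64_t>& hash_fn,
                       int64_t value) {
  *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}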
@@ -341,7 +323,7 @@ static cublasLtEpilogue_t GetEpilogueType(const std::string& activation,
} else if (activation == "none") {
return CUBLASLT_EPILOGUE_BIAS;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(phi::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.",
@@ -381,24 +363,24 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
}
cublasLtMatmulDesc_t operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate(
&operation_desc, compute_type, scale_type));
cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transx, sizeof(transx)));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transy, sizeof(transy)));
cublasLtEpilogue_t epiloque_func =
GetEpilogueType(activation, enable_auxiliary);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func,
sizeof(epiloque_func)));
const T* bias_data = bias->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_data,
@@ -420,15 +402,13 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
void* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data,
sizeof(aux_data)));
int64_t aux_ld = N;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&aux_ld,
@@ -437,28 +417,28 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL;
if (trans_x) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, M, K, M));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, M, K, M));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&x_desc, mat_type, K, M, K));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, K, M, K));
}
if (trans_y) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, K, N, K));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, K, N, K));
} else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&y_desc, mat_type, N, K, N));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, N, K, N));
}
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
&out_desc, mat_type, N, M, N));
PADDLE_ENFORCE_GPU_SUCCESS(
phi::dynload::cublasLtMatrixLayoutCreate(&out_desc, mat_type, N, M, N));
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
// NOTE(zengjinle): I do not know whether the 4MB workspace size is
// "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace = memory::Alloc(
auto workspace = memory_utils::Alloc(
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
@@ -482,7 +462,7 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
stream,
workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(lt_handle,
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle,
operation_desc,
&alpha,
y_data,
@@ -500,13 +480,11 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
stream));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(operation_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(y_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(x_desc));
phi::dynload::cublasLtMatmulDescDestroy(operation_desc));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutDestroy(y_desc));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutDestroy(x_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(out_desc));
phi::dynload::cublasLtMatrixLayoutDestroy(out_desc));
}
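A note on the forward call above: cuBLASLt works in column-major order, so the row-major product is issued transposed. That is why y_data is passed as matrix A and x_data as matrix B, why CUBLASLT_MATMUL_DESC_TRANSB is set from trans_x and TRANSA from trans_y, and why out_desc is created as (N, M, N):

// Row-major GEMM via column-major cuBLASLt:
//   desired (row-major):  Out[M, N]   = op(X)[M, K]   * op(Y)[K, N] + bias
//   issued  (col-major):  Out^T[N, M] = op(Y)^T[N, K] * op(X)^T[K, M]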
enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
@@ -579,7 +557,7 @@ static cublasLtEpilogue_t GetEpilogueGradType(
} else if (activation_grad == "gelu_grad") {
return CUBLASLT_EPILOGUE_DGELU;
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
PADDLE_THROW(phi::errors::InvalidArgument(
"The activation_grad attribute of fused_gemm_epilogue op should "
"be one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation_grad=%s.",
@@ -644,18 +622,18 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
for (auto desc : descs) {
if (desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(desc));
phi::dynload::cublasLtMatrixLayoutDestroy(desc));
}
}
if (dx_operation_desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc));
phi::dynload::cublasLtMatmulDescDestroy(dx_operation_desc));
}
if (dy_operation_desc) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc));
phi::dynload::cublasLtMatmulDescDestroy(dy_operation_desc));
}
});
@@ -673,16 +651,16 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
if (TransX) {
dx_dout_desc = &dout_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
dx_dout_desc, mat_type, z_row, z_col, z_row));
} else {
dx_dout_desc = &dout_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
dx_dout_desc, mat_type, z_col, z_row, z_col));
}
dx_y_desc = &y_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
dx_y_desc, mat_type, y_col, y_row, y_col));
auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc);
@@ -690,23 +668,21 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans);
auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
&dx_desc,
phi::backends::gpu::ToCudaDataType<DXT>(),
x_col,
x_row,
x_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate(
&dx_operation_desc, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSB,
&a_trans,
sizeof(a_trans)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&b_trans,
@@ -714,8 +690,7 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
cublasLtEpilogue_t epiloque_func_for_dx =
GetEpilogueGradType(activation_grad);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func_for_dx,
@@ -723,22 +698,20 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
if (activation_grad != "none") {
auto* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
&aux_data,
sizeof(aux_data)));
int64_t aux_ld = TransX ? M : K;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
&aux_ld,
sizeof(aux_ld)));
}
auto dx_workspace = memory::Alloc(
auto dx_workspace = memory_utils::Alloc(
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
@@ -764,8 +737,7 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
dx_workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle,
dx_operation_desc,
&alpha,
b_data,
@@ -791,21 +763,19 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
if (TransX) {
dy_dout_desc = &dout_trans_desc;
if (dout_trans_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
dy_dout_desc, mat_type, z_row, z_col, z_row));
}
} else {
dy_dout_desc = &dout_desc;
if (dout_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
dy_dout_desc, mat_type, z_col, z_row, z_col));
}
}
dy_x_desc = &x_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
dy_x_desc, mat_type, x_col, x_row, x_col));
auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc);
@@ -813,24 +783,22 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans);
auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
&dy_desc,
phi::backends::gpu::ToCudaDataType<DYT>(),
y_col,
y_row,
y_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate(
&dy_operation_desc, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSB,
&a_trans,
sizeof(a_trans)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_TRANSA,
&b_trans,
@@ -847,8 +815,7 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
}
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func_for_dy,
@@ -857,15 +824,14 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
if (dbias) {
auto* dbias_data =
dev_ctx.Alloc<DYT>(dbias, dbias->numel() * sizeof(DYT));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
dy_operation_desc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&dbias_data,
sizeof(dbias_data)));
}
auto dy_workspace = memory::Alloc(
auto dy_workspace = memory_utils::Alloc(
dev_ctx.GetPlace(),
workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
@@ -890,8 +856,7 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
dy_workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle,
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle,
dy_operation_desc,
&alpha,
b_data,
@@ -1002,7 +967,7 @@ void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx,
}
}
} // namespace operators
} // namespace paddle
} // namespace funcs
} // namespace phi
#endif
#endif
@@ -15,7 +15,7 @@
#include "paddle/phi/kernels/fusion/fused_linear_param_grad_add_kernel.h"
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
#endif
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h"
@@ -41,7 +41,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
const bool fuse_bias_grad = kIsMultiPrecision && dweight_out;
if (dweight_out) {
paddle::operators::ComputeFusedGemmEpilogueBackward<T, T, MT>(
phi::funcs::ComputeFusedGemmEpilogueBackward<T, T, MT>(
ctx,
&dout,
&x,
@@ -184,10 +184,6 @@ void FusedLinearParamGradAdd(const Context &ctx,
FusedLinearParamGradAddImpl<T, T, Context>(
ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
}
if (VLOG_IS_ON(kLogLevel)) {
ctx.Wait();
}
}
#else
......