未验证 提交 5da1a27b 编写于 作者: S sneaxiy 提交者: GitHub

Remove fluid deps in fused_linear_param_grad_add_kernel.cu (#51975)

* remove fluid deps in fused_linear_param_grad_add_kernel

* fix compile error

* fix ut error

* follow comments
上级 101c9bb0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
// //
// Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
...@@ -14,40 +14,4 @@ ...@@ -14,40 +14,4 @@
#pragma once #pragma once
#include <type_traits> #include "paddle/phi/core/scope_guard.h"
#include <utility>
#include "paddle/fluid/platform/macros.h"
namespace paddle {
namespace framework {
// Scope guard: keeps a copy of the callable passed to the constructor and
// invokes that copy from the destructor.
//
// ReleaseCallback is the type of the callable; it is called with no arguments.
template <typename ReleaseCallback>
class ScopeGuard {
public:
// Copies `callback` into the guard. Nothing is invoked at construction time.
explicit ScopeGuard(const ReleaseCallback &callback) : callback_(callback) {}
// Invokes the stored callable when the guard is destroyed.
~ScopeGuard() { callback_(); }
private:
// NOTE(review): presumably deletes the copy operations so the callable is not
// run by more than one guard; the macro comes from
// paddle/fluid/platform/macros.h and is not shown here -- confirm.
DISABLE_COPY_AND_ASSIGN(ScopeGuard);
private:
// The guard's own copy of the callable.
ReleaseCallback callback_;
};
// Two macros are needed here.
// See:
// https://stackoverflow.com/questions/10379691/creating-macro-using-line-for-different-variable-names
// The extra level of indirection lets an argument such as __LINE__ be
// macro-expanded before `##` pastes the two tokens together.
#define _PADDLE_CONCAT_TOKEN(x, y) x##y
#define PADDLE_CONCAT_TOKEN(x, y) _PADDLE_CONCAT_TOKEN(x, y)
// DEFINE_PADDLE_SCOPE_GUARD(callable):
//   1. stores the macro arguments in a local `auto` variable whose name ends
//      with the current line number, then
//   2. declares a ::paddle::framework::ScopeGuard, also named after the
//      current line number, constructed from that variable.
// The expansion is two declarations and the second one carries no trailing
// semicolon, so the caller supplies it. Both names are derived from __LINE__,
// so two uses on the same source line in the same scope would redeclare the
// same names.
#define DEFINE_PADDLE_SCOPE_GUARD(...) \
auto PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__) = __VA_ARGS__; \
::paddle::framework::ScopeGuard<typename std::remove_reference< \
decltype(PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__))>::type> \
PADDLE_CONCAT_TOKEN(__scope_guard, __LINE__)( \
PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__))
} // namespace framework
} // namespace paddle
...@@ -14,13 +14,13 @@ limitations under the License. */ ...@@ -14,13 +14,13 @@ limitations under the License. */
#pragma once #pragma once
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h" #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h" #include "paddle/phi/kernels/funcs/broadcast_function.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h" #include "paddle/phi/kernels/primitive/kernel_primitives.h"
namespace paddle { namespace paddle {
...@@ -129,21 +129,21 @@ class AttnMatMul { ...@@ -129,21 +129,21 @@ class AttnMatMul {
bool fused = false) { bool fused = false) {
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
if (compute_bias_ && fused) { if (compute_bias_ && fused) {
ComputeFusedGemmEpilogueBackward<T>(dev_ctx_, phi::funcs::ComputeFusedGemmEpilogueBackward<T>(dev_ctx_,
d_output, d_output,
input, input,
weight, weight,
nullptr, nullptr,
bsz_seq_, // M bsz_seq_, // M
output_size_, // N output_size_, // N
input_size_, // K input_size_, // K
transA_, transA_,
transB_, transB_,
"none", "none",
d_input, d_input,
d_weight, d_weight,
d_bias, d_bias,
use_addto); use_addto);
return; return;
} }
#endif #endif
......
...@@ -13,9 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,9 +13,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
...@@ -13,12 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -13,12 +13,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h" #include "paddle/phi/kernels/funcs/blas/blaslt_impl.cu.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -151,20 +151,20 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> { ...@@ -151,20 +151,20 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
<< ", activation=" << activation_grad << ", activation=" << activation_grad
<< ", reserve_space=" << reserve_space; << ", reserve_space=" << reserve_space;
ComputeFusedGemmEpilogueBackward<T>(dev_ctx, phi::funcs::ComputeFusedGemmEpilogueBackward<T>(dev_ctx,
dout, dout,
x, x,
y, y,
reserve_space, reserve_space,
M, M,
N, N,
K, K,
trans_x, trans_x,
trans_y, trans_y,
activation_grad, activation_grad,
dx, dx,
dy, dy,
dbias); dbias);
} }
}; };
#endif #endif
......
...@@ -30,11 +30,11 @@ limitations under the License. */ ...@@ -30,11 +30,11 @@ limitations under the License. */
#include "paddle/fluid/operators/fused/attn_gemm.h" #include "paddle/fluid/operators/fused/attn_gemm.h"
#include "paddle/fluid/operators/fused/fmha_ref.h" #include "paddle/fluid/operators/fused/fmha_ref.h"
#include "paddle/fluid/operators/fused/fused_dropout_helper.h" #include "paddle/fluid/operators/fused/fused_dropout_helper.h"
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h" #include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
...@@ -1871,19 +1871,20 @@ class CublasFusedMLP { ...@@ -1871,19 +1871,20 @@ class CublasFusedMLP {
const auto *x_data = x->data<T>(); const auto *x_data = x->data<T>();
const auto *w_data = weight->data<T>(); const auto *w_data = weight->data<T>();
auto algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(lt_handle, auto algo = phi::funcs::GemmEpilogueAlgoCache::Instance().GetGemmAlgo(
operation_desc_, lt_handle,
w_desc_, operation_desc_,
x_desc_, w_desc_,
out_desc_, x_desc_,
alpha, out_desc_,
beta, alpha,
w_data, beta,
x_data, w_data,
out_data, x_data,
stream, out_data,
workspace->ptr(), stream,
workspace_size); workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmul(lt_handle, platform::dynload::cublasLtMatmul(lt_handle,
......
...@@ -726,6 +726,16 @@ ...@@ -726,6 +726,16 @@
optional : skip_update, master_params optional : skip_update, master_params
inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out) inplace : (params -> params_out), (moments1 -> moments1_out), (moments2 -> moments2_out), (beta1_pows -> beta1_pows_out), (beta2_pows -> beta2_pows_out), (master_params -> master_params_out)
- op : fused_linear_param_grad_add
args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
output : Tensor(dweight_out), Tensor(dbias_out)
infer_meta:
func : FusedLinearParamGradAddInferMeta
optional : dweight, dbias
kernel:
func : fused_linear_param_grad_add
data_type : dout
- op : gather - op : gather
args : (Tensor x, Tensor index, Scalar(int) axis=0) args : (Tensor x, Tensor index, Scalar(int) axis=0)
output : Tensor(out) output : Tensor(out)
......
...@@ -614,16 +614,6 @@ ...@@ -614,16 +614,6 @@
data_type : x data_type : x
backward : fused_dropout_add_grad backward : fused_dropout_add_grad
- op : fused_linear_param_grad_add
args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true)
output : Tensor(dweight_out), Tensor(dbias_out)
infer_meta:
func : FusedLinearParamGradAddInferMeta
optional : dweight, dbias
kernel:
func : fused_linear_param_grad_add
data_type : dout
- op : gather_nd - op : gather_nd
args : (Tensor x, Tensor index) args : (Tensor x, Tensor index)
output : Tensor output : Tensor
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <type_traits>
#include <utility>
#include "paddle/phi/core/macros.h"
namespace phi {
// Scope guard that owns a copy of a callable and invokes it on destruction.
//
// Template parameter:
//   ReleaseCallback - type of the callable; it is invoked with no arguments.
//
// The constructor only copies the callable; the single call site is the
// destructor.
template <typename ReleaseCallback>
class ScopeGuard {
 private:
  // Private copy of the callable handed to the constructor.
  ReleaseCallback release_fn_;

  // NOTE(review): presumably removes the copy operations so the callable is
  // not run by more than one guard; the macro lives in
  // paddle/phi/core/macros.h and is not shown here -- confirm.
  DISABLE_COPY_AND_ASSIGN(ScopeGuard);

 public:
  // Stores a copy of `callback`; nothing is executed yet.
  explicit ScopeGuard(const ReleaseCallback &callback)
      : release_fn_(callback) {}

  // Runs the stored callable as the guard goes away.
  ~ScopeGuard() {
    release_fn_();
  }
};
// Two macros are needed here.
// See:
// https://stackoverflow.com/questions/10379691/creating-macro-using-line-for-different-variable-names
// The extra level of indirection lets an argument such as __LINE__ be
// macro-expanded before `##` pastes the two tokens together.
#define _PADDLE_CONCAT_TOKEN(x, y) x##y
#define PADDLE_CONCAT_TOKEN(x, y) _PADDLE_CONCAT_TOKEN(x, y)
// DEFINE_PADDLE_SCOPE_GUARD(callable):
//   1. stores the macro arguments in a local `auto` variable whose name ends
//      with the current line number, then
//   2. declares a ::phi::ScopeGuard, also named after the current line
//      number, constructed from that variable.
// The expansion is two declarations and the second one carries no trailing
// semicolon, so the caller supplies it. Both names are derived from __LINE__,
// so two uses on the same source line in the same scope would redeclare the
// same names.
#define DEFINE_PADDLE_SCOPE_GUARD(...) \
auto PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__) = __VA_ARGS__; \
::phi::ScopeGuard<typename std::remove_reference< \
decltype(PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__))>::type> \
PADDLE_CONCAT_TOKEN(__scope_guard, __LINE__)( \
PADDLE_CONCAT_TOKEN(__scope_guard_func, __LINE__))
} // namespace phi
...@@ -27,21 +27,21 @@ limitations under the License. */ ...@@ -27,21 +27,21 @@ limitations under the License. */
#if CUDA_VERSION >= 11060 #if CUDA_VERSION >= 11060
#include "gflags/gflags.h" #include "gflags/gflags.h"
#include "paddle/fluid/framework/scope_guard.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/all_context.h" #include "paddle/phi/backends/all_context.h"
#include "paddle/phi/backends/dynload/cublasLt.h"
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h" #include "paddle/phi/backends/gpu/cuda/cuda_helper.h"
#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/dense_tensor.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/scope_guard.h"
#include "paddle/utils/optional.h" #include "paddle/utils/optional.h"
DECLARE_int64(cublaslt_exhaustive_search_times); DECLARE_int64(cublaslt_exhaustive_search_times);
namespace paddle { namespace phi {
namespace operators { namespace funcs {
class GemmEpilogueAlgoCache { class GemmEpilogueAlgoCache {
public: public:
...@@ -88,9 +88,9 @@ class GemmEpilogueAlgoCache { ...@@ -88,9 +88,9 @@ class GemmEpilogueAlgoCache {
cublasLtMatmulPreference_t preference; cublasLtMatmulPreference_t preference;
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceCreate(&preference)); phi::dynload::cublasLtMatmulPreferenceCreate(&preference));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceSetAttribute( phi::dynload::cublasLtMatmulPreferenceSetAttribute(
preference, preference,
CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size, &workspace_size,
...@@ -100,25 +100,24 @@ class GemmEpilogueAlgoCache { ...@@ -100,25 +100,24 @@ class GemmEpilogueAlgoCache {
std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results( std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results(
requested_algo_count_); requested_algo_count_);
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulAlgoGetHeuristic( phi::dynload::cublasLtMatmulAlgoGetHeuristic(lt_handle,
lt_handle, op_desc,
op_desc, a_desc,
a_desc, b_desc,
b_desc, c_desc,
c_desc, c_desc,
c_desc, preference,
preference, requested_algo_count_,
requested_algo_count_, heuristic_results.data(),
heuristic_results.data(), &returned_results));
&returned_results));
PADDLE_ENFORCE_GT( PADDLE_ENFORCE_GT(
returned_results, returned_results,
0, 0,
platform::errors::Unavailable("No GEMM epilogue algorithm support!")); phi::errors::Unavailable("No GEMM epilogue algorithm support!"));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceDestroy(preference)); phi::dynload::cublasLtMatmulPreferenceDestroy(preference));
int best_algo_idx = -1; int best_algo_idx = -1;
float best_algo_time = 0; float best_algo_time = 0;
...@@ -126,29 +125,29 @@ class GemmEpilogueAlgoCache { ...@@ -126,29 +125,29 @@ class GemmEpilogueAlgoCache {
// Run 100 times for warmup // Run 100 times for warmup
int warmup_algo_idx = 0; int warmup_algo_idx = 0;
for (int t = 0; t < 100; t++) { for (int t = 0; t < 100; t++) {
cublasStatus_t status = platform::dynload::cublasLtMatmul( cublasStatus_t status =
lt_handle, phi::dynload::cublasLtMatmul(lt_handle,
op_desc, op_desc,
alpha, alpha,
a, a,
a_desc, a_desc,
b, b,
b_desc, b_desc,
beta, beta,
c, c,
c_desc, c_desc,
c, c,
c_desc, c_desc,
&heuristic_results[warmup_algo_idx].algo, &heuristic_results[warmup_algo_idx].algo,
workspace, workspace,
workspace_size, workspace_size,
stream); stream);
if (status != CUBLAS_STATUS_SUCCESS) { if (status != CUBLAS_STATUS_SUCCESS) {
t = -1; t = -1;
warmup_algo_idx += 1; warmup_algo_idx += 1;
if (warmup_algo_idx == requested_algo_count_) { if (warmup_algo_idx == requested_algo_count_) {
PADDLE_THROW(platform::errors::Unavailable( PADDLE_THROW(
"No GEMM epilogue algorithm support!")); phi::errors::Unavailable("No GEMM epilogue algorithm support!"));
} }
} }
} }
...@@ -164,22 +163,22 @@ class GemmEpilogueAlgoCache { ...@@ -164,22 +163,22 @@ class GemmEpilogueAlgoCache {
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream));
cublasStatus_t status = cublasStatus_t status =
platform::dynload::cublasLtMatmul(lt_handle, phi::dynload::cublasLtMatmul(lt_handle,
op_desc, op_desc,
alpha, alpha,
a, a,
a_desc, a_desc,
b, b,
b_desc, b_desc,
beta, beta,
c, c,
c_desc, c_desc,
c, c,
c_desc, c_desc,
&heuristic_results[algo_idx].algo, &heuristic_results[algo_idx].algo,
workspace, workspace,
workspace_size, workspace_size,
stream); stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(stop_event)); PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(stop_event));
...@@ -204,7 +203,7 @@ class GemmEpilogueAlgoCache { ...@@ -204,7 +203,7 @@ class GemmEpilogueAlgoCache {
if (best_algo_idx == -1) { if (best_algo_idx == -1) {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Unavailable("No GEMM epilogue algorithm support!")); phi::errors::Unavailable("No GEMM epilogue algorithm support!"));
} }
ret = heuristic_results[best_algo_idx].algo; ret = heuristic_results[best_algo_idx].algo;
...@@ -235,31 +234,28 @@ class GemmEpilogueAlgoCache { ...@@ -235,31 +234,28 @@ class GemmEpilogueAlgoCache {
int trans_a, trans_b; int trans_a, trans_b;
uint32_t epilogue; uint32_t epilogue;
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescGetAttribute(
platform::dynload::cublasLtMatmulDescGetAttribute( desc,
desc, CUBLASLT_MATMUL_DESC_TRANSA,
CUBLASLT_MATMUL_DESC_TRANSA, &trans_a,
&trans_a, sizeof(trans_a),
sizeof(trans_a), &size_to_write));
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(trans_a)); HashValue_(seed, hash_fn, static_cast<int64_t>(trans_a));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescGetAttribute(
platform::dynload::cublasLtMatmulDescGetAttribute( desc,
desc, CUBLASLT_MATMUL_DESC_TRANSB,
CUBLASLT_MATMUL_DESC_TRANSB, &trans_b,
&trans_b, sizeof(trans_b),
sizeof(trans_b), &size_to_write));
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(trans_b)); HashValue_(seed, hash_fn, static_cast<int64_t>(trans_b));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescGetAttribute(
platform::dynload::cublasLtMatmulDescGetAttribute( desc,
desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue,
&epilogue, sizeof(epilogue),
sizeof(epilogue), &size_to_write));
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(epilogue)); HashValue_(seed, hash_fn, static_cast<int64_t>(epilogue));
} }
...@@ -272,54 +268,40 @@ class GemmEpilogueAlgoCache { ...@@ -272,54 +268,40 @@ class GemmEpilogueAlgoCache {
uint64_t row, col; uint64_t row, col;
int64_t ld, batch_offset; int64_t ld, batch_offset;
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
platform::dynload::cublasLtMatrixLayoutGetAttribute( desc,
desc, CUBLASLT_MATRIX_LAYOUT_TYPE,
CUBLASLT_MATRIX_LAYOUT_TYPE, &dtype,
&dtype, sizeof(dtype),
sizeof(dtype), &size_to_write));
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(dtype)); HashValue_(seed, hash_fn, static_cast<int64_t>(dtype));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
platform::dynload::cublasLtMatrixLayoutGetAttribute( desc,
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT,
CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch,
&batch, sizeof(batch),
sizeof(batch), &size_to_write));
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(batch)); HashValue_(seed, hash_fn, static_cast<int64_t>(batch));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
platform::dynload::cublasLtMatrixLayoutGetAttribute( desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), &size_to_write));
desc,
CUBLASLT_MATRIX_LAYOUT_ROWS,
&row,
sizeof(row),
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(row)); HashValue_(seed, hash_fn, static_cast<int64_t>(row));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
platform::dynload::cublasLtMatrixLayoutGetAttribute( desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), &size_to_write));
desc,
CUBLASLT_MATRIX_LAYOUT_COLS,
&col,
sizeof(col),
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(col)); HashValue_(seed, hash_fn, static_cast<int64_t>(col));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
platform::dynload::cublasLtMatrixLayoutGetAttribute( desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write));
desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(ld)); HashValue_(seed, hash_fn, static_cast<int64_t>(ld));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutGetAttribute(
platform::dynload::cublasLtMatrixLayoutGetAttribute( desc,
desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset,
&batch_offset, sizeof(batch_offset),
sizeof(batch_offset), &size_to_write));
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(batch_offset)); HashValue_(seed, hash_fn, static_cast<int64_t>(batch_offset));
} }
...@@ -341,7 +323,7 @@ static cublasLtEpilogue_t GetEpilogueType(const std::string& activation, ...@@ -341,7 +323,7 @@ static cublasLtEpilogue_t GetEpilogueType(const std::string& activation,
} else if (activation == "none") { } else if (activation == "none") {
return CUBLASLT_EPILOGUE_BIAS; return CUBLASLT_EPILOGUE_BIAS;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"The activation attribute of fused_gemm_epilogue op should be" "The activation attribute of fused_gemm_epilogue op should be"
" one of {\"none\", \"relu\", \"gelu\"}. But received %s." " one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation=%s.", "But received activation=%s.",
...@@ -381,24 +363,24 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx, ...@@ -381,24 +363,24 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
} }
cublasLtMatmulDesc_t operation_desc = NULL; cublasLtMatmulDesc_t operation_desc = NULL;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate(
&operation_desc, compute_type, scale_type)); &operation_desc, compute_type, scale_type));
cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t transx = trans_x ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N; cublasOperation_t transy = trans_y ? CUBLAS_OP_T : CUBLAS_OP_N;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transx, sizeof(transx))); operation_desc, CUBLASLT_MATMUL_DESC_TRANSB, &transx, sizeof(transx)));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transy, sizeof(transy))); operation_desc, CUBLASLT_MATMUL_DESC_TRANSA, &transy, sizeof(transy)));
cublasLtEpilogue_t epiloque_func = cublasLtEpilogue_t epiloque_func =
GetEpilogueType(activation, enable_auxiliary); GetEpilogueType(activation, enable_auxiliary);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, operation_desc,
CUBLASLT_MATMUL_DESC_EPILOGUE, CUBLASLT_MATMUL_DESC_EPILOGUE,
&epiloque_func, &epiloque_func,
sizeof(epiloque_func))); sizeof(epiloque_func)));
const T* bias_data = bias->data<T>(); const T* bias_data = bias->data<T>();
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescSetAttribute( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
operation_desc, operation_desc,
CUBLASLT_MATMUL_DESC_BIAS_POINTER, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
&bias_data, &bias_data,
...@@ -420,45 +402,43 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx, ...@@ -420,45 +402,43 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
void* aux_data = reserve_space->data(); void* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
platform::dynload::cublasLtMatmulDescSetAttribute( operation_desc,
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &aux_data,
&aux_data, sizeof(aux_data)));
sizeof(aux_data)));
int64_t aux_ld = N; int64_t aux_ld = N;
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
platform::dynload::cublasLtMatmulDescSetAttribute( operation_desc,
operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &aux_ld,
&aux_ld, sizeof(aux_ld)));
sizeof(aux_ld)));
} }
cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL; cublasLtMatrixLayout_t x_desc = NULL, y_desc = NULL, out_desc = NULL;
if (trans_x) { if (trans_x) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(
&x_desc, mat_type, M, K, M)); phi::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, M, K, M));
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(
&x_desc, mat_type, K, M, K)); phi::dynload::cublasLtMatrixLayoutCreate(&x_desc, mat_type, K, M, K));
} }
if (trans_y) { if (trans_y) {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(
&y_desc, mat_type, K, N, K)); phi::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, K, N, K));
} else { } else {
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(
&y_desc, mat_type, N, K, N)); phi::dynload::cublasLtMatrixLayoutCreate(&y_desc, mat_type, N, K, N));
} }
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(
&out_desc, mat_type, N, M, N)); phi::dynload::cublasLtMatrixLayoutCreate(&out_desc, mat_type, N, M, N));
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
// NOTE(zengjinle): I do not know whether the 4MB workspace size is // NOTE(zengjinle): I do not know whether the 4MB workspace size is
// "enough". I just followed the settings from the NVIDIA MLPerf BERT code. // "enough". I just followed the settings from the NVIDIA MLPerf BERT code.
size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024; size_t workspace_size = static_cast<size_t>(4) * 1024 * 1024;
cudaStream_t stream = dev_ctx.stream(); cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace = memory::Alloc( auto workspace = memory_utils::Alloc(
dev_ctx.GetPlace(), dev_ctx.GetPlace(),
workspace_size, workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
...@@ -482,31 +462,29 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx, ...@@ -482,31 +462,29 @@ void ComputeFusedGemmEpilogueForward(const phi::GPUContext& dev_ctx,
stream, stream,
workspace->ptr(), workspace->ptr(),
workspace_size); workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(lt_handle, PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle,
operation_desc, operation_desc,
&alpha, &alpha,
y_data, y_data,
y_desc, y_desc,
x_data, x_data,
x_desc, x_desc,
&beta, &beta,
out_data, out_data,
out_desc, out_desc,
out_data, out_data,
out_desc, out_desc,
algo, algo,
workspace->ptr(), workspace->ptr(),
workspace_size, workspace_size,
stream)); stream));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(operation_desc)); phi::dynload::cublasLtMatmulDescDestroy(operation_desc));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutDestroy(y_desc));
PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutDestroy(x_desc));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(y_desc)); phi::dynload::cublasLtMatrixLayoutDestroy(out_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(x_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(out_desc));
} }
enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 }; enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
...@@ -579,7 +557,7 @@ static cublasLtEpilogue_t GetEpilogueGradType( ...@@ -579,7 +557,7 @@ static cublasLtEpilogue_t GetEpilogueGradType(
} else if (activation_grad == "gelu_grad") { } else if (activation_grad == "gelu_grad") {
return CUBLASLT_EPILOGUE_DGELU; return CUBLASLT_EPILOGUE_DGELU;
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(phi::errors::InvalidArgument(
"The activation_grad attribute of fused_gemm_epilogue op should " "The activation_grad attribute of fused_gemm_epilogue op should "
"be one of {\"none\", \"relu\", \"gelu\"}. But received %s." "be one of {\"none\", \"relu\", \"gelu\"}. But received %s."
"But received activation_grad=%s.", "But received activation_grad=%s.",
...@@ -644,18 +622,18 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, ...@@ -644,18 +622,18 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
for (auto desc : descs) { for (auto desc : descs) {
if (desc) { if (desc) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(desc)); phi::dynload::cublasLtMatrixLayoutDestroy(desc));
} }
} }
if (dx_operation_desc) { if (dx_operation_desc) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc)); phi::dynload::cublasLtMatmulDescDestroy(dx_operation_desc));
} }
if (dy_operation_desc) { if (dy_operation_desc) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc)); phi::dynload::cublasLtMatmulDescDestroy(dy_operation_desc));
} }
}); });
...@@ -673,16 +651,16 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, ...@@ -673,16 +651,16 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
if (TransX) { if (TransX) {
dx_dout_desc = &dout_trans_desc; dx_dout_desc = &dout_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
dx_dout_desc, mat_type, z_row, z_col, z_row)); dx_dout_desc, mat_type, z_row, z_col, z_row));
} else { } else {
dx_dout_desc = &dout_desc; dx_dout_desc = &dout_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
dx_dout_desc, mat_type, z_col, z_row, z_col)); dx_dout_desc, mat_type, z_col, z_row, z_col));
} }
dx_y_desc = &y_trans_desc; dx_y_desc = &y_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
dx_y_desc, mat_type, y_col, y_row, y_col)); dx_y_desc, mat_type, y_col, y_row, y_col));
auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc); auto& a_desc = kXGradAIsDZ ? (*dx_dout_desc) : (*dx_y_desc);
...@@ -690,55 +668,50 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, ...@@ -690,55 +668,50 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans); auto a_trans = BoolToCuBlasEnum(Trait::kXGradATrans);
auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans); auto b_trans = BoolToCuBlasEnum(Trait::kXGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
&dx_desc, &dx_desc,
phi::backends::gpu::ToCudaDataType<DXT>(), phi::backends::gpu::ToCudaDataType<DXT>(),
x_col, x_col,
x_row, x_row,
x_col)); x_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate(
&dx_operation_desc, compute_type, scale_type)); &dx_operation_desc, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
platform::dynload::cublasLtMatmulDescSetAttribute( dx_operation_desc,
dx_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB,
CUBLASLT_MATMUL_DESC_TRANSB, &a_trans,
&a_trans, sizeof(a_trans)));
sizeof(a_trans))); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS( dx_operation_desc,
platform::dynload::cublasLtMatmulDescSetAttribute( CUBLASLT_MATMUL_DESC_TRANSA,
dx_operation_desc, &b_trans,
CUBLASLT_MATMUL_DESC_TRANSA, sizeof(b_trans)));
&b_trans,
sizeof(b_trans)));
cublasLtEpilogue_t epiloque_func_for_dx = cublasLtEpilogue_t epiloque_func_for_dx =
GetEpilogueGradType(activation_grad); GetEpilogueGradType(activation_grad);
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
platform::dynload::cublasLtMatmulDescSetAttribute( dx_operation_desc,
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
CUBLASLT_MATMUL_DESC_EPILOGUE, &epiloque_func_for_dx,
&epiloque_func_for_dx, sizeof(epiloque_func_for_dx)));
sizeof(epiloque_func_for_dx)));
if (activation_grad != "none") { if (activation_grad != "none") {
auto* aux_data = reserve_space->data(); auto* aux_data = reserve_space->data();
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
platform::dynload::cublasLtMatmulDescSetAttribute( dx_operation_desc,
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_POINTER, &aux_data,
&aux_data, sizeof(aux_data)));
sizeof(aux_data)));
int64_t aux_ld = TransX ? M : K; int64_t aux_ld = TransX ? M : K;
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
platform::dynload::cublasLtMatmulDescSetAttribute( dx_operation_desc,
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD,
CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &aux_ld,
&aux_ld, sizeof(aux_ld)));
sizeof(aux_ld)));
} }
auto dx_workspace = memory::Alloc( auto dx_workspace = memory_utils::Alloc(
dev_ctx.GetPlace(), dev_ctx.GetPlace(),
workspace_size, workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
...@@ -764,23 +737,22 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, ...@@ -764,23 +737,22 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
dx_workspace->ptr(), dx_workspace->ptr(),
workspace_size); workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle,
platform::dynload::cublasLtMatmul(lt_handle, dx_operation_desc,
dx_operation_desc, &alpha,
&alpha, b_data,
b_data, b_desc,
b_desc, a_data,
a_data, a_desc,
a_desc, &beta_dx,
&beta_dx, dx_data,
dx_data, dx_desc,
dx_desc, dx_data,
dx_data, dx_desc,
dx_desc, algo,
algo, dx_workspace->ptr(),
dx_workspace->ptr(), workspace_size,
workspace_size, stream));
stream));
} }
// dy = func(dout, x) // dy = func(dout, x)
...@@ -791,21 +763,19 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, ...@@ -791,21 +763,19 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
if (TransX) { if (TransX) {
dy_dout_desc = &dout_trans_desc; dy_dout_desc = &dout_trans_desc;
if (dout_trans_desc == nullptr) { if (dout_trans_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
platform::dynload::cublasLtMatrixLayoutCreate( dy_dout_desc, mat_type, z_row, z_col, z_row));
dy_dout_desc, mat_type, z_row, z_col, z_row));
} }
} else { } else {
dy_dout_desc = &dout_desc; dy_dout_desc = &dout_desc;
if (dout_desc == nullptr) { if (dout_desc == nullptr) {
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
platform::dynload::cublasLtMatrixLayoutCreate( dy_dout_desc, mat_type, z_col, z_row, z_col));
dy_dout_desc, mat_type, z_col, z_row, z_col));
} }
} }
dy_x_desc = &x_trans_desc; dy_x_desc = &x_trans_desc;
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
dy_x_desc, mat_type, x_col, x_row, x_col)); dy_x_desc, mat_type, x_col, x_row, x_col));
auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc); auto& a_desc = kYGradAIsDZ ? (*dy_dout_desc) : (*dy_x_desc);
...@@ -813,28 +783,26 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, ...@@ -813,28 +783,26 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans); auto a_trans = BoolToCuBlasEnum(Trait::kYGradATrans);
auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans); auto b_trans = BoolToCuBlasEnum(Trait::kYGradBTrans);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatrixLayoutCreate( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatrixLayoutCreate(
&dy_desc, &dy_desc,
phi::backends::gpu::ToCudaDataType<DYT>(), phi::backends::gpu::ToCudaDataType<DYT>(),
y_col, y_col,
y_row, y_row,
y_col)); y_col));
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmulDescCreate( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescCreate(
&dy_operation_desc, compute_type, scale_type)); &dy_operation_desc, compute_type, scale_type));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
platform::dynload::cublasLtMatmulDescSetAttribute( dy_operation_desc,
dy_operation_desc, CUBLASLT_MATMUL_DESC_TRANSB,
CUBLASLT_MATMUL_DESC_TRANSB, &a_trans,
&a_trans, sizeof(a_trans)));
sizeof(a_trans))); PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
PADDLE_ENFORCE_GPU_SUCCESS( dy_operation_desc,
platform::dynload::cublasLtMatmulDescSetAttribute( CUBLASLT_MATMUL_DESC_TRANSA,
dy_operation_desc, &b_trans,
CUBLASLT_MATMUL_DESC_TRANSA, sizeof(b_trans)));
&b_trans,
sizeof(b_trans)));
cublasLtEpilogue_t epiloque_func_for_dy; cublasLtEpilogue_t epiloque_func_for_dy;
if (dbias == nullptr) { if (dbias == nullptr) {
...@@ -847,25 +815,23 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, ...@@ -847,25 +815,23 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
} }
} }
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
platform::dynload::cublasLtMatmulDescSetAttribute( dy_operation_desc,
dy_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE,
CUBLASLT_MATMUL_DESC_EPILOGUE, &epiloque_func_for_dy,
&epiloque_func_for_dy, sizeof(epiloque_func_for_dy)));
sizeof(epiloque_func_for_dy)));
if (dbias) { if (dbias) {
auto* dbias_data = auto* dbias_data =
dev_ctx.Alloc<DYT>(dbias, dbias->numel() * sizeof(DYT)); dev_ctx.Alloc<DYT>(dbias, dbias->numel() * sizeof(DYT));
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmulDescSetAttribute(
platform::dynload::cublasLtMatmulDescSetAttribute( dy_operation_desc,
dy_operation_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER,
CUBLASLT_MATMUL_DESC_BIAS_POINTER, &dbias_data,
&dbias_data, sizeof(dbias_data)));
sizeof(dbias_data)));
} }
auto dy_workspace = memory::Alloc( auto dy_workspace = memory_utils::Alloc(
dev_ctx.GetPlace(), dev_ctx.GetPlace(),
workspace_size, workspace_size,
phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream()))); phi::Stream(reinterpret_cast<phi::StreamId>(dev_ctx.stream())));
...@@ -890,23 +856,22 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx, ...@@ -890,23 +856,22 @@ void ComputeFusedGemmEpilogueBackwardImpl(const phi::GPUContext& dev_ctx,
dy_workspace->ptr(), dy_workspace->ptr(),
workspace_size); workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS( PADDLE_ENFORCE_GPU_SUCCESS(phi::dynload::cublasLtMatmul(lt_handle,
platform::dynload::cublasLtMatmul(lt_handle, dy_operation_desc,
dy_operation_desc, &alpha,
&alpha, b_data,
b_data, b_desc,
b_desc, a_data,
a_data, a_desc,
a_desc, &beta_dy,
&beta_dy, dy_data,
dy_data, dy_desc,
dy_desc, dy_data,
dy_data, dy_desc,
dy_desc, algo,
algo, dy_workspace->ptr(),
dy_workspace->ptr(), workspace_size,
workspace_size, stream));
stream));
} }
} }
...@@ -1002,7 +967,7 @@ void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx, ...@@ -1002,7 +967,7 @@ void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx,
} }
} }
} // namespace operators } // namespace funcs
} // namespace paddle } // namespace phi
#endif #endif
#endif #endif
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "paddle/phi/kernels/fusion/fused_linear_param_grad_add_kernel.h" #include "paddle/phi/kernels/fusion/fused_linear_param_grad_add_kernel.h"
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 11060
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/phi/kernels/funcs/fused_gemm_epilogue.h"
#endif #endif
#include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
...@@ -41,7 +41,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx, ...@@ -41,7 +41,7 @@ void FusedLinearParamGradAddImpl(const Context &ctx,
const bool fuse_bias_grad = kIsMultiPrecision && dweight_out; const bool fuse_bias_grad = kIsMultiPrecision && dweight_out;
if (dweight_out) { if (dweight_out) {
paddle::operators::ComputeFusedGemmEpilogueBackward<T, T, MT>( phi::funcs::ComputeFusedGemmEpilogueBackward<T, T, MT>(
ctx, ctx,
&dout, &dout,
&x, &x,
...@@ -184,10 +184,6 @@ void FusedLinearParamGradAdd(const Context &ctx, ...@@ -184,10 +184,6 @@ void FusedLinearParamGradAdd(const Context &ctx,
FusedLinearParamGradAddImpl<T, T, Context>( FusedLinearParamGradAddImpl<T, T, Context>(
ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out); ctx, x, dout, dbias, M, K, N, use_addto, dweight_out, dbias_out);
} }
if (VLOG_IS_ON(kLogLevel)) {
ctx.Wait();
}
} }
#else #else
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册