未验证 提交 19650d72 编写于 作者: M Ming-Xu Huang 提交者: GitHub

[WIP] Algorithm Cache of cuBlasLt Epilogue (#41010)

* Fix leading dimension setting error in fused_gemm_epilogue_grad_op.

* Add dyload to cuBlasLt functions.

* Added cublasLtMatmulAlgoGetHeuristic to improve performance.

* Added FLAGS_cublaslt_exhaustive_search_times to cublasLt epilogue

* Added UTs to FLAGS_cublaslt_exhaustive_search_times

* Added warmup runs in algo searching of Gemm epilogue.

* Update copyright and documents.

* Fixed error handling.
上级 9e3cfdfa
......@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/float16.h"
......@@ -56,7 +57,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
scale_type = CUDA_R_16F;
}
if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F;
......@@ -130,7 +130,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = 4 * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream();
memory::allocation::AllocationPtr workspace =
memory::Alloc(dev_ctx, workspace_size);
......@@ -146,10 +146,26 @@ class FusedGemmEpilogueKernel : public framework::OpKernel<T> {
beta = &beta32;
}
const auto* y_data = y->data<T>();
const auto* x_data = x->data<T>();
cublasLtMatmulAlgo_t algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(
lt_handle, operation_desc, y_desc, x_desc, out_desc, alpha, beta,
y_data, x_data, out_data, stream, workspace->ptr(), workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
lt_handle, operation_desc, alpha, y->data<T>(), y_desc, x->data<T>(),
x_desc, beta, out_data, out_desc, out_data, out_desc, algo,
workspace->ptr(), workspace_size, stream));
lt_handle, operation_desc, alpha, y_data, y_desc, x_data, x_desc, beta,
out_data, out_desc, out_data, out_desc, &algo, workspace->ptr(),
workspace_size, stream));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(operation_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(y_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(x_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(out_desc));
}
private:
......@@ -205,7 +221,6 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F;
if (std::is_same<T, paddle::platform::float16>::value) {
mat_type = CUDA_R_16F;
scale_type = CUDA_R_16F;
}
if (std::is_same<T, double>::value) {
mat_type = CUDA_R_64F;
......@@ -215,7 +230,6 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle();
size_t workspace_size = 4 * 1024 * 1024;
const cublasLtMatmulAlgo_t* algo = nullptr;
cudaStream_t stream = dev_ctx.stream();
double alpha64 = 1.0, beta64 = 0.0;
......@@ -262,8 +276,8 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
&aux_data, sizeof(aux_data)));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescSetAttribute(
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N,
sizeof(N)));
dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &K,
sizeof(K)));
}
cublasLtMatrixLayout_t y_desc = NULL, dx_desc = NULL;
......@@ -277,10 +291,24 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
dx->mutable_data<T>(ctx.GetPlace());
auto* dx_data = dx->data<T>();
const auto* y_data = y->data<T>();
const auto* dout_data = dout->data<T>();
cublasLtMatmulAlgo_t algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(
lt_handle, dx_operation_desc, y_desc, dout_desc, dx_desc, alpha, beta,
y_data, dout_data, dx_data, stream, dx_workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
lt_handle, dx_operation_desc, alpha, y->data<T>(), y_desc,
dout->data<T>(), dout_desc, beta, dx_data, dx_desc, dx_data, dx_desc,
algo, dx_workspace->ptr(), workspace_size, stream));
&algo, dx_workspace->ptr(), workspace_size, stream));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(y_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(dx_desc));
}
if (dy) {
......@@ -324,11 +352,27 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel<T> {
dy->mutable_data<T>(ctx.GetPlace());
auto* dy_data = dy->data<T>();
const auto* dout_data = dout->data<T>();
const auto* x_data = x->data<T>();
cublasLtMatmulAlgo_t algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo(
lt_handle, dy_operation_desc, dout_desc, x_desc, dy_desc, alpha, beta,
dout_data, x_data, dy_data, stream, dy_workspace->ptr(),
workspace_size);
PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul(
lt_handle, dy_operation_desc, alpha, dout->data<T>(), dout_desc,
x->data<T>(), x_desc, beta, dy_data, dy_desc, dy_data, dy_desc, algo,
lt_handle, dy_operation_desc, alpha, dout_data, dout_desc, x_data,
x_desc, beta, dy_data, dy_desc, dy_data, dy_desc, &algo,
dy_workspace->ptr(), workspace_size, stream));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(x_desc));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(dy_desc));
}
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutDestroy(dout_desc));
}
private:
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cuda_runtime_api.h>
#include <algorithm>
#include <mutex>
#include <unordered_map>
#include "gflags/gflags.h"
#include "paddle/fluid/platform/dynload/cublasLt.h"
#include "paddle/fluid/platform/float16.h"
DECLARE_int64(cublaslt_exhaustive_search_times);
namespace paddle {
namespace operators {
class GemmEpilogueAlgoCache {
public:
static GemmEpilogueAlgoCache &Instance() {
static GemmEpilogueAlgoCache instance(
FLAGS_cublaslt_exhaustive_search_times);
return instance;
}
GemmEpilogueAlgoCache(GemmEpilogueAlgoCache const &) = delete;
void operator=(GemmEpilogueAlgoCache const &) = delete;
cublasLtMatmulAlgo_t GetGemmAlgo(
cublasLtHandle_t lt_handle, cublasLtMatmulDesc_t op_desc,
cublasLtMatrixLayout_t a_desc, cublasLtMatrixLayout_t b_desc,
cublasLtMatrixLayout_t c_desc, const void *alpha, const void *beta,
const void *a, const void *b, void *c, cudaStream_t stream,
void *workspace, size_t workspace_size) {
int64_t seed = 0;
std::hash<int64_t> hash_fn;
HashMatmulDesc_(op_desc, &seed, hash_fn);
HashMatrixLayoutDesc_(a_desc, &seed, hash_fn);
HashMatrixLayoutDesc_(b_desc, &seed, hash_fn);
HashMatrixLayoutDesc_(c_desc, &seed, hash_fn);
cublasLtMatmulAlgo_t ret;
auto it = map_.end();
bool have_found = false;
{
std::lock_guard<std::mutex> lock(cache_mutex_);
it = map_.find(seed);
if (it != map_.end()) {
ret = it->second;
have_found = true;
}
}
if (!have_found) {
cublasLtMatmulPreference_t preference;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceCreate(&preference));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceSetAttribute(
preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&workspace_size, sizeof(workspace_size)));
int returned_results = 0;
cublasLtMatmulHeuristicResult_t heuristic_results[requested_algo_count_] =
{0};
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulAlgoGetHeuristic(
lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, preference,
requested_algo_count_, heuristic_results, &returned_results));
PADDLE_ENFORCE_GT(
returned_results, 0,
platform::errors::Unavailable("No GEMM epilogue algorithm support!"));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulPreferenceDestroy(preference));
if (search_times_ > 0) {
int best_algo_idx = -1;
float best_algo_time = 0;
// Run 100 times for warmup
int warmup_algo_idx = 0;
for (int t = 0; t < 100; t++) {
cublasStatus_t status = platform::dynload::cublasLtMatmul(
lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, c_desc,
c, c_desc, &heuristic_results[warmup_algo_idx].algo, workspace,
workspace_size, stream);
if (status != CUBLAS_STATUS_SUCCESS) {
t = -1;
warmup_algo_idx += 1;
if (warmup_algo_idx == requested_algo_count_) {
PADDLE_THROW(platform::errors::Unavailable(
"No GEMM epilogue algorithm support!"));
}
}
}
cudaEvent_t start_event, stop_event;
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event));
for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) {
float curr_time = 0;
for (int check_idx = 0; check_idx < search_times_; check_idx++) {
float time = 0;
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream));
cublasStatus_t status = platform::dynload::cublasLtMatmul(
lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c,
c_desc, c, c_desc, &heuristic_results[algo_idx].algo, workspace,
workspace_size, stream);
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(stop_event));
PADDLE_ENFORCE_GPU_SUCCESS(
cudaEventElapsedTime(&time, start_event, stop_event));
curr_time += time;
if (status != CUBLAS_STATUS_SUCCESS) {
curr_time = 3.40282e+038; // Max Value of float
break;
}
}
curr_time = curr_time / search_times_;
if (curr_time < best_algo_time || algo_idx == 0) {
best_algo_idx = algo_idx;
best_algo_time = curr_time;
}
}
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(start_event));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(stop_event));
if (best_algo_idx == -1) {
PADDLE_THROW(platform::errors::Unavailable(
"No GEMM epilogue algorithm support!"));
}
ret = heuristic_results[best_algo_idx].algo;
} else {
int decided_algo_idx = -1;
for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) {
cublasStatus_t status = platform::dynload::cublasLtMatmul(
lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, c_desc,
c, c_desc, &heuristic_results[algo_idx].algo, workspace,
workspace_size, stream);
if (status == CUBLAS_STATUS_SUCCESS) {
decided_algo_idx = algo_idx;
break;
}
}
if (decided_algo_idx == -1) {
PADDLE_THROW(platform::errors::Unavailable(
"No GEMM epilogue algorithm support!"));
}
ret = heuristic_results[decided_algo_idx].algo;
}
std::lock_guard<std::mutex> lock(cache_mutex_);
map_[seed] = ret;
}
VLOG(4) << "Search time:" << search_times_ << ", Is hash-key (" << seed
<< ") found in GemmEpilogueAlgoCache? " << have_found;
return ret;
}
private:
explicit GemmEpilogueAlgoCache(int search_times)
: search_times_(search_times) {
map_.clear();
}
std::unordered_map<int64_t, cublasLtMatmulAlgo_t> map_;
int search_times_;
const int requested_algo_count_ = 10;
std::mutex cache_mutex_;
void HashMatmulDesc_(cublasLtMatmulDesc_t desc, int64_t *seed,
const std::hash<int64_t> &hash_fn) {
size_t size_to_write;
int trans_a, trans_b;
uint32_t epilogue;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescGetAttribute(
desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a),
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(trans_a));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescGetAttribute(
desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b),
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(trans_b));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatmulDescGetAttribute(
desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue),
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(epilogue));
}
void HashMatrixLayoutDesc_(cublasLtMatrixLayout_t desc, int64_t *seed,
const std::hash<int64_t> &hash_fn) {
size_t size_to_write;
uint32_t dtype;
int32_t batch;
uint64_t row, col;
int64_t ld, batch_offset;
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_TYPE, &dtype, sizeof(dtype),
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(dtype));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch),
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(batch));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row),
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(row));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col),
&size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(col));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(ld));
PADDLE_ENFORCE_GPU_SUCCESS(
platform::dynload::cublasLtMatrixLayoutGetAttribute(
desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset,
sizeof(batch_offset), &size_to_write));
HashValue_(seed, hash_fn, static_cast<int64_t>(batch_offset));
}
void HashValue_(int64_t *seed, const std::hash<int64_t> &hash_fn,
int64_t value) {
*seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}
};
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -38,19 +39,25 @@ namespace dynload {
// APIs available after CUDA 10.1
// #if CUDA_VERSION >= 10100
#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasLtCreate); \
__macro(cublasLtDestroy); \
__macro(cublasLtMatmul); \
__macro(cublasLtMatmulDescCreate); \
__macro(cublasLtMatmulDescDestroy); \
__macro(cublasLtMatmulDescSetAttribute); \
__macro(cublasLtMatrixLayoutCreate); \
__macro(cublasLtMatrixLayoutDestroy); \
__macro(cublasLtMatrixLayoutSetAttribute); \
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasLtCreate); \
__macro(cublasLtDestroy); \
__macro(cublasLtMatmul); \
__macro(cublasLtMatmulDescCreate); \
__macro(cublasLtMatmulDescDestroy); \
__macro(cublasLtMatmulDescSetAttribute); \
__macro(cublasLtMatmulDescGetAttribute); \
__macro(cublasLtMatrixLayoutCreate); \
__macro(cublasLtMatrixLayoutDestroy); \
__macro(cublasLtMatrixLayoutSetAttribute); \
__macro(cublasLtMatrixLayoutGetAttribute); \
__macro(cublasLtMatmulPreferenceCreate); \
__macro(cublasLtMatmulPreferenceDestroy); \
__macro(cublasLtMatmulPreferenceSetAttribute); \
__macro(cublasLtMatmulAlgoGetHeuristic); \
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
__macro(cublasLtMatrixTransformDescSetAttribute);
CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -107,6 +108,29 @@ PADDLE_DEFINE_EXPORTED_string(
"share-memory only.");
#endif
#if defined(PADDLE_WITH_CUDA)
/**
* CUDA related FLAG
* Name: FLAGS_cublaslt_exhaustive_search_times
* Since Version: 2.3.0
* Value Range: int64_t, default=0
* Example:
* Note: Represents times of exhaustive search to evaluate performance of
* cuBlasLt matmul algorithm (with/without epilogue). Set this flag
* with value > 0 to enable exhaustive search. Default is 0, means
* getting algorithms via heuristic search. There are two search methods
* in cuBlasLt, heuristic search and exhaustive search. Exhaustive search
* attempts all cuBlasLt algorithms to select the fastest, which is very
* time-consuming, and the selected algorithm will be cached for a given
* layer specification Once you change the layer specifications
* (such as M, N and K), it will re-search again.
*/
PADDLE_DEFINE_EXPORTED_int64(
cublaslt_exhaustive_search_times, 0,
"The times of exhaustive search for cuBlasLt matmul with/without "
" epilogue algorithms, default is 0, means disabling exhaustive search.");
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
PADDLE_DEFINE_EXPORTED_string(
selected_npus, "",
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Copyright (c) 2022 NVIDIA Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
......@@ -52,19 +53,25 @@ extern void *cublasLt_dso_handle;
// APIs available after CUDA 10.1
// #if CUDA_VERSION >= 10100
#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasLtCreate); \
__macro(cublasLtDestroy); \
__macro(cublasLtMatmul); \
__macro(cublasLtMatmulDescCreate); \
__macro(cublasLtMatmulDescDestroy); \
__macro(cublasLtMatmulDescSetAttribute); \
__macro(cublasLtMatrixLayoutCreate); \
__macro(cublasLtMatrixLayoutDestroy); \
__macro(cublasLtMatrixLayoutSetAttribute); \
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \
__macro(cublasLtCreate); \
__macro(cublasLtDestroy); \
__macro(cublasLtMatmul); \
__macro(cublasLtMatmulDescCreate); \
__macro(cublasLtMatmulDescDestroy); \
__macro(cublasLtMatmulDescSetAttribute); \
__macro(cublasLtMatmulDescGetAttribute); \
__macro(cublasLtMatrixLayoutCreate); \
__macro(cublasLtMatrixLayoutDestroy); \
__macro(cublasLtMatrixLayoutSetAttribute); \
__macro(cublasLtMatrixLayoutGetAttribute); \
__macro(cublasLtMatmulPreferenceCreate); \
__macro(cublasLtMatmulPreferenceDestroy); \
__macro(cublasLtMatmulPreferenceSetAttribute); \
__macro(cublasLtMatmulAlgoGetHeuristic); \
__macro(cublasLtMatrixTransform); \
__macro(cublasLtMatrixTransformDescCreate); \
__macro(cublasLtMatrixTransformDescDestroy); \
__macro(cublasLtMatrixTransformDescSetAttribute);
CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP)
......
......@@ -129,18 +129,11 @@ if(NOT WITH_GPU)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api)
LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
LIST(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass)
endif()
if (WITH_GPU)
if (CUDA_VERSION LESS 11.6)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
LIST(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass)
endif()
endif()
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op)
LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op)
LIST(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass)
if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op)
......@@ -644,6 +637,15 @@ py_test_modules(test_imperative_static_runner_mnist MODULES test_imperative_stat
FLAGS_cudnn_deterministic=1)
py_test_modules(test_imperative_static_runner_while MODULES test_imperative_static_runner_while ENVS
FLAGS_cudnn_deterministic=1)
if ((WITH_GPU) AND (CUDA_VERSION GREATER_EQUAL 11.6))
py_test_modules(test_fused_gemm_epilogue_op MODULES test_fused_gemm_epilogue_op)
py_test_modules(test_fused_gemm_epilogue_grad_op MODULES test_fused_gemm_epilogue_grad_op)
py_test_modules(test_fused_gemm_epilogue_op_with_es MODULES test_fused_gemm_epilogue_op ENVS FLAGS_cublaslt_exhaustive_search_times=30)
py_test_modules(test_fused_gemm_epilogue_grad_op_with_es MODULES test_fused_gemm_epilogue_grad_op ENVS FLAGS_cublaslt_exhaustive_search_times=30)
py_test_modules(test_fuse_gemm_epilogue_pass MODULES test_fuse_gemm_epilogue_pass)
endif()
set_tests_properties(test_conv2d_op PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_faster_tokenizer_op PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_conv2d_op_depthwise_conv PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
......
......@@ -49,8 +49,8 @@ def verify_node_count(graph, node_name, target_count):
class MultiFCLayer(paddle.nn.Layer):
def __init__(self, hidden, Activation):
super(MultiFCLayer, self).__init__()
self.linear1 = paddle.nn.Linear(hidden, hidden)
self.linear2 = paddle.nn.Linear(hidden, hidden)
self.linear1 = paddle.nn.Linear(hidden, 4 * hidden)
self.linear2 = paddle.nn.Linear(4 * hidden, hidden)
self.linear3 = paddle.nn.Linear(hidden, hidden)
self.relu1 = Activation()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册