diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu index e16c9e8f483ccc2cbf1d7006159cccfe906dd06b..9bf3d1a485efc71a19960525cb427ffb823eeefa 100644 --- a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.cu @@ -15,6 +15,7 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_version_registry.h" +#include "paddle/fluid/operators/fused/fused_gemm_epilogue_op.h" #include "paddle/fluid/platform/dynload/cublasLt.h" #include "paddle/fluid/platform/float16.h" @@ -56,7 +57,6 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; if (std::is_same::value) { mat_type = CUDA_R_16F; - scale_type = CUDA_R_16F; } if (std::is_same::value) { mat_type = CUDA_R_64F; @@ -130,7 +130,7 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); size_t workspace_size = 4 * 1024 * 1024; - const cublasLtMatmulAlgo_t* algo = nullptr; + cudaStream_t stream = dev_ctx.stream(); memory::allocation::AllocationPtr workspace = memory::Alloc(dev_ctx, workspace_size); @@ -146,10 +146,26 @@ class FusedGemmEpilogueKernel : public framework::OpKernel { beta = &beta32; } + const auto* y_data = y->data(); + const auto* x_data = x->data(); + + cublasLtMatmulAlgo_t algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo( + lt_handle, operation_desc, y_desc, x_desc, out_desc, alpha, beta, + y_data, x_data, out_data, stream, workspace->ptr(), workspace_size); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul( - lt_handle, operation_desc, alpha, y->data(), y_desc, x->data(), - x_desc, beta, out_data, out_desc, out_data, out_desc, algo, - workspace->ptr(), workspace_size, stream)); + lt_handle, operation_desc, alpha, y_data, y_desc, x_data, x_desc, beta, + out_data, out_desc, out_data, out_desc, &algo, workspace->ptr(), + workspace_size, stream)); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescDestroy(operation_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(y_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(x_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(out_desc)); } private: @@ -205,7 +221,6 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { cublasComputeType_t compute_type = CUBLAS_COMPUTE_32F; if (std::is_same::value) { mat_type = CUDA_R_16F; - scale_type = CUDA_R_16F; } if (std::is_same::value) { mat_type = CUDA_R_64F; @@ -215,7 +230,6 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { cublasLtHandle_t lt_handle = dev_ctx.cublaslt_handle(); size_t workspace_size = 4 * 1024 * 1024; - const cublasLtMatmulAlgo_t* algo = nullptr; cudaStream_t stream = dev_ctx.stream(); double alpha64 = 1.0, beta64 = 0.0; @@ -262,8 +276,8 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { &aux_data, sizeof(aux_data))); PADDLE_ENFORCE_GPU_SUCCESS( platform::dynload::cublasLtMatmulDescSetAttribute( - dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &N, - sizeof(N))); + dx_operation_desc, CUBLASLT_MATMUL_DESC_EPILOGUE_AUX_LD, &K, + sizeof(K))); } cublasLtMatrixLayout_t y_desc = NULL, dx_desc = NULL; @@ -277,10 +291,24 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { dx->mutable_data(ctx.GetPlace()); auto* dx_data = dx->data(); + const auto* y_data = y->data(); + const auto* dout_data = dout->data(); + + cublasLtMatmulAlgo_t algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo( + lt_handle, dx_operation_desc, y_desc, dout_desc, dx_desc, alpha, beta, + y_data, dout_data, dx_data, stream, dx_workspace->ptr(), + workspace_size); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul( lt_handle, dx_operation_desc, alpha, y->data(), y_desc, dout->data(), dout_desc, beta, dx_data, dx_desc, dx_data, dx_desc, - algo, dx_workspace->ptr(), workspace_size, stream)); + &algo, dx_workspace->ptr(), workspace_size, stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescDestroy(dx_operation_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(y_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(dx_desc)); } if (dy) { @@ -324,11 +352,27 @@ class FusedGemmEpilogueGradKernel : public framework::OpKernel { dy->mutable_data(ctx.GetPlace()); auto* dy_data = dy->data(); + const auto* dout_data = dout->data(); + const auto* x_data = x->data(); + + cublasLtMatmulAlgo_t algo = GemmEpilogueAlgoCache::Instance().GetGemmAlgo( + lt_handle, dy_operation_desc, dout_desc, x_desc, dy_desc, alpha, beta, + dout_data, x_data, dy_data, stream, dy_workspace->ptr(), + workspace_size); + PADDLE_ENFORCE_GPU_SUCCESS(platform::dynload::cublasLtMatmul( - lt_handle, dy_operation_desc, alpha, dout->data(), dout_desc, - x->data(), x_desc, beta, dy_data, dy_desc, dy_data, dy_desc, algo, + lt_handle, dy_operation_desc, alpha, dout_data, dout_desc, x_data, + x_desc, beta, dy_data, dy_desc, dy_data, dy_desc, &algo, dy_workspace->ptr(), workspace_size, stream)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescDestroy(dy_operation_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(x_desc)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(dy_desc)); } + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutDestroy(dout_desc)); } private: diff --git a/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h new file mode 100644 index 0000000000000000000000000000000000000000..c90a6966fe0a841dd3eb692aaafcdd03535b16a0 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_gemm_epilogue_op.h @@ -0,0 +1,271 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include "gflags/gflags.h" +#include "paddle/fluid/platform/dynload/cublasLt.h" +#include "paddle/fluid/platform/float16.h" + +DECLARE_int64(cublaslt_exhaustive_search_times); + +namespace paddle { +namespace operators { + +class GemmEpilogueAlgoCache { + public: + static GemmEpilogueAlgoCache &Instance() { + static GemmEpilogueAlgoCache instance( + FLAGS_cublaslt_exhaustive_search_times); + return instance; + } + + GemmEpilogueAlgoCache(GemmEpilogueAlgoCache const &) = delete; + void operator=(GemmEpilogueAlgoCache const &) = delete; + + cublasLtMatmulAlgo_t GetGemmAlgo( + cublasLtHandle_t lt_handle, cublasLtMatmulDesc_t op_desc, + cublasLtMatrixLayout_t a_desc, cublasLtMatrixLayout_t b_desc, + cublasLtMatrixLayout_t c_desc, const void *alpha, const void *beta, + const void *a, const void *b, void *c, cudaStream_t stream, + void *workspace, size_t workspace_size) { + int64_t seed = 0; + std::hash hash_fn; + + HashMatmulDesc_(op_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(a_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(b_desc, &seed, hash_fn); + HashMatrixLayoutDesc_(c_desc, &seed, hash_fn); + + cublasLtMatmulAlgo_t ret; + auto it = map_.end(); + bool have_found = false; + { + std::lock_guard lock(cache_mutex_); + it = map_.find(seed); + + if (it != map_.end()) { + ret = it->second; + have_found = true; + } + } + + if (!have_found) { + cublasLtMatmulPreference_t preference; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulPreferenceCreate(&preference)); + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulPreferenceSetAttribute( + preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspace_size, sizeof(workspace_size))); + + int returned_results = 0; + cublasLtMatmulHeuristicResult_t heuristic_results[requested_algo_count_] = + {0}; + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulAlgoGetHeuristic( + lt_handle, op_desc, a_desc, b_desc, c_desc, c_desc, preference, + requested_algo_count_, heuristic_results, &returned_results)); + + PADDLE_ENFORCE_GT( + returned_results, 0, + platform::errors::Unavailable("No GEMM epilogue algorithm support!")); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulPreferenceDestroy(preference)); + + if (search_times_ > 0) { + int best_algo_idx = -1; + float best_algo_time = 0; + + // Run 100 times for warmup + int warmup_algo_idx = 0; + for (int t = 0; t < 100; t++) { + cublasStatus_t status = platform::dynload::cublasLtMatmul( + lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, c_desc, + c, c_desc, &heuristic_results[warmup_algo_idx].algo, workspace, + workspace_size, stream); + if (status != CUBLAS_STATUS_SUCCESS) { + t = -1; + warmup_algo_idx += 1; + if (warmup_algo_idx == requested_algo_count_) { + PADDLE_THROW(platform::errors::Unavailable( + "No GEMM epilogue algorithm support!")); + } + } + } + + cudaEvent_t start_event, stop_event; + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&start_event)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&stop_event)); + + for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { + float curr_time = 0; + for (int check_idx = 0; check_idx < search_times_; check_idx++) { + float time = 0; + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(start_event, stream)); + + cublasStatus_t status = platform::dynload::cublasLtMatmul( + lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, + c_desc, c, c_desc, &heuristic_results[algo_idx].algo, workspace, + workspace_size, stream); + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(stop_event, stream)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(stop_event)); + PADDLE_ENFORCE_GPU_SUCCESS( + cudaEventElapsedTime(&time, start_event, stop_event)); + curr_time += time; + if (status != CUBLAS_STATUS_SUCCESS) { + curr_time = 3.40282e+038; // Max Value of float + break; + } + } + + curr_time = curr_time / search_times_; + if (curr_time < best_algo_time || algo_idx == 0) { + best_algo_idx = algo_idx; + best_algo_time = curr_time; + } + } + + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(start_event)); + PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(stop_event)); + + if (best_algo_idx == -1) { + PADDLE_THROW(platform::errors::Unavailable( + "No GEMM epilogue algorithm support!")); + } + + ret = heuristic_results[best_algo_idx].algo; + } else { + int decided_algo_idx = -1; + for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) { + cublasStatus_t status = platform::dynload::cublasLtMatmul( + lt_handle, op_desc, alpha, a, a_desc, b, b_desc, beta, c, c_desc, + c, c_desc, &heuristic_results[algo_idx].algo, workspace, + workspace_size, stream); + if (status == CUBLAS_STATUS_SUCCESS) { + decided_algo_idx = algo_idx; + break; + } + } + if (decided_algo_idx == -1) { + PADDLE_THROW(platform::errors::Unavailable( + "No GEMM epilogue algorithm support!")); + } + ret = heuristic_results[decided_algo_idx].algo; + } + + std::lock_guard lock(cache_mutex_); + map_[seed] = ret; + } + + VLOG(4) << "Search time:" << search_times_ << ", Is hash-key (" << seed + << ") found in GemmEpilogueAlgoCache? " << have_found; + + return ret; + } + + private: + explicit GemmEpilogueAlgoCache(int search_times) + : search_times_(search_times) { + map_.clear(); + } + std::unordered_map map_; + int search_times_; + const int requested_algo_count_ = 10; + std::mutex cache_mutex_; + + void HashMatmulDesc_(cublasLtMatmulDesc_t desc, int64_t *seed, + const std::hash &hash_fn) { + size_t size_to_write; + int trans_a, trans_b; + uint32_t epilogue; + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescGetAttribute( + desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_a, sizeof(trans_a), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(trans_a)); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescGetAttribute( + desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b, sizeof(trans_b), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(trans_b)); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatmulDescGetAttribute( + desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(epilogue)); + } + + void HashMatrixLayoutDesc_(cublasLtMatrixLayout_t desc, int64_t *seed, + const std::hash &hash_fn) { + size_t size_to_write; + uint32_t dtype; + int32_t batch; + uint64_t row, col; + int64_t ld, batch_offset; + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_TYPE, &dtype, sizeof(dtype), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(dtype)); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(batch)); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_ROWS, &row, sizeof(row), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(row)); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_COLS, &col, sizeof(col), + &size_to_write)); + HashValue_(seed, hash_fn, static_cast(col)); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_LD, &ld, sizeof(ld), &size_to_write)); + HashValue_(seed, hash_fn, static_cast(ld)); + + PADDLE_ENFORCE_GPU_SUCCESS( + platform::dynload::cublasLtMatrixLayoutGetAttribute( + desc, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &batch_offset, + sizeof(batch_offset), &size_to_write)); + HashValue_(seed, hash_fn, static_cast(batch_offset)); + } + + void HashValue_(int64_t *seed, const std::hash &hash_fn, + int64_t value) { + *seed ^= hash_fn(value) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2); + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/platform/dynload/cublasLt.h b/paddle/fluid/platform/dynload/cublasLt.h index c9a59751a320a22139df4e918cecd6cb1eb40c87..5157cfdad2e5939afb1b66b8c5ac80a4556669b6 100644 --- a/paddle/fluid/platform/dynload/cublasLt.h +++ b/paddle/fluid/platform/dynload/cublasLt.h @@ -1,4 +1,5 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -38,19 +39,25 @@ namespace dynload { // APIs available after CUDA 10.1 // #if CUDA_VERSION >= 10100 -#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasLtCreate); \ - __macro(cublasLtDestroy); \ - __macro(cublasLtMatmul); \ - __macro(cublasLtMatmulDescCreate); \ - __macro(cublasLtMatmulDescDestroy); \ - __macro(cublasLtMatmulDescSetAttribute); \ - __macro(cublasLtMatrixLayoutCreate); \ - __macro(cublasLtMatrixLayoutDestroy); \ - __macro(cublasLtMatrixLayoutSetAttribute); \ - __macro(cublasLtMatrixTransform); \ - __macro(cublasLtMatrixTransformDescCreate); \ - __macro(cublasLtMatrixTransformDescDestroy); \ +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); CUBLASLT_BLAS_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index f89452853b49b40f11351423f582932897d6fe09..054a804e6b38e840ac0c9890a1c6f5ebcdb19341 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -1,4 +1,5 @@ // Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -107,6 +108,29 @@ PADDLE_DEFINE_EXPORTED_string( "share-memory only."); #endif +#if defined(PADDLE_WITH_CUDA) +/** + * CUDA related FLAG + * Name: FLAGS_cublaslt_exhaustive_search_times + * Since Version: 2.3.0 + * Value Range: int64_t, default=0 + * Example: + * Note: Represents times of exhaustive search to evaluate performance of + * cuBlasLt matmul algorithm (with/without epilogue). Set this flag + * with value > 0 to enable exhaustive search. Default is 0, means + * getting algorithms via heuristic search. There are two search methods + * in cuBlasLt, heuristic search and exhaustive search. Exhaustive search + * attempts all cuBlasLt algorithms to select the fastest, which is very + * time-consuming, and the selected algorithm will be cached for a given + * layer specification Once you change the layer specifications + * (such as M, N and K), it will re-search again. + */ +PADDLE_DEFINE_EXPORTED_int64( + cublaslt_exhaustive_search_times, 0, + "The times of exhaustive search for cuBlasLt matmul with/without " + " epilogue algorithms, default is 0, means disabling exhaustive search."); +#endif + #if defined(PADDLE_WITH_ASCEND_CL) PADDLE_DEFINE_EXPORTED_string( selected_npus, "", diff --git a/paddle/phi/backends/dynload/cublasLt.h b/paddle/phi/backends/dynload/cublasLt.h index a1562370c377b247a0b80e6412ff30da8da2e5b0..4c7ac9c3f21c45a301dc947f32471f364ec12439 100644 --- a/paddle/phi/backends/dynload/cublasLt.h +++ b/paddle/phi/backends/dynload/cublasLt.h @@ -1,4 +1,5 @@ /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +Copyright (c) 2022 NVIDIA Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -52,19 +53,25 @@ extern void *cublasLt_dso_handle; // APIs available after CUDA 10.1 // #if CUDA_VERSION >= 10100 -#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ - __macro(cublasLtCreate); \ - __macro(cublasLtDestroy); \ - __macro(cublasLtMatmul); \ - __macro(cublasLtMatmulDescCreate); \ - __macro(cublasLtMatmulDescDestroy); \ - __macro(cublasLtMatmulDescSetAttribute); \ - __macro(cublasLtMatrixLayoutCreate); \ - __macro(cublasLtMatrixLayoutDestroy); \ - __macro(cublasLtMatrixLayoutSetAttribute); \ - __macro(cublasLtMatrixTransform); \ - __macro(cublasLtMatrixTransformDescCreate); \ - __macro(cublasLtMatrixTransformDescDestroy); \ +#define CUBLASLT_BLAS_ROUTINE_EACH(__macro) \ + __macro(cublasLtCreate); \ + __macro(cublasLtDestroy); \ + __macro(cublasLtMatmul); \ + __macro(cublasLtMatmulDescCreate); \ + __macro(cublasLtMatmulDescDestroy); \ + __macro(cublasLtMatmulDescSetAttribute); \ + __macro(cublasLtMatmulDescGetAttribute); \ + __macro(cublasLtMatrixLayoutCreate); \ + __macro(cublasLtMatrixLayoutDestroy); \ + __macro(cublasLtMatrixLayoutSetAttribute); \ + __macro(cublasLtMatrixLayoutGetAttribute); \ + __macro(cublasLtMatmulPreferenceCreate); \ + __macro(cublasLtMatmulPreferenceDestroy); \ + __macro(cublasLtMatmulPreferenceSetAttribute); \ + __macro(cublasLtMatmulAlgoGetHeuristic); \ + __macro(cublasLtMatrixTransform); \ + __macro(cublasLtMatrixTransformDescCreate); \ + __macro(cublasLtMatrixTransformDescDestroy); \ __macro(cublasLtMatrixTransformDescSetAttribute); CUBLASLT_BLAS_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUBLASLT_WRAP) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 5235b7f1e88abce0db5af2bc1f96190e672c6890..32d8f5e3847c81bed286c7a0bda07c764312d5f7 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -129,18 +129,11 @@ if(NOT WITH_GPU) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api) LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer) - LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op) - LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op) - LIST(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass) endif() -if (WITH_GPU) - if (CUDA_VERSION LESS 11.6) - LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op) - LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op) - LIST(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass) - endif() -endif() +LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_op) +LIST(REMOVE_ITEM TEST_OPS test_fused_gemm_epilogue_grad_op) +LIST(REMOVE_ITEM TEST_OPS test_fuse_gemm_epilogue_pass) if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_c_comm_init_all_op) @@ -644,6 +637,15 @@ py_test_modules(test_imperative_static_runner_mnist MODULES test_imperative_stat FLAGS_cudnn_deterministic=1) py_test_modules(test_imperative_static_runner_while MODULES test_imperative_static_runner_while ENVS FLAGS_cudnn_deterministic=1) + +if ((WITH_GPU) AND (CUDA_VERSION GREATER_EQUAL 11.6)) + py_test_modules(test_fused_gemm_epilogue_op MODULES test_fused_gemm_epilogue_op) + py_test_modules(test_fused_gemm_epilogue_grad_op MODULES test_fused_gemm_epilogue_grad_op) + py_test_modules(test_fused_gemm_epilogue_op_with_es MODULES test_fused_gemm_epilogue_op ENVS FLAGS_cublaslt_exhaustive_search_times=30) + py_test_modules(test_fused_gemm_epilogue_grad_op_with_es MODULES test_fused_gemm_epilogue_grad_op ENVS FLAGS_cublaslt_exhaustive_search_times=30) + py_test_modules(test_fuse_gemm_epilogue_pass MODULES test_fuse_gemm_epilogue_pass) +endif() + set_tests_properties(test_conv2d_op PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_faster_tokenizer_op PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") set_tests_properties(test_conv2d_op_depthwise_conv PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE") diff --git a/python/paddle/fluid/tests/unittests/test_fuse_gemm_epilogue_pass.py b/python/paddle/fluid/tests/unittests/test_fuse_gemm_epilogue_pass.py index 7f3180e21d8c63dd3fbc87d58c01f43422a01bcb..00d91b1fab0f1322f676ac15ca315c728dc72c74 100644 --- a/python/paddle/fluid/tests/unittests/test_fuse_gemm_epilogue_pass.py +++ b/python/paddle/fluid/tests/unittests/test_fuse_gemm_epilogue_pass.py @@ -49,8 +49,8 @@ def verify_node_count(graph, node_name, target_count): class MultiFCLayer(paddle.nn.Layer): def __init__(self, hidden, Activation): super(MultiFCLayer, self).__init__() - self.linear1 = paddle.nn.Linear(hidden, hidden) - self.linear2 = paddle.nn.Linear(hidden, hidden) + self.linear1 = paddle.nn.Linear(hidden, 4 * hidden) + self.linear2 = paddle.nn.Linear(4 * hidden, hidden) self.linear3 = paddle.nn.Linear(hidden, hidden) self.relu1 = Activation()