diff --git a/lite/backends/cuda/math/batched_gemm.cc b/lite/backends/cuda/math/batched_gemm.cc index e81510927615daa88e7f5bef3ce7b8421d8f6539..bc605e39fb2acdc53c1f2ac9da738a24f29330c8 100644 --- a/lite/backends/cuda/math/batched_gemm.cc +++ b/lite/backends/cuda/math/batched_gemm.cc @@ -33,6 +33,9 @@ bool BatchedGemm<float, float>::init(const bool trans_a, } cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N; + if (A_ != nullptr) { + cudaFree(A_); + } cudaMalloc(reinterpret_cast<void **>(&A_), 3 * max_batch_size * sizeof(float *)); return true; diff --git a/lite/kernels/cuda/search_aligned_mat_mul_compute.h b/lite/kernels/cuda/search_aligned_mat_mul_compute.h index 8304b0f2b43d4114def029e32aa9086fc29199a4..3d5fc19f1479b65370d823e46b7e18ae9d742139 100644 --- a/lite/kernels/cuda/search_aligned_mat_mul_compute.h +++ b/lite/kernels/cuda/search_aligned_mat_mul_compute.h @@ -13,6 +13,7 @@ // limitations under the License. #pragma once +#include <limits> #include <memory> #include "lite/backends/cuda/math/batched_gemm.h" #include "lite/core/context.h" @@ -32,6 +33,7 @@ class SearchAlignedMatMulCompute void PrepareForRun() override { batched_gemm_impl_.reset(new lite::cuda::math::BatchedGemm<float, float>); + last_seq_num_ = std::numeric_limits<int>::min(); } void Run() override { @@ -75,8 +77,11 @@ class SearchAlignedMatMulCompute A_[seq + seq_num * 2] = out_data + seq * out_stride; } - CHECK( - batched_gemm_impl_->init(x_transpose, y_transpose, seq_num, &cuda_ctx)); + if (seq_num != last_seq_num_) { + CHECK(batched_gemm_impl_->init( + x_transpose, y_transpose, seq_num, &cuda_ctx)); + last_seq_num_ = seq_num; + } batched_gemm_impl_->run( alpha, 0.0f, const_cast<const float **>(A_), M, N, K, seq_num); } @@ -86,6 +91,7 @@ class SearchAlignedMatMulCompute private: std::unique_ptr<lite::cuda::math::BatchedGemm<float, float>> batched_gemm_impl_; + int last_seq_num_; }; } // namespace cuda