From 2cc5baed5cec9a853eab003a33861412e56a70da Mon Sep 17 00:00:00 2001 From: Wilber Date: Sat, 9 May 2020 21:25:42 +0800 Subject: [PATCH] fix graphics memory leak problem. test=develop (#3598) --- lite/backends/cuda/math/batched_gemm.cc | 3 +++ lite/kernels/cuda/search_aligned_mat_mul_compute.h | 10 ++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/lite/backends/cuda/math/batched_gemm.cc b/lite/backends/cuda/math/batched_gemm.cc index e815109276..bc605e39fb 100644 --- a/lite/backends/cuda/math/batched_gemm.cc +++ b/lite/backends/cuda/math/batched_gemm.cc @@ -33,6 +33,9 @@ bool BatchedGemm::init(const bool trans_a, } cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N; cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N; + if (A_ != nullptr) { + cudaFree(A_); + } cudaMalloc(reinterpret_cast<void **>(&A_), 3 * max_batch_size * sizeof(float *)); return true; diff --git a/lite/kernels/cuda/search_aligned_mat_mul_compute.h b/lite/kernels/cuda/search_aligned_mat_mul_compute.h index 8304b0f2b4..3d5fc19f14 100644 --- a/lite/kernels/cuda/search_aligned_mat_mul_compute.h +++ b/lite/kernels/cuda/search_aligned_mat_mul_compute.h @@ -13,6 +13,7 @@ // limitations under the License.
#pragma once +#include <limits> #include <memory> #include "lite/backends/cuda/math/batched_gemm.h" #include "lite/core/context.h" @@ -32,6 +33,7 @@ class SearchAlignedMatMulCompute void PrepareForRun() override { batched_gemm_impl_.reset(new lite::cuda::math::BatchedGemm); + last_seq_num_ = std::numeric_limits<int>::min(); } void Run() override { @@ -75,8 +77,11 @@ class SearchAlignedMatMulCompute A_[seq + seq_num * 2] = out_data + seq * out_stride; } - CHECK( - batched_gemm_impl_->init(x_transpose, y_transpose, seq_num, &cuda_ctx)); + if (seq_num != last_seq_num_) { + CHECK(batched_gemm_impl_->init( + x_transpose, y_transpose, seq_num, &cuda_ctx)); + last_seq_num_ = seq_num; + } batched_gemm_impl_->run( alpha, 0.0f, const_cast(A_), M, N, K, seq_num); } @@ -86,6 +91,7 @@ class SearchAlignedMatMulCompute private: std::unique_ptr> batched_gemm_impl_; + int last_seq_num_; }; } // namespace cuda -- GitLab