Commit 2cc5baed authored by Wilber, committed by GitHub

fix graphics memory leak problem. test=develop (#3598)

Parent a9664357
@@ -33,6 +33,9 @@ bool BatchedGemm<float, float>::init(const bool trans_a,
   }
   cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
   cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
+  if (A_ != nullptr) {
+    cudaFree(A_);
+  }
   cudaMalloc(reinterpret_cast<void **>(&A_),
              3 * max_batch_size * sizeof(float *));
   return true;
......
@@ -13,6 +13,7 @@
 // limitations under the License.
 #pragma once
+#include <limits>
 #include <memory>
 #include "lite/backends/cuda/math/batched_gemm.h"
 #include "lite/core/context.h"
@@ -32,6 +33,7 @@ class SearchAlignedMatMulCompute
   void PrepareForRun() override {
     batched_gemm_impl_.reset(new lite::cuda::math::BatchedGemm<float, float>);
+    last_seq_num_ = std::numeric_limits<int>::min();
   }
   void Run() override {
@@ -75,8 +77,11 @@ class SearchAlignedMatMulCompute
       A_[seq + seq_num * 2] = out_data + seq * out_stride;
     }
-    CHECK(
-        batched_gemm_impl_->init(x_transpose, y_transpose, seq_num, &cuda_ctx));
+    if (seq_num != last_seq_num_) {
+      CHECK(batched_gemm_impl_->init(
+          x_transpose, y_transpose, seq_num, &cuda_ctx));
+      last_seq_num_ = seq_num;
+    }
     batched_gemm_impl_->run(
         alpha, 0.0f, const_cast<const float**>(A_), M, N, K, seq_num);
   }
@@ -86,6 +91,7 @@ class SearchAlignedMatMulCompute
  private:
   std::unique_ptr<lite::cuda::math::BatchedGemm<float, float>>
       batched_gemm_impl_;
+  int last_seq_num_;
 };
 }  // namespace cuda
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册