提交 2cc5baed 编写于 作者: W Wilber 提交者: GitHub

fix graphics memory leak problem. test=develop (#3598)

上级 a9664357
......@@ -33,6 +33,9 @@ bool BatchedGemm<float, float>::init(const bool trans_a,
}
cu_trans_a_ = trans_a ? CUBLAS_OP_T : CUBLAS_OP_N;
cu_trans_b_ = trans_b ? CUBLAS_OP_T : CUBLAS_OP_N;
if (A_ != nullptr) {
cudaFree(A_);
}
cudaMalloc(reinterpret_cast<void **>(&A_),
3 * max_batch_size * sizeof(float *));
return true;
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#pragma once
#include <limits>
#include <memory>
#include "lite/backends/cuda/math/batched_gemm.h"
#include "lite/core/context.h"
......@@ -32,6 +33,7 @@ class SearchAlignedMatMulCompute
void PrepareForRun() override {
batched_gemm_impl_.reset(new lite::cuda::math::BatchedGemm<float, float>);
last_seq_num_ = std::numeric_limits<int>::min();
}
void Run() override {
......@@ -75,8 +77,11 @@ class SearchAlignedMatMulCompute
A_[seq + seq_num * 2] = out_data + seq * out_stride;
}
CHECK(
batched_gemm_impl_->init(x_transpose, y_transpose, seq_num, &cuda_ctx));
if (seq_num != last_seq_num_) {
CHECK(batched_gemm_impl_->init(
x_transpose, y_transpose, seq_num, &cuda_ctx));
last_seq_num_ = seq_num;
}
batched_gemm_impl_->run(
alpha, 0.0f, const_cast<const float**>(A_), M, N, K, seq_num);
}
......@@ -86,6 +91,7 @@ class SearchAlignedMatMulCompute
private:
std::unique_ptr<lite::cuda::math::BatchedGemm<float, float>>
batched_gemm_impl_;
int last_seq_num_;
};
} // namespace cuda
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册