未验证 提交 2425f9a1 编写于 作者: Z zhaoyang-star 提交者: GitHub

[Cherry-pick][Bugfix][OpenCL][Core] fix opencl multi-run result error (#4413)

* [Bugfix][OpenCL][Core] fix opencl multi-run result error when using memory_optimize_pass (#4410)

* [Bugfix][OpenCL][Core] fix opencl multi-run result error when using memory_optimize_pass. test=develop

* test=develop
Co-authored-by: Nysh329 <ysh329@users.noreply.github.com>
上级 01242ee2
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <algorithm>
#include <string> #include <string>
#include "lite/api/paddle_place.h" #include "lite/api/paddle_place.h"
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h"
...@@ -135,20 +136,21 @@ class Buffer { ...@@ -135,20 +136,21 @@ class Buffer {
#ifdef LITE_WITH_OPENCL #ifdef LITE_WITH_OPENCL
template <typename T> template <typename T>
void ResetLazyImage2D(TargetType target, void ResetLazyImage2D(TargetType target,
const size_t img_w, const size_t img_w_req,
const size_t img_h, const size_t img_h_req,
void* host_ptr = nullptr) { void* host_ptr = nullptr) {
if (target != target_ || cl_image2d_width_ < img_w || if (target != target_ || cl_image2d_width_ < img_w_req ||
cl_image2d_height_ < img_h || host_ptr != nullptr) { cl_image2d_height_ < img_h_req || host_ptr != nullptr) {
CHECK_EQ(own_data_, true) << "Can not reset unowned buffer."; CHECK_EQ(own_data_, true) << "Can not reset unowned buffer.";
cl_image2d_width_ = std::max(cl_image2d_width_, img_w_req);
cl_image2d_height_ = std::max(cl_image2d_height_, img_h_req);
Free(); Free();
data_ = TargetWrapperCL::MallocImage<T>(img_w, img_h, host_ptr); data_ = TargetWrapperCL::MallocImage<T>(
cl_image2d_width_, cl_image2d_height_, host_ptr);
target_ = target; target_ = target;
space_ = sizeof(T) * img_w * img_h * space_ = sizeof(T) * cl_image2d_width_ * cl_image2d_height_ *
4; // un-used for opencl Image2D, 4 for RGBA, 4; // un-used for opencl Image2D, 4 for RGBA,
cl_use_image2d_ = true; cl_use_image2d_ = true;
cl_image2d_width_ = img_w;
cl_image2d_height_ = img_h;
} }
} }
#endif #endif
......
...@@ -28,6 +28,12 @@ TEST(memory, test) { ...@@ -28,6 +28,12 @@ TEST(memory, test) {
ASSERT_TRUE(buf_cuda); ASSERT_TRUE(buf_cuda);
TargetFree(TARGET(kCUDA), buf_cuda); TargetFree(TARGET(kCUDA), buf_cuda);
#endif #endif
#ifdef LITE_WITH_OPENCL
auto* buf_cl = TargetMalloc(TARGET(kOpenCL), 10);
ASSERT_TRUE(buf_cl);
TargetFree(TARGET(kOpenCL), buf_cl);
#endif
} }
} // namespace lite } // namespace lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册