diff --git a/src/framework/cl/cl_image.cpp b/src/framework/cl/cl_image.cpp index fd2cf760cdd9d5d6a0a0ddb8d83dc095447ad0cc..19a6b2896e434c710ad11ee08bc524708ddbbcdd 100644 --- a/src/framework/cl/cl_image.cpp +++ b/src/framework/cl/cl_image.cpp @@ -125,7 +125,7 @@ Print &operator<<(Print &printer, const CLImage &cl_image) { float *data = new float[cl_image.numel()]; DDim ddim = cl_image.dims(); size_t N, C, H, W, width, height; - if (cl_image.GetImageType() == Normal) { + if (cl_image.GetImageType() == 0 || ddim.size() == 4) { if (ddim.size() == 4) { N = ddim[0]; if (N < 0) { @@ -152,7 +152,7 @@ Print &operator<<(Print &printer, const CLImage &cl_image) { W = ddim[0]; } float *p = data; - half imageData[width * height * 4]; + half *imageData = new half[height * width * 4]; cl_int err; cl_mem image = cl_image.GetCLImage(); size_t origin[3] = {0, 0, 0}; @@ -175,7 +175,7 @@ Print &operator<<(Print &printer, const CLImage &cl_image) { } i0 += width * H; } - + delete (imageData); CL_CHECK_ERRORS(err); } else { if (ddim.size() == 2) { @@ -191,7 +191,7 @@ Print &operator<<(Print &printer, const CLImage &cl_image) { W = ddim[0]; } float *p = data; - half imageData[width * height * 4]; + half *imageData = new half[width * height * 4]; cl_int err; cl_mem image = cl_image.GetCLImage(); size_t origin[3] = {0, 0, 0}; @@ -210,6 +210,7 @@ Print &operator<<(Print &printer, const CLImage &cl_image) { for (int i = 0; i < cl_image.numel(); i += stride) { printer << data[i] << " "; } + delete (data); return printer; } #endif diff --git a/src/framework/cl/cl_image.h b/src/framework/cl/cl_image.h index de2578e7cb682126baa5a752dfd29dc1d680f8b4..0493264eaa6072f27d0a5d8574ecb0dfee9ada90 100644 --- a/src/framework/cl/cl_image.h +++ b/src/framework/cl/cl_image.h @@ -26,7 +26,7 @@ limitations under the License. */ namespace paddle_mobile { namespace framework { -enum ImageType { Normal, Folder }; +enum ImageType { Normal = 0, Folder = 1 }; class CLImage { public: @@ -88,7 +88,7 @@ class CLImage { cl_mem GetCLImage() const { return cl_image_; } - const DDim &ImageDims() { return image_dims_; } + const DDim &ImageDims() const { return image_dims_; } inline size_t ImageWidth() const { return image_width_; } @@ -139,7 +139,7 @@ class CLImage { * */ const DDim &dims() const { return tensor_dims_; } - const ImageType GetImageType() const { type; } + const ImageType GetImageType() const { return type; } private: ImageType type; @@ -157,6 +157,14 @@ class CLImage { } int width = (tdim[1] + 3) / 4; int height = tdim[0]; + + width_of_one_block_ = tdim[1]; + height_of_one_block_ = tdim[0]; + + image_width_ = width; + image_height_ = height; + image_dims_ = make_ddim({image_width_, image_height_}); + c_block_ = tdim[1] / width; std::unique_ptr imageData{}; if (tensor_data) { imageData.reset(new half_t[width * height * 4]);