diff --git a/src/framework/cl/cl_tensor.h b/src/framework/cl/cl_tensor.h index e4d840f8322c05b21c3bcc8af46a284501129b49..c38091dd39c776254035f9b13c8505d64686915a 100644 --- a/src/framework/cl/cl_tensor.h +++ b/src/framework/cl/cl_tensor.h @@ -37,10 +37,10 @@ class CLTensor : TensorBase { } template - inline T *mutable_with_data(void *data) { + inline T mutable_with_data(void *data) { int64_t size = numel() * sizeof(float); holder_.reset(new PlaceholderImpl(size, data, typeid(T), context_)); - return reinterpret_cast( + return reinterpret_cast( reinterpret_cast(reinterpret_cast(holder_->ptr()))); } diff --git a/src/operators/kernel/cl/feed_kernel.cpp b/src/operators/kernel/cl/feed_kernel.cpp index 93bb371eba6638807943a1b66768305481e4755e..d12dcdbdcea2e1070666ef297fe60b13c2ec6c6d 100644 --- a/src/operators/kernel/cl/feed_kernel.cpp +++ b/src/operators/kernel/cl/feed_kernel.cpp @@ -32,16 +32,13 @@ void FeedKernel::Compute(const FeedParam ¶m) { const Tensor *input = param.InputX(); const float *input_data = input->data(); int numel = input->numel(); - DLOG << "numel = " << numel; cl_mem cl_image = output->GetCLImage(); int height = output->dims()[2]; int width = output->dims()[3]; - DLOG << output->dims(); CLTensor input_cl_tensor(this->cl_helper_.CLContext()); input_cl_tensor.Resize(input->dims()); - cl_mem *inputBuffer = + cl_mem inputBuffer = input_cl_tensor.mutable_with_data((void *)input_data); - DLOG << "yangfei"; status = clSetKernelArg(kernel, 0, sizeof(cl_mem), (void *)&inputBuffer); CL_CHECK_ERRORS(status); @@ -53,21 +50,18 @@ void FeedKernel::Compute(const FeedParam ¶m) { CL_CHECK_ERRORS(status); size_t global_work_size[2] = {height, width}; - DLOG << "yangfei"; status = clEnqueueNDRangeKernel(this->cl_helper_.CLCommandQueue(), kernel, 2, NULL, global_work_size, NULL, 0, NULL, NULL); CL_CHECK_ERRORS(status); int len = 4 * 224 * 224; half *out = new half[len]; - DLOG << "yangfei"; cl_command_queue commandQueue = this->cl_helper_.CLCommandQueue(); size_t origin[3] = {0, 0, 0}; size_t region[3] = {height, width, 1}; clEnqueueReadImage(commandQueue, cl_image, CL_TRUE, origin, region, 0, 0, out, 0, NULL, NULL); - DLOG << "yangfei"; - for (int i = 0; i < 100; i++) DLOG << out[i]; + for (int i = 0; i < numel; i++) DLOG << Half2Float(out[i]); } template class FeedKernel;