From bb07379e2ba78babf5e1d3a7e5f15afdb677e6b8 Mon Sep 17 00:00:00 2001 From: liuruilong Date: Wed, 17 Oct 2018 16:44:40 +0800 Subject: [PATCH] update cl tensor --- src/framework/cl/cl_tensor.h | 14 ++++++++------ src/operators/kernel/cl/feed_kernel.cpp | 2 +- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/framework/cl/cl_tensor.h b/src/framework/cl/cl_tensor.h index fc805cbba1..204c3e3ca1 100644 --- a/src/framework/cl/cl_tensor.h +++ b/src/framework/cl/cl_tensor.h @@ -48,12 +48,14 @@ class CLTensor : TensorBase { return *this; } - inline cl_mem mutable_with_data(void *data) { - int64_t size = numel() * sizeof(float); - holder_.reset(new PlaceholderImpl(size, data, typeid(cl_mem), context_, - command_queue_)); - return reinterpret_cast( - reinterpret_cast(reinterpret_cast(holder_->ptr()))); + template + inline cl_mem mutable_with_data(const T *data) { + int64_t size = numel() * sizeof(T); + + holder_.reset(new PlaceholderImpl( + size, reinterpret_cast(const_cast(data)), typeid(T), + context_, command_queue_)); + return reinterpret_cast(holder_->ptr()); } inline cl_mem mutable_data(std::type_index type) { diff --git a/src/operators/kernel/cl/feed_kernel.cpp b/src/operators/kernel/cl/feed_kernel.cpp index 7c61180e51..0edb1f2bee 100644 --- a/src/operators/kernel/cl/feed_kernel.cpp +++ b/src/operators/kernel/cl/feed_kernel.cpp @@ -39,7 +39,7 @@ void FeedKernel::Compute(const FeedParam ¶m) { CLTensor input_cl_tensor(this->cl_helper_.CLContext(), this->cl_helper_.CLCommandQueue()); input_cl_tensor.Resize(input->dims()); - cl_mem inputBuffer = input_cl_tensor.mutable_with_data((void *)input_data); + cl_mem inputBuffer = input_cl_tensor.mutable_with_data(input_data); status = clSetKernelArg(kernel, 0, sizeof(cl_mem), (void *)&inputBuffer); CL_CHECK_ERRORS(status); -- GitLab