diff --git a/src/framework/cl/cl_tensor.h b/src/framework/cl/cl_tensor.h index fc805cbba148516a1759a9b2c98e02ede26344e2..204c3e3ca185817bdf5b966fe04fac8574acd8a7 100644 --- a/src/framework/cl/cl_tensor.h +++ b/src/framework/cl/cl_tensor.h @@ -48,12 +48,14 @@ class CLTensor : TensorBase { return *this; } - inline cl_mem mutable_with_data(void *data) { - int64_t size = numel() * sizeof(float); - holder_.reset(new PlaceholderImpl(size, data, typeid(cl_mem), context_, - command_queue_)); - return reinterpret_cast( - reinterpret_cast(reinterpret_cast(holder_->ptr()))); + template + inline cl_mem mutable_with_data(const T *data) { + int64_t size = numel() * sizeof(T); + + holder_.reset(new PlaceholderImpl( + size, reinterpret_cast(const_cast(data)), typeid(T), + context_, command_queue_)); + return reinterpret_cast(holder_->ptr()); } inline cl_mem mutable_data(std::type_index type) { diff --git a/src/operators/kernel/cl/feed_kernel.cpp b/src/operators/kernel/cl/feed_kernel.cpp index 7c61180e51d9d0c0598b1b8f6d61351a4567ed8c..0edb1f2bee2dd65861e665f76bdd28ac0a3913c0 100644 --- a/src/operators/kernel/cl/feed_kernel.cpp +++ b/src/operators/kernel/cl/feed_kernel.cpp @@ -39,7 +39,7 @@ void FeedKernel::Compute(const FeedParam ¶m) { CLTensor input_cl_tensor(this->cl_helper_.CLContext(), this->cl_helper_.CLCommandQueue()); input_cl_tensor.Resize(input->dims()); - cl_mem inputBuffer = input_cl_tensor.mutable_with_data((void *)input_data); + cl_mem inputBuffer = input_cl_tensor.mutable_with_data(input_data); status = clSetKernelArg(kernel, 0, sizeof(cl_mem), (void *)&inputBuffer); CL_CHECK_ERRORS(status);