Commit bb07379e authored by liuruilong

update cl tensor

Parent ca9f3648
...
@@ -48,12 +48,14 @@ class CLTensor : TensorBase {
     return *this;
   }
 
-  inline cl_mem mutable_with_data(void *data) {
-    int64_t size = numel() * sizeof(float);
-    holder_.reset(new PlaceholderImpl(size, data, typeid(cl_mem), context_,
-                                      command_queue_));
-    return reinterpret_cast<cl_mem>(
-        reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(holder_->ptr())));
+  template <typename T>
+  inline cl_mem mutable_with_data(const T *data) {
+    int64_t size = numel() * sizeof(T);
+
+    holder_.reset(new PlaceholderImpl(
+        size, reinterpret_cast<void *>(const_cast<T *>(data)), typeid(T),
+        context_, command_queue_));
+    return reinterpret_cast<cl_mem>(holder_->ptr());
   }
 
   inline cl_mem mutable_data(std::type_index type) {
...
...
@@ -39,7 +39,7 @@ void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
   CLTensor input_cl_tensor(this->cl_helper_.CLContext(),
                            this->cl_helper_.CLCommandQueue());
   input_cl_tensor.Resize(input->dims());
-  cl_mem inputBuffer = input_cl_tensor.mutable_with_data((void *)input_data);
+  cl_mem inputBuffer = input_cl_tensor.mutable_with_data<float>(input_data);
   status = clSetKernelArg(kernel, 0, sizeof(cl_mem), (void *)&inputBuffer);
   CL_CHECK_ERRORS(status);
...
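
Note: the substance of this change is that the buffer size is now computed from the element type T (numel() * sizeof(T)) instead of being hardcoded to sizeof(float), and the typeid stored in the placeholder matches T rather than cl_mem. A minimal standalone sketch of that sizing idea follows; buffer_bytes is an illustrative name only, not part of this repository:

#include <cstdint>
#include <cstdio>

// Illustrative analogue of the templated mutable_with_data: the byte size
// of the backing buffer is derived from the element type T rather than
// being fixed at sizeof(float).
template <typename T>
int64_t buffer_bytes(int64_t numel) {
  return numel * static_cast<int64_t>(sizeof(T));
}

int main() {
  // With the old non-template signature, both cases below would have been
  // sized as numel * sizeof(float).
  std::printf("float : %lld bytes\n",
              static_cast<long long>(buffer_bytes<float>(16)));   // 64
  std::printf("int8_t: %lld bytes\n",
              static_cast<long long>(buffer_bytes<int8_t>(16)));  // 16
  return 0;
}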