提交 bb07379e 编写于 作者: L liuruilong

update cl tensor

上级 ca9f3648
......@@ -48,12 +48,14 @@ class CLTensor : TensorBase {
return *this;
}
inline cl_mem mutable_with_data(void *data) {
int64_t size = numel() * sizeof(float);
holder_.reset(new PlaceholderImpl(size, data, typeid(cl_mem), context_,
command_queue_));
return reinterpret_cast<cl_mem>(
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(holder_->ptr())));
template <typename T>
inline cl_mem mutable_with_data(const T *data) {
int64_t size = numel() * sizeof(T);
holder_.reset(new PlaceholderImpl(
size, reinterpret_cast<void *>(const_cast<T *>(data)), typeid(T),
context_, command_queue_));
return reinterpret_cast<cl_mem>(holder_->ptr());
}
inline cl_mem mutable_data(std::type_index type) {
......
......@@ -39,7 +39,7 @@ void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
CLTensor input_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
input_cl_tensor.Resize(input->dims());
cl_mem inputBuffer = input_cl_tensor.mutable_with_data((void *)input_data);
cl_mem inputBuffer = input_cl_tensor.mutable_with_data<float>(input_data);
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), (void *)&inputBuffer);
CL_CHECK_ERRORS(status);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册