提交 ca9f3648 编写于 作者: L liuruilong

update cl tensor

上级 e5fbd777
......@@ -48,16 +48,15 @@ class CLTensor : TensorBase {
return *this;
}
template <typename T>
inline T mutable_with_data(void *data) {
inline cl_mem mutable_with_data(void *data) {
int64_t size = numel() * sizeof(float);
holder_.reset(
new PlaceholderImpl(size, data, typeid(T), context_, command_queue_));
return reinterpret_cast<T>(
holder_.reset(new PlaceholderImpl(size, data, typeid(cl_mem), context_,
command_queue_));
return reinterpret_cast<cl_mem>(
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(holder_->ptr())));
}
inline void *mutable_data(std::type_index type) {
inline cl_mem mutable_data(std::type_index type) {
if (holder_ != nullptr) {
holder_->set_type(type);
}
......@@ -67,22 +66,20 @@ class CLTensor : TensorBase {
holder_.reset(new PlaceholderImpl(size, type, context_, command_queue_));
offset_ = 0;
}
return reinterpret_cast<void *>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
return reinterpret_cast<cl_mem>(holder_->ptr());
}
/**
* @brief Return a pointer to mutable memory block.
* @brief Return a pointer to cl buffer.
* @note If not exist, then allocation.
*/
template <typename T>
inline T *mutable_data() {
static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T *>(mutable_data(typeid(T)));
inline cl_mem mutable_data() {
return reinterpret_cast<cl_mem>(mutable_data(typeid(T)));
}
/**
* @brief Return a pointer to mutable memory block.
* @brief Return a pointer to cl buffer.
*
* @param[in] dims The dimensions of the memory block.
* @param[in] place The place of the memory block.
......@@ -90,27 +87,44 @@ class CLTensor : TensorBase {
* @note If not exist, then allocation.
*/
template <typename T>
inline T *mutable_data(DDim dims) {
static_assert(std::is_pod<T>::value, "T must be POD");
inline cl_mem mutable_data(DDim dims) {
Resize(dims);
return mutable_data<T>();
}
private:
cl_context context_;
cl_command_queue command_queue_;
/*
* virtual ~Placeholder() = default;
inline cl_mem CLBuffer() {
check_memory_size();
return reinterpret_cast<cl_mem>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
virtual void *ptr() const = 0;
template <typename T>
inline T *Data() {
if (host_ptr_) {
delete (host_ptr_);
host_ptr_ = nullptr;
}
cl_mem buffer = CLBuffer();
host_ptr_ = new char[holder_->size()];
cl_int status;
status = clEnqueueReadBuffer(command_queue_, buffer, CL_TRUE, 0,
holder_->size(), host_ptr_, 0, NULL, NULL);
CL_CHECK_ERRORS(status);
return reinterpret_cast<T *>(host_ptr_);
}
virtual size_t size() const = 0;
~CLTensor() {
if (host_ptr_) {
delete (host_ptr_);
host_ptr_ = nullptr;
}
}
virtual std::type_index type() const = 0;
private:
cl_context context_;
cl_command_queue command_queue_;
void *host_ptr_;
virtual void set_type(std::type_index type) = 0;
* */
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(size_t size, void *input, std::type_index type,
cl_context context, cl_command_queue command_queue)
......@@ -129,15 +143,7 @@ class CLTensor : TensorBase {
virtual size_t size() const { return size_; }
virtual void *ptr() const {
if (host_ptr_) {
delete (host_ptr_);
}
char *host_ptr = new char[size_];
clEnqueueReadBuffer(command_queue_, ptr_.get(), CL_TRUE, 0, size_,
host_ptr, 0, NULL, NULL);
return static_cast<void *>(host_ptr);
}
virtual void *ptr() const { return static_cast<void *>(ptr_.get()); }
virtual std::type_index type() const { return type_; }
......@@ -151,15 +157,6 @@ class CLTensor : TensorBase {
std::type_index type_;
cl_command_queue command_queue_;
~PlaceholderImpl() {
if (host_ptr_) {
delete (host_ptr_);
}
}
private:
void *host_ptr_;
};
};
......
......@@ -157,6 +157,34 @@ class Tensor : public TensorBase {
}
}
/*! Return a pointer to mutable memory block. */
template <typename T>
inline T *data() {
check_memory_size();
PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code()),
"Tensor holds the wrong type, it holds %s",
this->holder_->type().name());
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
/*! Return a pointer to constant memory block. */
template <typename T>
inline const T *data() const {
check_memory_size();
PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code()),
"Tensor holds the wrong type, it holds %s ,requested:%s",
this->holder_->type().name(), typeid(T).name());
return reinterpret_cast<const T *>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
private:
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(size_t size, std::type_index type)
......
......@@ -72,36 +72,6 @@ class TensorBase {
inline bool IsInitialized() const { return holder_ != nullptr; }
virtual inline void *mutable_data(std::type_index type) = 0;
/*! Return a pointer to mutable memory block. */
template <typename T>
inline T *data() {
check_memory_size();
PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code()),
"Tensor holds the wrong type, it holds %s",
this->holder_->type().name());
return reinterpret_cast<T *>(reinterpret_cast<uintptr_t>(holder_->ptr()) +
offset_);
}
/*! Return a pointer to constant memory block. */
template <typename T>
inline const T *data() const {
check_memory_size();
PADDLE_MOBILE_ENFORCE(
(std::is_same<T, void>::value ||
holder_->type().hash_code() == typeid(T).hash_code()),
"Tensor holds the wrong type, it holds %s ,requested:%s",
this->holder_->type().name(), typeid(T).name());
return reinterpret_cast<const T *>(
reinterpret_cast<uintptr_t>(holder_->ptr()) + offset_);
}
/*! Return the dimensions of the memory block. */
inline const DDim &dims() const { return dims_; }
......
......@@ -39,8 +39,7 @@ void FeedKernel<GPU_CL, float>::Compute(const FeedParam<GPU_CL> &param) {
CLTensor input_cl_tensor(this->cl_helper_.CLContext(),
this->cl_helper_.CLCommandQueue());
input_cl_tensor.Resize(input->dims());
cl_mem inputBuffer =
input_cl_tensor.mutable_with_data<cl_mem>((void *)input_data);
cl_mem inputBuffer = input_cl_tensor.mutable_with_data((void *)input_data);
status = clSetKernelArg(kernel, 0, sizeof(cl_mem), (void *)&inputBuffer);
CL_CHECK_ERRORS(status);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册