diff --git a/src/framework/cl/cl_tensor.h b/src/framework/cl/cl_tensor.h index 1d6829fe4b77639f34df0be37d7a539b91ff4bcc..fc805cbba148516a1759a9b2c98e02ede26344e2 100644 --- a/src/framework/cl/cl_tensor.h +++ b/src/framework/cl/cl_tensor.h @@ -48,16 +48,15 @@ class CLTensor : TensorBase { return *this; } - template - inline T mutable_with_data(void *data) { + inline cl_mem mutable_with_data(void *data) { int64_t size = numel() * sizeof(float); - holder_.reset( - new PlaceholderImpl(size, data, typeid(T), context_, command_queue_)); - return reinterpret_cast( + holder_.reset(new PlaceholderImpl(size, data, typeid(cl_mem), context_, + command_queue_)); + return reinterpret_cast( reinterpret_cast(reinterpret_cast(holder_->ptr()))); } - inline void *mutable_data(std::type_index type) { + inline cl_mem mutable_data(std::type_index type) { if (holder_ != nullptr) { holder_->set_type(type); } @@ -67,22 +66,20 @@ class CLTensor : TensorBase { holder_.reset(new PlaceholderImpl(size, type, context_, command_queue_)); offset_ = 0; } - return reinterpret_cast( - reinterpret_cast(holder_->ptr()) + offset_); + return reinterpret_cast(holder_->ptr()); } /** - * @brief Return a pointer to mutable memory block. + * @brief Return a pointer to cl buffer. * @note If not exist, then allocation. */ template - inline T *mutable_data() { - static_assert(std::is_pod::value, "T must be POD"); - return reinterpret_cast(mutable_data(typeid(T))); + inline cl_mem mutable_data() { + return reinterpret_cast(mutable_data(typeid(T))); } /** - * @brief Return a pointer to mutable memory block. + * @brief Return a pointer to cl buffer. * * @param[in] dims The dimensions of the memory block. * @param[in] place The place of the memory block. @@ -90,27 +87,44 @@ class CLTensor : TensorBase { * @note If not exist, then allocation. */ template - inline T *mutable_data(DDim dims) { - static_assert(std::is_pod::value, "T must be POD"); + inline cl_mem mutable_data(DDim dims) { Resize(dims); return mutable_data(); } - private: - cl_context context_; - cl_command_queue command_queue_; - - /* - * virtual ~Placeholder() = default; + inline cl_mem CLBuffer() { + check_memory_size(); + return reinterpret_cast( + reinterpret_cast(holder_->ptr()) + offset_); + } - virtual void *ptr() const = 0; + template + inline T *Data() { + if (host_ptr_) { + delete (host_ptr_); + host_ptr_ = nullptr; + } + cl_mem buffer = CLBuffer(); + host_ptr_ = new char[holder_->size()]; + cl_int status; + status = clEnqueueReadBuffer(command_queue_, buffer, CL_TRUE, 0, + holder_->size(), host_ptr_, 0, NULL, NULL); + CL_CHECK_ERRORS(status); + return reinterpret_cast(host_ptr_); + } - virtual size_t size() const = 0; + ~CLTensor() { + if (host_ptr_) { + delete (host_ptr_); + host_ptr_ = nullptr; + } + } - virtual std::type_index type() const = 0; + private: + cl_context context_; + cl_command_queue command_queue_; + void *host_ptr_; - virtual void set_type(std::type_index type) = 0; - * */ struct PlaceholderImpl : public Placeholder { PlaceholderImpl(size_t size, void *input, std::type_index type, cl_context context, cl_command_queue command_queue) @@ -129,15 +143,7 @@ class CLTensor : TensorBase { virtual size_t size() const { return size_; } - virtual void *ptr() const { - if (host_ptr_) { - delete (host_ptr_); - } - char *host_ptr = new char[size_]; - clEnqueueReadBuffer(command_queue_, ptr_.get(), CL_TRUE, 0, size_, - host_ptr, 0, NULL, NULL); - return static_cast(host_ptr); - } + virtual void *ptr() const { return static_cast(ptr_.get()); } virtual std::type_index type() const { return type_; } @@ -151,15 +157,6 @@ class CLTensor : TensorBase { std::type_index type_; cl_command_queue command_queue_; - - ~PlaceholderImpl() { - if (host_ptr_) { - delete (host_ptr_); - } - } - - private: - void *host_ptr_; }; }; diff --git a/src/framework/tensor.h b/src/framework/tensor.h index b6990a07d81f5b8ce005ec56aca9c7403d028157..240f78e3c4c1ed3e9e744214dc1384a1249ee4fc 100644 --- a/src/framework/tensor.h +++ b/src/framework/tensor.h @@ -157,6 +157,34 @@ class Tensor : public TensorBase { } } + /*! Return a pointer to mutable memory block. */ + template + inline T *data() { + check_memory_size(); + PADDLE_MOBILE_ENFORCE( + (std::is_same::value || + holder_->type().hash_code() == typeid(T).hash_code()), + "Tensor holds the wrong type, it holds %s", + this->holder_->type().name()); + + return reinterpret_cast(reinterpret_cast(holder_->ptr()) + + offset_); + } + + /*! Return a pointer to constant memory block. */ + template + inline const T *data() const { + check_memory_size(); + PADDLE_MOBILE_ENFORCE( + (std::is_same::value || + holder_->type().hash_code() == typeid(T).hash_code()), + "Tensor holds the wrong type, it holds %s ,requested:%s", + this->holder_->type().name(), typeid(T).name()); + + return reinterpret_cast( + reinterpret_cast(holder_->ptr()) + offset_); + } + private: struct PlaceholderImpl : public Placeholder { PlaceholderImpl(size_t size, std::type_index type) diff --git a/src/framework/tensor_base.h b/src/framework/tensor_base.h index fe0c9116d4182218a6349a50ca6740c3dbcd5b6b..e1539d2e681973b39eeca5b30e2ed35b535be8cb 100644 --- a/src/framework/tensor_base.h +++ b/src/framework/tensor_base.h @@ -72,36 +72,6 @@ class TensorBase { inline bool IsInitialized() const { return holder_ != nullptr; } - virtual inline void *mutable_data(std::type_index type) = 0; - - /*! Return a pointer to mutable memory block. */ - template - inline T *data() { - check_memory_size(); - PADDLE_MOBILE_ENFORCE( - (std::is_same::value || - holder_->type().hash_code() == typeid(T).hash_code()), - "Tensor holds the wrong type, it holds %s", - this->holder_->type().name()); - - return reinterpret_cast(reinterpret_cast(holder_->ptr()) + - offset_); - } - - /*! Return a pointer to constant memory block. */ - template - inline const T *data() const { - check_memory_size(); - PADDLE_MOBILE_ENFORCE( - (std::is_same::value || - holder_->type().hash_code() == typeid(T).hash_code()), - "Tensor holds the wrong type, it holds %s ,requested:%s", - this->holder_->type().name(), typeid(T).name()); - - return reinterpret_cast( - reinterpret_cast(holder_->ptr()) + offset_); - } - /*! Return the dimensions of the memory block. */ inline const DDim &dims() const { return dims_; } diff --git a/src/operators/kernel/cl/feed_kernel.cpp b/src/operators/kernel/cl/feed_kernel.cpp index 0db2b7cc4665ff74d06ca62ba9e77d427d883233..7c61180e51d9d0c0598b1b8f6d61351a4567ed8c 100644 --- a/src/operators/kernel/cl/feed_kernel.cpp +++ b/src/operators/kernel/cl/feed_kernel.cpp @@ -39,8 +39,7 @@ void FeedKernel::Compute(const FeedParam ¶m) { CLTensor input_cl_tensor(this->cl_helper_.CLContext(), this->cl_helper_.CLCommandQueue()); input_cl_tensor.Resize(input->dims()); - cl_mem inputBuffer = - input_cl_tensor.mutable_with_data((void *)input_data); + cl_mem inputBuffer = input_cl_tensor.mutable_with_data((void *)input_data); status = clSetKernelArg(kernel, 0, sizeof(cl_mem), (void *)&inputBuffer); CL_CHECK_ERRORS(status);