diff --git a/mace/core/buffer.h b/mace/core/buffer.h index c859268f818d998983d610333636f187195e8aea..0ec4782433cd810657d8a97545ad0880d2e5ad46 100644 --- a/mace/core/buffer.h +++ b/mace/core/buffer.h @@ -25,6 +25,12 @@ #include "mace/core/types.h" namespace mace { +namespace core { +enum BufferType { + BT_BUFFER, + BT_IMAGE, +}; +} // namespace core class BufferBase { public: @@ -32,6 +38,8 @@ class BufferBase { explicit BufferBase(index_t size) : size_(size) {} virtual ~BufferBase() {} + virtual core::BufferType buffer_type() const = 0; + virtual void *buffer() = 0; virtual const void *raw_data() const = 0; @@ -63,6 +71,8 @@ class BufferBase { virtual void Clear(index_t size) = 0; + virtual const std::vector shape() const = 0; + virtual index_t offset() const { return 0; } template @@ -106,6 +116,10 @@ class Buffer : public BufferBase { } } + core::BufferType buffer_type() const { + return core::BufferType::BT_BUFFER; + } + void *buffer() { MACE_CHECK_NOTNULL(buf_); return buf_; @@ -207,6 +221,11 @@ class Buffer : public BufferBase { memset(reinterpret_cast(raw_mutable_data()), 0, size); } + const std::vector shape() const { + MACE_NOT_IMPLEMENTED; + return {}; + } + protected: Allocator *allocator_; void *buf_; @@ -238,6 +257,10 @@ class Image : public BufferBase { return data_type_; } + core::BufferType buffer_type() const { + return core::BufferType::BT_IMAGE; + } + void *buffer() { MACE_CHECK_NOTNULL(buf_); return buf_; @@ -253,8 +276,6 @@ class Image : public BufferBase { return mapped_buf_; } - std::vector image_shape() const { return shape_; } - MaceStatus Allocate(index_t nbytes) { MACE_UNUSED(nbytes); LOG(FATAL) << "Image should not call this allocate function"; @@ -328,6 +349,10 @@ class Image : public BufferBase { MACE_NOT_IMPLEMENTED; } + const std::vector shape() const { + return shape_; + } + private: Allocator *allocator_; std::vector shape_; @@ -365,6 +390,10 @@ class BufferSlice : public BufferBase { } } + core::BufferType buffer_type() const { + return core::BufferType::BT_BUFFER; + } + void *buffer() { MACE_CHECK_NOTNULL(buffer_); return buffer_->buffer(); @@ -454,6 +483,11 @@ class BufferSlice : public BufferBase { memset(raw_mutable_data(), 0, size); } + const std::vector shape() const { + MACE_NOT_IMPLEMENTED; + return {}; + } + private: BufferBase *buffer_; void *mapped_buf_; diff --git a/mace/core/runtime/opencl/scratch_image.cc b/mace/core/runtime/opencl/scratch_image.cc index d2d4dcfebca536e2ef99e37ac90cdd6194053108..ccca3896b972412f75376973ff1d210131ca0691 100644 --- a/mace/core/runtime/opencl/scratch_image.cc +++ b/mace/core/runtime/opencl/scratch_image.cc @@ -33,7 +33,7 @@ Image *ScratchImageManager::Spawn( for (int i = 0; i < image_count; ++i) { int count = reference_count_[i]; if (count == 0 && images_.at(count)->dtype() == dt) { - auto image_shape = images_.at(count)->image_shape(); + auto image_shape = images_.at(count)->shape(); if (image_shape[0] >= shape[0] && image_shape[1] >= shape[1]) { found_image_idx = i; break; diff --git a/mace/core/tensor.h b/mace/core/tensor.h index 22d5f77270fc030c6915805c850ef2bb379ee489..e70d48a9feca6705deefde877eeb00212cc26b9c 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -215,7 +215,7 @@ class Tensor { inline bool has_opencl_image() const { return buffer_ != nullptr && !buffer_->OnHost() && - typeid(*buffer_) == typeid(Image); + buffer_->buffer_type() == core::BufferType::BT_IMAGE; } inline bool has_opencl_buffer() const { @@ -226,7 +226,7 @@ class Tensor { MACE_CHECK(buffer_ != nullptr, "Tensor ", name_, " is empty"); if (buffer_->OnHost()) { return MemoryType::CPU_BUFFER; - } else if (typeid(*buffer_) == typeid(Image)) { + } else if (buffer_->buffer_type() == core::BufferType::BT_IMAGE) { return MemoryType::GPU_IMAGE; } else { return MemoryType::GPU_BUFFER; @@ -343,12 +343,11 @@ class Tensor { } else { MACE_CHECK(has_opencl_image(), name_, ": Cannot ResizeImage buffer, use Resize."); - Image *image = dynamic_cast(buffer_); - MACE_CHECK(image_shape[0] <= image->image_shape()[0] && - image_shape[1] <= image->image_shape()[1], + MACE_CHECK(image_shape[0] <= buffer_->shape()[0] && + image_shape[1] <= buffer_->shape()[1], "tensor (source op ", name_, - "): current physical image shape: ", image->image_shape()[0], - ", ", image->image_shape()[1], " < logical image shape: ", + "): current physical image shape: ", buffer_->shape()[0], + ", ", buffer_->shape()[1], " < logical image shape: ", image_shape[0], ", ", image_shape[1]); return MaceStatus::MACE_SUCCESS; } diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index 5123e670bc900fb5e6f4be145f8fc64be5105b5d..43950a9db00a76ba84b35cb519c4c0c30ded6263 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -272,11 +272,9 @@ MaceStatus Workspace::PreallocateOutputTensor( << " Mem: " << tensor_mem.second.first << " Data type: " << tensor->dtype() << " Image shape: " - << dynamic_cast(tensor->UnderlyingBuffer()) - ->image_shape()[0] + << tensor->UnderlyingBuffer()->shape()[0] << ", " - << dynamic_cast(tensor->UnderlyingBuffer()) - ->image_shape()[1]; + << tensor->UnderlyingBuffer()->shape()[1]; tensor->set_data_format(DataFormat::NHWC); } else { VLOG(1) << "Tensor: " << tensor_mem.first