diff --git a/mace/core/tensor.h b/mace/core/tensor.h index 1d7d5debf9dfefaeab59205d6de67d29867d2c35..53ac3c2e34e65f72d75a5aa518a48f0eeab3ed28 100644 --- a/mace/core/tensor.h +++ b/mace/core/tensor.h @@ -178,6 +178,18 @@ class Tensor { } } + inline void ResizeWithBuffer(const std::vector &shape, + BufferBase *buffer) { + MACE_CHECK(!has_opencl_image(), "Cannot resize image, use ResizeImage."); + shape_ = shape; + image_shape_.clear(); + if (buffer_ != nullptr && is_buffer_owner_) { + delete buffer_; + } + buffer_ = buffer; + is_buffer_owner_ = false; + } + inline void ResizeImage(const std::vector &shape, const std::vector &image_shape) { shape_ = shape; diff --git a/mace/kernels/reshape.h b/mace/kernels/reshape.h index 14e560789db709464400136116ba02d373207c65..ddcd0dba58c5241554623c67a884aca7cbe0c060 100644 --- a/mace/kernels/reshape.h +++ b/mace/kernels/reshape.h @@ -21,9 +21,7 @@ struct ReshapeFunctor { const std::vector &out_shape, Tensor *output, StatsFuture *future) { - output->Resize(out_shape); - // TODO(liuqi): copy on write to avoid this copy. - output->CopyBytes(input->raw_data(), input->size() * sizeof(T)); + output->ResizeWithBuffer(out_shape, input->UnderlyingBuffer()); } };