From b3524eb512ce8acace96917d5349e7f98f2ddc3d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E5=AF=85?= <liyin@xiaomi.com>
Date: Thu, 20 Sep 2018 15:33:54 +0800
Subject: [PATCH] Preallocate full size scratch buffer for variable-length
 models

---
 mace/core/buffer.h     |  1 +
 mace/core/operator.h   |  9 +++++++++
 mace/core/tensor.h     | 30 ++++++++++++++++++++++++++++++
 mace/core/workspace.cc | 11 +++++++++++
 mace/kernels/matmul.h  |  8 ++++++++
 5 files changed, 59 insertions(+)

diff --git a/mace/core/buffer.h b/mace/core/buffer.h
index ba43e96c..521ccc82 100644
--- a/mace/core/buffer.h
+++ b/mace/core/buffer.h
@@ -469,6 +469,7 @@ class ScratchBuffer: public Buffer {
 
   MaceStatus GrowSize(index_t size) {
     if (size > size_) {
+      VLOG(1) << "Grow scratch size to: " << size;
       MACE_CHECK(offset_ == 0, "scratch is being used, cannot grow size");
       return Resize(size);
     }
diff --git a/mace/core/operator.h b/mace/core/operator.h
index 6be38890..e0b84535 100644
--- a/mace/core/operator.h
+++ b/mace/core/operator.h
@@ -117,6 +117,15 @@ class Operator : public OperatorBase {
       }
       outputs_.push_back(MACE_CHECK_NOTNULL(ws->CreateTensor(
           output_str, context->device()->allocator(), output_type)));
+
+      if (i < operator_def.output_shape_size()) {
+        std::vector<index_t>
+            shape_configured(operator_def.output_shape(i).dims_size());
+        for (size_t dim = 0; dim < shape_configured.size(); ++dim) {
+          shape_configured[dim] = operator_def.output_shape(i).dims(dim);
+        }
+        ws->GetTensor(output_str)->SetShapeConfigured(shape_configured);
+      }
     }
   }
 }
diff --git a/mace/core/tensor.h b/mace/core/tensor.h
index 713a6d1e..b3c58152 100644
--- a/mace/core/tensor.h
+++ b/mace/core/tensor.h
@@ -18,6 +18,7 @@
 #include <string>
 #include <vector>
 #include <functional>
+#include <numeric>
 
 #include "mace/core/buffer.h"
 #include "mace/core/preallocated_pooled_allocator.h"
@@ -159,6 +160,34 @@ class Tensor {
 
   inline const std::vector<index_t> &shape() const { return shape_; }
 
+  inline std::vector<index_t> max_shape() const {
+    if (shape_configured_.empty()) {
+      return shape();
+    } else {
+      auto &_shape = shape();
+      std::vector<index_t> max_shape(_shape.size());
+      MACE_CHECK(_shape.size() == shape_configured_.size());
+      for (size_t i = 0; i < shape_configured_.size(); ++i) {
+        max_shape[i] = std::max(_shape[i], shape_configured_[i]);
+      }
+      return max_shape;
+    }
+  }
+
+  inline index_t max_size() const {
+    auto _max_shape = max_shape();
+    return std::accumulate(_max_shape.begin(),
+                           _max_shape.end(),
+                           1,
+                           std::multiplies<index_t>());
+  }
+
+  inline index_t raw_max_size() const { return max_size() * SizeOfType(); }
+
+  inline void SetShapeConfigured(const std::vector<index_t> &shape_configured) {
+    shape_configured_ = shape_configured;
+  }
+
   inline index_t dim_size() const { return shape_.size(); }
 
   inline index_t dim(unsigned int index) const {
@@ -431,6 +460,7 @@ class Tensor {
   Allocator *allocator_;
   DataType dtype_;
   std::vector<index_t> shape_;
+  std::vector<index_t> shape_configured_;
   std::vector<index_t> image_shape_;
   BufferBase *buffer_;
   BufferSlice buffer_slice_;
diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc
index 4a34cd44..333d6e60 100644
--- a/mace/core/workspace.cc
+++ b/mace/core/workspace.cc
@@ -340,6 +340,17 @@ MaceStatus Workspace::CreateOutputTensorBuffer(const NetDef &net_def,
             output_type);
       }
     }
+
+    for (int output_idx = 0; output_idx < op.output_shape_size();
+         ++output_idx) {
+      std::vector<index_t>
+          shape_configured(op.output_shape(output_idx).dims_size());
+      for (size_t dim = 0; dim < shape_configured.size(); ++dim) {
+        shape_configured[dim] = op.output_shape(output_idx).dims(dim);
+      }
+      tensor_map_[op.output(output_idx)]->SetShapeConfigured(
+          shape_configured);
+    }
   }
 
   return MaceStatus::MACE_SUCCESS;
diff --git a/mace/kernels/matmul.h b/mace/kernels/matmul.h
index ad7ab968..137c7151 100644
--- a/mace/kernels/matmul.h
+++ b/mace/kernels/matmul.h
@@ -91,6 +91,14 @@ struct MatMulFunctor : OpKernel {
 
     auto scratch_buffer = context_->workspace()->GetScratchBuffer(D);
     scratch_buffer->Rewind();
+    index_t scratch_size = C->raw_max_size();
+    if (!A->is_weight()) {
+      scratch_size += A->raw_max_size();
+    }
+    if (!B->is_weight()) {
+      scratch_size += B->raw_max_size();
+    }
+    scratch_buffer->GrowSize(scratch_size);
 
     sgemm_.Run(a_ptr_base,
                b_ptr_base,
-- 
GitLab