diff --git a/mace/libmace/mace.cc b/mace/libmace/mace.cc index e9256b9fdd88813fbcadd8ce211c328e6fceb1c2..ff812c6bb893af994299a8a99820afce8c21ff58 100644 --- a/mace/libmace/mace.cc +++ b/mace/libmace/mace.cc @@ -18,6 +18,8 @@ #include #include +#include +#include #include #include "mace/core/device_context.h" @@ -313,6 +315,7 @@ class MaceTensor::Impl { std::vector shape; std::shared_ptr data; DataFormat format; + int64_t buffer_size; }; MaceTensor::MaceTensor(const std::vector &shape, @@ -323,6 +326,8 @@ MaceTensor::MaceTensor(const std::vector &shape, impl_->shape = shape; impl_->data = data; impl_->format = format; + impl_->buffer_size = + std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); } MaceTensor::MaceTensor() { @@ -334,6 +339,7 @@ MaceTensor::MaceTensor(const MaceTensor &other) { impl_->shape = other.shape(); impl_->data = other.data(); impl_->format = other.data_format(); + impl_->buffer_size = other.impl_->buffer_size; } MaceTensor::MaceTensor(const MaceTensor &&other) { @@ -341,12 +347,14 @@ MaceTensor::MaceTensor(const MaceTensor &&other) { impl_->shape = other.shape(); impl_->data = other.data(); impl_->format = other.data_format(); + impl_->buffer_size = other.impl_->buffer_size; } MaceTensor &MaceTensor::operator=(const MaceTensor &other) { impl_->shape = other.shape(); impl_->data = other.data(); impl_->format = other.data_format(); + impl_->buffer_size = other.impl_->buffer_size; return *this; } @@ -354,6 +362,7 @@ MaceTensor &MaceTensor::operator=(const MaceTensor &&other) { impl_->shape = other.shape(); impl_->data = other.data(); impl_->format = other.data_format(); + impl_->buffer_size = other.impl_->buffer_size; return *this; } @@ -484,7 +493,14 @@ MaceStatus MaceEngine::Impl::Init( << "' does not belong to model's inputs: " << MakeString(MapKeys(input_info_map_)); } - ws_->CreateTensor(input_name, device_->allocator(), DT_FLOAT); + Tensor *input_tensor = + ws_->CreateTensor(input_name, device_->allocator(), DT_FLOAT); + // Resize to possible largest shape to avoid resize during running. + std::vector shape(input_info_map_[input_name].dims_size()); + for (int i = 0; i < input_info_map_[input_name].dims_size(); ++i) { + shape[i] = input_info_map_[input_name].dims(i); + } + input_tensor->Resize(shape); } for (auto output_name : output_nodes) { if (output_info_map_.find(output_name) == output_info_map_.end()) { @@ -637,10 +653,13 @@ MaceStatus MaceEngine::Impl::TransposeOutput( std::vector shape = TransposeShape(output_tensor->shape(), dst_dims); - MACE_CHECK(shape == output->second.shape()) - << "Output shape mismatch: " - << MakeString(shape) << " != " - << MakeString(output->second.shape()); + int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); + MACE_CHECK(output_size <= output->second.impl_->buffer_size) + << "Output size exceeds buffer size: shape" + << MakeString(shape) << " vs buffer size " + << output->second.impl_->buffer_size; + output->second.impl_->shape = shape; Tensor::MappingGuard output_guard(output_tensor); const float *output_data = output_tensor->data(); return ops::Transpose(output_data, @@ -660,10 +679,13 @@ MaceStatus MaceEngine::Impl::TransposeOutput( std::vector shape = TransposeShape(output_tensor->shape(), dst_dims); - MACE_CHECK(shape == output->second.shape()) - << "Output shape mismatch: " - << MakeString(shape) << " != " - << MakeString(output->second.shape()); + int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1, + std::multiplies()); + MACE_CHECK(output_size <= output->second.impl_->buffer_size) + << "Output size exceeds buffer size: shape" + << MakeString(shape) << " vs buffer size " + << output->second.impl_->buffer_size; + output->second.impl_->shape = shape; Tensor::MappingGuard output_guard(output_tensor); const float *output_data = output_tensor->data(); return ops::Transpose(output_data, @@ -675,10 +697,11 @@ MaceStatus MaceEngine::Impl::TransposeOutput( auto shape = output_tensor->shape(); int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1, std::multiplies()); - MACE_CHECK(shape == output->second.shape()) - << "Output shape mismatch: " - << MakeString(shape) << " != " - << MakeString(output->second.shape()); + MACE_CHECK(output_size <= output->second.impl_->buffer_size) + << "Output size exceeds buffer size: shape" + << MakeString(shape) << " vs buffer size " + << output->second.impl_->buffer_size; + output->second.impl_->shape = shape; std::memcpy(output->second.data().get(), output_tensor->data(), output_size * sizeof(float)); return MaceStatus::MACE_SUCCESS; diff --git a/mace/public/mace.h b/mace/public/mace.h index 912867f74b60b613439a7f545e0b2d2fab335454..aa8cad4037fc9835188fa4a5274f1cb4fea46f24 100644 --- a/mace/public/mace.h +++ b/mace/public/mace.h @@ -282,8 +282,12 @@ class MACE_API MaceEngineConfig { // MACE input/output tensor class MACE_API MaceTensor { + friend class MaceEngine; + public: - // shape - the shape of the tensor, with size n + // shape - the shape of the tensor, with size n, if shape is unknown + // in advance, it should be specified large enough to hold tensor of all + // possible size. // data - the buffer of the tensor, must not be null with size equals // shape[0] * shape[1] * ... * shape[n-1]. // If you want to pass a buffer which is unsuitable to use the default @@ -301,6 +305,7 @@ class MACE_API MaceTensor { MaceTensor &operator=(const MaceTensor &&other); ~MaceTensor(); + // shape will be updated to the actual output shape after running. const std::vector &shape() const; const std::shared_ptr data() const; std::shared_ptr data();