Commit f33d4d5a authored by 李寅

Support various model input/output

Parent a74002cb
@@ -18,6 +18,8 @@
 #include <sys/mman.h>
 #include <unistd.h>
+#include <algorithm>
+#include <numeric>
 #include <memory>

 #include "mace/core/device_context.h"
@@ -313,6 +315,7 @@ class MaceTensor::Impl {
   std::vector<int64_t> shape;
   std::shared_ptr<float> data;
   DataFormat format;
+  int64_t buffer_size;
 };

 MaceTensor::MaceTensor(const std::vector<int64_t> &shape,
@@ -323,6 +326,8 @@ MaceTensor::MaceTensor(const std::vector<int64_t> &shape,
   impl_->shape = shape;
   impl_->data = data;
   impl_->format = format;
+  impl_->buffer_size =
+      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<float>());
 }

 MaceTensor::MaceTensor() {
@@ -334,6 +339,7 @@ MaceTensor::MaceTensor(const MaceTensor &other) {
   impl_->shape = other.shape();
   impl_->data = other.data();
   impl_->format = other.data_format();
+  impl_->buffer_size = other.impl_->buffer_size;
 }

 MaceTensor::MaceTensor(const MaceTensor &&other) {
@@ -341,12 +347,14 @@ MaceTensor::MaceTensor(const MaceTensor &&other) {
   impl_->shape = other.shape();
   impl_->data = other.data();
   impl_->format = other.data_format();
+  impl_->buffer_size = other.impl_->buffer_size;
 }

 MaceTensor &MaceTensor::operator=(const MaceTensor &other) {
   impl_->shape = other.shape();
   impl_->data = other.data();
   impl_->format = other.data_format();
+  impl_->buffer_size = other.impl_->buffer_size;
   return *this;
 }
@@ -354,6 +362,7 @@ MaceTensor &MaceTensor::operator=(const MaceTensor &&other) {
   impl_->shape = other.shape();
   impl_->data = other.data();
   impl_->format = other.data_format();
+  impl_->buffer_size = other.impl_->buffer_size;
   return *this;
 }
@@ -484,7 +493,14 @@ MaceStatus MaceEngine::Impl::Init(
                  << "' does not belong to model's inputs: "
                  << MakeString(MapKeys(input_info_map_));
     }
-    ws_->CreateTensor(input_name, device_->allocator(), DT_FLOAT);
+    Tensor *input_tensor =
+        ws_->CreateTensor(input_name, device_->allocator(), DT_FLOAT);
+    // Resize to possible largest shape to avoid resize during running.
+    std::vector<index_t> shape(input_info_map_[input_name].dims_size());
+    for (int i = 0; i < input_info_map_[input_name].dims_size(); ++i) {
+      shape[i] = input_info_map_[input_name].dims(i);
+    }
+    input_tensor->Resize(shape);
   }
   for (auto output_name : output_nodes) {
     if (output_info_map_.find(output_name) == output_info_map_.end()) {
@@ -637,10 +653,13 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
       std::vector<index_t> shape =
           TransposeShape<index_t, index_t>(output_tensor->shape(),
                                            dst_dims);
-      MACE_CHECK(shape == output->second.shape())
-          << "Output shape mismatch: "
-          << MakeString<int64_t>(shape) << " != "
-          << MakeString<int64_t>(output->second.shape());
+      int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
+                                            std::multiplies<int64_t>());
+      MACE_CHECK(output_size <= output->second.impl_->buffer_size)
+          << "Output size exceeds buffer size: shape"
+          << MakeString<int64_t>(shape) << " vs buffer size "
+          << output->second.impl_->buffer_size;
+      output->second.impl_->shape = shape;
       Tensor::MappingGuard output_guard(output_tensor);
       const float *output_data = output_tensor->data<float>();
       return ops::Transpose(output_data,
@@ -660,10 +679,13 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
       std::vector<index_t> shape =
           TransposeShape<index_t, index_t>(output_tensor->shape(),
                                            dst_dims);
-      MACE_CHECK(shape == output->second.shape())
-          << "Output shape mismatch: "
-          << MakeString<int64_t>(shape) << " != "
-          << MakeString<int64_t>(output->second.shape());
+      int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
+                                            std::multiplies<int64_t>());
+      MACE_CHECK(output_size <= output->second.impl_->buffer_size)
+          << "Output size exceeds buffer size: shape"
+          << MakeString<int64_t>(shape) << " vs buffer size "
+          << output->second.impl_->buffer_size;
+      output->second.impl_->shape = shape;
       Tensor::MappingGuard output_guard(output_tensor);
       const float *output_data = output_tensor->data<float>();
       return ops::Transpose(output_data,
@@ -675,10 +697,11 @@ MaceStatus MaceEngine::Impl::TransposeOutput(
     auto shape = output_tensor->shape();
     int64_t output_size = std::accumulate(shape.begin(), shape.end(), 1,
                                           std::multiplies<int64_t>());
-    MACE_CHECK(shape == output->second.shape())
-        << "Output shape mismatch: "
-        << MakeString<int64_t>(shape) << " != "
-        << MakeString<int64_t>(output->second.shape());
+    MACE_CHECK(output_size <= output->second.impl_->buffer_size)
+        << "Output size exceeds buffer size: shape"
+        << MakeString<int64_t>(shape) << " vs buffer size "
+        << output->second.impl_->buffer_size;
+    output->second.impl_->shape = shape;
     std::memcpy(output->second.data().get(), output_tensor->data<float>(),
                 output_size * sizeof(float));
     return MaceStatus::MACE_SUCCESS;
......
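The hunks above relax the exact shape-equality check on outputs into a capacity check: the product of the actual output dimensions only needs to fit into the buffer the caller allocated, after which the MaceTensor's shape is updated to the actual output shape. Below is a minimal standalone sketch of that capacity check, not part of this commit; the function name and parameters are illustrative only.

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// Returns true if a tensor with `actual_shape` fits into a buffer that was
// allocated for `buffer_size` elements, mirroring the MACE_CHECK added above.
bool FitsInBuffer(const std::vector<int64_t> &actual_shape,
                  int64_t buffer_size) {
  const int64_t output_size =
      std::accumulate(actual_shape.begin(), actual_shape.end(),
                      static_cast<int64_t>(1), std::multiplies<int64_t>());
  return output_size <= buffer_size;
}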
@@ -282,8 +282,12 @@ class MACE_API MaceEngineConfig {

 // MACE input/output tensor
 class MACE_API MaceTensor {
+  friend class MaceEngine;
+
  public:
-  // shape - the shape of the tensor, with size n
+  // shape - the shape of the tensor, with size n, if shape is unknown
+  //         in advance, it should be specified large enough to hold tensor of
+  //         all possible size.
   // data - the buffer of the tensor, must not be null with size equals
   //        shape[0] * shape[1] * ... * shape[n-1].
   //        If you want to pass a buffer which is unsuitable to use the default
@@ -301,6 +305,7 @@ class MACE_API MaceTensor {
   MaceTensor &operator=(const MaceTensor &&other);
   ~MaceTensor();

+  // shape will be updated to the actual output shape after running.
   const std::vector<int64_t> &shape() const;
   const std::shared_ptr<float> data() const;
   std::shared_ptr<float> data();
......
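For context, here is a caller-side sketch of how the changed API comments could be used: the output buffer is allocated once for the largest shape the model may produce, and MaceTensor::shape() is read back after Run() to get the actual output shape. This example is not part of the commit; the engine setup, tensor names ("input", "output"), and the helper function are hypothetical placeholders around the public MaceEngine::Run and MaceTensor API.

#include <map>
#include <memory>
#include <string>
#include <vector>

#include "mace/public/mace.h"

// Runs one inference where the output shape is not known in advance.
void RunWithVariableOutput(mace::MaceEngine *engine,
                           const std::vector<int64_t> &input_shape,
                           const std::vector<int64_t> &max_output_shape) {
  auto count = [](const std::vector<int64_t> &s) {
    int64_t n = 1;
    for (int64_t d : s) n *= d;
    return n;
  };
  // Input buffer sized to the actual input shape.
  std::shared_ptr<float> input_data(new float[count(input_shape)],
                                    std::default_delete<float[]>());
  // Output buffer sized to the largest shape the model may produce,
  // as the updated comment in mace.h suggests.
  std::shared_ptr<float> output_data(new float[count(max_output_shape)],
                                     std::default_delete<float[]>());

  std::map<std::string, mace::MaceTensor> inputs;
  std::map<std::string, mace::MaceTensor> outputs;
  inputs["input"] = mace::MaceTensor(input_shape, input_data);
  outputs["output"] = mace::MaceTensor(max_output_shape, output_data);

  engine->Run(inputs, &outputs);

  // After Run(), shape() reflects the actual output shape, which may be
  // smaller than max_output_shape.
  const std::vector<int64_t> &actual_shape = outputs["output"].shape();
  (void)actual_shape;
}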