diff --git a/mace/core/mace.cc b/mace/core/mace.cc index 2361a435979a338e7ccb280f7d20e90474cad67c..770a8067cc735e02c89123d98aab879c1c619986 100644 --- a/mace/core/mace.cc +++ b/mace/core/mace.cc @@ -534,7 +534,9 @@ MaceEngine::MaceEngine(const NetDef *net_def, DeviceType device_type) : ws_->CreateTensor("mace_input_node:0", GetDeviceAllocator(device_type_), DT_FLOAT); - + ws_->CreateTensor("mace_output_node:0", + GetDeviceAllocator(device_type_), + DT_FLOAT); net_ = std::move(CreateNet(op_registry_, *net_def, ws_.get(), device_type)); } } @@ -551,9 +553,7 @@ bool MaceEngine::Run(const float *input, float *output) { MACE_CHECK(output != nullptr, "output ptr cannot be NULL"); Tensor *input_tensor = ws_->GetTensor("mace_input_node:0"); - Tensor *output_tensor = ws_->CreateTensor("mace_output_node:0", - GetDeviceAllocator(device_type_), - DT_FLOAT); + Tensor *output_tensor = ws_->GetTensor("mace_output_node:0"); input_tensor->Resize(input_shape); { Tensor::MappingGuard input_guard(input_tensor);