From 6bb057ea278e9745ebb9eef61c7b32a86a657edd Mon Sep 17 00:00:00 2001 From: cjh9368 Date: Wed, 5 Aug 2020 12:51:20 +0800 Subject: [PATCH] benchmark mark accuracy use only 1 tensor --- mindspore/lite/tools/benchmark/benchmark.cc | 26 ++++++++++++--------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/mindspore/lite/tools/benchmark/benchmark.cc b/mindspore/lite/tools/benchmark/benchmark.cc index 473913841..1c3e92928 100644 --- a/mindspore/lite/tools/benchmark/benchmark.cc +++ b/mindspore/lite/tools/benchmark/benchmark.cc @@ -234,17 +234,21 @@ int Benchmark::CompareOutput() { MS_LOG(ERROR) << "Cannot find output node: " << nodeName.c_str() << " , compare output data fail."; return RET_ERROR; } - for (auto tensor : tensors) { - MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT); - MS_ASSERT(tensor->GetData() != nullptr); - float bias = CompareData(nodeName, tensor->shape(), static_cast(tensor->MutableData())); - if (bias >= 0) { - totalBias += bias; - totalSize++; - } else { - hasError = true; - break; - } + // make sure tensor size is 1 + if (tensors.size() != 1) { + MS_LOG(ERROR) << "Only support 1 tensor with a name now."; + return RET_ERROR; + } + auto &tensor = tensors.front(); + MS_ASSERT(tensor->GetDataType() == DataType_DT_FLOAT); + MS_ASSERT(tensor->GetData() != nullptr); + float bias = CompareData(nodeName, tensor->shape(), static_cast(tensor->MutableData())); + if (bias >= 0) { + totalBias += bias; + totalSize++; + } else { + hasError = true; + break; } } -- GitLab