提交 770d8746 编写于 作者: L liuqi

Add input/output info correctness check in MaceEngine.

上级 df29340c
......@@ -106,6 +106,8 @@ class MaceEngine::Impl {
DeviceType device_type_;
std::unique_ptr<Workspace> ws_;
std::unique_ptr<NetBase> net_;
std::map<std::string, mace::InputInfo> input_info_map_;
std::map<std::string, mace::OutputInfo> output_info_map_;
#ifdef MACE_ENABLE_HEXAGON
std::unique_ptr<HexagonControlWrapper> hexagon_controller_;
#endif
......@@ -131,12 +133,29 @@ MaceStatus MaceEngine::Impl::Init(
const std::vector<std::string> &output_nodes,
const unsigned char *model_data) {
LOG(INFO) << "Initializing MaceEngine";
// Get input and output information.
for (auto &input_info : net_def->input_info()) {
input_info_map_[input_info.name()] = input_info;
}
for (auto &output_info : net_def->output_info()) {
output_info_map_[output_info.name()] = output_info;
}
// Set storage path for internal usage
for (auto input_name : input_nodes) {
if (input_info_map_.find(input_name) == input_info_map_.end()) {
LOG(FATAL) << "'" << input_name
<< "' is not belong to model's inputs: "
<< MakeString(MapKeys(input_info_map_));
}
ws_->CreateTensor(MakeString("mace_input_node_", input_name),
GetDeviceAllocator(device_type_), DT_FLOAT);
}
for (auto output_name : output_nodes) {
if (output_info_map_.find(output_name) == output_info_map_.end()) {
LOG(FATAL) << "'" << output_name
<< "' is not belong to model's outputs "
<< MakeString(MapKeys(output_info_map_));
}
ws_->CreateTensor(MakeString("mace_output_node_", output_name),
GetDeviceAllocator(device_type_), DT_FLOAT);
}
......@@ -193,6 +212,11 @@ MaceStatus MaceEngine::Impl::Run(
std::vector<Tensor *> input_tensors;
std::vector<Tensor *> output_tensors;
for (auto &input : inputs) {
if (input_info_map_.find(input.first) == input_info_map_.end()) {
LOG(FATAL) << "'" << input.first
<< "' is not belong to model's inputs: "
<< MakeString(MapKeys(input_info_map_));
}
MACE_CHECK(input.second.shape().size() == 4,
"The Inputs' shape must be 4-dimension with NHWC format,"
" please use 1 to fill missing dimensions");
......@@ -208,6 +232,11 @@ MaceStatus MaceEngine::Impl::Run(
input_tensors.push_back(input_tensor);
}
for (auto &output : *outputs) {
if (output_info_map_.find(output.first) == output_info_map_.end()) {
LOG(FATAL) << "'" << output.first
<< "' is not belong to model's outputs: "
<< MakeString(MapKeys(output_info_map_));
}
if (device_type_ == DeviceType::GPU) {
MACE_CHECK(output.second.shape().size() == 4,
"The outputs' shape must be 4-dimension with NHWC format,"
......@@ -245,7 +274,7 @@ MaceStatus MaceEngine::Impl::Run(
std::multiplies<int64_t>());
MACE_CHECK(!shape.empty()) << "Output's shape must greater than 0";
MACE_CHECK(shape == output.second.shape())
<< "Output shape mispatch: "
<< "Output shape mismatch: "
<< MakeString<int64_t>(output.second.shape())
<< " != " << MakeString<int64_t>(shape);
std::memcpy(output.second.data().get(), output_tensor->data<float>(),
......
......@@ -281,7 +281,9 @@ bool OpenCLLibraryImpl::Load() {
}
if (handle_ == nullptr) {
LOG(ERROR) << "Failed to load OpenCL library";
LOG(ERROR) << "Failed to load OpenCL library, "
"please make sure there exist OpenCL library on your device, "
"and your APP have right to access the library.";
return false;
}
......
......@@ -164,6 +164,7 @@ class TransformerRule(Enum):
TRANSFORM_BUFFER_IMAGE = 17
ADD_DEVICE_AND_DATA_TYPE = 18
SORT_BY_EXECUTION = 19
ADD_IN_OUT_TENSOR_INFO = 20
class ConverterInterface(object):
......@@ -210,6 +211,7 @@ class ConverterOption(object):
self._device = DeviceType.CPU.value
self._winograd_enabled = False
self._transformer_option = [
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.REMOVE_USELESS_RESHAPE_OP,
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
......
......@@ -166,6 +166,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._option = option
self._mace_net_def = mace_pb2.NetDef()
ConverterUtil.set_filter_format(self._mace_net_def, FilterFormat.HWIO)
# import tensorflow graph
tf_graph_def = tf.GraphDef()
with tf.gfile.Open(src_model_file, 'rb') as f:
tf_graph_def.ParseFromString(f.read())
......
......@@ -55,6 +55,7 @@ class Transformer(base_converter.ConverterInterface):
def __init__(self, option, model):
# DO NOT reorder the following transformers' order
self._registered_transformers_order = [
TransformerRule.ADD_IN_OUT_TENSOR_INFO,
TransformerRule.REMOVE_USELESS_RESHAPE_OP,
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
......@@ -78,6 +79,8 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.SORT_BY_EXECUTION,
]
self._registered_transformers = {
TransformerRule.ADD_IN_OUT_TENSOR_INFO:
self.add_in_out_tensor_info,
TransformerRule.REMOVE_USELESS_RESHAPE_OP:
self.remove_useless_reshape_op,
TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
......@@ -271,6 +274,21 @@ class Transformer(base_converter.ConverterInterface):
self._model.op.remove(op)
def add_in_out_tensor_info(self):
net = self._model
for input_node in self._option.input_nodes.values():
input_info = net.input_info.add()
input_info.name = input_node.name
input_info.dims.extend(input_node.shape)
for output_node in self._option.output_nodes.values():
output_info = net.output_info.add()
output_info.name = output_node.name
output_info.dims.extend(
self._producer[output_node.name].output_shape[0].dims)
return False
def remove_useless_reshape_op(self):
net = self._model
for op in net.op:
......
......@@ -50,3 +50,24 @@ cc_test(
"@gtest//:gtest_main",
],
)
cc_test(
name = "mace_api_exception_test",
testonly = 1,
srcs = ["mace_api_exception_test.cc"],
copts = ["-Werror", "-Wextra", "-Wno-missing-field-initializers"] +
if_openmp_enabled(["-fopenmp"]) +
if_neon_enabled(["-DMACE_ENABLE_NEON"]) +
if_android_armv7(["-mfpu=neon"]) +
if_android_armv7(["-mfloat-abi=softfp"]) +
if_android(["-DMACE_ENABLE_OPENCL"]) +
if_hexagon_enabled(["-DMACE_ENABLE_HEXAGON"]),
linkopts = ["-fopenmp"],
linkstatic = 1,
deps = [
"//mace/ops:test",
"//mace/kernels:kernels",
"//mace/ops:ops",
"@gtest//:gtest_main",
],
)
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace test {
TEST(MaceAPIExceptionTest, WrongInputTest) {
std::vector<std::string> input_names;
std::vector<std::string> output_names;
input_names.push_back(MakeString("input", 0));
output_names.push_back(MakeString("output", 0));
const DeviceType device = DeviceType::GPU;
std::shared_ptr<NetDef> net_def(new NetDef());
for (size_t i = 0; i < input_names.size(); ++i) {
InputInfo *info = net_def->add_input_info();
info->set_name(input_names[i]);
}
MaceEngine engine(device);
ASSERT_DEATH(engine.Init(net_def.get(), {"input"}, output_names, nullptr),
"");
}
} // namespace test
} // namespace mace
......@@ -298,6 +298,8 @@ void MaceRunFunc(const int in_out_size) {
{mem_map[input_names[i]]},
device,
net_def.get());
InputInfo *info = net_def->add_input_info();
info->set_name(input_names[i]);
}
BufferToImage<half>(filter_tensor_name, filter_tensor_img_name,
mace::kernels::CONV2D_FILTER, {}, device,
......@@ -315,6 +317,8 @@ void MaceRunFunc(const int in_out_size) {
mace::kernels::IN_OUT_CHANNEL,
device,
net_def.get());
OutputInfo *info = net_def->add_output_info();
info->set_name(output_names[i]);
}
const std::string file_path ="/data/local/tmp/mace";
......
......@@ -308,6 +308,8 @@ void MaceRun(const int in_out_size,
{mem_map[input_names[i]]},
device,
net_def.get());
InputInfo *info = net_def->add_input_info();
info->set_name(input_names[i]);
}
BufferToImage<half>(filter_tensor_name, filter_tensor_img_name,
mace::kernels::CONV2D_FILTER, {}, device,
......@@ -324,6 +326,8 @@ void MaceRun(const int in_out_size,
mace::kernels::IN_OUT_CHANNEL,
device,
net_def.get());
OutputInfo *info = net_def->add_output_info();
info->set_name(output_names[i]);
}
MaceEngine engine(device);
......@@ -376,5 +380,6 @@ TEST_F(MaceAPITest, GPUVariableInputShape) {
{{1, 16, 32, 16}, {1, 32, 64, 16}},
{16, 16, 3, 3});
}
} // namespace test
} // namespace mace
......@@ -16,6 +16,7 @@
#define MACE_UTILS_UTILS_H_
#include <fstream>
#include <map>
#include <sstream>
#include <string>
#include <utility>
......@@ -152,5 +153,14 @@ inline bool ReadBinaryFile(std::vector<unsigned char> *data,
return true;
}
template <typename T>
std::vector<std::string> MapKeys(const std::map<std::string, T> &data) {
std::vector<std::string> keys;
for (auto &kv : data) {
keys.push_back(kv.first);
}
return keys;
}
} // namespace mace
#endif // MACE_UTILS_UTILS_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册