diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index d3cbcd4ce6cccc0703c95ac6bb17b8a84f1f2cf8..3130fd697bf85fa6cb4ce7bea9571635a2bc1d5d 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -3,6 +3,7 @@ stages: - pycodestyle - platform_compitable_tests - ops_test + - api_test - ops_benchmark - extra_tests @@ -21,7 +22,13 @@ ops_test: stage: ops_test script: - if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi - - python tools/bazel_adb_run.py --target="//mace/ops:ops_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS + - python tools/bazel_adb_run.py --target="//mace/ops:ops_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS + +api_test: + stage: api_test + script: + - if [ -z "$TARGET_SOCS" ]; then TARGET_SOCS=random; fi + - python tools/bazel_adb_run.py --target="//mace/test:mace_api_test" --run_target=True --stdout_processor=unittest_stdout_processor --target_abis=armeabi-v7a,arm64-v8a --target_socs=$TARGET_SOCS ops_benchmark: stage: ops_benchmark diff --git a/mace/core/mace.cc b/mace/core/mace.cc index 76e243b138616b5dffc3cac8c7072a6bf3e18000..ca926aa5acf80325e92075cc1af5d00e3e83aa1e 100644 --- a/mace/core/mace.cc +++ b/mace/core/mace.cc @@ -178,6 +178,9 @@ MaceStatus MaceEngine::Impl::Run( std::vector input_tensors; std::vector output_tensors; for (auto &input : inputs) { + MACE_CHECK(input.second.shape().size() == 4, + "The Inputs' shape must be 4-dimension with NHWC format," + " please use 1 to fill missing dimensions"); Tensor *input_tensor = ws_->GetTensor(MakeString("mace_input_node_", input.first, ":0")); input_tensor->Resize(input.second.shape()); @@ -190,6 +193,9 @@ MaceStatus MaceEngine::Impl::Run( input_tensors.push_back(input_tensor); } for (auto &output : *outputs) { + MACE_CHECK(output.second.shape().size() == 4, + "The outputs' shape must be 4-dimension with NHWC format," + " please use 1 to fill missing dimensions"); Tensor *output_tensor = ws_->GetTensor(MakeString("mace_output_node_", output.first + ":0")); output_tensors.push_back(output_tensor); diff --git a/mace/core/workspace.cc b/mace/core/workspace.cc index 0c681b14b70d2df9c81773652413b0a140513358..7a3bd994fa8baaae98a5878f92c73c0ef6ca74ae 100644 --- a/mace/core/workspace.cc +++ b/mace/core/workspace.cc @@ -81,15 +81,19 @@ void Workspace::LoadModelTensor(const NetDef &net_def, DeviceType type) { } VLOG(3) << "Model data size: " << model_data_size; - if (type == DeviceType::CPU || type == DeviceType::NEON) { - tensor_buffer_ = std::unique_ptr( - new Buffer(GetDeviceAllocator(type), model_data_ptr, model_data_size)); - } else { - tensor_buffer_ = std::unique_ptr( - new Buffer(GetDeviceAllocator(type), model_data_size)); - tensor_buffer_->Map(nullptr); - tensor_buffer_->Copy(model_data_ptr, 0, model_data_size); - tensor_buffer_->UnMap(); + if (model_data_size > 0) { + if (type == DeviceType::CPU || type == DeviceType::NEON) { + tensor_buffer_ = std::unique_ptr( + new Buffer(GetDeviceAllocator(type), + model_data_ptr, + model_data_size)); + } else { + tensor_buffer_ = std::unique_ptr( + new Buffer(GetDeviceAllocator(type), model_data_size)); + tensor_buffer_->Map(nullptr); + tensor_buffer_->Copy(model_data_ptr, 0, model_data_size); + tensor_buffer_->UnMap(); + } } for (auto &const_tensor : net_def.tensors()) { diff --git a/mace/examples/example.cc b/mace/examples/example.cc index aa852fdab4ece6bb053e9efbe11030ba7164fec3..52809a4fc217de50ceca7df2635836625fc7cacf 100644 --- a/mace/examples/example.cc +++ b/mace/examples/example.cc @@ -163,6 +163,8 @@ bool RunModel(const std::vector &input_names, static_cast(FLAGS_gpu_priority_hint)); } + // DO NOT USE tmp directory. + // please use APP's own directory const std::string kernel_file_path = "/data/local/tmp/mace_run/cl"; diff --git a/mace/ops/ops_test_util.h b/mace/ops/ops_test_util.h index 3235c9027f16ffa0b250beaf9f64073f620ace78..a8f72f58489b7edc376948ab699034f8f31851c3 100644 --- a/mace/ops/ops_test_util.h +++ b/mace/ops/ops_test_util.h @@ -283,6 +283,16 @@ class OpsTestNet { return RunOp(DeviceType::CPU); } + bool RunNet(const NetDef &net_def, const DeviceType device) { + device_ = device; + net_ = CreateNet(op_registry_, net_def, &ws_, device, NetMode::INIT); + if (!net_->Run()) { + return false; + } + net_ = CreateNet(op_registry_, net_def, &ws_, device); + return net_->Run(); + } + Tensor *GetOutput(const char *output_name) { return ws_.GetTensor(output_name); } diff --git a/mace/python/tools/caffe_converter_lib.py b/mace/python/tools/caffe_converter_lib.py index cce64c4dce96a956fc89a65e69222656d7159318..14e39039df14edb4b62723574e97a6e93b2d8257 100644 --- a/mace/python/tools/caffe_converter_lib.py +++ b/mace/python/tools/caffe_converter_lib.py @@ -306,6 +306,13 @@ class CaffeConverter(object): arg.name = 'T' arg.i = self.dt + input_op = self.ops_map[name] + if input_op.layer is not None: + output_shape = input_op.output_shape_map[input_op.layer.top[0]] + else: + output_shape = input_op.output_shape_map[input_op.name] + self.add_output_shape(op_def, output_shape) + def add_output_transform(self, names): for name in names: output_name = MACE_OUTPUT_NODE_NAME + '_' + name + ":0" @@ -1077,15 +1084,15 @@ class CaffeConverter(object): dims_arg.ints.extend([0, 2, 3, 1]) # NCHW -> NHWC def convert(self, input_nodes, input_shapes, output_nodes): + assert self.ops[0].type == 'Input' + self.add_input_op_shape(input_nodes, input_shapes) + if self.device == 'gpu': self.add_input_transform(input_nodes) if self.device == 'neon': self.add_neon_input_transform(input_nodes) - assert self.ops[0].type == 'Input' - self.add_input_op_shape(input_nodes, input_shapes) - for op in self.ops: if op.name in self.resolved_ops: continue diff --git a/mace/python/tools/memory_optimizer.py b/mace/python/tools/memory_optimizer.py index da92448d130bf4c056fe9482972d7f4e0ee35ebc..01f271626b34d7333a17b24d0b9bf679f13ade11 100644 --- a/mace/python/tools/memory_optimizer.py +++ b/mace/python/tools/memory_optimizer.py @@ -32,7 +32,11 @@ class MemoryOptimizer(object): self.ref_counter[tensor_name] = 0 def is_buffer_image_op(self, op): - return op.type == 'BufferToImage' or op.type == 'ImageToBuffer' + if op.type == 'BufferToImage': + for arg in op.arg: + if arg.name == 'mode' and arg.i == 0: + return True + return op.type == 'ImageToBuffer' def get_mem_size(self, op_type, output_shape): mem_size = [0, 0] diff --git a/mace/python/tools/tf_converter_lib.py b/mace/python/tools/tf_converter_lib.py index c50766cb44830bfbd098ecef2903a3b9d5edb5d9..1d4aac42e0c9ae8ef310993c55c6ed1161be0ac9 100644 --- a/mace/python/tools/tf_converter_lib.py +++ b/mace/python/tools/tf_converter_lib.py @@ -141,6 +141,8 @@ class TFConverter(object): arg.name = 'T' arg.i = self.dt + self.add_output_shape(self.ops[name].outputs, op_def) + def add_neon_input_transform(self, names): for name in names: new_input_name = MACE_INPUT_NODE_NAME + '_' + name + ":0" diff --git a/mace/test/BUILD b/mace/test/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..7c453ed869fd52720bab18521e649e1dd3417ea6 --- /dev/null +++ b/mace/test/BUILD @@ -0,0 +1,25 @@ +# Description: +# Mace operators. +# +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +cc_test( + name = "mace_api_test", + testonly = 1, + srcs = glob( + ["mace_api_test.cc"], + ), + linkopts = ["-fopenmp"], + linkstatic = 1, + deps = [ + "//mace/core:core", + "//mace/kernels:kernels", + "//mace/ops:ops", + "//mace/ops:test", + "@gtest//:gtest_main", + ], +) diff --git a/mace/test/mace_api_test.cc b/mace/test/mace_api_test.cc new file mode 100644 index 0000000000000000000000000000000000000000..fc22a450edfc5a5971d6717e4db01c2bf2dc96ad --- /dev/null +++ b/mace/test/mace_api_test.cc @@ -0,0 +1,336 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include + +#include "mace/core/operator.h" +#include "mace/kernels/conv_pool_2d_util.h" +#include "mace/ops/ops_test_util.h" + +namespace mace { +namespace test { + +class MaceAPITest : public ::testing::Test {}; + +namespace { + +void GenerateInputs(const std::vector &input_names, + const std::vector &input_shape, + std::map *inputs) { + size_t input_size = input_names.size(); + for (size_t i = 0; i < input_size; ++i) { + // Allocate input and output + int64_t input_size = + std::accumulate(input_shape.begin(), input_shape.end(), 1, + std::multiplies()); + auto buffer_in = std::shared_ptr(new float[input_size], + std::default_delete()); + // load input + std::vector input_data; + ops::test::GenerateRandomRealTypeData(input_shape, &input_data); + memcpy(buffer_in.get(), input_data.data(), input_size * sizeof(float)); + (*inputs)[input_names[i]] = mace::MaceTensor(input_shape, buffer_in); + } +} + +void GenerateOutputs(const std::vector &output_names, + const std::vector &output_shape, + std::map *outputs) { + size_t output_size = output_names.size(); + for (size_t i = 0; i < output_size; ++i) { + int64_t output_size = + std::accumulate(output_shape.begin(), output_shape.end(), 1, + std::multiplies()); + auto buffer_out = std::shared_ptr(new float[output_size], + std::default_delete()); + (*outputs)[output_names[i]] = mace::MaceTensor(output_shape, buffer_out); + } +} + +template +void BufferToImage(const std::string &input_name, + const std::string &output_name, + const int buffer_type, + const std::vector &mem_ids, + NetDef *net_def, + const int mode = NetMode::NORMAL) { + OperatorDef operator_def; + + ops::test::OpDefBuilder("BufferToImage", "BufferToImageOp") + .Input(input_name) + .Output(output_name) + .AddIntArg("buffer_type", buffer_type) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .AddIntArg("mode", mode) + .Finalize(&operator_def); + + operator_def.set_mem_id(mem_ids); + + net_def->add_op()->CopyFrom(operator_def); +} + +template +void ImageToBuffer(const std::string &input_name, + const std::string &output_name, + const int buffer_type, + NetDef *net_def) { + OperatorDef operator_def; + + ops::test::OpDefBuilder("ImageToBuffer", "ImageToBufferOp") + .Input(input_name) + .Output(output_name) + .AddIntArg("buffer_type", buffer_type) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(&operator_def); + + net_def->add_op()->CopyFrom(operator_def); +} + +template +void Conv3x3(const std::string &input_name, + const std::string &filter_name, + const std::string &output_name, + const std::vector &mem_ids, + NetDef *net_def) { + OperatorDef operator_def; + ops::test::OpDefBuilder("Conv2D", "Conv2dOp") + .Input(input_name) + .Input(filter_name) + .Output(output_name) + .AddIntsArg("strides", {1, 1}) + .AddIntArg("padding", Padding::SAME) + .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(&operator_def); + + operator_def.set_mem_id(mem_ids); + net_def->add_op()->CopyFrom(operator_def); +} + +template +void Relu(const std::string &input_name, + const std::string &output_name, + NetDef *net_def) { + OperatorDef operator_def; + ops::test::OpDefBuilder("Activation", "ReluTest") + .Input(input_name) + .Output(output_name) + .AddStringArg("activation", "RELU") + .AddIntArg("T", static_cast(DataTypeToEnum::value)) + .Finalize(&operator_def); + + net_def->add_op()->CopyFrom(operator_def); +} + +template +void AddTensor(const std::string &name, + const std::vector &shape, + T *data, + NetDef *net_def) { + ConstTensor tensor(name, + reinterpret_cast(data), + shape, + DataTypeToEnum::value); + + net_def->mutable_tensors().push_back(tensor); +} + +template +void CheckOutputs(const NetDef &net_def, + const std::map &inputs, + const std::map &outputs) { + ops::test::OpsTestNet net; + for (auto input : inputs) { + auto input_shape = input.second.shape(); + const int64_t data_size = std::accumulate(input_shape.begin(), + input_shape.end(), 1, + std::multiplies()); + std::vector input_data(data_size); + memcpy(input_data.data(), input.second.data().get(), + data_size * sizeof(float)); + std::string input_name = MakeString("mace_input_node_", + input.first, ":0"); + net.AddInputFromArray(input_name, input.second.shape(), + input_data); + } + auto tensors = net_def.tensors(); + for (auto tensor : tensors) { + auto shape = tensor.dims(); + const int64_t data_size = std::accumulate(shape.begin(), + shape.end(), 1, + std::multiplies()); + std::vector data(data_size); + memcpy(data.data(), reinterpret_cast(tensor.data()), + data_size * sizeof(T)); + net.AddInputFromArray(tensor.name(), shape, data); + } + net.RunNet(net_def, D); + + for (auto output : outputs) { + std::unique_ptr tmp_tensor( + new Tensor(GetDeviceAllocator(DeviceType::CPU), + DataTypeToEnum::v())); + auto output_shape = output.second.shape(); + const int64_t data_size = std::accumulate(output_shape.begin(), + output_shape.end(), 1, + std::multiplies()); + tmp_tensor->Resize(output.second.shape()); + float *data = tmp_tensor->mutable_data(); + memcpy(data, output.second.data().get(), data_size * sizeof(float)); + std::string output_name = MakeString("mace_output_node_", + output.first, ":0"); + ops::test::ExpectTensorNear(*tmp_tensor, + *net.GetOutput(output_name.data()), + 1e-5); + } +} + +std::map AddMemoryOptimization( + const std::vector &input_names, + const std::vector &output_names, + const std::vector> &input_shapes, + const std::vector> &output_shapes, + NetDef *net_def) { + std::map res; + int mem_id = 0; + size_t input_shape_size = input_shapes.size(); + uint32_t in_mem_block_x = 0; + uint32_t in_mem_block_y = 0; + for (size_t i = 0; i < input_shape_size; ++i) { + in_mem_block_x = std::max(in_mem_block_x, + input_shapes[i][2] * + RoundUpDiv4(input_shapes[i][3])); + in_mem_block_y = std::max(in_mem_block_y, + input_shapes[i][0] * + input_shapes[i][1]); + } + size_t input_size = input_names.size(); + for (size_t i = 0; i < input_size; ++i) { + net_def->mutable_mem_arena().mutable_mem_block().push_back( + MemoryBlock(mem_id, in_mem_block_x, in_mem_block_y)); + res[input_names[i]] = mem_id; + mem_id++; + } + size_t output_shape_size = output_shapes.size(); + uint32_t out_mem_block_x = 0; + uint32_t out_mem_block_y = 0; + for (size_t i = 0; i < output_shape_size; ++i) { + out_mem_block_x = std::max(out_mem_block_x, + output_shapes[i][2] * + RoundUpDiv4(output_shapes[i][3])); + out_mem_block_y = std::max(out_mem_block_y, + output_shapes[i][0] * + output_shapes[i][1]); + } + size_t output_size = output_names.size(); + for (size_t i = 0; i < output_size; ++i) { + net_def->mutable_mem_arena().mutable_mem_block().push_back( + MemoryBlock(mem_id, out_mem_block_x, out_mem_block_y)); + res[output_names[i]] = mem_id; + mem_id++; + } + return res; +} + +// The height and width of input and output must be equal. +template +void MaceRun(const int in_out_size, + const std::vector> &input_shapes, + const std::vector> &output_shapes, + const std::vector &filter_shape) { + std::vector input_names; + std::vector output_names; + for (int i = 0; i < in_out_size; ++i) { + input_names.push_back(MakeString("input", i)); + output_names.push_back(MakeString("output", i)); + } + std::string filter_tensor_name = "filter"; + std::string filter_tensor_img_name = filter_tensor_name + "_image"; + + const DeviceType device = DeviceType::OPENCL; + + NetDef net_def; + + // Add memory optimization + auto mem_map = AddMemoryOptimization(input_names, output_names, + input_shapes, output_shapes, + &net_def); + + std::vector data; + ops::test::GenerateRandomRealTypeData(filter_shape, &data); + AddTensor(filter_tensor_name, filter_shape, data.data(), &net_def); + + for (size_t i = 0; i < input_names.size(); ++i) { + std::string input_name = MakeString("mace_input_node_", + input_names[i], ":0"); + BufferToImage(input_name, input_names[i], + mace::kernels::IN_OUT_CHANNEL, + {mem_map[input_names[i]]}, + &net_def); + } + BufferToImage(filter_tensor_name, filter_tensor_img_name, + mace::kernels::CONV2D_FILTER, {}, + &net_def, NetMode::INIT); + for (size_t i = 0; i < output_names.size(); ++i) { + Conv3x3(input_names[i], filter_tensor_img_name, + output_names[i], {mem_map[output_names[i]]}, + &net_def); + } + for (size_t i = 0; i < output_names.size(); ++i) { + std::string output_name = MakeString("mace_output_node_", + output_names[i], ":0"); + ImageToBuffer(output_names[i], output_name, + mace::kernels::IN_OUT_CHANNEL, &net_def); + } + + MaceEngine engine(&net_def, device, input_names, output_names); + + std::map inputs; + std::map outputs; + + for (int i = 0; i < 5; ++i) { + size_t input_shape_size = input_shapes.size(); + for (size_t j = 0; j < input_shape_size; ++j) { + inputs.clear(); + outputs.clear(); + GenerateInputs(input_names, input_shapes[j], &inputs); + GenerateOutputs(output_names, output_shapes[j], &outputs); + engine.Run(inputs, &outputs); + } + } + + CheckOutputs(net_def, inputs, outputs); +} + +} // namespace + +TEST_F(MaceAPITest, GPUSingleInputOutput) { + MaceRun(1, {{1, 32, 32, 16}}, {{1, 32, 32, 16}}, {3, 3, 16, 16}); + MaceRun(1, {{1, 32, 32, 16}}, {{1, 32, 32, 16}}, {3, 3, 16, 16}); +} + +TEST_F(MaceAPITest, GPUMultipleInputOutput) { + MaceRun(2, + {{1, 16, 32, 16}}, + {{1, 16, 32, 16}}, + {3, 3, 16, 16}); + MaceRun(2, + {{1, 16, 32, 16}}, + {{1, 16, 32, 16}}, + {3, 3, 16, 16}); +} + +TEST_F(MaceAPITest, GPUVariableInputShape) { + MaceRun(1, + {{1, 16, 32, 16}, {1, 32, 64, 16}}, + {{1, 16, 32, 16}, {1, 32, 64, 16}}, + {3, 3, 16, 16}); + MaceRun(2, + {{1, 16, 32, 16}, {1, 32, 64, 16}}, + {{1, 16, 32, 16}, {1, 32, 64, 16}}, + {3, 3, 16, 16}); +} + +} // namespace test +} // namespace mace diff --git a/mace/utils/tuner.h b/mace/utils/tuner.h index fa25f6daaa5fad5ea786e83fec88d53e5e92d82e..db4f25fa8288cc65c094017c08b63465345fd5be 100644 --- a/mace/utils/tuner.h +++ b/mace/utils/tuner.h @@ -94,8 +94,8 @@ class Tuner { Tuner &operator=(const Tuner &) = delete; inline void WriteRunParameters() { - VLOG(3) << "Write tuning result to " << path_; if (path_ != nullptr) { + VLOG(3) << "Write tuning result to " << path_; std::ofstream ofs(path_, std::ios::binary | std::ios::out); if (ofs.is_open()) { int64_t num_pramas = param_table_.size();