diff --git a/mace/examples/BUILD b/mace/examples/BUILD index 516557b427cd7a07c3f7d350ab4807cae38d8911..d55eb4f20360226b4b53ee479ad56e5c6dc68879 100644 --- a/mace/examples/BUILD +++ b/mace/examples/BUILD @@ -12,7 +12,7 @@ cc_binary( "@org_tensorflow//tensorflow/core:android_tensorflow_lib", ], copts = ["-std=c++11"], - linkopts = if_android(["-ldl"]), + linkopts = ["-fopenmp",] + if_android(["-ldl"]), ) cc_test( @@ -23,7 +23,22 @@ cc_test( "//mace/core:test_benchmark_main", ], copts = ["-std=c++11"], - linkopts = if_android(["-ldl"]), + linkopts = ["-fopenmp",] + if_android(["-ldl"]), linkstatic = 1, testonly = 1, ) + +cc_binary( + name = "mace_run", + srcs = [ + "mace_run.cc", + ], + deps = [ + "//mace/core", + "//mace/utils", + "//mace/ops", + ], + copts = ["-std=c++11",], + linkopts = ["-fopenmp",] + if_android(["-ldl"]), + linkstatic = 1, +) diff --git a/mace/examples/mace_run.cc b/mace/examples/mace_run.cc new file mode 100644 index 0000000000000000000000000000000000000000..e969010cc0d893eb469655bf399973eb260a253d --- /dev/null +++ b/mace/examples/mace_run.cc @@ -0,0 +1,121 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +/** + * Usage: + * mace_run --model=mobi_mace.pb \ + * --input=input_node \ + * --output=MobilenetV1/Logits/conv2d/convolution \ + * --input_shape=1,3,224,224 \ + * --input_file=input_data \ + * --output_file=mace.out \ + * --device=NEON + */ +#include +#include +#include "mace/core/net.h" +#include "mace/utils/command_line_flags.h" + +using namespace std; +using namespace mace; + +void ParseShape(const string &str, vector *shape) { + string tmp = str; + while (!tmp.empty()) { + int dim = atoi(tmp.data()); + shape->push_back(dim); + size_t next_offset = tmp.find(","); + if (next_offset == string::npos) { + break; + } else { + tmp = tmp.substr(next_offset + 1); + } + } +} + +int main(int argc, char **argv) { + string model_file; + string input_node; + string output_node; + string input_shape; + string input_file; + string output_file; + string device; + int round = 1; + + std::vector flag_list = { + Flag("model", &model_file, "model file name"), + Flag("input", &input_node, "input node"), + Flag("output", &output_node, "output node"), + Flag("input_shape", &input_shape, "input shape, separated by comma"), + Flag("input_file", &input_file, "input file name"), + Flag("output_file", &output_file, "output file name"), + Flag("device", &device, "CPU/NEON"), + Flag("round", &round, "round"), + }; + + string usage = Flags::Usage(argv[0], flag_list); + const bool parse_result = Flags::Parse(&argc, argv, flag_list); + + if (!parse_result) { + LOG(ERROR) << usage; + return -1; + } + + VLOG(0) << "model: " << model_file << std::endl + << "input: " << input_node << std::endl + << "output: " << output_node << std::endl + << "input_shape: " << input_shape << std::endl + << "input_file: " << input_file << std::endl + << "output_file: " << output_file << std::endl + << "device: " << device << std::endl + << "round: " << round << std::endl; + + vector shape; + ParseShape(input_shape, &shape); + + // load model + ifstream file_stream(model_file, ios::in | ios::binary); + NetDef net_def; + net_def.ParseFromIstream(&file_stream); + file_stream.close(); + + Workspace ws; + ws.LoadModelTensor(net_def, DeviceType::CPU); + Tensor *input_tensor = ws.CreateTensor(input_node + ":0", + cpu_allocator(), DT_FLOAT); + input_tensor->Resize(shape); + float *input_data = input_tensor->mutable_data(); + + + // load input + ifstream in_file(input_file, ios::in | ios::binary); + in_file.read(reinterpret_cast(input_data), + input_tensor->size() * sizeof(float)); + in_file.close(); + + // run model + DeviceType device_type; + DeviceType_Parse(device, &device_type); + VLOG(0) << device_type; + auto net = CreateNet(net_def, &ws, device_type); + + timeval tv1, tv2; + gettimeofday(&tv1, NULL); + for (int i = 0; i < round; ++i) { + net->Run(); + } + gettimeofday(&tv2, NULL); + cout << "avg duration: " << ((tv2.tv_sec - tv1.tv_sec) * 1000 + + (tv2.tv_usec - tv1.tv_usec) / 1000) / round << endl; + + // save output + const Tensor *output = ws.GetTensor(output_node + ":0"); + + ofstream out_file(output_file, ios::binary); + out_file.write((const char *) (output->data()), + output->size() * sizeof(float)); + out_file.flush(); + out_file.close(); +} \ No newline at end of file diff --git a/mace/python/tools/mace_pb2.py b/mace/python/tools/mace_pb2.py new file mode 100755 index 0000000000000000000000000000000000000000..190ecb886f3aef331e331526fd6ca2a3a55b577e --- /dev/null +++ b/mace/python/tools/mace_pb2.py @@ -0,0 +1,456 @@ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: mace/proto/mace.proto + +import sys +_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1')) +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf import reflection as _reflection +from google.protobuf import symbol_database as _symbol_database +from google.protobuf import descriptor_pb2 +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor.FileDescriptor( + name='mace/proto/mace.proto', + package='mace', + syntax='proto2', + serialized_pb=_b('\n\x15mace/proto/mace.proto\x12\x04mace\"\xdf\x01\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12+\n\tdata_type\x18\x02 \x01(\x0e\x32\x0e.mace.DataType:\x08\x44T_FLOAT\x12\x16\n\nfloat_data\x18\x03 \x03(\x02\x42\x02\x10\x01\x12\x16\n\nint32_data\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x11\n\tbyte_data\x18\x05 \x01(\x0c\x12\x13\n\x0bstring_data\x18\x06 \x03(\x0c\x12\x17\n\x0b\x64ouble_data\x18\t \x03(\x01\x42\x02\x10\x01\x12\x16\n\nint64_data\x18\n \x03(\x03\x42\x02\x10\x01\x12\x0c\n\x04name\x18\x07 \x01(\t\"h\n\x08\x41rgument\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01s\x18\x04 \x01(\x0c\x12\x0e\n\x06\x66loats\x18\x05 \x03(\x02\x12\x0c\n\x04ints\x18\x06 \x03(\x03\x12\x0f\n\x07strings\x18\x07 \x03(\x0c\"e\n\x0bOperatorDef\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x1b\n\x03\x61rg\x18\x05 \x03(\x0b\x32\x0e.mace.Argument\"\x87\x01\n\x06NetDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1d\n\x02op\x18\x02 \x03(\x0b\x32\x11.mace.OperatorDef\x12\x0f\n\x07version\x18\x03 \x01(\t\x12\x1b\n\x03\x61rg\x18\x04 \x03(\x0b\x32\x0e.mace.Argument\x12\"\n\x07tensors\x18\x05 \x03(\x0b\x32\x11.mace.TensorProto*+\n\nDeviceType\x12\x07\n\x03\x43PU\x10\x00\x12\x08\n\x04NEON\x10\x01\x12\n\n\x06OPENCL\x10\x02*\xa7\x01\n\x08\x44\x61taType\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0c\n\x08\x44T_FLOAT\x10\x01\x12\r\n\tDT_DOUBLE\x10\x02\x12\x0c\n\x08\x44T_INT32\x10\x03\x12\x0c\n\x08\x44T_UINT8\x10\x04\x12\x0c\n\x08\x44T_INT16\x10\x05\x12\x0b\n\x07\x44T_INT8\x10\x06\x12\r\n\tDT_STRING\x10\x07\x12\x0c\n\x08\x44T_INT64\x10\x08\x12\r\n\tDT_UINT16\x10\t\x12\x0b\n\x07\x44T_BOOL\x10\n') +) + +_DEVICETYPE = _descriptor.EnumDescriptor( + name='DeviceType', + full_name='mace.DeviceType', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='CPU', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='NEON', index=1, number=1, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='OPENCL', index=2, number=2, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=604, + serialized_end=647, +) +_sym_db.RegisterEnumDescriptor(_DEVICETYPE) + +DeviceType = enum_type_wrapper.EnumTypeWrapper(_DEVICETYPE) +_DATATYPE = _descriptor.EnumDescriptor( + name='DataType', + full_name='mace.DataType', + filename=None, + file=DESCRIPTOR, + values=[ + _descriptor.EnumValueDescriptor( + name='DT_INVALID', index=0, number=0, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_FLOAT', index=1, number=1, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_DOUBLE', index=2, number=2, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INT32', index=3, number=3, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_UINT8', index=4, number=4, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INT16', index=5, number=5, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INT8', index=6, number=6, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_STRING', index=7, number=7, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_INT64', index=8, number=8, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_UINT16', index=9, number=9, + options=None, + type=None), + _descriptor.EnumValueDescriptor( + name='DT_BOOL', index=10, number=10, + options=None, + type=None), + ], + containing_type=None, + options=None, + serialized_start=650, + serialized_end=817, +) +_sym_db.RegisterEnumDescriptor(_DATATYPE) + +DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE) +CPU = 0 +NEON = 1 +OPENCL = 2 +DT_INVALID = 0 +DT_FLOAT = 1 +DT_DOUBLE = 2 +DT_INT32 = 3 +DT_UINT8 = 4 +DT_INT16 = 5 +DT_INT8 = 6 +DT_STRING = 7 +DT_INT64 = 8 +DT_UINT16 = 9 +DT_BOOL = 10 + + + +_TENSORPROTO = _descriptor.Descriptor( + name='TensorProto', + full_name='mace.TensorProto', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='dims', full_name='mace.TensorProto.dims', index=0, + number=1, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='data_type', full_name='mace.TensorProto.data_type', index=1, + number=2, type=14, cpp_type=8, label=1, + has_default_value=True, default_value=1, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='float_data', full_name='mace.TensorProto.float_data', index=2, + number=3, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), + _descriptor.FieldDescriptor( + name='int32_data', full_name='mace.TensorProto.int32_data', index=3, + number=4, type=5, cpp_type=1, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), + _descriptor.FieldDescriptor( + name='byte_data', full_name='mace.TensorProto.byte_data', index=4, + number=5, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='string_data', full_name='mace.TensorProto.string_data', index=5, + number=6, type=12, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='double_data', full_name='mace.TensorProto.double_data', index=6, + number=9, type=1, cpp_type=5, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), + _descriptor.FieldDescriptor( + name='int64_data', full_name='mace.TensorProto.int64_data', index=7, + number=10, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))), + _descriptor.FieldDescriptor( + name='name', full_name='mace.TensorProto.name', index=8, + number=7, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=32, + serialized_end=255, +) + + +_ARGUMENT = _descriptor.Descriptor( + name='Argument', + full_name='mace.Argument', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='mace.Argument.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='f', full_name='mace.Argument.f', index=1, + number=2, type=2, cpp_type=6, label=1, + has_default_value=False, default_value=float(0), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='i', full_name='mace.Argument.i', index=2, + number=3, type=3, cpp_type=2, label=1, + has_default_value=False, default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='s', full_name='mace.Argument.s', index=3, + number=4, type=12, cpp_type=9, label=1, + has_default_value=False, default_value=_b(""), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='floats', full_name='mace.Argument.floats', index=4, + number=5, type=2, cpp_type=6, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='ints', full_name='mace.Argument.ints', index=5, + number=6, type=3, cpp_type=2, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='strings', full_name='mace.Argument.strings', index=6, + number=7, type=12, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=257, + serialized_end=361, +) + + +_OPERATORDEF = _descriptor.Descriptor( + name='OperatorDef', + full_name='mace.OperatorDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='input', full_name='mace.OperatorDef.input', index=0, + number=1, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='output', full_name='mace.OperatorDef.output', index=1, + number=2, type=9, cpp_type=9, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='name', full_name='mace.OperatorDef.name', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='type', full_name='mace.OperatorDef.type', index=3, + number=4, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='arg', full_name='mace.OperatorDef.arg', index=4, + number=5, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=363, + serialized_end=464, +) + + +_NETDEF = _descriptor.Descriptor( + name='NetDef', + full_name='mace.NetDef', + filename=None, + file=DESCRIPTOR, + containing_type=None, + fields=[ + _descriptor.FieldDescriptor( + name='name', full_name='mace.NetDef.name', index=0, + number=1, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='op', full_name='mace.NetDef.op', index=1, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='version', full_name='mace.NetDef.version', index=2, + number=3, type=9, cpp_type=9, label=1, + has_default_value=False, default_value=_b("").decode('utf-8'), + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='arg', full_name='mace.NetDef.arg', index=3, + number=4, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + _descriptor.FieldDescriptor( + name='tensors', full_name='mace.NetDef.tensors', index=4, + number=5, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + options=None), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + options=None, + is_extendable=False, + syntax='proto2', + extension_ranges=[], + oneofs=[ + ], + serialized_start=467, + serialized_end=602, +) + +_TENSORPROTO.fields_by_name['data_type'].enum_type = _DATATYPE +_OPERATORDEF.fields_by_name['arg'].message_type = _ARGUMENT +_NETDEF.fields_by_name['op'].message_type = _OPERATORDEF +_NETDEF.fields_by_name['arg'].message_type = _ARGUMENT +_NETDEF.fields_by_name['tensors'].message_type = _TENSORPROTO +DESCRIPTOR.message_types_by_name['TensorProto'] = _TENSORPROTO +DESCRIPTOR.message_types_by_name['Argument'] = _ARGUMENT +DESCRIPTOR.message_types_by_name['OperatorDef'] = _OPERATORDEF +DESCRIPTOR.message_types_by_name['NetDef'] = _NETDEF +DESCRIPTOR.enum_types_by_name['DeviceType'] = _DEVICETYPE +DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE +_sym_db.RegisterFileDescriptor(DESCRIPTOR) + +TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), dict( + DESCRIPTOR = _TENSORPROTO, + __module__ = 'mace.proto.mace_pb2' + # @@protoc_insertion_point(class_scope:mace.TensorProto) + )) +_sym_db.RegisterMessage(TensorProto) + +Argument = _reflection.GeneratedProtocolMessageType('Argument', (_message.Message,), dict( + DESCRIPTOR = _ARGUMENT, + __module__ = 'mace.proto.mace_pb2' + # @@protoc_insertion_point(class_scope:mace.Argument) + )) +_sym_db.RegisterMessage(Argument) + +OperatorDef = _reflection.GeneratedProtocolMessageType('OperatorDef', (_message.Message,), dict( + DESCRIPTOR = _OPERATORDEF, + __module__ = 'mace.proto.mace_pb2' + # @@protoc_insertion_point(class_scope:mace.OperatorDef) + )) +_sym_db.RegisterMessage(OperatorDef) + +NetDef = _reflection.GeneratedProtocolMessageType('NetDef', (_message.Message,), dict( + DESCRIPTOR = _NETDEF, + __module__ = 'mace.proto.mace_pb2' + # @@protoc_insertion_point(class_scope:mace.NetDef) + )) +_sym_db.RegisterMessage(NetDef) + + +_TENSORPROTO.fields_by_name['float_data'].has_options = True +_TENSORPROTO.fields_by_name['float_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) +_TENSORPROTO.fields_by_name['int32_data'].has_options = True +_TENSORPROTO.fields_by_name['int32_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) +_TENSORPROTO.fields_by_name['double_data'].has_options = True +_TENSORPROTO.fields_by_name['double_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) +_TENSORPROTO.fields_by_name['int64_data'].has_options = True +_TENSORPROTO.fields_by_name['int64_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001')) +# @@protoc_insertion_point(module_scope) diff --git a/mace/utils/BUILD b/mace/utils/BUILD new file mode 100644 index 0000000000000000000000000000000000000000..aa5f65d312ec6db4d856d66dd12da6adf682f6d1 --- /dev/null +++ b/mace/utils/BUILD @@ -0,0 +1,22 @@ +# Description: +# Mace utils. +# +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +cc_library( + name = "utils", + srcs = glob([ + "*.cc", + ]), + hdrs = glob([ + "*.h", + ]), + copts = ["-std=c++11"], + deps = [ + "//mace/core:core", + ], +) \ No newline at end of file diff --git a/mace/utils/command_line_flags.cc b/mace/utils/command_line_flags.cc new file mode 100644 index 0000000000000000000000000000000000000000..d9a249b8b56ab6b294de2447a21b14f3bef980eb --- /dev/null +++ b/mace/utils/command_line_flags.cc @@ -0,0 +1,208 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#include "mace/utils/command_line_flags.h" +#include +#include + +namespace mace { +namespace { + +bool StringConsume(string &arg, const string &x) { + if ((arg.size() >= x.size()) + && (memcmp(arg.data(), x.data(), x.size()) == 0)) { + arg = arg.substr(x.size()); + return true; + } + return false; +} + +bool ParseStringFlag(string arg, string flag, + string *dst, bool *value_parsing_ok) { + *value_parsing_ok = true; + if (StringConsume(arg, "--") && StringConsume(arg, flag) + && StringConsume(arg, "=")) { + *dst = arg; + return true; + } + + return false; +} + +bool ParseInt32Flag(string arg, string flag, + int32_t *dst, bool *value_parsing_ok) { + *value_parsing_ok = true; + if (StringConsume(arg, "--") && StringConsume(arg, flag) + && StringConsume(arg, "=")) { + char extra; + if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) { + LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag + << "."; + *value_parsing_ok = false; + } + return true; + } + + return false; +} + +bool ParseInt64Flag(string arg, string flag, + long long *dst, bool *value_parsing_ok) { + *value_parsing_ok = true; + if (StringConsume(arg, "--") && StringConsume(arg, flag) + && StringConsume(arg, "=")) { + char extra; + if (sscanf(arg.data(), "%lld%c", dst, &extra) != 1) { + LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag + << "."; + *value_parsing_ok = false; + } + return true; + } + + return false; +} + +bool ParseBoolFlag(string arg, string flag, + bool *dst, bool *value_parsing_ok) { + *value_parsing_ok = true; + if (StringConsume(arg, "--") && StringConsume(arg, flag)) { + if (arg.empty()) { + *dst = true; + return true; + } + + if (arg == "=true") { + *dst = true; + return true; + } else if (arg == "=false") { + *dst = false; + return true; + } else { + LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag + << "."; + *value_parsing_ok = false; + return true; + } + } + + return false; +} + +bool ParseFloatFlag(string arg, string flag, + float *dst, bool *value_parsing_ok) { + *value_parsing_ok = true; + if (StringConsume(arg, "--") && StringConsume(arg, flag) + && StringConsume(arg, "=")) { + char extra; + if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) { + LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag + << "."; + *value_parsing_ok = false; + } + return true; + } + + return false; +} + +} // namespace + +Flag::Flag(const char *name, int *dst, const string &usage_text) + : name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {} + +Flag::Flag(const char *name, long long *dst, const string &usage_text) + : name_(name), + type_(TYPE_INT64), + int64_value_(dst), + usage_text_(usage_text) {} + +Flag::Flag(const char *name, bool *dst, const string &usage_text) + : name_(name), + type_(TYPE_BOOL), + bool_value_(dst), + usage_text_(usage_text) {} + +Flag::Flag(const char *name, string *dst, const string &usage_text) + : name_(name), + type_(TYPE_STRING), + string_value_(dst), + usage_text_(usage_text) {} + +Flag::Flag(const char *name, float *dst, const string &usage_text) + : name_(name), + type_(TYPE_FLOAT), + float_value_(dst), + usage_text_(usage_text) {} + +bool Flag::Parse(string arg, bool *value_parsing_ok) const { + bool result = false; + if (type_ == TYPE_INT) { + result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok); + } else if (type_ == TYPE_INT64) { + result = ParseInt64Flag(arg, name_, int64_value_, value_parsing_ok); + } else if (type_ == TYPE_BOOL) { + result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok); + } else if (type_ == TYPE_STRING) { + result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok); + } else if (type_ == TYPE_FLOAT) { + result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok); + } + return result; +} + +/*static*/ bool Flags::Parse(int *argc, char **argv, + const std::vector &flag_list) { + bool result = true; + std::vector unknown_flags; + for (int i = 1; i < *argc; ++i) { + if (string(argv[i]) == "--") { + while (i < *argc) { + unknown_flags.push_back(argv[i]); + ++i; + } + break; + } + + bool was_found = false; + for (const Flag &flag : flag_list) { + bool value_parsing_ok; + was_found = flag.Parse(argv[i], &value_parsing_ok); + if (!value_parsing_ok) { + result = false; + } + if (was_found) { + break; + } + } + if (!was_found) { + unknown_flags.push_back(argv[i]); + } + } + // Passthrough any extra flags. + int dst = 1; // Skip argv[0] + for (char *f : unknown_flags) { + argv[dst++] = f; + } + argv[dst++] = nullptr; + *argc = unknown_flags.size() + 1; + return result && (*argc < 2 || strcmp(argv[1], "--help") != 0); +} + +/*static*/ string Flags::Usage(const string &cmdline, + const std::vector &flag_list) { + std::stringstream usage_text; + usage_text << "usage: " << cmdline << std::endl; + + if (!flag_list.empty()) { + usage_text << "Flags: " << std::endl; + } + for (const Flag &flag : flag_list) { + usage_text << "\t" << std::left << std::setw(30) << flag.name_; + usage_text << flag.usage_text_ << std::endl; + } + return usage_text.str(); +} + +} // namespace mace diff --git a/mace/utils/command_line_flags.h b/mace/utils/command_line_flags.h new file mode 100644 index 0000000000000000000000000000000000000000..0d3daf286537debc341baf41af962aead3e8e348 --- /dev/null +++ b/mace/utils/command_line_flags.h @@ -0,0 +1,54 @@ +// +// Copyright (c) 2017 XiaoMi All rights reserved. +// + +#ifndef MACE_CORE_COMMAND_LINE_FLAGS_H +#define MACE_CORE_COMMAND_LINE_FLAGS_H + +#include "mace/core/common.h" + +namespace mace { + +class Flag { + public: + Flag(const char *name, int *dst1, const string &usage_text); + Flag(const char *name, long long *dst1, const string &usage_text); + Flag(const char *name, bool *dst, const string &usage_text); + Flag(const char *name, string *dst, const string &usage_text); + Flag(const char *name, float *dst, const string &usage_text); + + private: + friend class Flags; + + bool Parse(string arg, bool *value_parsing_ok) const; + + string name_; + enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_; + int *int_value_; + long long *int64_value_; + bool *bool_value_; + string *string_value_; + float *float_value_; + string usage_text_; +}; + +class Flags { + public: + // Parse the command line represented by argv[0, ..., (*argc)-1] to find flag + // instances matching flags in flaglist[]. Update the variables associated + // with matching flags, and remove the matching arguments from (*argc, argv). + // Return true iff all recognized flag values were parsed correctly, and the + // first remaining argument is not "--help". + static bool Parse(int *argc, + char **argv, + const std::vector &flag_list); + + // Return a usage message with command line cmdline, and the + // usage_text strings in flag_list[]. + static string Usage(const string &cmdline, + const std::vector &flag_list); +}; + +} // namespace mace + +#endif // MACE_CORE_COMMAND_LINE_FLAGS_H