提交 eedb3921 编写于 作者: 李寅

Add mace run tool

上级 2b2cd38f
......@@ -12,7 +12,7 @@ cc_binary(
"@org_tensorflow//tensorflow/core:android_tensorflow_lib",
],
copts = ["-std=c++11"],
linkopts = if_android(["-ldl"]),
linkopts = ["-fopenmp",] + if_android(["-ldl"]),
)
cc_test(
......@@ -23,7 +23,22 @@ cc_test(
"//mace/core:test_benchmark_main",
],
copts = ["-std=c++11"],
linkopts = if_android(["-ldl"]),
linkopts = ["-fopenmp",] + if_android(["-ldl"]),
linkstatic = 1,
testonly = 1,
)
cc_binary(
name = "mace_run",
srcs = [
"mace_run.cc",
],
deps = [
"//mace/core",
"//mace/utils",
"//mace/ops",
],
copts = ["-std=c++11",],
linkopts = ["-fopenmp",] + if_android(["-ldl"]),
linkstatic = 1,
)
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
/**
* Usage:
* mace_run --model=mobi_mace.pb \
* --input=input_node \
* --output=MobilenetV1/Logits/conv2d/convolution \
* --input_shape=1,3,224,224 \
* --input_file=input_data \
* --output_file=mace.out \
* --device=NEON
*/
#include <fstream>
#include <sys/time.h>
#include "mace/core/net.h"
#include "mace/utils/command_line_flags.h"
using namespace std;
using namespace mace;
void ParseShape(const string &str, vector<index_t> *shape) {
string tmp = str;
while (!tmp.empty()) {
int dim = atoi(tmp.data());
shape->push_back(dim);
size_t next_offset = tmp.find(",");
if (next_offset == string::npos) {
break;
} else {
tmp = tmp.substr(next_offset + 1);
}
}
}
int main(int argc, char **argv) {
string model_file;
string input_node;
string output_node;
string input_shape;
string input_file;
string output_file;
string device;
int round = 1;
std::vector<Flag> flag_list = {
Flag("model", &model_file, "model file name"),
Flag("input", &input_node, "input node"),
Flag("output", &output_node, "output node"),
Flag("input_shape", &input_shape, "input shape, separated by comma"),
Flag("input_file", &input_file, "input file name"),
Flag("output_file", &output_file, "output file name"),
Flag("device", &device, "CPU/NEON"),
Flag("round", &round, "round"),
};
string usage = Flags::Usage(argv[0], flag_list);
const bool parse_result = Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
LOG(ERROR) << usage;
return -1;
}
VLOG(0) << "model: " << model_file << std::endl
<< "input: " << input_node << std::endl
<< "output: " << output_node << std::endl
<< "input_shape: " << input_shape << std::endl
<< "input_file: " << input_file << std::endl
<< "output_file: " << output_file << std::endl
<< "device: " << device << std::endl
<< "round: " << round << std::endl;
vector<index_t> shape;
ParseShape(input_shape, &shape);
// load model
ifstream file_stream(model_file, ios::in | ios::binary);
NetDef net_def;
net_def.ParseFromIstream(&file_stream);
file_stream.close();
Workspace ws;
ws.LoadModelTensor(net_def, DeviceType::CPU);
Tensor *input_tensor = ws.CreateTensor(input_node + ":0",
cpu_allocator(), DT_FLOAT);
input_tensor->Resize(shape);
float *input_data = input_tensor->mutable_data<float>();
// load input
ifstream in_file(input_file, ios::in | ios::binary);
in_file.read(reinterpret_cast<char *>(input_data),
input_tensor->size() * sizeof(float));
in_file.close();
// run model
DeviceType device_type;
DeviceType_Parse(device, &device_type);
VLOG(0) << device_type;
auto net = CreateNet(net_def, &ws, device_type);
timeval tv1, tv2;
gettimeofday(&tv1, NULL);
for (int i = 0; i < round; ++i) {
net->Run();
}
gettimeofday(&tv2, NULL);
cout << "avg duration: " << ((tv2.tv_sec - tv1.tv_sec) * 1000
+ (tv2.tv_usec - tv1.tv_usec) / 1000) / round << endl;
// save output
const Tensor *output = ws.GetTensor(output_node + ":0");
ofstream out_file(output_file, ios::binary);
out_file.write((const char *) (output->data<float>()),
output->size() * sizeof(float));
out_file.flush();
out_file.close();
}
\ No newline at end of file
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: mace/proto/mace.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from google.protobuf import reflection as _reflection
from google.protobuf import symbol_database as _symbol_database
from google.protobuf import descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor.FileDescriptor(
name='mace/proto/mace.proto',
package='mace',
syntax='proto2',
serialized_pb=_b('\n\x15mace/proto/mace.proto\x12\x04mace\"\xdf\x01\n\x0bTensorProto\x12\x0c\n\x04\x64ims\x18\x01 \x03(\x03\x12+\n\tdata_type\x18\x02 \x01(\x0e\x32\x0e.mace.DataType:\x08\x44T_FLOAT\x12\x16\n\nfloat_data\x18\x03 \x03(\x02\x42\x02\x10\x01\x12\x16\n\nint32_data\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x11\n\tbyte_data\x18\x05 \x01(\x0c\x12\x13\n\x0bstring_data\x18\x06 \x03(\x0c\x12\x17\n\x0b\x64ouble_data\x18\t \x03(\x01\x42\x02\x10\x01\x12\x16\n\nint64_data\x18\n \x03(\x03\x42\x02\x10\x01\x12\x0c\n\x04name\x18\x07 \x01(\t\"h\n\x08\x41rgument\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\t\n\x01\x66\x18\x02 \x01(\x02\x12\t\n\x01i\x18\x03 \x01(\x03\x12\t\n\x01s\x18\x04 \x01(\x0c\x12\x0e\n\x06\x66loats\x18\x05 \x03(\x02\x12\x0c\n\x04ints\x18\x06 \x03(\x03\x12\x0f\n\x07strings\x18\x07 \x03(\x0c\"e\n\x0bOperatorDef\x12\r\n\x05input\x18\x01 \x03(\t\x12\x0e\n\x06output\x18\x02 \x03(\t\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\x0c\n\x04type\x18\x04 \x01(\t\x12\x1b\n\x03\x61rg\x18\x05 \x03(\x0b\x32\x0e.mace.Argument\"\x87\x01\n\x06NetDef\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x1d\n\x02op\x18\x02 \x03(\x0b\x32\x11.mace.OperatorDef\x12\x0f\n\x07version\x18\x03 \x01(\t\x12\x1b\n\x03\x61rg\x18\x04 \x03(\x0b\x32\x0e.mace.Argument\x12\"\n\x07tensors\x18\x05 \x03(\x0b\x32\x11.mace.TensorProto*+\n\nDeviceType\x12\x07\n\x03\x43PU\x10\x00\x12\x08\n\x04NEON\x10\x01\x12\n\n\x06OPENCL\x10\x02*\xa7\x01\n\x08\x44\x61taType\x12\x0e\n\nDT_INVALID\x10\x00\x12\x0c\n\x08\x44T_FLOAT\x10\x01\x12\r\n\tDT_DOUBLE\x10\x02\x12\x0c\n\x08\x44T_INT32\x10\x03\x12\x0c\n\x08\x44T_UINT8\x10\x04\x12\x0c\n\x08\x44T_INT16\x10\x05\x12\x0b\n\x07\x44T_INT8\x10\x06\x12\r\n\tDT_STRING\x10\x07\x12\x0c\n\x08\x44T_INT64\x10\x08\x12\r\n\tDT_UINT16\x10\t\x12\x0b\n\x07\x44T_BOOL\x10\n')
)
_DEVICETYPE = _descriptor.EnumDescriptor(
name='DeviceType',
full_name='mace.DeviceType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='CPU', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='NEON', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='OPENCL', index=2, number=2,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=604,
serialized_end=647,
)
_sym_db.RegisterEnumDescriptor(_DEVICETYPE)
DeviceType = enum_type_wrapper.EnumTypeWrapper(_DEVICETYPE)
_DATATYPE = _descriptor.EnumDescriptor(
name='DataType',
full_name='mace.DataType',
filename=None,
file=DESCRIPTOR,
values=[
_descriptor.EnumValueDescriptor(
name='DT_INVALID', index=0, number=0,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DT_FLOAT', index=1, number=1,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DT_DOUBLE', index=2, number=2,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DT_INT32', index=3, number=3,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DT_UINT8', index=4, number=4,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DT_INT16', index=5, number=5,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DT_INT8', index=6, number=6,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DT_STRING', index=7, number=7,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DT_INT64', index=8, number=8,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DT_UINT16', index=9, number=9,
options=None,
type=None),
_descriptor.EnumValueDescriptor(
name='DT_BOOL', index=10, number=10,
options=None,
type=None),
],
containing_type=None,
options=None,
serialized_start=650,
serialized_end=817,
)
_sym_db.RegisterEnumDescriptor(_DATATYPE)
DataType = enum_type_wrapper.EnumTypeWrapper(_DATATYPE)
CPU = 0
NEON = 1
OPENCL = 2
DT_INVALID = 0
DT_FLOAT = 1
DT_DOUBLE = 2
DT_INT32 = 3
DT_UINT8 = 4
DT_INT16 = 5
DT_INT8 = 6
DT_STRING = 7
DT_INT64 = 8
DT_UINT16 = 9
DT_BOOL = 10
_TENSORPROTO = _descriptor.Descriptor(
name='TensorProto',
full_name='mace.TensorProto',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='dims', full_name='mace.TensorProto.dims', index=0,
number=1, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='data_type', full_name='mace.TensorProto.data_type', index=1,
number=2, type=14, cpp_type=8, label=1,
has_default_value=True, default_value=1,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='float_data', full_name='mace.TensorProto.float_data', index=2,
number=3, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))),
_descriptor.FieldDescriptor(
name='int32_data', full_name='mace.TensorProto.int32_data', index=3,
number=4, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))),
_descriptor.FieldDescriptor(
name='byte_data', full_name='mace.TensorProto.byte_data', index=4,
number=5, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='string_data', full_name='mace.TensorProto.string_data', index=5,
number=6, type=12, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='double_data', full_name='mace.TensorProto.double_data', index=6,
number=9, type=1, cpp_type=5, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))),
_descriptor.FieldDescriptor(
name='int64_data', full_name='mace.TensorProto.int64_data', index=7,
number=10, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=_descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))),
_descriptor.FieldDescriptor(
name='name', full_name='mace.TensorProto.name', index=8,
number=7, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=32,
serialized_end=255,
)
_ARGUMENT = _descriptor.Descriptor(
name='Argument',
full_name='mace.Argument',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='mace.Argument.name', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='f', full_name='mace.Argument.f', index=1,
number=2, type=2, cpp_type=6, label=1,
has_default_value=False, default_value=float(0),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='i', full_name='mace.Argument.i', index=2,
number=3, type=3, cpp_type=2, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='s', full_name='mace.Argument.s', index=3,
number=4, type=12, cpp_type=9, label=1,
has_default_value=False, default_value=_b(""),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='floats', full_name='mace.Argument.floats', index=4,
number=5, type=2, cpp_type=6, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='ints', full_name='mace.Argument.ints', index=5,
number=6, type=3, cpp_type=2, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='strings', full_name='mace.Argument.strings', index=6,
number=7, type=12, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=257,
serialized_end=361,
)
_OPERATORDEF = _descriptor.Descriptor(
name='OperatorDef',
full_name='mace.OperatorDef',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='input', full_name='mace.OperatorDef.input', index=0,
number=1, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='output', full_name='mace.OperatorDef.output', index=1,
number=2, type=9, cpp_type=9, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='name', full_name='mace.OperatorDef.name', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='type', full_name='mace.OperatorDef.type', index=3,
number=4, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='arg', full_name='mace.OperatorDef.arg', index=4,
number=5, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=363,
serialized_end=464,
)
_NETDEF = _descriptor.Descriptor(
name='NetDef',
full_name='mace.NetDef',
filename=None,
file=DESCRIPTOR,
containing_type=None,
fields=[
_descriptor.FieldDescriptor(
name='name', full_name='mace.NetDef.name', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='op', full_name='mace.NetDef.op', index=1,
number=2, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='version', full_name='mace.NetDef.version', index=2,
number=3, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='arg', full_name='mace.NetDef.arg', index=3,
number=4, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
_descriptor.FieldDescriptor(
name='tensors', full_name='mace.NetDef.tensors', index=4,
number=5, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
options=None),
],
extensions=[
],
nested_types=[],
enum_types=[
],
options=None,
is_extendable=False,
syntax='proto2',
extension_ranges=[],
oneofs=[
],
serialized_start=467,
serialized_end=602,
)
_TENSORPROTO.fields_by_name['data_type'].enum_type = _DATATYPE
_OPERATORDEF.fields_by_name['arg'].message_type = _ARGUMENT
_NETDEF.fields_by_name['op'].message_type = _OPERATORDEF
_NETDEF.fields_by_name['arg'].message_type = _ARGUMENT
_NETDEF.fields_by_name['tensors'].message_type = _TENSORPROTO
DESCRIPTOR.message_types_by_name['TensorProto'] = _TENSORPROTO
DESCRIPTOR.message_types_by_name['Argument'] = _ARGUMENT
DESCRIPTOR.message_types_by_name['OperatorDef'] = _OPERATORDEF
DESCRIPTOR.message_types_by_name['NetDef'] = _NETDEF
DESCRIPTOR.enum_types_by_name['DeviceType'] = _DEVICETYPE
DESCRIPTOR.enum_types_by_name['DataType'] = _DATATYPE
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
TensorProto = _reflection.GeneratedProtocolMessageType('TensorProto', (_message.Message,), dict(
DESCRIPTOR = _TENSORPROTO,
__module__ = 'mace.proto.mace_pb2'
# @@protoc_insertion_point(class_scope:mace.TensorProto)
))
_sym_db.RegisterMessage(TensorProto)
Argument = _reflection.GeneratedProtocolMessageType('Argument', (_message.Message,), dict(
DESCRIPTOR = _ARGUMENT,
__module__ = 'mace.proto.mace_pb2'
# @@protoc_insertion_point(class_scope:mace.Argument)
))
_sym_db.RegisterMessage(Argument)
OperatorDef = _reflection.GeneratedProtocolMessageType('OperatorDef', (_message.Message,), dict(
DESCRIPTOR = _OPERATORDEF,
__module__ = 'mace.proto.mace_pb2'
# @@protoc_insertion_point(class_scope:mace.OperatorDef)
))
_sym_db.RegisterMessage(OperatorDef)
NetDef = _reflection.GeneratedProtocolMessageType('NetDef', (_message.Message,), dict(
DESCRIPTOR = _NETDEF,
__module__ = 'mace.proto.mace_pb2'
# @@protoc_insertion_point(class_scope:mace.NetDef)
))
_sym_db.RegisterMessage(NetDef)
_TENSORPROTO.fields_by_name['float_data'].has_options = True
_TENSORPROTO.fields_by_name['float_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))
_TENSORPROTO.fields_by_name['int32_data'].has_options = True
_TENSORPROTO.fields_by_name['int32_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))
_TENSORPROTO.fields_by_name['double_data'].has_options = True
_TENSORPROTO.fields_by_name['double_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))
_TENSORPROTO.fields_by_name['int64_data'].has_options = True
_TENSORPROTO.fields_by_name['int64_data']._options = _descriptor._ParseOptions(descriptor_pb2.FieldOptions(), _b('\020\001'))
# @@protoc_insertion_point(module_scope)
# Description:
# Mace utils.
#
package(
default_visibility = ["//visibility:public"],
)
licenses(["notice"]) # Apache 2.0
cc_library(
name = "utils",
srcs = glob([
"*.cc",
]),
hdrs = glob([
"*.h",
]),
copts = ["-std=c++11"],
deps = [
"//mace/core:core",
],
)
\ No newline at end of file
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/utils/command_line_flags.h"
#include <cstring>
#include <iomanip>
namespace mace {
namespace {
bool StringConsume(string &arg, const string &x) {
if ((arg.size() >= x.size())
&& (memcmp(arg.data(), x.data(), x.size()) == 0)) {
arg = arg.substr(x.size());
return true;
}
return false;
}
bool ParseStringFlag(string arg, string flag,
string *dst, bool *value_parsing_ok) {
*value_parsing_ok = true;
if (StringConsume(arg, "--") && StringConsume(arg, flag)
&& StringConsume(arg, "=")) {
*dst = arg;
return true;
}
return false;
}
bool ParseInt32Flag(string arg, string flag,
int32_t *dst, bool *value_parsing_ok) {
*value_parsing_ok = true;
if (StringConsume(arg, "--") && StringConsume(arg, flag)
&& StringConsume(arg, "=")) {
char extra;
if (sscanf(arg.data(), "%d%c", dst, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
}
return true;
}
return false;
}
bool ParseInt64Flag(string arg, string flag,
long long *dst, bool *value_parsing_ok) {
*value_parsing_ok = true;
if (StringConsume(arg, "--") && StringConsume(arg, flag)
&& StringConsume(arg, "=")) {
char extra;
if (sscanf(arg.data(), "%lld%c", dst, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
}
return true;
}
return false;
}
bool ParseBoolFlag(string arg, string flag,
bool *dst, bool *value_parsing_ok) {
*value_parsing_ok = true;
if (StringConsume(arg, "--") && StringConsume(arg, flag)) {
if (arg.empty()) {
*dst = true;
return true;
}
if (arg == "=true") {
*dst = true;
return true;
} else if (arg == "=false") {
*dst = false;
return true;
} else {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
return true;
}
}
return false;
}
bool ParseFloatFlag(string arg, string flag,
float *dst, bool *value_parsing_ok) {
*value_parsing_ok = true;
if (StringConsume(arg, "--") && StringConsume(arg, flag)
&& StringConsume(arg, "=")) {
char extra;
if (sscanf(arg.data(), "%f%c", dst, &extra) != 1) {
LOG(ERROR) << "Couldn't interpret value " << arg << " for flag " << flag
<< ".";
*value_parsing_ok = false;
}
return true;
}
return false;
}
} // namespace
Flag::Flag(const char *name, int *dst, const string &usage_text)
: name_(name), type_(TYPE_INT), int_value_(dst), usage_text_(usage_text) {}
Flag::Flag(const char *name, long long *dst, const string &usage_text)
: name_(name),
type_(TYPE_INT64),
int64_value_(dst),
usage_text_(usage_text) {}
Flag::Flag(const char *name, bool *dst, const string &usage_text)
: name_(name),
type_(TYPE_BOOL),
bool_value_(dst),
usage_text_(usage_text) {}
Flag::Flag(const char *name, string *dst, const string &usage_text)
: name_(name),
type_(TYPE_STRING),
string_value_(dst),
usage_text_(usage_text) {}
Flag::Flag(const char *name, float *dst, const string &usage_text)
: name_(name),
type_(TYPE_FLOAT),
float_value_(dst),
usage_text_(usage_text) {}
bool Flag::Parse(string arg, bool *value_parsing_ok) const {
bool result = false;
if (type_ == TYPE_INT) {
result = ParseInt32Flag(arg, name_, int_value_, value_parsing_ok);
} else if (type_ == TYPE_INT64) {
result = ParseInt64Flag(arg, name_, int64_value_, value_parsing_ok);
} else if (type_ == TYPE_BOOL) {
result = ParseBoolFlag(arg, name_, bool_value_, value_parsing_ok);
} else if (type_ == TYPE_STRING) {
result = ParseStringFlag(arg, name_, string_value_, value_parsing_ok);
} else if (type_ == TYPE_FLOAT) {
result = ParseFloatFlag(arg, name_, float_value_, value_parsing_ok);
}
return result;
}
/*static*/ bool Flags::Parse(int *argc, char **argv,
const std::vector<Flag> &flag_list) {
bool result = true;
std::vector<char *> unknown_flags;
for (int i = 1; i < *argc; ++i) {
if (string(argv[i]) == "--") {
while (i < *argc) {
unknown_flags.push_back(argv[i]);
++i;
}
break;
}
bool was_found = false;
for (const Flag &flag : flag_list) {
bool value_parsing_ok;
was_found = flag.Parse(argv[i], &value_parsing_ok);
if (!value_parsing_ok) {
result = false;
}
if (was_found) {
break;
}
}
if (!was_found) {
unknown_flags.push_back(argv[i]);
}
}
// Passthrough any extra flags.
int dst = 1; // Skip argv[0]
for (char *f : unknown_flags) {
argv[dst++] = f;
}
argv[dst++] = nullptr;
*argc = unknown_flags.size() + 1;
return result && (*argc < 2 || strcmp(argv[1], "--help") != 0);
}
/*static*/ string Flags::Usage(const string &cmdline,
const std::vector<Flag> &flag_list) {
std::stringstream usage_text;
usage_text << "usage: " << cmdline << std::endl;
if (!flag_list.empty()) {
usage_text << "Flags: " << std::endl;
}
for (const Flag &flag : flag_list) {
usage_text << "\t" << std::left << std::setw(30) << flag.name_;
usage_text << flag.usage_text_ << std::endl;
}
return usage_text.str();
}
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_CORE_COMMAND_LINE_FLAGS_H
#define MACE_CORE_COMMAND_LINE_FLAGS_H
#include "mace/core/common.h"
namespace mace {
class Flag {
public:
Flag(const char *name, int *dst1, const string &usage_text);
Flag(const char *name, long long *dst1, const string &usage_text);
Flag(const char *name, bool *dst, const string &usage_text);
Flag(const char *name, string *dst, const string &usage_text);
Flag(const char *name, float *dst, const string &usage_text);
private:
friend class Flags;
bool Parse(string arg, bool *value_parsing_ok) const;
string name_;
enum { TYPE_INT, TYPE_INT64, TYPE_BOOL, TYPE_STRING, TYPE_FLOAT } type_;
int *int_value_;
long long *int64_value_;
bool *bool_value_;
string *string_value_;
float *float_value_;
string usage_text_;
};
class Flags {
public:
// Parse the command line represented by argv[0, ..., (*argc)-1] to find flag
// instances matching flags in flaglist[]. Update the variables associated
// with matching flags, and remove the matching arguments from (*argc, argv).
// Return true iff all recognized flag values were parsed correctly, and the
// first remaining argument is not "--help".
static bool Parse(int *argc,
char **argv,
const std::vector<Flag> &flag_list);
// Return a usage message with command line cmdline, and the
// usage_text strings in flag_list[].
static string Usage(const string &cmdline,
const std::vector <Flag> &flag_list);
};
} // namespace mace
#endif // MACE_CORE_COMMAND_LINE_FLAGS_H
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册