Commit 5c7a519f authored by 叶剑武

Merge branch 'scalar-input' into 'master'

Support scalar input

See merge request !1017
@@ -67,12 +67,19 @@ def file_checksum(fname):
     return hash_func.hexdigest()
 
 
+def split_shape(shape):
+    if shape.strip() == "":
+        return []
+    else:
+        return shape.split(',')
+
+
 def parse_int_array_from_str(ints_str):
-    return [int(int_str) for int_str in ints_str.split(',')]
+    return [int(i) for i in split_shape(ints_str)]
 
 
-def parse_float_array_from_str(ints_str):
-    return [float(int_str) for int_str in ints_str.split(',')]
+def parse_float_array_from_str(floats_str):
+    return [float(i) for i in floats_str.split(',')]
 
 
 def transpose_shape(shape, dst_order):
......
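The new `split_shape` helper is what makes an empty shape string legal: `"".split(',')` returns `['']`, so the old parser failed on `int('')`, whereas an empty string now maps to a rank-0 (scalar) shape. A standalone sketch of the behaviour (helpers copied from this hunk):

```python
# Standalone sketch of the helpers above: an empty shape string now parses
# to [] (a scalar shape) instead of raising ValueError on int('').
def split_shape(shape):
    if shape.strip() == "":
        return []
    else:
        return shape.split(',')


def parse_int_array_from_str(ints_str):
    return [int(i) for i in split_shape(ints_str)]


print(parse_int_array_from_str("1,224,224,3"))  # [1, 224, 224, 3]
print(parse_int_array_from_str(""))             # [] -> scalar (rank-0) input
```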
@@ -265,6 +265,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
             tf_graph_def.ParseFromString(f.read())
 
         self._placeholders = {}
+        self._skip_tensor = set()
+        self._output_shape = {}
 
         print("Run transform_graph: %s" % TFTransformGraphOptions)
         try:
@@ -291,10 +293,16 @@ class TensorflowConverter(base_converter.ConverterInterface):
             with session.graph.as_default() as graph:
                 tf.import_graph_def(transformed_graph_def, name='')
                 self._tf_graph = graph
+                self.update_output_shapes(session)
 
-        self._skip_tensor = set()
-        self._output_shape_list = []
-        self._output_shape_op_list = []
+        # we have polluted graph with 'shape' ops, so reset it and reload it
+        # again
+        tf.reset_default_graph()
+        with tf.Session() as session:
+            with session.graph.as_default() as graph:
+                tf.import_graph_def(transformed_graph_def, name='')
+                self._tf_graph = graph
 
     def run(self):
         with tf.Session() as session:
@@ -339,10 +347,17 @@ class TensorflowConverter(base_converter.ConverterInterface):
         return tensor_name[:idx]
 
     def update_output_shapes(self, sess):
-        output_shapes = sess.run(self._output_shape_op_list,
+        tensors = []
+        shape_tensors = []
+        for tf_op in self._tf_graph.get_operations():
+            for output in tf_op.outputs:
+                tensors.append(output.name)
+                shape_tensors.append(tf.shape(output))
+
+        tensor_shapes = sess.run(shape_tensors,
                                  feed_dict=self._placeholders)
-        for i in range(len(self._output_shape_list)):
-            self._output_shape_list[i].dims.extend(output_shapes[i])
+        for i in range(len(tensors)):
+            self._output_shape[tensors[i]] = tensor_shapes[i]
 
     def convert_ops(self, sess):
         for tf_op in self._tf_graph.get_operations():
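The rewritten `update_output_shapes` caches a concrete shape for every tensor in the graph by evaluating one `tf.shape` op per output under the placeholder feed, keyed by tensor name, instead of patching protobuf shapes after conversion. A minimal sketch of the same pattern on a toy graph; the graph, names and feed data here are illustrative only, and TensorFlow 1.x is assumed (the converter in this diff uses `tf.Session` and `tf.reset_default_graph`):

```python
# Sketch of the shape-caching pattern above on a toy graph (TensorFlow 1.x
# assumed). Every output tensor gets a concrete shape keyed by tensor name,
# even if its static shape contains None dimensions.
import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3], name='input')
y = tf.reduce_sum(x, axis=1, name='sum')   # static shape is [None]

output_shape = {}
with tf.Session() as sess:
    tensors = []
    shape_tensors = []
    for tf_op in sess.graph.get_operations():   # snapshot of existing ops
        for output in tf_op.outputs:
            tensors.append(output.name)
            shape_tensors.append(tf.shape(output))
    tensor_shapes = sess.run(shape_tensors,
                             feed_dict={x: np.zeros((2, 3), np.float32)})
    for name, shape in zip(tensors, tensor_shapes):
        output_shape[name] = shape

print(output_shape['sum:0'])  # [2] -- concrete, although the static shape is [None]
```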
@@ -350,7 +365,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
                       "Mace does not support tensorflow op type %s yet"
                       % tf_op.type)
             self._op_converters[tf_op.type](tf_op)
-        self.update_output_shapes(sess)
+
         self.convert_tensors()
 
     def convert_tensors(self):
@@ -384,18 +399,17 @@ class TensorflowConverter(base_converter.ConverterInterface):
     # this function tries to infer tensor shape, but some dimension shape
     # may be undefined due to variance of input length
-    def infer_tensor_shape(self, output_shape, tensor):
-        inferred_tensor_shape = tensor.shape.as_list()
-        inferred_success = True
-        for _, dim in enumerate(inferred_tensor_shape):
-            if dim is None:
-                inferred_success = False
-                break
-        if inferred_success:
-            output_shape.dims.extend(inferred_tensor_shape)
+    def infer_tensor_shape(self, tensor, output_shape=None):
+        shape = None
+        if tensor.name in self._output_shape:
+            shape = self._output_shape[tensor.name]
         else:
-            self._output_shape_list.append(output_shape)
-            self._output_shape_op_list.append(tf.shape(tensor))
+            shape = tensor.shape.as_list()
+
+        if output_shape:
+            output_shape.dims.extend(shape)
+
+        return shape
 
     def convert_nop(self, tf_op):
         pass
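With the cache filled up front, `infer_tensor_shape` prefers the concrete runtime shape and only falls back to the static `tensor.shape.as_list()`; it also returns the shape, so callers can test the rank without passing a protobuf `output_shape`. A minimal sketch of that lookup order; `FakeTensor`, `output_shape_cache` and `infer_shape` are hypothetical stand-ins, not names from this diff:

```python
# Minimal sketch of the lookup order in infer_tensor_shape above: a cached
# concrete shape wins over the (possibly partial) static shape.
# FakeTensor is a hypothetical stand-in for a tf.Tensor.
class FakeTensor(object):
    def __init__(self, name, static_shape):
        self.name = name
        self.static_shape = static_shape


output_shape_cache = {'sum:0': [2]}       # filled up front, as above


def infer_shape(tensor):
    if tensor.name in output_shape_cache:
        return output_shape_cache[tensor.name]
    return list(tensor.static_shape)


print(infer_shape(FakeTensor('sum:0', [None])))   # [2], from the cache
print(infer_shape(FakeTensor('bias:0', [64])))    # [64], static fallback
```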
@@ -408,7 +422,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
         op.output.extend([tf_output.name for tf_output in tf_op.outputs])
         for tf_output in tf_op.outputs:
             output_shape = op.output_shape.add()
-            self.infer_tensor_shape(output_shape, tf_output)
+            self.infer_tensor_shape(tf_output, output_shape)
 
         data_type_arg = op.arg.add()
         data_type_arg.name = 'T'
@@ -491,10 +505,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
         def check_is_scalar(tf_op):
             if len(tf_op.inputs) == 1:
-                return len(tf_op.inputs[0].shape) == 0
+                return len(self.infer_tensor_shape(tf_op.inputs[0])) == 0
             elif len(tf_op.inputs) == 2:
-                return len(tf_op.inputs[0].shape) == 0 and \
-                       len(tf_op.inputs[1].shape) == 0
+                return len(self.infer_tensor_shape(tf_op.inputs[0])) == 0 and \
+                       len(self.infer_tensor_shape(tf_op.inputs[1])) == 0
 
         if check_is_scalar(tf_op):
             op.type = MaceOp.ScalarMath.name
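`check_is_scalar` relies on the fact that a rank-0 tensor has an empty shape list, so `len(shape) == 0` is the scalar test; such ops are routed to `MaceOp.ScalarMath` instead of the regular tensor path. A tiny numpy illustration of the rank test:

```python
# The rank test behind check_is_scalar above: a scalar has an empty shape.
import numpy as np

print(len(np.array(3.0).shape))        # 0 -> scalar operand
print(len(np.zeros((1, 3)).shape))     # 2 -> regular tensor operand
```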
@@ -521,9 +535,9 @@ class TensorflowConverter(base_converter.ConverterInterface):
                 EltwiseType.SUM, EltwiseType.PROD,
                 EltwiseType.MAX, EltwiseType.MIN]
 
-        if len(tf_op.inputs) > 1 and \
-                len(tf_op.inputs[1].shape) == 0 and \
-                tf_op.inputs[1].op.type == TFOpType.Const.name:
+        if (len(tf_op.inputs) > 1 and
+                len(self.infer_tensor_shape(tf_op.inputs[1])) == 0 and
+                tf_op.inputs[1].op.type == TFOpType.Const.name):
             scalar = tf_op.inputs[1].eval().astype(np.float32)
             value_arg = op.arg.add()
             value_arg.name = MaceKeyword.mace_scalar_input_str
@@ -535,7 +549,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
             value_index_arg.i = 1
             self._skip_tensor.add(tf_op.inputs[1].name)
             del op.input[1]
-        elif len(tf_op.inputs[0].shape) == 0 and \
+        elif len(self.infer_tensor_shape(tf_op.inputs[0])) == 0 and \
                 tf_op.inputs[0].op.type == TFOpType.Const.name and \
                 is_commutative(type_arg.i):
             scalar = tf_op.inputs[0].eval().astype(np.float32)
......
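The `is_commutative` guard matters because folding the scalar constant into `mace_scalar_input` discards the operand order: when the scalar is the second input its position is recorded via `value_index_arg.i = 1`, but when it is the first input only the commutative eltwise types (SUM, PROD, MAX, MIN) are folded, since SUB and DIV change value when their operands are swapped. A two-line numeric illustration:

```python
# Why the is_commutative check above matters: subtraction and division are
# order-sensitive, so a leading scalar may only be folded for commutative ops.
print(5.0 - 2.0, 2.0 - 5.0)   # 3.0 -3.0  -> order matters, do not fold
print(5.0 + 2.0, 2.0 + 5.0)   # 7.0 7.0   -> commutative, safe to fold
```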
@@ -275,7 +275,6 @@ bool RunModel(const std::string &model_name,
     MemoryMap(FLAGS_model_data_file,
               &model_weights_data,
               &model_weights_data_size);
-    MACE_CHECK(model_weights_data != nullptr && model_weights_data_size != 0);
   }
 
   std::shared_ptr<mace::MaceEngine> engine;
......
@@ -122,6 +122,9 @@ void MemoryMap(const std::string &file,
   struct stat st;
   fstat(fd, &st);
   *size = static_cast<size_t>(st.st_size);
+  if (*size == 0) {
+    return;
+  }
 
   *data = static_cast<const unsigned char *>(
       mmap(nullptr, *size, PROT_READ, MAP_PRIVATE, fd, 0));
@@ -135,7 +138,10 @@ void MemoryMap(const std::string &file,
 
 void MemoryUnMap(const unsigned char *data,
                  const size_t &size) {
-  MACE_CHECK(data != nullptr && size > 0, "data is null or size is 0");
+  if (size == 0) {
+    return;
+  }
+  MACE_CHECK(data != nullptr, "data is null");
 
   int ret = munmap(const_cast<unsigned char *>(data), size);
......
@@ -531,3 +531,10 @@ class ToolchainType:
 class TargetSOCTag:
     all = 'all'
     random = 'random'
+
+
+def split_shape(shape):
+    if shape.strip() == "":
+        return []
+    else:
+        return shape.split(',')
@@ -59,7 +59,7 @@ def generate_input_data(input_file, input_node, input_shape, input_ranges,
     assert len(input_names) == len(input_shapes) == len(input_ranges) == len(input_data_types)  # noqa
 
     for i in range(len(input_names)):
-        shape = [int(x) for x in input_shapes[i].split(',')]
+        shape = [int(x) for x in common.split_shape(input_shapes[i])]
         input_range = [float(x) for x in input_ranges[i].split(',')]
         generate_data(input_names[i], shape, input_file, input_range,
                       input_data_types[i])
......
@@ -68,6 +68,8 @@ def calculate_similarity(u, v, data_type=np.float64):
 
 
 def calculate_pixel_accuracy(out_value, mace_out_value):
+    if len(out_value.shape) < 2:
+        return 1.0
     out_value = out_value.reshape((-1, out_value.shape[-1]))
     batches = out_value.shape[0]
     classes = out_value.shape[1]
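The early return guards scalar and 1-D outputs: a rank-0 output has an empty shape, so `out_value.shape[-1]` would raise `IndexError`, and per-class pixel accuracy is not meaningful there, so it is reported as a perfect 1.0. A small illustration:

```python
# Why the len(shape) < 2 guard above is needed: a scalar output has an empty
# shape, so indexing shape[-1] for the reshape would fail.
import numpy as np

out_value = np.array(3.14, dtype=np.float32)   # rank-0 model output
print(out_value.shape)                         # ()
print(len(out_value.shape) < 2)                # True -> return 1.0 early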
@@ -323,10 +325,10 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
              validation_outputs_data, log_file):
     input_names = [name for name in input_node.split(',')]
     input_shape_strs = [shape for shape in input_shape.split(':')]
-    input_shapes = [[int(x) for x in shape.split(',')]
+    input_shapes = [[int(x) for x in common.split_shape(shape)]
                     for shape in input_shape_strs]
     output_shape_strs = [shape for shape in output_shape.split(':')]
-    output_shapes = [[int(x) for x in shape.split(',')]
+    output_shapes = [[int(x) for x in common.split_shape(shape)]
                      for shape in output_shape_strs]
     input_data_formats = [df for df in input_data_format_str.split(',')]
     output_data_formats = [df for df in output_data_format_str.split(',')]
......
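With `common.split_shape`, an empty entry in the colon-separated shape strings now parses to `[]`, and reshaping a single-element buffer to `[]` yields the rank-0 array expected for a scalar input or output. A sketch (the `split_shape` body is reproduced inline so the snippet stands alone; the value is illustrative):

```python
# Sketch: an empty shape string parses to [], and reshaping a one-element
# buffer to [] gives a rank-0 (scalar) array.
import numpy as np


def split_shape(shape):
    return [] if shape.strip() == "" else shape.split(',')


out_shape = [int(x) for x in split_shape("")]            # []
value = np.array([2.718], dtype=np.float32).reshape(out_shape)
print(value.shape, float(value))                         # () ~2.718
```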