提交 5c7a519f 编写于 作者: 叶剑武

Merge branch 'scalar-input' into 'master'

Support scalar input

See merge request !1017
......@@ -67,12 +67,19 @@ def file_checksum(fname):
return hash_func.hexdigest()
def split_shape(shape):
if shape.strip() == "":
return []
else:
return shape.split(',')
def parse_int_array_from_str(ints_str):
return [int(int_str) for int_str in ints_str.split(',')]
return [int(i) for i in split_shape(ints_str)]
def parse_float_array_from_str(ints_str):
return [float(int_str) for int_str in ints_str.split(',')]
def parse_float_array_from_str(floats_str):
return [float(i) for i in floats_str.split(',')]
def transpose_shape(shape, dst_order):
......
......@@ -265,6 +265,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
tf_graph_def.ParseFromString(f.read())
self._placeholders = {}
self._skip_tensor = set()
self._output_shape = {}
print("Run transform_graph: %s" % TFTransformGraphOptions)
try:
......@@ -291,10 +293,16 @@ class TensorflowConverter(base_converter.ConverterInterface):
with session.graph.as_default() as graph:
tf.import_graph_def(transformed_graph_def, name='')
self._tf_graph = graph
self.update_output_shapes(session)
self._skip_tensor = set()
self._output_shape_list = []
self._output_shape_op_list = []
# we have polluted graph with 'shape' ops, so reset it and reload it
# again
tf.reset_default_graph()
with tf.Session() as session:
with session.graph.as_default() as graph:
tf.import_graph_def(transformed_graph_def, name='')
self._tf_graph = graph
def run(self):
with tf.Session() as session:
......@@ -339,10 +347,17 @@ class TensorflowConverter(base_converter.ConverterInterface):
return tensor_name[:idx]
def update_output_shapes(self, sess):
output_shapes = sess.run(self._output_shape_op_list,
tensors = []
shape_tensors = []
for tf_op in self._tf_graph.get_operations():
for output in tf_op.outputs:
tensors.append(output.name)
shape_tensors.append(tf.shape(output))
tensor_shapes = sess.run(shape_tensors,
feed_dict=self._placeholders)
for i in range(len(self._output_shape_list)):
self._output_shape_list[i].dims.extend(output_shapes[i])
for i in range(len(tensors)):
self._output_shape[tensors[i]] = tensor_shapes[i]
def convert_ops(self, sess):
for tf_op in self._tf_graph.get_operations():
......@@ -350,7 +365,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
"Mace does not support tensorflow op type %s yet"
% tf_op.type)
self._op_converters[tf_op.type](tf_op)
self.update_output_shapes(sess)
self.convert_tensors()
def convert_tensors(self):
......@@ -384,18 +399,17 @@ class TensorflowConverter(base_converter.ConverterInterface):
# this function tries to infer tensor shape, but some dimension shape
# may be undefined due to variance of input length
def infer_tensor_shape(self, output_shape, tensor):
inferred_tensor_shape = tensor.shape.as_list()
inferred_success = True
for _, dim in enumerate(inferred_tensor_shape):
if dim is None:
inferred_success = False
break
if inferred_success:
output_shape.dims.extend(inferred_tensor_shape)
def infer_tensor_shape(self, tensor, output_shape=None):
shape = None
if tensor.name in self._output_shape:
shape = self._output_shape[tensor.name]
else:
self._output_shape_list.append(output_shape)
self._output_shape_op_list.append(tf.shape(tensor))
shape = tensor.shape.as_list()
if output_shape:
output_shape.dims.extend(shape)
return shape
def convert_nop(self, tf_op):
pass
......@@ -408,7 +422,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
op.output.extend([tf_output.name for tf_output in tf_op.outputs])
for tf_output in tf_op.outputs:
output_shape = op.output_shape.add()
self.infer_tensor_shape(output_shape, tf_output)
self.infer_tensor_shape(tf_output, output_shape)
data_type_arg = op.arg.add()
data_type_arg.name = 'T'
......@@ -491,10 +505,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
def check_is_scalar(tf_op):
if len(tf_op.inputs) == 1:
return len(tf_op.inputs[0].shape) == 0
return len(self.infer_tensor_shape(tf_op.inputs[0])) == 0
elif len(tf_op.inputs) == 2:
return len(tf_op.inputs[0].shape) == 0 and \
len(tf_op.inputs[1].shape) == 0
return len(self.infer_tensor_shape(tf_op.inputs[0])) == 0 and \
len(self.infer_tensor_shape(tf_op.inputs[1])) == 0
if check_is_scalar(tf_op):
op.type = MaceOp.ScalarMath.name
......@@ -521,9 +535,9 @@ class TensorflowConverter(base_converter.ConverterInterface):
EltwiseType.SUM, EltwiseType.PROD,
EltwiseType.MAX, EltwiseType.MIN]
if len(tf_op.inputs) > 1 and \
len(tf_op.inputs[1].shape) == 0 and \
tf_op.inputs[1].op.type == TFOpType.Const.name:
if (len(tf_op.inputs) > 1 and
len(self.infer_tensor_shape(tf_op.inputs[1])) == 0 and
tf_op.inputs[1].op.type == TFOpType.Const.name):
scalar = tf_op.inputs[1].eval().astype(np.float32)
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
......@@ -535,7 +549,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
value_index_arg.i = 1
self._skip_tensor.add(tf_op.inputs[1].name)
del op.input[1]
elif len(tf_op.inputs[0].shape) == 0 and \
elif len(self.infer_tensor_shape(tf_op.inputs[0])) == 0 and \
tf_op.inputs[0].op.type == TFOpType.Const.name and \
is_commutative(type_arg.i):
scalar = tf_op.inputs[0].eval().astype(np.float32)
......
......@@ -275,7 +275,6 @@ bool RunModel(const std::string &model_name,
MemoryMap(FLAGS_model_data_file,
&model_weights_data,
&model_weights_data_size);
MACE_CHECK(model_weights_data != nullptr && model_weights_data_size != 0);
}
std::shared_ptr<mace::MaceEngine> engine;
......
......@@ -122,6 +122,9 @@ void MemoryMap(const std::string &file,
struct stat st;
fstat(fd, &st);
*size = static_cast<size_t>(st.st_size);
if (*size == 0) {
return;
}
*data = static_cast<const unsigned char *>(
mmap(nullptr, *size, PROT_READ, MAP_PRIVATE, fd, 0));
......@@ -135,7 +138,10 @@ void MemoryMap(const std::string &file,
void MemoryUnMap(const unsigned char *data,
const size_t &size) {
MACE_CHECK(data != nullptr && size > 0, "data is null or size is 0");
if (size == 0) {
return;
}
MACE_CHECK(data != nullptr, "data is null");
int ret = munmap(const_cast<unsigned char *>(data), size);
......
......@@ -531,3 +531,10 @@ class ToolchainType:
class TargetSOCTag:
all = 'all'
random = 'random'
def split_shape(shape):
if shape.strip() == "":
return []
else:
return shape.split(',')
......@@ -59,7 +59,7 @@ def generate_input_data(input_file, input_node, input_shape, input_ranges,
assert len(input_names) == len(input_shapes) == len(input_ranges) == len(input_data_types) # noqa
for i in range(len(input_names)):
shape = [int(x) for x in input_shapes[i].split(',')]
shape = [int(x) for x in common.split_shape(input_shapes[i])]
input_range = [float(x) for x in input_ranges[i].split(',')]
generate_data(input_names[i], shape, input_file, input_range,
input_data_types[i])
......
......@@ -68,6 +68,8 @@ def calculate_similarity(u, v, data_type=np.float64):
def calculate_pixel_accuracy(out_value, mace_out_value):
if len(out_value.shape) < 2:
return 1.0
out_value = out_value.reshape((-1, out_value.shape[-1]))
batches = out_value.shape[0]
classes = out_value.shape[1]
......@@ -323,10 +325,10 @@ def validate(platform, model_file, weight_file, input_file, mace_out_file,
validation_outputs_data, log_file):
input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')]
input_shapes = [[int(x) for x in common.split_shape(shape)]
for shape in input_shape_strs]
output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')]
output_shapes = [[int(x) for x in common.split_shape(shape)]
for shape in output_shape_strs]
input_data_formats = [df for df in input_data_format_str.split(',')]
output_data_formats = [df for df in output_data_format_str.split(',')]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册