Commit a422aa26 — Author: 李寅

Add tensorflow fc support

Parent: cc3ea692
......@@ -23,7 +23,8 @@ namespace mace {
ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
for (auto &arg : def.arg()) {
if (arg_map_.find(arg.name()) != arg_map_.end()) {
LOG(WARNING) << "Duplicated argument name found in operator def.";
LOG(WARNING) << "Duplicated argument name found in operator def: "
<< def.name() << " " << arg.name();
}
arg_map_[arg.name()] = arg;
......
......@@ -128,7 +128,7 @@ def main(unused_args):
FLAGS.weight_file)
output_graph_def = converter.run()
print("Transform model to one that can better run on device.")
print("Transform model to one that can better run on device")
if not FLAGS.runtime:
cpu_graph_def = copy.deepcopy(output_graph_def)
option.device = mace_pb2.CPU
......
......@@ -136,23 +136,25 @@ class MaceKeyword(object):
class TransformerRule(Enum):
REMOVE_IDENTITY_OP = 0
TRANSFORM_GLOBAL_POOLING = 1
FOLD_SOFTMAX = 2
FOLD_BATCHNORM = 3,
FOLD_CONV_AND_BN = 4,
FOLD_DEPTHWISE_CONV_AND_BN = 5,
TRANSFORM_GPU_WINOGRAD = 6,
TRANSFORM_ADD_TO_BIASADD = 7,
FOLD_BIASADD = 8,
FOLD_ACTIVATION = 9,
TRANSPOSE_FILTERS = 10,
RESHAPE_FC_WEIGHT = 11,
TRANSPOSE_DATA_FORMAT = 12,
TRANSFORM_GLOBAL_CONV_TO_FC = 13,
TRANSFORM_BUFFER_IMAGE = 14,
ADD_DEVICE_AND_DATA_TYPE = 15,
SORT_BY_EXECUTION = 16
REMOVE_USELESS_RESHAPE_OP = 0
REMOVE_IDENTITY_OP = 1
TRANSFORM_GLOBAL_POOLING = 2
FOLD_RESHAPE = 3
TRANSFORM_MATMUL_TO_FC = 4
FOLD_BATCHNORM = 5
FOLD_CONV_AND_BN = 6
FOLD_DEPTHWISE_CONV_AND_BN = 7
TRANSFORM_GPU_WINOGRAD = 8
TRANSFORM_ADD_TO_BIASADD = 9
FOLD_BIASADD = 10
FOLD_ACTIVATION = 11
TRANSPOSE_FILTERS = 12
RESHAPE_FC_WEIGHT = 13
TRANSPOSE_DATA_FORMAT = 14
TRANSFORM_GLOBAL_CONV_TO_FC = 15
TRANSFORM_BUFFER_IMAGE = 16
ADD_DEVICE_AND_DATA_TYPE = 17
SORT_BY_EXECUTION = 18
class ConverterInterface(object):
......@@ -199,9 +201,11 @@ class ConverterOption(object):
self._device = mace_pb2.CPU
self._winograd_enabled = False
self._transformer_option = [
TransformerRule.REMOVE_USELESS_RESHAPE_OP,
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_SOFTMAX,
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
......
......@@ -101,9 +101,11 @@ class TensorflowConverter(base_converter.ConverterInterface):
'AvgPool': self.convert_pooling,
'MaxPool': self.convert_pooling,
'Squeeze': self.convert_identity,
'MatMul': self.convert_matmul,
'Identity': self.convert_identity,
'Reshape': self.convert_reshape,
'Shape': self.convert_nop,
'Transpose': self.convert_transpose,
'Softmax': self.convert_softmax,
'ResizeBilinear': self.convert_resize_bilinear,
'Placeholder': self.convert_nop,
......@@ -144,7 +146,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
for i in xrange(len(op.input)):
if op.input[i][-2:] == ':0':
op_name = op.input[i][:-2]
if op_name in self._option.input_nodes:
if op_name in self._option.input_nodes \
or op_name in self._option.output_nodes:
op.input[i] = op_name
for i in xrange(len(op.output)):
if op.output[i][-2:] == ':0':
......@@ -411,6 +414,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._skip_tensor.update(tf_op.inputs[-1].name)
def convert_matmul(self, tf_op):
    """Map a TensorFlow MatMul node onto a MACE MatMul op."""
    mace_op = self.convert_general_op(tf_op)
    mace_op.type = MaceOp.MatMul.name
def convert_reshape(self, tf_op):
    """Map a TensorFlow Reshape node onto a MACE Reshape op."""
    mace_op = self.convert_general_op(tf_op)
    mace_op.type = MaceOp.Reshape.name
......@@ -430,6 +437,20 @@ class TensorflowConverter(base_converter.ConverterInterface):
shape_arg.ints.extend(shape_value)
def convert_transpose(self, tf_op):
    """Convert a tf Transpose node; only identity permutations are accepted.

    A permutation that is already sorted leaves the tensor unchanged, so
    the node degenerates into an Identity op that later passes strip away.
    The perm tensor input is dropped and recorded as a skipped tensor.
    """
    permutation = tf_op.inputs[1].eval().astype(np.int32)
    mace_check(np.array_equal(permutation, np.sort(permutation)),
               "Transpose not supported yet, only internal transpose"
               " in composed ops might be supported")
    identity_op = self.convert_general_op(tf_op)
    identity_op.type = 'Identity'
    del identity_op.input[1:]
    self._skip_tensor.add(tf_op.inputs[1].name)
def convert_mean(self, tf_op):
op = self.convert_general_op(tf_op)
del op.input[1:]
......
......@@ -53,9 +53,11 @@ class Transformer(base_converter.ConverterInterface):
def __init__(self, option, model):
# DO NOT reorder the following transformers
self._registered_transformers_order = [
TransformerRule.REMOVE_USELESS_RESHAPE_OP,
TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_SOFTMAX,
TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
......@@ -72,10 +74,14 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.SORT_BY_EXECUTION,
]
self._registered_transformers = {
TransformerRule.REMOVE_USELESS_RESHAPE_OP:
self.remove_useless_reshape_op,
TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
TransformerRule.TRANSFORM_GLOBAL_POOLING:
self.transform_global_pooling,
TransformerRule.FOLD_SOFTMAX: self.fold_softmax,
TransformerRule.FOLD_RESHAPE: self.fold_reshape,
TransformerRule.TRANSFORM_MATMUL_TO_FC:
self.transform_matmul_to_fc,
TransformerRule.FOLD_BATCHNORM: self.fold_batchnorm,
TransformerRule.FOLD_CONV_AND_BN:
self.fold_conv_and_bn, # data_format related
......@@ -161,18 +167,26 @@ class Transformer(base_converter.ConverterInterface):
for output_tensor in op.output:
self._producer[output_tensor] = op
for input_node in self._option.input_nodes.values():
op = mace_pb2.OperatorDef()
op.name = self.normalize_op_name(input_node.name)
op.type = 'Input'
op.output.extend(input_node.name)
output_shape = op.output_shape.add()
output_shape.dims.extend(input_node.shape)
if self._option.device == mace_pb2.CPU:
self.transpose_shape(output_shape.dims, [0, 3, 1, 2])
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
else:
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
self._producer[op.output[0]] = op
input_node_existed = False
for op in self._model.op:
if input_node.name in op.output:
input_node_existed = True
break
if not input_node_existed:
op = mace_pb2.OperatorDef()
op.name = self.normalize_op_name(input_node.name)
op.type = 'Input'
op.output.extend([input_node.name])
output_shape = op.output_shape.add()
output_shape.dims.extend(input_node.shape)
if ConverterUtil.data_format(
self._consumers[input_node.name][0]) \
== DataFormat.NCHW:
self.transpose_shape(output_shape.dims, [0, 3, 1, 2])
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
else:
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
self._producer[op.output[0]] = op
@staticmethod
def replace(obj_list, source, target):
......@@ -191,6 +205,12 @@ class Transformer(base_converter.ConverterInterface):
def normalize_op_name(name):
    """Return a normalized op name with every ':' replaced by '_'."""
    return '_'.join(name.split(':'))
def get_tensor_shape(self, tensor):
    """Look up `tensor`'s shape from its producer's recorded output_shape.

    Returns the dims as a plain list, or None if the name is not found
    among the producer's outputs.
    """
    producer = self._producer[tensor]
    for idx, output_name in enumerate(producer.output):
        if output_name == tensor:
            return list(producer.output_shape[idx].dims)
def consumer_count(self, tensor_name):
    """Return how many ops consume `tensor_name` (0 when none do)."""
    consumers = self._consumers.get(tensor_name, [])
    return len(consumers)
......@@ -203,23 +223,68 @@ class Transformer(base_converter.ConverterInterface):
return False
def replace_output_node(self, op):
    """If `op` is an output node, promote its producer to output node.

    The producer's matching output tensor is renamed to op's output so
    the graph output survives when `op` is later removed.
    """
    if not self.is_op_output_node(op):
        return
    upstream = self._producer[op.input[0]]
    self.replace(upstream.output, op.input[0], op.output[0])
    print("change %s to %s" % (upstream.name, op.name))
def safe_remove_node(self, op, replace_op):
"""Remove `op` from the model graph without breaking consumers.

1. Re-wire the inputs of op's consumers to the outputs of replace_op.
2. If op produces a graph output, rename replace_op's corresponding
   output so replace_op takes over as the output node.

When replace_op is None, op must be a single-input/single-output
pass-through (e.g. an identity whose input is a plain tensor); its
consumers are re-wired directly to op's own input instead.
"""
if replace_op is None:
# When no replace op specified, we change the inputs of
# its consumers to the input of the op. This handles the case
# that the op is identity op and its input is a tensor.
mace_check(len(op.output) == 1 and len(op.input) == 1,
"cannot remove op that w/o replace op specified"
" and input/output length > 1" + str(op))
for consumer_op in self._consumers.get(op.output[0], []):
self.replace(consumer_op.input, op.output[0], op.input[0])
# An output node has no downstream consumer to re-wire, so it can
# never be removed without a replacement op.
mace_check(op.output[0] not in self._option.output_nodes,
"cannot remove op that is output node")
else:
# Outputs must pair up one-to-one for the re-wiring below.
mace_check(len(op.output) == len(replace_op.output),
"cannot remove op since len(op.output) "
"!= len(replace_op.output)")
for i in xrange(len(op.output)):
for consumer_op in self._consumers.get(op.output[i], []):
self.replace(consumer_op.input,
op.output[i],
replace_op.output[i])
# if the op is output node, change replace_op output name to the op
# output name
for i in xrange(len(op.output)):
if op.output[i] in self._option.output_nodes:
# Rename consumers of replace_op's old output first, then the
# output itself, so the graph output name is preserved.
for consumer in self._consumers.get(
replace_op.output[i], []):
self.replace(consumer.input,
replace_op.output[i],
op.output[i])
replace_op.output[i] = op.output[i]
self._model.op.remove(op)
def remove_useless_reshape_op(self):
    """Mark Reshape ops whose target shape equals their input shape as
    Identity, so a later REMOVE_IDENTITY_OP pass strips them.

    Always returns False (single sweep; no re-run requested).
    """
    for op in self._model.op:
        if op.type != MaceOp.Reshape.name:
            continue
        target_shape = list(ConverterUtil.get_arg(
            op, MaceKeyword.mace_shape_str).ints)
        if target_shape == self.get_tensor_shape(op.input[0]):
            print("Remove useless reshape: %s(%s)"
                  % (op.name, op.type))
            op.type = 'Identity'
    return False
def remove_identity_op(self):
net = self._model
for op in net.op:
if op.type == 'Identity':
print("Remove identity: %s(%s)" % (op.name, op.type))
for consumer_op in self._consumers.get(op.output[0], []):
Transformer.replace(consumer_op.input, op.output[0],
op.input[0])
self.replace_output_node(op)
net.op.remove(op)
self.safe_remove_node(op,
self._producer.get(op.input[0], None))
return True
return False
......@@ -264,10 +329,10 @@ class Transformer(base_converter.ConverterInterface):
and len(self._consts[consumer_op.input[1]].dims) == 1:
print("Fold batchnorm: %s(%s)" % (op.name, op.type))
consumer_op.type = MaceOp.FoldedBatchNorm.name
inputs = [op.input[0], op.input[1], consumer_op.input[1]]
consumer_op.input[:] = inputs[:]
consumer_op.input[:] = [op.input[0], op.input[1],
consumer_op.input[1]]
net.op.remove(op)
self.safe_remove_node(op, None)
return True
return False
......@@ -514,7 +579,7 @@ class Transformer(base_converter.ConverterInterface):
filter.float_data[:] = weight_tensor_value.flat[:]
filter.dims[:] = weight_tensor_value.shape[:]
net.op.remove(op)
self.safe_remove_node(op, iwt_op)
return False
......@@ -544,10 +609,8 @@ class Transformer(base_converter.ConverterInterface):
consumer_op = self._consumers[op.output[0]][0]
if consumer_op.type == MaceOp.BiasAdd.name:
print("Fold biasadd: %s(%s)" % (op.name, op.type))
op.name = consumer_op.name
op.input.append(consumer_op.input[1])
op.output[0] = consumer_op.output[0]
net.op.remove(consumer_op)
self.safe_remove_node(consumer_op, op)
return True
return False
......@@ -575,7 +638,7 @@ class Transformer(base_converter.ConverterInterface):
or arg.name == MaceKeyword.mace_activation_max_limit_str: # noqa
op.arg.extend([arg])
net.op.remove(consumer_op)
self.safe_remove_node(consumer_op, op)
return True
return False
......@@ -651,11 +714,14 @@ class Transformer(base_converter.ConverterInterface):
op.output.extend([input_node.name])
output_shape = op.output_shape.add()
output_shape.dims.extend(input_node.shape)
self.transpose_shape(output_shape.dims, [0, 3, 1, 2])
dims_arg = op.arg.add()
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 3, 1, 2])
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
for output_node in self._option.output_nodes.values():
output_name = MaceKeyword.mace_output_node_name \
+ '_' + output_node.name
......@@ -673,6 +739,8 @@ class Transformer(base_converter.ConverterInterface):
dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 2, 3, 1])
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
return False
def transpose_filters(self):
......@@ -695,12 +763,19 @@ class Transformer(base_converter.ConverterInterface):
filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
if op.type == MaceOp.FullyConnected.name:
weight = self._consts[op.input[1]]
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight_data = weight_data.transpose(1, 0)
weight.float_data[:] = weight_data.flat
weight.dims[:] = weight_data.shape
self.set_filter_format(FilterFormat.OIHW)
return False
def reshape_fc_weight(self):
print("Reshape fully connecrted weight shape")
print("Reshape fully connected weight shape")
net = self._model
for op in net.op:
if op.type == MaceOp.FullyConnected.name:
......@@ -789,6 +864,8 @@ class Transformer(base_converter.ConverterInterface):
arg.name = MaceKeyword.mace_buffer_type
arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
for output_node in self._option.output_nodes.values():
output_name = MaceKeyword.mace_output_node_name \
+ '_' + output_node.name
......@@ -804,14 +881,16 @@ class Transformer(base_converter.ConverterInterface):
arg.name = MaceKeyword.mace_buffer_type
arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
return False
def fold_softmax(self):
def fold_reshape(self):
changed = False
net = self._model
for op in net.op:
if op.type == MaceOp.Softmax.name:
print("Fold softmax: %s(%s)" % (op.name, op.type))
if op.type == MaceOp.Softmax.name or op.type == MaceOp.MatMul.name:
print("Fold reshape: %s(%s)" % (op.name, op.type))
if self.consumer_count(op.output[0]) == 1:
consumer = self._consumers[op.output[0]][0]
if consumer.type == MaceOp.Reshape.name:
......@@ -819,15 +898,14 @@ class Transformer(base_converter.ConverterInterface):
MaceKeyword.mace_shape_str).ints # noqa
del op.output_shape[0].dims[:]
op.output_shape[0].dims.extend(shape)
self.replace_output_node(consumer)
net.op.remove(consumer)
self.safe_remove_node(consumer, op)
changed = True
producer = self._producer[op.input[0]]
if producer.type == MaceOp.Reshape.name:
op.input[0] = producer.input[0]
self.replace_output_node(producer)
net.op.remove(producer)
self.safe_remove_node(producer,
self._producer[
producer.input[0]])
changed = True
if len(op.output_shape[0].dims) < 4:
......@@ -840,6 +918,20 @@ class Transformer(base_converter.ConverterInterface):
return False
def transform_matmul_to_fc(self):
    """Rewrite MatMul over a 1x1 feature map with a constant 2-D weight
    as FullyConnected.

    Always returns False (single sweep; no re-run requested).
    """
    for op in self._model.op:
        if op.type != MaceOp.MatMul.name:
            continue
        in_shape = self.get_tensor_shape(op.input[0])
        data_format = ConverterUtil.data_format(
            self._producer[op.input[0]])
        _, height, width, _ = self.sort_feature_map_shape(
            in_shape, data_format)
        # Only a spatially-collapsed input with a constant 2-D weight
        # is equivalent to a fully-connected layer.
        if height == 1 and width == 1 and op.input[1] in self._consts:
            weight = self._consts[op.input[1]]
            if len(weight.dims) == 2:
                op.type = MaceOp.FullyConnected.name
    return False
def transform_global_conv_to_fc(self):
"""Transform global conv to fc should be placed after transposing
input/output and filter"""
......@@ -918,4 +1010,8 @@ class Transformer(base_converter.ConverterInterface):
del net.op[:]
net.op.extend(sorted_nodes)
print("Final ops:")
for op in net.op:
print("%s (%s)" % (op.name, op.type))
return False
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register or sign in.