提交 f13461ef 编写于 作者: Y yejianwu

support multi fc

上级 8b9021f7
......@@ -256,6 +256,7 @@ class TransformerRule(Enum):
TRANSPOSE_MATMUL_WEIGHT = 34
FOLD_EMBEDDING_LOOKUP = 35
TRANSPOSE_CAFFE_RESHAPE_AND_FLATTEN = 36
FOLD_FC_RESHAPE = 37
class ConverterInterface(object):
......@@ -461,6 +462,7 @@ class ConverterOption(object):
TransformerRule.FOLD_SQRDIFF_MEAN,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
TransformerRule.RESHAPE_FC_WEIGHT,
TransformerRule.FOLD_FC_RESHAPE,
# Model data format related transformation
TransformerRule.TRANSPOSE_FILTERS,
TransformerRule.TRANSPOSE_DATA_FORMAT,
......
......@@ -75,6 +75,8 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters,
TransformerRule.TRANSPOSE_MATMUL_WEIGHT:
self.transpose_matmul_weight,
TransformerRule.FOLD_FC_RESHAPE:
self.fold_fc_reshape,
TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format,
TransformerRule.ADD_WINOGRAD_ARG: self.add_winograd_arg,
TransformerRule.ADD_IN_OUT_TENSOR_INFO:
......@@ -1227,11 +1229,24 @@ class Transformer(base_converter.ConverterInterface):
return True
return False
def is_after_fc(self, op):
while op.input[0] in self._producer:
producer = self._producer[op.input[0]]
if producer.type in [MaceOp.Activation.name, MaceOp.BiasAdd.name]:
op = producer
continue
elif producer.type == MaceOp.FullyConnected.name:
return True
else:
return False
return False
def transform_matmul_to_fc(self):
net = self._model
filter_format = self.filter_format()
for op in net.op:
# transform input(4D) -> reshape(2D) -> matmul to fc
# transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)`
# fc output is 2D in transformer, using as 4D in op kernel
# work for TensorFlow
if op.type == MaceOp.Reshape.name and \
len(op.input) == 2 and \
......@@ -1268,6 +1283,21 @@ class Transformer(base_converter.ConverterInterface):
[weight_data.shape[1]]
return True
# transform `fc1(2D) -> matmul` to `fc1(2D) -> fc1(2D)`
if op.type == MaceOp.MatMul.name and \
filter_format == FilterFormat.HWIO:
producer = self._producer[op.input[0]]
weight = self._consts[op.input[1]]
if len(weight.dims) == 2 and self.is_after_fc(op) and \
len(producer.output_shape[0].dims) == 2 and \
weight.dims[0] == producer.output_shape[0].dims[1]:
six.print_('convert matmul to fc')
op.type = MaceOp.FullyConnected.name
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight.dims[:] = [1, 1] + list(weight_data.shape)
return True
return False
def update_float_op_data_type(self):
......@@ -1750,3 +1780,22 @@ class Transformer(base_converter.ConverterInterface):
shape_tensor.data_type = mace_pb2.DT_INT32
shape_tensor.int32_data.extend(dims)
op.input.append(shape_tensor.name)
def fold_fc_reshape(self):
net = self._model
for op in net.op:
# whether to reshape fc output(default 4D)
if op.type == MaceOp.FullyConnected.name:
consumers = self._consumers[op.output[0]]
op_output_shape = op.output_shape[0].dims[:]
for consumer in consumers:
if consumer.type == MaceOp.Reshape.name and \
consumer.input[1] in self._consts and \
self._consts[consumer.input[1]].int32_data[:] == \
[op_output_shape[0], 1, 1, op_output_shape[1]]:
# work for tensorflow
net.tensors.remove(self._consts[consumer.input[1]])
del consumer.input[1]
self.safe_remove_node(consumer, None)
return True
return False
......@@ -411,7 +411,7 @@ def format_model_config(flags):
ModuleName.YAML_CONFIG,
"'input_data_formats' must be in "
+ str(DataFormatStrs) + ", but got "
+ input_data_formats)
+ input_data_format)
else:
subgraph[YAMLKeyword.input_data_formats] = [DataFormat.NHWC]
......@@ -431,7 +431,7 @@ def format_model_config(flags):
subgraph[YAMLKeyword.output_data_formats]:
mace_check(output_data_format in DataFormatStrs,
ModuleName.YAML_CONFIG,
"'input_data_formats' must be in "
"'output_data_formats' must be in "
+ str(DataFormatStrs))
else:
subgraph[YAMLKeyword.output_data_formats] = [DataFormat.NHWC]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册