提交 7a9ee4ca 编写于 作者: 李寅

Merge branch 'fix-tf-tf' into 'master'

Fix bug: transform fc of tensorflow.

See merge request !683
......@@ -578,7 +578,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis = tf_op.inputs[-1].eval().astype(np.int32)
axis = 4 + axis if axis < 0 else axis
axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis
axis_arg.i = axis
self._skip_tensor.add(tf_op.inputs[-1].name)
......
......@@ -751,6 +751,15 @@ class Transformer(base_converter.ConverterInterface):
"only support concat at "
"channel dimension")
arg.i = 3
producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims
if producer.type == MaceOp.FullyConnected.name and \
len(input_shape) == 2:
axis_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str)
if axis_arg.i == 1 \
and self._target_data_format == DataFormat.NHWC: # noqa
axis_arg.i = 3
elif op.type == MaceOp.Squeeze.name:
for arg in op.arg:
......@@ -938,7 +947,10 @@ class Transformer(base_converter.ConverterInterface):
input_shape = list(input_op.output_shape[0].dims)
input_data_format = ConverterUtil.data_format(input_op)
weight.dims[:] = [weight.dims[0]] + input_shape[1:]
if input_data_format == DataFormat.NHWC:
if len(input_shape) == 2:
weight.dims[:] = weight.dims[:] + [1, 1]
if input_data_format == DataFormat.NHWC and \
len(input_shape) == 4:
self.transpose_shape(weight.dims, [0, 3, 1, 2])
return False
......@@ -1113,31 +1125,48 @@ class Transformer(base_converter.ConverterInterface):
net = self._model
filter_format = self.filter_format()
for op in net.op:
# transform reshape + matmul -> fc
# transform input(4D) -> reshape(2D) -> matmul to fc
# work for TensorFlow
if op.type == MaceOp.MatMul.name and \
if op.type == MaceOp.Reshape.name and \
op.input[1] in self._consts and \
len(op.output_shape[0].dims) == 2 and \
filter_format == FilterFormat.HWIO:
producer = self._producer[op.input[0]]
weight = self._consts[op.input[1]]
if len(weight.dims) == 2 \
and producer.type == MaceOp.Reshape.name \
and len(producer.output) == 1 \
and producer.input[1] in self._consts \
and len(producer.output_shape[0].dims) == 2:
input_op = self._producer[producer.input[0]]
input_op = self._producer[op.input[0]]
input_shape = input_op.output_shape[0].dims
feature_size = np.prod(input_shape[1:])
self.safe_remove_node(producer, input_op,
remove_input_tensor=True)
if feature_size == producer.output_shape[0].dims[1]:
# check input op
if len(input_shape) == 4 and \
np.prod(input_shape[1:]) == op.output_shape[0].dims[1]:
is_fc = True
consumers = self._consumers[op.output[0]]
# check matmul op
for matmul_op in consumers:
if matmul_op.type != MaceOp.MatMul.name:
is_fc = False
else:
weight = self._consts[matmul_op.input[1]]
if len(weight.dims) != 2 or \
weight.dims[0] != op.output_shape[0].dims[1]:
is_fc = False
if is_fc:
print 'convert reshape and matmul to fc'
op.type = MaceOp.FullyConnected.name
self.safe_remove_node(op, input_op,
remove_input_tensor=True)
for matmul_op in consumers:
weight = self._consts[matmul_op.input[1]]
matmul_op.type = MaceOp.FullyConnected.name
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight.dims[:] = input_shape[1:] + \
[weight_data.shape[1]]
return True
elif len(weight.dims) == 2 and \
# transform input(2D) -> matmul to fc
if op.type == MaceOp.MatMul.name and \
filter_format == FilterFormat.HWIO:
producer = self._producer[op.input[0]]
weight = self._consts[op.input[1]]
if len(weight.dims) == 2 and \
producer.type != MaceOp.Reshape.name and \
len(producer.output_shape[0].dims) == 2 and \
weight.dims[0] == producer.output_shape[0].dims[1]:
print 'convert matmul to fc'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册