提交 0c8ef13b 编写于 作者: 刘琦

Merge branch 'transform' into 'master'

Add tensorflow fc support

See merge request !496
...@@ -23,7 +23,8 @@ namespace mace { ...@@ -23,7 +23,8 @@ namespace mace {
ArgumentHelper::ArgumentHelper(const OperatorDef &def) { ArgumentHelper::ArgumentHelper(const OperatorDef &def) {
for (auto &arg : def.arg()) { for (auto &arg : def.arg()) {
if (arg_map_.find(arg.name()) != arg_map_.end()) { if (arg_map_.find(arg.name()) != arg_map_.end()) {
LOG(WARNING) << "Duplicated argument name found in operator def."; LOG(WARNING) << "Duplicated argument name found in operator def: "
<< def.name() << " " << arg.name();
} }
arg_map_[arg.name()] = arg; arg_map_[arg.name()] = arg;
......
...@@ -128,7 +128,7 @@ def main(unused_args): ...@@ -128,7 +128,7 @@ def main(unused_args):
FLAGS.weight_file) FLAGS.weight_file)
output_graph_def = converter.run() output_graph_def = converter.run()
print("Transform model to one that can better run on device.") print("Transform model to one that can better run on device")
if not FLAGS.runtime: if not FLAGS.runtime:
cpu_graph_def = copy.deepcopy(output_graph_def) cpu_graph_def = copy.deepcopy(output_graph_def)
option.device = mace_pb2.CPU option.device = mace_pb2.CPU
......
...@@ -136,23 +136,25 @@ class MaceKeyword(object): ...@@ -136,23 +136,25 @@ class MaceKeyword(object):
class TransformerRule(Enum): class TransformerRule(Enum):
REMOVE_IDENTITY_OP = 0 REMOVE_USELESS_RESHAPE_OP = 0
TRANSFORM_GLOBAL_POOLING = 1 REMOVE_IDENTITY_OP = 1
FOLD_SOFTMAX = 2 TRANSFORM_GLOBAL_POOLING = 2
FOLD_BATCHNORM = 3, FOLD_RESHAPE = 3
FOLD_CONV_AND_BN = 4, TRANSFORM_MATMUL_TO_FC = 4
FOLD_DEPTHWISE_CONV_AND_BN = 5, FOLD_BATCHNORM = 5
TRANSFORM_GPU_WINOGRAD = 6, FOLD_CONV_AND_BN = 6
TRANSFORM_ADD_TO_BIASADD = 7, FOLD_DEPTHWISE_CONV_AND_BN = 7
FOLD_BIASADD = 8, TRANSFORM_GPU_WINOGRAD = 8
FOLD_ACTIVATION = 9, TRANSFORM_ADD_TO_BIASADD = 9
TRANSPOSE_FILTERS = 10, FOLD_BIASADD = 10
RESHAPE_FC_WEIGHT = 11, FOLD_ACTIVATION = 11
TRANSPOSE_DATA_FORMAT = 12, TRANSPOSE_FILTERS = 12
TRANSFORM_GLOBAL_CONV_TO_FC = 13, RESHAPE_FC_WEIGHT = 13
TRANSFORM_BUFFER_IMAGE = 14, TRANSPOSE_DATA_FORMAT = 14
ADD_DEVICE_AND_DATA_TYPE = 15, TRANSFORM_GLOBAL_CONV_TO_FC = 15
SORT_BY_EXECUTION = 16 TRANSFORM_BUFFER_IMAGE = 16
ADD_DEVICE_AND_DATA_TYPE = 17
SORT_BY_EXECUTION = 18
class ConverterInterface(object): class ConverterInterface(object):
...@@ -199,9 +201,11 @@ class ConverterOption(object): ...@@ -199,9 +201,11 @@ class ConverterOption(object):
self._device = mace_pb2.CPU self._device = mace_pb2.CPU
self._winograd_enabled = False self._winograd_enabled = False
self._transformer_option = [ self._transformer_option = [
TransformerRule.REMOVE_USELESS_RESHAPE_OP,
TransformerRule.REMOVE_IDENTITY_OP, TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING, TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_SOFTMAX, TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM, TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN, TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN, TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
......
...@@ -101,9 +101,11 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -101,9 +101,11 @@ class TensorflowConverter(base_converter.ConverterInterface):
'AvgPool': self.convert_pooling, 'AvgPool': self.convert_pooling,
'MaxPool': self.convert_pooling, 'MaxPool': self.convert_pooling,
'Squeeze': self.convert_identity, 'Squeeze': self.convert_identity,
'MatMul': self.convert_matmul,
'Identity': self.convert_identity, 'Identity': self.convert_identity,
'Reshape': self.convert_reshape, 'Reshape': self.convert_reshape,
'Shape': self.convert_nop, 'Shape': self.convert_nop,
'Transpose': self.convert_transpose,
'Softmax': self.convert_softmax, 'Softmax': self.convert_softmax,
'ResizeBilinear': self.convert_resize_bilinear, 'ResizeBilinear': self.convert_resize_bilinear,
'Placeholder': self.convert_nop, 'Placeholder': self.convert_nop,
...@@ -144,7 +146,8 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -144,7 +146,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
for i in xrange(len(op.input)): for i in xrange(len(op.input)):
if op.input[i][-2:] == ':0': if op.input[i][-2:] == ':0':
op_name = op.input[i][:-2] op_name = op.input[i][:-2]
if op_name in self._option.input_nodes: if op_name in self._option.input_nodes \
or op_name in self._option.output_nodes:
op.input[i] = op_name op.input[i] = op_name
for i in xrange(len(op.output)): for i in xrange(len(op.output)):
if op.output[i][-2:] == ':0': if op.output[i][-2:] == ':0':
...@@ -411,6 +414,10 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -411,6 +414,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
self._skip_tensor.update(tf_op.inputs[-1].name) self._skip_tensor.update(tf_op.inputs[-1].name)
def convert_matmul(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.MatMul.name
def convert_reshape(self, tf_op): def convert_reshape(self, tf_op):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
op.type = MaceOp.Reshape.name op.type = MaceOp.Reshape.name
...@@ -430,6 +437,20 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -430,6 +437,20 @@ class TensorflowConverter(base_converter.ConverterInterface):
shape_arg.ints.extend(shape_value) shape_arg.ints.extend(shape_value)
def convert_transpose(self, tf_op):
perm = tf_op.inputs[1].eval().astype(np.int32)
ordered_perm = np.sort(perm)
mace_check(np.array_equal(perm, ordered_perm),
"Transpose not supported yet, only internal transpose"
" in composed ops might be supported")
op = self.convert_general_op(tf_op)
op.type = 'Identity'
del op.input[1:]
self._skip_tensor.add(tf_op.inputs[1].name)
def convert_mean(self, tf_op): def convert_mean(self, tf_op):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
del op.input[1:] del op.input[1:]
......
...@@ -53,9 +53,11 @@ class Transformer(base_converter.ConverterInterface): ...@@ -53,9 +53,11 @@ class Transformer(base_converter.ConverterInterface):
def __init__(self, option, model): def __init__(self, option, model):
# DO NOT reorder the following transformers # DO NOT reorder the following transformers
self._registered_transformers_order = [ self._registered_transformers_order = [
TransformerRule.REMOVE_USELESS_RESHAPE_OP,
TransformerRule.REMOVE_IDENTITY_OP, TransformerRule.REMOVE_IDENTITY_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING, TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.FOLD_SOFTMAX, TransformerRule.FOLD_RESHAPE,
TransformerRule.TRANSFORM_MATMUL_TO_FC,
TransformerRule.FOLD_BATCHNORM, TransformerRule.FOLD_BATCHNORM,
TransformerRule.FOLD_CONV_AND_BN, TransformerRule.FOLD_CONV_AND_BN,
TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN, TransformerRule.FOLD_DEPTHWISE_CONV_AND_BN,
...@@ -72,10 +74,14 @@ class Transformer(base_converter.ConverterInterface): ...@@ -72,10 +74,14 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.SORT_BY_EXECUTION, TransformerRule.SORT_BY_EXECUTION,
] ]
self._registered_transformers = { self._registered_transformers = {
TransformerRule.REMOVE_USELESS_RESHAPE_OP:
self.remove_useless_reshape_op,
TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op, TransformerRule.REMOVE_IDENTITY_OP: self.remove_identity_op,
TransformerRule.TRANSFORM_GLOBAL_POOLING: TransformerRule.TRANSFORM_GLOBAL_POOLING:
self.transform_global_pooling, self.transform_global_pooling,
TransformerRule.FOLD_SOFTMAX: self.fold_softmax, TransformerRule.FOLD_RESHAPE: self.fold_reshape,
TransformerRule.TRANSFORM_MATMUL_TO_FC:
self.transform_matmul_to_fc,
TransformerRule.FOLD_BATCHNORM: self.fold_batchnorm, TransformerRule.FOLD_BATCHNORM: self.fold_batchnorm,
TransformerRule.FOLD_CONV_AND_BN: TransformerRule.FOLD_CONV_AND_BN:
self.fold_conv_and_bn, # data_format related self.fold_conv_and_bn, # data_format related
...@@ -161,18 +167,26 @@ class Transformer(base_converter.ConverterInterface): ...@@ -161,18 +167,26 @@ class Transformer(base_converter.ConverterInterface):
for output_tensor in op.output: for output_tensor in op.output:
self._producer[output_tensor] = op self._producer[output_tensor] = op
for input_node in self._option.input_nodes.values(): for input_node in self._option.input_nodes.values():
op = mace_pb2.OperatorDef() input_node_existed = False
op.name = self.normalize_op_name(input_node.name) for op in self._model.op:
op.type = 'Input' if input_node.name in op.output:
op.output.extend(input_node.name) input_node_existed = True
output_shape = op.output_shape.add() break
output_shape.dims.extend(input_node.shape) if not input_node_existed:
if self._option.device == mace_pb2.CPU: op = mace_pb2.OperatorDef()
self.transpose_shape(output_shape.dims, [0, 3, 1, 2]) op.name = self.normalize_op_name(input_node.name)
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW) op.type = 'Input'
else: op.output.extend([input_node.name])
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC) output_shape = op.output_shape.add()
self._producer[op.output[0]] = op output_shape.dims.extend(input_node.shape)
if ConverterUtil.data_format(
self._consumers[input_node.name][0]) \
== DataFormat.NCHW:
self.transpose_shape(output_shape.dims, [0, 3, 1, 2])
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
else:
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
self._producer[op.output[0]] = op
@staticmethod @staticmethod
def replace(obj_list, source, target): def replace(obj_list, source, target):
...@@ -191,6 +205,12 @@ class Transformer(base_converter.ConverterInterface): ...@@ -191,6 +205,12 @@ class Transformer(base_converter.ConverterInterface):
def normalize_op_name(name): def normalize_op_name(name):
return name.replace(':', '_') return name.replace(':', '_')
def get_tensor_shape(self, tensor):
producer = self._producer[tensor]
for i in xrange(len(producer.output)):
if producer.output[i] == tensor:
return list(producer.output_shape[i].dims)
def consumer_count(self, tensor_name): def consumer_count(self, tensor_name):
return len(self._consumers.get(tensor_name, [])) return len(self._consumers.get(tensor_name, []))
...@@ -203,23 +223,68 @@ class Transformer(base_converter.ConverterInterface): ...@@ -203,23 +223,68 @@ class Transformer(base_converter.ConverterInterface):
return False return False
def replace_output_node(self, op): def safe_remove_node(self, op, replace_op):
"""if it is an output node, change output node to the op before it""" """remove op.
if self.is_op_output_node(op): 1. change the inputs of its consumers to the outputs of replace_op
real_output_node = self._producer[op.input[0]] 2. if the op is output node, change output node to replace op"""
self.replace(real_output_node.output, op.input[0], op.output[0])
print("change %s to %s" % (real_output_node.name, op.name)) if replace_op is None:
# When no replace op specified, we change the inputs of
# its consumers to the input of the op. This handles the case
# that the op is identity op and its input is a tensor.
mace_check(len(op.output) == 1 and len(op.input) == 1,
"cannot remove op that w/o replace op specified"
" and input/output length > 1" + str(op))
for consumer_op in self._consumers.get(op.output[0], []):
self.replace(consumer_op.input, op.output[0], op.input[0])
mace_check(op.output[0] not in self._option.output_nodes,
"cannot remove op that is output node")
else:
mace_check(len(op.output) == len(replace_op.output),
"cannot remove op since len(op.output) "
"!= len(replace_op.output)")
for i in xrange(len(op.output)):
for consumer_op in self._consumers.get(op.output[i], []):
self.replace(consumer_op.input,
op.output[i],
replace_op.output[i])
# if the op is output node, change replace_op output name to the op
# output name
for i in xrange(len(op.output)):
if op.output[i] in self._option.output_nodes:
for consumer in self._consumers.get(
replace_op.output[i], []):
self.replace(consumer.input,
replace_op.output[i],
op.output[i])
replace_op.output[i] = op.output[i]
self._model.op.remove(op)
def remove_useless_reshape_op(self):
net = self._model
for op in net.op:
if op.type == MaceOp.Reshape.name:
shape = list(ConverterUtil.get_arg(
op, MaceKeyword.mace_shape_str).ints)
if shape == self.get_tensor_shape(op.input[0]):
print("Remove useless reshape: %s(%s)"
% (op.name, op.type))
op.type = 'Identity'
return False
def remove_identity_op(self): def remove_identity_op(self):
net = self._model net = self._model
for op in net.op: for op in net.op:
if op.type == 'Identity': if op.type == 'Identity':
print("Remove identity: %s(%s)" % (op.name, op.type)) print("Remove identity: %s(%s)" % (op.name, op.type))
for consumer_op in self._consumers.get(op.output[0], []): self.safe_remove_node(op,
Transformer.replace(consumer_op.input, op.output[0], self._producer.get(op.input[0], None))
op.input[0])
self.replace_output_node(op)
net.op.remove(op)
return True return True
return False return False
...@@ -264,10 +329,10 @@ class Transformer(base_converter.ConverterInterface): ...@@ -264,10 +329,10 @@ class Transformer(base_converter.ConverterInterface):
and len(self._consts[consumer_op.input[1]].dims) == 1: and len(self._consts[consumer_op.input[1]].dims) == 1:
print("Fold batchnorm: %s(%s)" % (op.name, op.type)) print("Fold batchnorm: %s(%s)" % (op.name, op.type))
consumer_op.type = MaceOp.FoldedBatchNorm.name consumer_op.type = MaceOp.FoldedBatchNorm.name
inputs = [op.input[0], op.input[1], consumer_op.input[1]] consumer_op.input[:] = [op.input[0], op.input[1],
consumer_op.input[:] = inputs[:] consumer_op.input[1]]
net.op.remove(op) self.safe_remove_node(op, None)
return True return True
return False return False
...@@ -514,7 +579,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -514,7 +579,7 @@ class Transformer(base_converter.ConverterInterface):
filter.float_data[:] = weight_tensor_value.flat[:] filter.float_data[:] = weight_tensor_value.flat[:]
filter.dims[:] = weight_tensor_value.shape[:] filter.dims[:] = weight_tensor_value.shape[:]
net.op.remove(op) self.safe_remove_node(op, iwt_op)
return False return False
...@@ -544,10 +609,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -544,10 +609,8 @@ class Transformer(base_converter.ConverterInterface):
consumer_op = self._consumers[op.output[0]][0] consumer_op = self._consumers[op.output[0]][0]
if consumer_op.type == MaceOp.BiasAdd.name: if consumer_op.type == MaceOp.BiasAdd.name:
print("Fold biasadd: %s(%s)" % (op.name, op.type)) print("Fold biasadd: %s(%s)" % (op.name, op.type))
op.name = consumer_op.name
op.input.append(consumer_op.input[1]) op.input.append(consumer_op.input[1])
op.output[0] = consumer_op.output[0] self.safe_remove_node(consumer_op, op)
net.op.remove(consumer_op)
return True return True
return False return False
...@@ -575,7 +638,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -575,7 +638,7 @@ class Transformer(base_converter.ConverterInterface):
or arg.name == MaceKeyword.mace_activation_max_limit_str: # noqa or arg.name == MaceKeyword.mace_activation_max_limit_str: # noqa
op.arg.extend([arg]) op.arg.extend([arg])
net.op.remove(consumer_op) self.safe_remove_node(consumer_op, op)
return True return True
return False return False
...@@ -651,11 +714,14 @@ class Transformer(base_converter.ConverterInterface): ...@@ -651,11 +714,14 @@ class Transformer(base_converter.ConverterInterface):
op.output.extend([input_node.name]) op.output.extend([input_node.name])
output_shape = op.output_shape.add() output_shape = op.output_shape.add()
output_shape.dims.extend(input_node.shape) output_shape.dims.extend(input_node.shape)
self.transpose_shape(output_shape.dims, [0, 3, 1, 2])
dims_arg = op.arg.add() dims_arg = op.arg.add()
dims_arg.name = MaceKeyword.mace_dims_str dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 3, 1, 2]) dims_arg.ints.extend([0, 3, 1, 2])
ConverterUtil.add_data_format_arg(op, DataFormat.NCHW)
for output_node in self._option.output_nodes.values(): for output_node in self._option.output_nodes.values():
output_name = MaceKeyword.mace_output_node_name \ output_name = MaceKeyword.mace_output_node_name \
+ '_' + output_node.name + '_' + output_node.name
...@@ -673,6 +739,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -673,6 +739,8 @@ class Transformer(base_converter.ConverterInterface):
dims_arg.name = MaceKeyword.mace_dims_str dims_arg.name = MaceKeyword.mace_dims_str
dims_arg.ints.extend([0, 2, 3, 1]) dims_arg.ints.extend([0, 2, 3, 1])
ConverterUtil.add_data_format_arg(op, DataFormat.NHWC)
return False return False
def transpose_filters(self): def transpose_filters(self):
...@@ -695,12 +763,19 @@ class Transformer(base_converter.ConverterInterface): ...@@ -695,12 +763,19 @@ class Transformer(base_converter.ConverterInterface):
filter_data = filter_data.transpose(3, 2, 0, 1) filter_data = filter_data.transpose(3, 2, 0, 1)
filter.float_data[:] = filter_data.flat filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape filter.dims[:] = filter_data.shape
if op.type == MaceOp.FullyConnected.name:
weight = self._consts[op.input[1]]
weight_data = np.array(weight.float_data).reshape(
weight.dims)
weight_data = weight_data.transpose(1, 0)
weight.float_data[:] = weight_data.flat
weight.dims[:] = weight_data.shape
self.set_filter_format(FilterFormat.OIHW) self.set_filter_format(FilterFormat.OIHW)
return False return False
def reshape_fc_weight(self): def reshape_fc_weight(self):
print("Reshape fully connecrted weight shape") print("Reshape fully connected weight shape")
net = self._model net = self._model
for op in net.op: for op in net.op:
if op.type == MaceOp.FullyConnected.name: if op.type == MaceOp.FullyConnected.name:
...@@ -789,6 +864,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -789,6 +864,8 @@ class Transformer(base_converter.ConverterInterface):
arg.name = MaceKeyword.mace_buffer_type arg.name = MaceKeyword.mace_buffer_type
arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
for output_node in self._option.output_nodes.values(): for output_node in self._option.output_nodes.values():
output_name = MaceKeyword.mace_output_node_name \ output_name = MaceKeyword.mace_output_node_name \
+ '_' + output_node.name + '_' + output_node.name
...@@ -804,14 +881,16 @@ class Transformer(base_converter.ConverterInterface): ...@@ -804,14 +881,16 @@ class Transformer(base_converter.ConverterInterface):
arg.name = MaceKeyword.mace_buffer_type arg.name = MaceKeyword.mace_buffer_type
arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value arg.i = OpenCLBufferType.IN_OUT_CHANNEL.value
ConverterUtil.add_data_format_arg(op_def, DataFormat.NHWC)
return False return False
def fold_softmax(self): def fold_reshape(self):
changed = False changed = False
net = self._model net = self._model
for op in net.op: for op in net.op:
if op.type == MaceOp.Softmax.name: if op.type == MaceOp.Softmax.name or op.type == MaceOp.MatMul.name:
print("Fold softmax: %s(%s)" % (op.name, op.type)) print("Fold reshape: %s(%s)" % (op.name, op.type))
if self.consumer_count(op.output[0]) == 1: if self.consumer_count(op.output[0]) == 1:
consumer = self._consumers[op.output[0]][0] consumer = self._consumers[op.output[0]][0]
if consumer.type == MaceOp.Reshape.name: if consumer.type == MaceOp.Reshape.name:
...@@ -819,15 +898,14 @@ class Transformer(base_converter.ConverterInterface): ...@@ -819,15 +898,14 @@ class Transformer(base_converter.ConverterInterface):
MaceKeyword.mace_shape_str).ints # noqa MaceKeyword.mace_shape_str).ints # noqa
del op.output_shape[0].dims[:] del op.output_shape[0].dims[:]
op.output_shape[0].dims.extend(shape) op.output_shape[0].dims.extend(shape)
self.replace_output_node(consumer) self.safe_remove_node(consumer, op)
net.op.remove(consumer)
changed = True changed = True
producer = self._producer[op.input[0]] producer = self._producer[op.input[0]]
if producer.type == MaceOp.Reshape.name: if producer.type == MaceOp.Reshape.name:
op.input[0] = producer.input[0] self.safe_remove_node(producer,
self.replace_output_node(producer) self._producer[
net.op.remove(producer) producer.input[0]])
changed = True changed = True
if len(op.output_shape[0].dims) < 4: if len(op.output_shape[0].dims) < 4:
...@@ -840,6 +918,20 @@ class Transformer(base_converter.ConverterInterface): ...@@ -840,6 +918,20 @@ class Transformer(base_converter.ConverterInterface):
return False return False
def transform_matmul_to_fc(self):
net = self._model
for op in net.op:
if op.type == MaceOp.MatMul.name:
input_shape = self.get_tensor_shape(op.input[0])
_, h, w, _ = self.sort_feature_map_shape(input_shape,
ConverterUtil.data_format(self._producer[op.input[0]])) # noqa
if h == 1 and w == 1 and op.input[1] in self._consts:
weight = self._consts[op.input[1]]
if len(weight.dims) == 2:
op.type = MaceOp.FullyConnected.name
return False
def transform_global_conv_to_fc(self): def transform_global_conv_to_fc(self):
"""Transform global conv to fc should be placed after transposing """Transform global conv to fc should be placed after transposing
input/output and filter""" input/output and filter"""
...@@ -918,4 +1010,8 @@ class Transformer(base_converter.ConverterInterface): ...@@ -918,4 +1010,8 @@ class Transformer(base_converter.ConverterInterface):
del net.op[:] del net.op[:]
net.op.extend(sorted_nodes) net.op.extend(sorted_nodes)
print("Final ops:")
for op in net.op:
print("%s (%s)" % (op.name, op.type))
return False return False
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册