提交 92404b91 编写于 作者: L liuqi

Add transpose weight of Matmul when transform model.

上级 70a41900
......@@ -195,6 +195,7 @@ class TransformerRule(Enum):
QUANTIZE_NODES = 23
ADD_QUANTIZE_TENSOR_RANGE = 24
QUANTIZE_WEIGHTS = 25
TRANSPOSE_MATMUL_WEIGHT = 26
class ConverterInterface(object):
......@@ -345,6 +346,8 @@ class ConverterOption(object):
TransformerRule.FOLD_ACTIVATION,
TransformerRule.TRANSFORM_GLOBAL_CONV_TO_FC,
TransformerRule.RESHAPE_FC_WEIGHT,
# Transpose the weight of matmul if necessary
TransformerRule.TRANSPOSE_MATMUL_WEIGHT,
# Model data format related transformation
TransformerRule.TRANSPOSE_FILTERS,
TransformerRule.TRANSPOSE_DATA_FORMAT,
......
......@@ -64,6 +64,9 @@ class Transformer(base_converter.ConverterInterface):
TransformerRule.FOLD_BIASADD: self.fold_biasadd,
TransformerRule.FLATTEN_ATROUS_CONV: self.flatten_atrous_conv,
TransformerRule.FOLD_ACTIVATION: self.fold_activation,
# TODO(liuqi): should move to transpose_filter
TransformerRule.TRANSPOSE_MATMUL_WEIGHT:
self.transpose_matmul_weight,
TransformerRule.TRANSPOSE_FILTERS: self.transpose_filters,
TransformerRule.TRANSPOSE_DATA_FORMAT: self.transpose_data_format,
TransformerRule.ADD_IN_OUT_TENSOR_INFO:
......@@ -954,6 +957,29 @@ class Transformer(base_converter.ConverterInterface):
return False
def transpose_matmul_weight(self):
if self._option.device != DeviceType.CPU.value:
return False
net = self._model
transpose_arg_names = [MaceKeyword.mace_transpose_a_str,
MaceKeyword.mace_transpose_b_str]
for op in net.op:
if op.type == MaceOp.MatMul.name: # noqa
for i in range(len(op.input)):
input = op.input[i]
if input in self._consts \
and len(self._consts[input].dims) == 2:
arg = ConverterUtil.get_arg(op, transpose_arg_names[i])
if arg is not None and arg.i == 1:
print 'convert matmul'
filter = self._consts[input]
filter_data = np.array(filter.float_data).reshape(
filter.dims)
filter_data = filter_data.transpose(1, 0)
filter.float_data[:] = filter_data.flat
filter.dims[:] = filter_data.shape
arg.i = 0
def transpose_filters(self):
net = self._model
filter_format = self.filter_format()
......
......@@ -457,6 +457,7 @@ def format_model_config(flags):
"'%s' is necessary in subgraph" % key)
if not isinstance(value, list):
subgraph[key] = [value]
subgraph[key] = [str(v) for v in subgraph[key]]
input_data_types = subgraph.get(YAMLKeyword.input_data_types, "")
if input_data_types:
......@@ -507,6 +508,8 @@ def format_model_config(flags):
subgraph[YAMLKeyword.input_ranges] = [input_ranges]
else:
subgraph[YAMLKeyword.input_ranges] = input_ranges
subgraph[YAMLKeyword.input_ranges] =\
[str(v) for v in subgraph[YAMLKeyword.input_ranges]]
for key in [YAMLKeyword.limit_opencl_kernel_time,
YAMLKeyword.nnlib_graph_mode,
......
......@@ -44,19 +44,22 @@ def generate_input_data(input_file, input_node, input_shape, input_ranges,
input_data_type):
input_names = [name for name in input_node.split(',')]
input_shapes = [shape for shape in input_shape.split(':')]
if input_ranges:
input_ranges = [r for r in input_ranges.split(':')]
else:
input_ranges = [[-1, 1]] * len(input_names)
input_ranges = ["-1,1"] * len(input_names)
if input_data_type:
input_data_types = [data_type
for data_type in input_data_type.split(',')]
else:
input_data_types = ['float32'] * len(input_names)
assert len(input_names) == len(input_shapes) == len(input_ranges) == len(input_data_types) # noqa
for i in range(len(input_names)):
shape = [int(x) for x in input_shapes[i].split(',')]
generate_data(input_names[i], shape, input_file, input_ranges[i],
input_range = [float(x) for x in input_ranges[i].split(',')]
generate_data(input_names[i], shape, input_file, input_range,
input_data_types[i])
print "Generate input file done."
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册