提交 041a5d2e 编写于 作者: L luxuhui

feat: support splitV op for tensorflow

N/A
Signed-off-by: NLuxuhui <luxuhui@xiaomi.com>
上级 a28e8128
...@@ -220,8 +220,8 @@ class DeviceWrapper: ...@@ -220,8 +220,8 @@ class DeviceWrapper:
"MACE_LOG_TENSOR_RANGE=%d" % (1 if quantize_stat else 0), "MACE_LOG_TENSOR_RANGE=%d" % (1 if quantize_stat else 0),
"%s/%s" % (target_dir, target_name), "%s/%s" % (target_dir, target_name),
"--model_name=%s" % model_tag, "--model_name=%s" % model_tag,
"--input_node='%s'" % ",".join(input_nodes), "--input_node=%s" % ",".join(input_nodes),
"--output_node='%s'" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes), "--output_shape=%s" % ":".join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats), "--input_data_format=%s" % ",".join(input_data_formats),
......
...@@ -236,6 +236,7 @@ class MaceKeyword(object): ...@@ -236,6 +236,7 @@ class MaceKeyword(object):
mace_end_axis_str = 'end_axis' mace_end_axis_str = 'end_axis'
mace_num_axes_str = 'num_axes' mace_num_axes_str = 'num_axes'
mace_num_split_str = 'num_split' mace_num_split_str = 'num_split'
mace_size_splits_str = 'size_splits'
mace_keepdims_str = 'keepdims' mace_keepdims_str = 'keepdims'
mace_shape_str = 'shape' mace_shape_str = 'shape'
mace_winograd_filter_transformed = 'is_filter_transformed' mace_winograd_filter_transformed = 'is_filter_transformed'
...@@ -548,6 +549,7 @@ class ConverterOption(object): ...@@ -548,6 +549,7 @@ class ConverterOption(object):
# Model structure related transformation # Model structure related transformation
TransformerRule.REMOVE_USELESS_OP, TransformerRule.REMOVE_USELESS_OP,
TransformerRule.TRANSFORM_FAKE_QUANTIZE, TransformerRule.TRANSFORM_FAKE_QUANTIZE,
TransformerRule.REMOVE_USELESS_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING, TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE, TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
TransformerRule.TRANSFORM_BASIC_LSTMCELL, TransformerRule.TRANSFORM_BASIC_LSTMCELL,
......
...@@ -166,7 +166,7 @@ class MegengineConverter(base_converter.ConverterInterface): ...@@ -166,7 +166,7 @@ class MegengineConverter(base_converter.ConverterInterface):
if "," in op.input[i]: if "," in op.input[i]:
op_name = op.input[i] op_name = op.input[i]
op_name = op_name.replace(",", "#") op_name = op_name.replace(",", "#")
if (op_name in self._option.input_nodes or \ if (op_name in self._option.input_nodes or
op_name in self._option.output_nodes): op_name in self._option.output_nodes):
op.input[i] = op_name op.input[i] = op_name
for i in six.moves.range(len(op.output)): for i in six.moves.range(len(op.output)):
...@@ -195,7 +195,8 @@ class MegengineConverter(base_converter.ConverterInterface): ...@@ -195,7 +195,8 @@ class MegengineConverter(base_converter.ConverterInterface):
kernels_arg.name = MaceKeyword.mace_kernel_str kernels_arg.name = MaceKeyword.mace_kernel_str
kernels_arg.ints.extend(kernel) kernels_arg.ints.extend(kernel)
if op_def.type in (MaceOp.Conv2D.name, MaceOp.DepthwiseConv2d.name, if op_def.type in (MaceOp.Conv2D.name, MaceOp.DepthwiseConv2d.name,
MaceOp.Deconv2D.name, MaceOp.DepthwiseDeconv2d.name): MaceOp.Deconv2D.name,
MaceOp.DepthwiseDeconv2d.name):
dilation = [params[mge_dilate_h_str], params[mge_dilate_w_str]] dilation = [params[mge_dilate_h_str], params[mge_dilate_w_str]]
dilation_arg = op_def.arg.add() dilation_arg = op_def.arg.add()
dilation_arg.name = MaceKeyword.mace_dilations_str dilation_arg.name = MaceKeyword.mace_dilations_str
...@@ -426,13 +427,14 @@ class MegengineConverter(base_converter.ConverterInterface): ...@@ -426,13 +427,14 @@ class MegengineConverter(base_converter.ConverterInterface):
# check the case of counting include padding # check the case of counting include padding
mode = mge_op.params["mode"] mode = mge_op.params["mode"]
if mode == "AVERAGE_COUNT_EXCLUDE_PADDING" or \ if mode == "AVERAGE_COUNT_EXCLUDE_PADDING" or \
(mode == "AVERAGE" and mge_op.params["pad_w"] == 0 and \ (mode == "AVERAGE" and mge_op.params["pad_w"] == 0 and
mge_op.params["pad_h"] == 0): mge_op.params["pad_h"] == 0):
pool_type_arg.i = PoolingType.AVG.value pool_type_arg.i = PoolingType.AVG.value
elif mode == "MAX": elif mode == "MAX":
pool_type_arg.i = PoolingType.MAX.value pool_type_arg.i = PoolingType.MAX.value
else: else:
mace_check(False, "AVERAGE pooling should not count padding values") mace_check(False,
"AVERAGE pooling should not count padding values")
self.add_stride_pad_kernel_arg(mge_op.params, op) self.add_stride_pad_kernel_arg(mge_op.params, op)
......
...@@ -116,6 +116,7 @@ TFSupportedOps = [ ...@@ -116,6 +116,7 @@ TFSupportedOps = [
'SpaceToBatchND', 'SpaceToBatchND',
'SpaceToDepth', 'SpaceToDepth',
'Split', 'Split',
'SplitV',
'Sqrt', 'Sqrt',
'Square', 'Square',
'SquaredDifference', 'SquaredDifference',
...@@ -279,6 +280,7 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -279,6 +280,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.SpaceToBatchND.name: self.convert_space_batch, TFOpType.SpaceToBatchND.name: self.convert_space_batch,
TFOpType.SpaceToDepth.name: self.convert_space_depth, TFOpType.SpaceToDepth.name: self.convert_space_depth,
TFOpType.Split.name: self.convert_split, TFOpType.Split.name: self.convert_split,
TFOpType.SplitV.name: self.convert_splitv,
TFOpType.Sqrt.name: self.convert_elementwise, TFOpType.Sqrt.name: self.convert_elementwise,
TFOpType.Squeeze.name: self.convert_squeeze, TFOpType.Squeeze.name: self.convert_squeeze,
TFOpType.Stack.name: self.convert_stack, TFOpType.Stack.name: self.convert_stack,
...@@ -1057,14 +1059,15 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -1057,14 +1059,15 @@ class TensorflowConverter(base_converter.ConverterInterface):
keep_dims_arg.name = MaceKeyword.mace_keepdims_str keep_dims_arg.name = MaceKeyword.mace_keepdims_str
keep_dims_arg.i = 0 keep_dims_arg.i = 0
def convert_split(self, tf_op): def convert_split(self, tf_op, axis_idx=0):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
num_or_size_splits = tf_op.get_attr('num_split') num_or_size_splits = tf_op.get_attr('num_split')
if num_or_size_splits == 1: is_split = (num_or_size_splits > 1)
if not is_split:
op.type = MaceOp.Identity.name op.type = MaceOp.Identity.name
else: else:
op.type = MaceOp.Split.name op.type = MaceOp.Split.name
axis = tf_op.inputs[0].eval().astype(np.int32) axis = tf_op.inputs[axis_idx].eval().astype(np.int32)
axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis
axis_arg = op.arg.add() axis_arg = op.arg.add()
...@@ -1074,8 +1077,24 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -1074,8 +1077,24 @@ class TensorflowConverter(base_converter.ConverterInterface):
num_split_arg = op.arg.add() num_split_arg = op.arg.add()
num_split_arg.name = MaceKeyword.mace_num_split_str num_split_arg.name = MaceKeyword.mace_num_split_str
num_split_arg.i = num_or_size_splits num_split_arg.i = num_or_size_splits
del op.input[0] del op.input[axis_idx]
self._skip_tensor.add(tf_op.inputs[0].name) self._skip_tensor.add(tf_op.inputs[axis_idx].name)
return (op, is_split)
def convert_splitv(self, tf_op):
(op, is_split) = self.convert_split(tf_op, 2)
if not is_split:
return
size_splits_arg = op.arg.add()
size_splits_arg.name = MaceKeyword.mace_size_splits_str
size_splits = tf_op.inputs[1].eval().astype(np.int32)
del op.input[1]
self._skip_tensor.add(tf_op.inputs[1].name)
# todo(luxuhui): support size_splits
for size in size_splits:
mace_check(size == size_splits[0],
"SplitV Only support even distribution")
size_splits_arg.ints.extend(size_splits)
def convert_tile(self, tf_op): def convert_tile(self, tf_op):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
......
...@@ -318,6 +318,7 @@ def validate_onnx_model(model_file, ...@@ -318,6 +318,7 @@ def validate_onnx_model(model_file,
mace_out_value, value, mace_out_value, value,
validation_threshold, log_file) validation_threshold, log_file)
def validate_megengine_model(model_file, input_file, def validate_megengine_model(model_file, input_file,
mace_out_file, input_names, input_shapes, mace_out_file, input_names, input_shapes,
input_data_formats, output_names, output_shapes, input_data_formats, output_names, output_shapes,
...@@ -337,7 +338,7 @@ def validate_megengine_model(model_file, input_file, ...@@ -337,7 +338,7 @@ def validate_megengine_model(model_file, input_file,
util.formatted_file_name(input_file, input_names[i]), util.formatted_file_name(input_file, input_names[i]),
input_data_types[i]) input_data_types[i])
input_value = input_value.reshape(input_shapes[i]) input_value = input_value.reshape(input_shapes[i])
if (input_data_formats[i] == DataFormat.NHWC and \ if (input_data_formats[i] == DataFormat.NHWC and
len(input_shapes[i]) == 4): len(input_shapes[i]) == 4):
input_value = input_value.transpose((0, 3, 1, 2)) input_value = input_value.transpose((0, 3, 1, 2))
feed_inputs.append(input_value) feed_inputs.append(input_value)
...@@ -356,10 +357,10 @@ def validate_megengine_model(model_file, input_file, ...@@ -356,10 +357,10 @@ def validate_megengine_model(model_file, input_file,
output_file_name = \ output_file_name = \
util.formatted_file_name(mace_out_file, output_names[i]) util.formatted_file_name(mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
if (output_data_formats[i] == DataFormat.NHWC and \ if (output_data_formats[i] == DataFormat.NHWC and
len(output_shapes[i]) == 4): len(output_shapes[i]) == 4):
mace_out_value = \ mace_out_value = mace_out_value.reshape(
mace_out_value.reshape(output_shapes[i]).transpose((0, 3, 1, 2)) output_shapes[i]).transpose((0, 3, 1, 2))
compare_output(output_names[i], mace_out_value, compare_output(output_names[i], mace_out_value,
mge_output_value, validation_threshold, log_file) mge_output_value, validation_threshold, log_file)
......
...@@ -748,8 +748,8 @@ def validate_model(abi, ...@@ -748,8 +748,8 @@ def validate_model(abi,
"--input_file=/mace/%s" % input_file_name, "--input_file=/mace/%s" % input_file_name,
"--mace_out_file=/mace/%s" % output_file_name, "--mace_out_file=/mace/%s" % output_file_name,
"--device_type=%s" % device_type, "--device_type=%s" % device_type,
"--input_node='%s'" % ",".join(input_nodes), "--input_node=%s" % ",".join(input_nodes),
"--output_node='%s'" % ",".join(output_nodes), "--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes), "--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes), "--output_shape=%s" % ":".join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats), "--input_data_format=%s" % ",".join(input_data_formats),
......
...@@ -350,7 +350,7 @@ def validate_megengine_model(platform, device_type, model_file, input_file, ...@@ -350,7 +350,7 @@ def validate_megengine_model(platform, device_type, model_file, input_file,
common.formatted_file_name(input_file, input_names[i]), common.formatted_file_name(input_file, input_names[i]),
input_data_types[i]) input_data_types[i])
input_value = input_value.reshape(input_shapes[i]) input_value = input_value.reshape(input_shapes[i])
if (input_data_formats[i] == common.DataFormat.NHWC and \ if (input_data_formats[i] == common.DataFormat.NHWC and
len(input_shapes[i]) == 4): len(input_shapes[i]) == 4):
input_value = input_value.transpose((0, 3, 1, 2)) input_value = input_value.transpose((0, 3, 1, 2))
feed_inputs.append(input_value) feed_inputs.append(input_value)
...@@ -369,10 +369,10 @@ def validate_megengine_model(platform, device_type, model_file, input_file, ...@@ -369,10 +369,10 @@ def validate_megengine_model(platform, device_type, model_file, input_file,
output_file_name = \ output_file_name = \
common.formatted_file_name(mace_out_file, output_names[i]) common.formatted_file_name(mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name) mace_out_value = load_data(output_file_name)
if (output_data_formats[i] == common.DataFormat.NHWC and \ if (output_data_formats[i] == common.DataFormat.NHWC and
len(output_shapes[i]) == 4): len(output_shapes[i]) == 4):
mace_out_value = \ mace_out_value = mace_out_value.reshape(
mace_out_value.reshape(output_shapes[i]).transpose((0, 3, 1, 2)) output_shapes[i]).transpose((0, 3, 1, 2))
compare_output(platform, device_type, output_names[i], mace_out_value, compare_output(platform, device_type, output_names[i], mace_out_value,
mge_output_value, validation_threshold, log_file) mge_output_value, validation_threshold, log_file)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册