提交 3bda6319 编写于 作者: 李滨

Merge branch 'splitv' into 'master'

feat: support splitV op for tensorflow

See merge request deep-computing/mace!1284
......@@ -220,8 +220,8 @@ class DeviceWrapper:
"MACE_LOG_TENSOR_RANGE=%d" % (1 if quantize_stat else 0),
"%s/%s" % (target_dir, target_name),
"--model_name=%s" % model_tag,
"--input_node='%s'" % ",".join(input_nodes),
"--output_node='%s'" % ",".join(output_nodes),
"--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats),
......
......@@ -236,6 +236,7 @@ class MaceKeyword(object):
mace_end_axis_str = 'end_axis'
mace_num_axes_str = 'num_axes'
mace_num_split_str = 'num_split'
mace_size_splits_str = 'size_splits'
mace_keepdims_str = 'keepdims'
mace_shape_str = 'shape'
mace_winograd_filter_transformed = 'is_filter_transformed'
......@@ -548,6 +549,7 @@ class ConverterOption(object):
# Model structure related transformation
TransformerRule.REMOVE_USELESS_OP,
TransformerRule.TRANSFORM_FAKE_QUANTIZE,
TransformerRule.REMOVE_USELESS_OP,
TransformerRule.TRANSFORM_GLOBAL_POOLING,
TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE,
TransformerRule.TRANSFORM_BASIC_LSTMCELL,
......
......@@ -166,7 +166,7 @@ class MegengineConverter(base_converter.ConverterInterface):
if "," in op.input[i]:
op_name = op.input[i]
op_name = op_name.replace(",", "#")
if (op_name in self._option.input_nodes or \
if (op_name in self._option.input_nodes or
op_name in self._option.output_nodes):
op.input[i] = op_name
for i in six.moves.range(len(op.output)):
......@@ -195,7 +195,8 @@ class MegengineConverter(base_converter.ConverterInterface):
kernels_arg.name = MaceKeyword.mace_kernel_str
kernels_arg.ints.extend(kernel)
if op_def.type in (MaceOp.Conv2D.name, MaceOp.DepthwiseConv2d.name,
MaceOp.Deconv2D.name, MaceOp.DepthwiseDeconv2d.name):
MaceOp.Deconv2D.name,
MaceOp.DepthwiseDeconv2d.name):
dilation = [params[mge_dilate_h_str], params[mge_dilate_w_str]]
dilation_arg = op_def.arg.add()
dilation_arg.name = MaceKeyword.mace_dilations_str
......@@ -426,13 +427,14 @@ class MegengineConverter(base_converter.ConverterInterface):
# check the case of counting include padding
mode = mge_op.params["mode"]
if mode == "AVERAGE_COUNT_EXCLUDE_PADDING" or \
(mode == "AVERAGE" and mge_op.params["pad_w"] == 0 and \
(mode == "AVERAGE" and mge_op.params["pad_w"] == 0 and
mge_op.params["pad_h"] == 0):
pool_type_arg.i = PoolingType.AVG.value
elif mode == "MAX":
pool_type_arg.i = PoolingType.MAX.value
else:
mace_check(False, "AVERAGE pooling should not count padding values")
mace_check(False,
"AVERAGE pooling should not count padding values")
self.add_stride_pad_kernel_arg(mge_op.params, op)
......
......@@ -116,6 +116,7 @@ TFSupportedOps = [
'SpaceToBatchND',
'SpaceToDepth',
'Split',
'SplitV',
'Sqrt',
'Square',
'SquaredDifference',
......@@ -279,6 +280,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.SpaceToBatchND.name: self.convert_space_batch,
TFOpType.SpaceToDepth.name: self.convert_space_depth,
TFOpType.Split.name: self.convert_split,
TFOpType.SplitV.name: self.convert_splitv,
TFOpType.Sqrt.name: self.convert_elementwise,
TFOpType.Squeeze.name: self.convert_squeeze,
TFOpType.Stack.name: self.convert_stack,
......@@ -1057,14 +1059,15 @@ class TensorflowConverter(base_converter.ConverterInterface):
keep_dims_arg.name = MaceKeyword.mace_keepdims_str
keep_dims_arg.i = 0
def convert_split(self, tf_op):
def convert_split(self, tf_op, axis_idx=0):
op = self.convert_general_op(tf_op)
num_or_size_splits = tf_op.get_attr('num_split')
if num_or_size_splits == 1:
is_split = (num_or_size_splits > 1)
if not is_split:
op.type = MaceOp.Identity.name
else:
op.type = MaceOp.Split.name
axis = tf_op.inputs[0].eval().astype(np.int32)
axis = tf_op.inputs[axis_idx].eval().astype(np.int32)
axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis
axis_arg = op.arg.add()
......@@ -1074,8 +1077,24 @@ class TensorflowConverter(base_converter.ConverterInterface):
num_split_arg = op.arg.add()
num_split_arg.name = MaceKeyword.mace_num_split_str
num_split_arg.i = num_or_size_splits
del op.input[0]
self._skip_tensor.add(tf_op.inputs[0].name)
del op.input[axis_idx]
self._skip_tensor.add(tf_op.inputs[axis_idx].name)
return (op, is_split)
def convert_splitv(self, tf_op):
    """Convert a TensorFlow ``SplitV`` op.

    The work is delegated to ``convert_split`` with the axis read from
    input index 2. When that call reports that a real split was
    produced, the values of input index 1 are stored in a
    ``size_splits`` argument, and that input is removed from the op and
    added to the skip set.

    Only splits whose sizes all equal the first size are accepted; each
    size is passed through ``mace_check`` (presumably this aborts the
    conversion on failure -- confirm against its definition).

    Returns ``None`` in every case.
    """
    split_op, has_split = self.convert_split(tf_op, 2)
    if has_split:
        # Create the argument before touching the inputs, in the same
        # order as the side effects are expected to happen.
        sizes_arg = split_op.arg.add()
        sizes_arg.name = MaceKeyword.mace_size_splits_str

        sizes = tf_op.inputs[1].eval().astype(np.int32)

        # The sizes now live in the argument, so the tensor is no longer
        # an input of the converted op.
        del split_op.input[1]
        self._skip_tensor.add(tf_op.inputs[1].name)

        # TODO(luxuhui): support size_splits
        uneven_msg = "SplitV Only support even distribution"
        for current in sizes:
            mace_check(current == sizes[0], uneven_msg)

        sizes_arg.ints.extend(sizes)
def convert_tile(self, tf_op):
op = self.convert_general_op(tf_op)
......
......@@ -318,6 +318,7 @@ def validate_onnx_model(model_file,
mace_out_value, value,
validation_threshold, log_file)
def validate_megengine_model(model_file, input_file,
mace_out_file, input_names, input_shapes,
input_data_formats, output_names, output_shapes,
......@@ -337,7 +338,7 @@ def validate_megengine_model(model_file, input_file,
util.formatted_file_name(input_file, input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i])
if (input_data_formats[i] == DataFormat.NHWC and \
if (input_data_formats[i] == DataFormat.NHWC and
len(input_shapes[i]) == 4):
input_value = input_value.transpose((0, 3, 1, 2))
feed_inputs.append(input_value)
......@@ -356,10 +357,10 @@ def validate_megengine_model(model_file, input_file,
output_file_name = \
util.formatted_file_name(mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
if (output_data_formats[i] == DataFormat.NHWC and \
if (output_data_formats[i] == DataFormat.NHWC and
len(output_shapes[i]) == 4):
mace_out_value = \
mace_out_value.reshape(output_shapes[i]).transpose((0, 3, 1, 2))
mace_out_value = mace_out_value.reshape(
output_shapes[i]).transpose((0, 3, 1, 2))
compare_output(output_names[i], mace_out_value,
mge_output_value, validation_threshold, log_file)
......
......@@ -748,8 +748,8 @@ def validate_model(abi,
"--input_file=/mace/%s" % input_file_name,
"--mace_out_file=/mace/%s" % output_file_name,
"--device_type=%s" % device_type,
"--input_node='%s'" % ",".join(input_nodes),
"--output_node='%s'" % ",".join(output_nodes),
"--input_node=%s" % ",".join(input_nodes),
"--output_node=%s" % ",".join(output_nodes),
"--input_shape=%s" % ":".join(input_shapes),
"--output_shape=%s" % ":".join(output_shapes),
"--input_data_format=%s" % ",".join(input_data_formats),
......
......@@ -350,7 +350,7 @@ def validate_megengine_model(platform, device_type, model_file, input_file,
common.formatted_file_name(input_file, input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i])
if (input_data_formats[i] == common.DataFormat.NHWC and \
if (input_data_formats[i] == common.DataFormat.NHWC and
len(input_shapes[i]) == 4):
input_value = input_value.transpose((0, 3, 1, 2))
feed_inputs.append(input_value)
......@@ -369,10 +369,10 @@ def validate_megengine_model(platform, device_type, model_file, input_file,
output_file_name = \
common.formatted_file_name(mace_out_file, output_names[i])
mace_out_value = load_data(output_file_name)
if (output_data_formats[i] == common.DataFormat.NHWC and \
if (output_data_formats[i] == common.DataFormat.NHWC and
len(output_shapes[i]) == 4):
mace_out_value = \
mace_out_value.reshape(output_shapes[i]).transpose((0, 3, 1, 2))
mace_out_value = mace_out_value.reshape(
output_shapes[i]).transpose((0, 3, 1, 2))
compare_output(platform, device_type, output_names[i], mace_out_value,
mge_output_value, validation_threshold, log_file)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册