diff --git a/tools/device.py b/tools/device.py index fd8c138fee10351d8647cc8c988c0d9173b3ab52..312cbb0855ca21c8bc8363fb00c1ddbc13be1d34 100644 --- a/tools/device.py +++ b/tools/device.py @@ -220,8 +220,8 @@ class DeviceWrapper: "MACE_LOG_TENSOR_RANGE=%d" % (1 if quantize_stat else 0), "%s/%s" % (target_dir, target_name), "--model_name=%s" % model_tag, - "--input_node='%s'" % ",".join(input_nodes), - "--output_node='%s'" % ",".join(output_nodes), + "--input_node=%s" % ",".join(input_nodes), + "--output_node=%s" % ",".join(output_nodes), "--input_shape=%s" % ":".join(input_shapes), "--output_shape=%s" % ":".join(output_shapes), "--input_data_format=%s" % ",".join(input_data_formats), diff --git a/tools/python/transform/base_converter.py b/tools/python/transform/base_converter.py index 5ccfaba13a431c7d61b4d64335f46982a3178357..4c23e65fbf9efa190c6f8e87e86a507f92b3af8e 100644 --- a/tools/python/transform/base_converter.py +++ b/tools/python/transform/base_converter.py @@ -236,6 +236,7 @@ class MaceKeyword(object): mace_end_axis_str = 'end_axis' mace_num_axes_str = 'num_axes' mace_num_split_str = 'num_split' + mace_size_splits_str = 'size_splits' mace_keepdims_str = 'keepdims' mace_shape_str = 'shape' mace_winograd_filter_transformed = 'is_filter_transformed' @@ -548,6 +549,7 @@ class ConverterOption(object): # Model structure related transformation TransformerRule.REMOVE_USELESS_OP, TransformerRule.TRANSFORM_FAKE_QUANTIZE, + TransformerRule.REMOVE_USELESS_OP, TransformerRule.TRANSFORM_GLOBAL_POOLING, TransformerRule.TRANSFORM_LSTMCELL_ZEROSTATE, TransformerRule.TRANSFORM_BASIC_LSTMCELL, diff --git a/tools/python/transform/megengine_converter.py b/tools/python/transform/megengine_converter.py index 7952509a2332cfe26a3537df231bfa18047f37ae..77d3fe4f57db19369d439db0502c40fe04462816 100644 --- a/tools/python/transform/megengine_converter.py +++ b/tools/python/transform/megengine_converter.py @@ -166,7 +166,7 @@ class MegengineConverter(base_converter.ConverterInterface): if "," in op.input[i]: op_name = op.input[i] op_name = op_name.replace(",", "#") - if (op_name in self._option.input_nodes or \ + if (op_name in self._option.input_nodes or op_name in self._option.output_nodes): op.input[i] = op_name for i in six.moves.range(len(op.output)): @@ -195,7 +195,8 @@ class MegengineConverter(base_converter.ConverterInterface): kernels_arg.name = MaceKeyword.mace_kernel_str kernels_arg.ints.extend(kernel) if op_def.type in (MaceOp.Conv2D.name, MaceOp.DepthwiseConv2d.name, - MaceOp.Deconv2D.name, MaceOp.DepthwiseDeconv2d.name): + MaceOp.Deconv2D.name, + MaceOp.DepthwiseDeconv2d.name): dilation = [params[mge_dilate_h_str], params[mge_dilate_w_str]] dilation_arg = op_def.arg.add() dilation_arg.name = MaceKeyword.mace_dilations_str @@ -426,13 +427,14 @@ class MegengineConverter(base_converter.ConverterInterface): # check the case of counting include padding mode = mge_op.params["mode"] if mode == "AVERAGE_COUNT_EXCLUDE_PADDING" or \ - (mode == "AVERAGE" and mge_op.params["pad_w"] == 0 and \ + (mode == "AVERAGE" and mge_op.params["pad_w"] == 0 and mge_op.params["pad_h"] == 0): pool_type_arg.i = PoolingType.AVG.value elif mode == "MAX": pool_type_arg.i = PoolingType.MAX.value else: - mace_check(False, "AVERAGE pooling should not count padding values") + mace_check(False, + "AVERAGE pooling should not count padding values") self.add_stride_pad_kernel_arg(mge_op.params, op) diff --git a/tools/python/transform/tensorflow_converter.py b/tools/python/transform/tensorflow_converter.py index 8629596e7fc6aa37d23c74679c251102dfd0de62..98b67f63c4669693334bbb0f166e55a6b2c28682 100644 --- a/tools/python/transform/tensorflow_converter.py +++ b/tools/python/transform/tensorflow_converter.py @@ -116,6 +116,7 @@ TFSupportedOps = [ 'SpaceToBatchND', 'SpaceToDepth', 'Split', + 'SplitV', 'Sqrt', 'Square', 'SquaredDifference', @@ -279,6 +280,7 @@ class TensorflowConverter(base_converter.ConverterInterface): TFOpType.SpaceToBatchND.name: self.convert_space_batch, TFOpType.SpaceToDepth.name: self.convert_space_depth, TFOpType.Split.name: self.convert_split, + TFOpType.SplitV.name: self.convert_splitv, TFOpType.Sqrt.name: self.convert_elementwise, TFOpType.Squeeze.name: self.convert_squeeze, TFOpType.Stack.name: self.convert_stack, @@ -1057,14 +1059,15 @@ class TensorflowConverter(base_converter.ConverterInterface): keep_dims_arg.name = MaceKeyword.mace_keepdims_str keep_dims_arg.i = 0 - def convert_split(self, tf_op): + def convert_split(self, tf_op, axis_idx=0): op = self.convert_general_op(tf_op) num_or_size_splits = tf_op.get_attr('num_split') - if num_or_size_splits == 1: + is_split = (num_or_size_splits > 1) + if not is_split: op.type = MaceOp.Identity.name else: op.type = MaceOp.Split.name - axis = tf_op.inputs[0].eval().astype(np.int32) + axis = tf_op.inputs[axis_idx].eval().astype(np.int32) axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis axis_arg = op.arg.add() @@ -1074,8 +1077,24 @@ class TensorflowConverter(base_converter.ConverterInterface): num_split_arg = op.arg.add() num_split_arg.name = MaceKeyword.mace_num_split_str num_split_arg.i = num_or_size_splits - del op.input[0] - self._skip_tensor.add(tf_op.inputs[0].name) + del op.input[axis_idx] + self._skip_tensor.add(tf_op.inputs[axis_idx].name) + return (op, is_split) + + def convert_splitv(self, tf_op): + (op, is_split) = self.convert_split(tf_op, 2) + if not is_split: + return + size_splits_arg = op.arg.add() + size_splits_arg.name = MaceKeyword.mace_size_splits_str + size_splits = tf_op.inputs[1].eval().astype(np.int32) + del op.input[1] + self._skip_tensor.add(tf_op.inputs[1].name) + # todo(luxuhui): support size_splits + for size in size_splits: + mace_check(size == size_splits[0], + "SplitV Only support even distribution") + size_splits_arg.ints.extend(size_splits) def convert_tile(self, tf_op): op = self.convert_general_op(tf_op) diff --git a/tools/python/validate.py b/tools/python/validate.py index 52fb8513a1d6277eff20a012052e68c02aaaeb70..2f70f4a69c7275e4265a3f4f9241889c4e148d77 100644 --- a/tools/python/validate.py +++ b/tools/python/validate.py @@ -318,6 +318,7 @@ def validate_onnx_model(model_file, mace_out_value, value, validation_threshold, log_file) + def validate_megengine_model(model_file, input_file, mace_out_file, input_names, input_shapes, input_data_formats, output_names, output_shapes, @@ -337,7 +338,7 @@ def validate_megengine_model(model_file, input_file, util.formatted_file_name(input_file, input_names[i]), input_data_types[i]) input_value = input_value.reshape(input_shapes[i]) - if (input_data_formats[i] == DataFormat.NHWC and \ + if (input_data_formats[i] == DataFormat.NHWC and len(input_shapes[i]) == 4): input_value = input_value.transpose((0, 3, 1, 2)) feed_inputs.append(input_value) @@ -356,10 +357,10 @@ def validate_megengine_model(model_file, input_file, output_file_name = \ util.formatted_file_name(mace_out_file, output_names[i]) mace_out_value = load_data(output_file_name) - if (output_data_formats[i] == DataFormat.NHWC and \ + if (output_data_formats[i] == DataFormat.NHWC and len(output_shapes[i]) == 4): - mace_out_value = \ - mace_out_value.reshape(output_shapes[i]).transpose((0, 3, 1, 2)) + mace_out_value = mace_out_value.reshape( + output_shapes[i]).transpose((0, 3, 1, 2)) compare_output(output_names[i], mace_out_value, mge_output_value, validation_threshold, log_file) diff --git a/tools/sh_commands.py b/tools/sh_commands.py index b41ec8f47248bb9e35e831d79fb2de5f9780f0ad..7f4432bf91de4c9d8a15d75ca40acb104fc880c2 100644 --- a/tools/sh_commands.py +++ b/tools/sh_commands.py @@ -748,8 +748,8 @@ def validate_model(abi, "--input_file=/mace/%s" % input_file_name, "--mace_out_file=/mace/%s" % output_file_name, "--device_type=%s" % device_type, - "--input_node='%s'" % ",".join(input_nodes), - "--output_node='%s'" % ",".join(output_nodes), + "--input_node=%s" % ",".join(input_nodes), + "--output_node=%s" % ",".join(output_nodes), "--input_shape=%s" % ":".join(input_shapes), "--output_shape=%s" % ":".join(output_shapes), "--input_data_format=%s" % ",".join(input_data_formats), diff --git a/tools/validate.py b/tools/validate.py index 1e76f5ba09fd9d75e85b3b5593736902cb61c27d..031ff04482c2929259f893baf3a0afbb2e2eb01e 100644 --- a/tools/validate.py +++ b/tools/validate.py @@ -350,7 +350,7 @@ def validate_megengine_model(platform, device_type, model_file, input_file, common.formatted_file_name(input_file, input_names[i]), input_data_types[i]) input_value = input_value.reshape(input_shapes[i]) - if (input_data_formats[i] == common.DataFormat.NHWC and \ + if (input_data_formats[i] == common.DataFormat.NHWC and len(input_shapes[i]) == 4): input_value = input_value.transpose((0, 3, 1, 2)) feed_inputs.append(input_value) @@ -369,10 +369,10 @@ def validate_megengine_model(platform, device_type, model_file, input_file, output_file_name = \ common.formatted_file_name(mace_out_file, output_names[i]) mace_out_value = load_data(output_file_name) - if (output_data_formats[i] == common.DataFormat.NHWC and \ + if (output_data_formats[i] == common.DataFormat.NHWC and len(output_shapes[i]) == 4): - mace_out_value = \ - mace_out_value.reshape(output_shapes[i]).transpose((0, 3, 1, 2)) + mace_out_value = mace_out_value.reshape( + output_shapes[i]).transpose((0, 3, 1, 2)) compare_output(platform, device_type, output_names[i], mace_out_value, mge_output_value, validation_threshold, log_file)