From 87041d07668a0363307bc32d9f3710cabdd1f097 Mon Sep 17 00:00:00 2001 From: luxuhui Date: Tue, 24 Mar 2020 17:48:14 +0800 Subject: [PATCH] fix validate bug on reshape layer and duplicate name bug in encrypt.py issue595 Signed-off-by: Luxuhui --- mace/ops/reshape.cc | 20 +++++++++++++++++-- tools/python/encrypt.py | 28 +++++++++++++-------------- tools/python/transform/transformer.py | 17 ++++++++++++---- 3 files changed, 45 insertions(+), 20 deletions(-) diff --git a/mace/ops/reshape.cc b/mace/ops/reshape.cc index 94561720..b5daa430 100644 --- a/mace/ops/reshape.cc +++ b/mace/ops/reshape.cc @@ -161,14 +161,30 @@ void RegisterReshape(OpRegistryBase *op_registry) { return {DeviceType::CPU, DeviceType::GPU}; } + // When transforming a model, has_data_format is set + // to true only when the data dimension conforms to + // specific rules, such as dimension == 4 + int has_data_format = + ProtoArgHelper::GetOptionalArg( + *op, "has_data_format", 0); + if (has_data_format) { + return {DeviceType::CPU, DeviceType::GPU}; + } + + DataFormat op_data_format = static_cast( + ProtoArgHelper::GetOptionalArg( + *context->operator_def(), "data_format", + static_cast(DataFormat::NONE))); auto tensor_shape_info = context->tensor_shape_info(); const std::string &input_0 = op->input(0); const auto out_dims_size = op->output_shape(0).dims_size(); - if (4 == tensor_shape_info->at(input_0).size() - && (out_dims_size == 4 || out_dims_size == 2)) { + if (op_data_format == DataFormat::NHWC && + 4 == tensor_shape_info->at(input_0).size() && + (out_dims_size == 4 || out_dims_size == 2)) { return {DeviceType::CPU, DeviceType::GPU}; } + return {DeviceType::CPU}; })); } diff --git a/tools/python/encrypt.py b/tools/python/encrypt.py index 6a28fe2f..263e4941 100644 --- a/tools/python/encrypt.py +++ b/tools/python/encrypt.py @@ -30,49 +30,48 @@ from utils import config_parser from utils.config_parser import CPP_KEYWORDS from utils.config_parser import ModelKeys -GENERATED_NAME = set() - -def generate_obfuscated_name(namespace, name): +def generate_obfuscated_name(namespace, name, model_names_set): md5 = hashlib.md5() md5.update(namespace) md5.update(name) md5_digest = md5.hexdigest() name = md5_digest[:8] - while name in GENERATED_NAME: + while name in model_names_set: name = md5_digest - assert name not in GENERATED_NAME - GENERATED_NAME.add(name) + assert name not in model_names_set + model_names_set.add(name) return name -def generate_tensor_map(tensors): +def generate_tensor_map(tensors, model_names_set): tensor_map = {} for t in tensors: if t.name not in tensor_map: - tensor_map[t.name] = generate_obfuscated_name("tensor", t.name) + tensor_map[t.name] = \ + generate_obfuscated_name("tensor", t.name, model_names_set) return tensor_map -def generate_in_out_map(ops, tensor_map): +def generate_in_out_map(ops, tensor_map, model_names_set): in_out_map = {} for op in ops: - op.name = generate_obfuscated_name("op", op.name) + op.name = generate_obfuscated_name("op", op.name, model_names_set) for input_name in op.input: if input_name not in in_out_map: if input_name in tensor_map: in_out_map[input_name] = tensor_map[input_name] else: in_out_map[input_name] = generate_obfuscated_name( - "in", input_name) + "in", input_name, model_names_set) for output_name in op.output: if output_name not in in_out_map: if output_name in tensor_map: in_out_map[output_name] = tensor_map[output_name] else: in_out_map[output_name] = generate_obfuscated_name( - "out", output_name) + "out", output_name, model_names_set) return in_out_map @@ -87,8 +86,9 @@ def obfuscate_name(model): output_nodes = set() for output_node in model.output_info: output_nodes.add(output_node.name) - tensor_map = generate_tensor_map(model.tensors) - in_out_map = generate_in_out_map(model.op, tensor_map) + model_names_set = set() + tensor_map = generate_tensor_map(model.tensors, model_names_set) + in_out_map = generate_in_out_map(model.op, tensor_map, model_names_set) for t in model.tensors: if t.name not in input_nodes and t.name not in output_nodes: t.name = tensor_map[t.name] diff --git a/tools/python/transform/transformer.py b/tools/python/transform/transformer.py index 2c8901a9..3ac310f0 100644 --- a/tools/python/transform/transformer.py +++ b/tools/python/transform/transformer.py @@ -1402,12 +1402,21 @@ class Transformer(base_converter.ConverterInterface): def is_transposable_data_format_ops(self, op): if op.type == MaceOp.Reshape: input_op = self._producer[op.input[0]] - out_dims_len = len(op.output_shape[0].dims) + input_dims = input_op.output_shape[0].dims + output_dims = op.output_shape[0].dims + tranposable = True if len(input_op.output_shape) != 1 or \ - len(input_op.output_shape[0].dims) != 4 \ - or (out_dims_len != 4 and out_dims_len != 2): + len(input_dims) != 4 or len(output_dims) != 4: + tranposable = False + else: + in_b, in_h, in_w, in_c = self.sort_feature_map_shape( + input_dims, ConverterUtil.data_format(input_op)) + ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape( + output_dims, ConverterUtil.data_format(op)) + tranposable = (in_b == ou_b and in_c == ou_c) + if not tranposable: print("In this model, reshape is not transposable op.") - return False + return tranposable return op.type in MaceTransposableDataFormatOps def update_data_format(self): -- GitLab