提交 87041d07 编写于 作者: L luxuhui

fix validate bug on reshape layer and duplicate name bug in encrypt.py

issue595
Signed-off-by: NLuxuhui <luxuhui@xiaomi.com>
上级 446bb57c
......@@ -161,14 +161,30 @@ void RegisterReshape(OpRegistryBase *op_registry) {
return {DeviceType::CPU, DeviceType::GPU};
}
// When transforming a model, has_data_format is set
// to true only when the data dimension conforms to
// specific rules, such as dimension == 4
int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0);
if (has_data_format) {
return {DeviceType::CPU, DeviceType::GPU};
}
DataFormat op_data_format = static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*context->operator_def(), "data_format",
static_cast<int>(DataFormat::NONE)));
auto tensor_shape_info = context->tensor_shape_info();
const std::string &input_0 = op->input(0);
const auto out_dims_size =
op->output_shape(0).dims_size();
if (4 == tensor_shape_info->at(input_0).size()
&& (out_dims_size == 4 || out_dims_size == 2)) {
if (op_data_format == DataFormat::NHWC &&
4 == tensor_shape_info->at(input_0).size() &&
(out_dims_size == 4 || out_dims_size == 2)) {
return {DeviceType::CPU, DeviceType::GPU};
}
return {DeviceType::CPU};
}));
}
......
......@@ -30,49 +30,48 @@ from utils import config_parser
from utils.config_parser import CPP_KEYWORDS
from utils.config_parser import ModelKeys
GENERATED_NAME = set()
def generate_obfuscated_name(namespace, name):
def generate_obfuscated_name(namespace, name, model_names_set):
md5 = hashlib.md5()
md5.update(namespace)
md5.update(name)
md5_digest = md5.hexdigest()
name = md5_digest[:8]
while name in GENERATED_NAME:
while name in model_names_set:
name = md5_digest
assert name not in GENERATED_NAME
GENERATED_NAME.add(name)
assert name not in model_names_set
model_names_set.add(name)
return name
def generate_tensor_map(tensors):
def generate_tensor_map(tensors, model_names_set):
tensor_map = {}
for t in tensors:
if t.name not in tensor_map:
tensor_map[t.name] = generate_obfuscated_name("tensor", t.name)
tensor_map[t.name] = \
generate_obfuscated_name("tensor", t.name, model_names_set)
return tensor_map
def generate_in_out_map(ops, tensor_map):
def generate_in_out_map(ops, tensor_map, model_names_set):
in_out_map = {}
for op in ops:
op.name = generate_obfuscated_name("op", op.name)
op.name = generate_obfuscated_name("op", op.name, model_names_set)
for input_name in op.input:
if input_name not in in_out_map:
if input_name in tensor_map:
in_out_map[input_name] = tensor_map[input_name]
else:
in_out_map[input_name] = generate_obfuscated_name(
"in", input_name)
"in", input_name, model_names_set)
for output_name in op.output:
if output_name not in in_out_map:
if output_name in tensor_map:
in_out_map[output_name] = tensor_map[output_name]
else:
in_out_map[output_name] = generate_obfuscated_name(
"out", output_name)
"out", output_name, model_names_set)
return in_out_map
......@@ -87,8 +86,9 @@ def obfuscate_name(model):
output_nodes = set()
for output_node in model.output_info:
output_nodes.add(output_node.name)
tensor_map = generate_tensor_map(model.tensors)
in_out_map = generate_in_out_map(model.op, tensor_map)
model_names_set = set()
tensor_map = generate_tensor_map(model.tensors, model_names_set)
in_out_map = generate_in_out_map(model.op, tensor_map, model_names_set)
for t in model.tensors:
if t.name not in input_nodes and t.name not in output_nodes:
t.name = tensor_map[t.name]
......
......@@ -1402,12 +1402,21 @@ class Transformer(base_converter.ConverterInterface):
def is_transposable_data_format_ops(self, op):
if op.type == MaceOp.Reshape:
input_op = self._producer[op.input[0]]
out_dims_len = len(op.output_shape[0].dims)
input_dims = input_op.output_shape[0].dims
output_dims = op.output_shape[0].dims
tranposable = True
if len(input_op.output_shape) != 1 or \
len(input_op.output_shape[0].dims) != 4 \
or (out_dims_len != 4 and out_dims_len != 2):
len(input_dims) != 4 or len(output_dims) != 4:
tranposable = False
else:
in_b, in_h, in_w, in_c = self.sort_feature_map_shape(
input_dims, ConverterUtil.data_format(input_op))
ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape(
output_dims, ConverterUtil.data_format(op))
tranposable = (in_b == ou_b and in_c == ou_c)
if not tranposable:
print("In this model, reshape is not transposable op.")
return False
return tranposable
return op.type in MaceTransposableDataFormatOps
def update_data_format(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册