提交 a5ff0aa4 编写于 作者: 叶剑武

Merge branch 'reshape' into 'master'

Fix a validation bug in the Reshape layer's device-placement check and a duplicate-name bug in encrypt.py

See merge request deep-computing/mace!1255
...@@ -161,14 +161,30 @@ void RegisterReshape(OpRegistryBase *op_registry) { ...@@ -161,14 +161,30 @@ void RegisterReshape(OpRegistryBase *op_registry) {
return {DeviceType::CPU, DeviceType::GPU}; return {DeviceType::CPU, DeviceType::GPU};
} }
// When transforming a model, has_data_format is set
// to true only when the data dimension conforms to
// specific rules, such as dimension == 4
int has_data_format =
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*op, "has_data_format", 0);
if (has_data_format) {
return {DeviceType::CPU, DeviceType::GPU};
}
DataFormat op_data_format = static_cast<DataFormat>(
ProtoArgHelper::GetOptionalArg<OperatorDef, int>(
*context->operator_def(), "data_format",
static_cast<int>(DataFormat::NONE)));
auto tensor_shape_info = context->tensor_shape_info(); auto tensor_shape_info = context->tensor_shape_info();
const std::string &input_0 = op->input(0); const std::string &input_0 = op->input(0);
const auto out_dims_size = const auto out_dims_size =
op->output_shape(0).dims_size(); op->output_shape(0).dims_size();
if (4 == tensor_shape_info->at(input_0).size() if (op_data_format == DataFormat::NHWC &&
&& (out_dims_size == 4 || out_dims_size == 2)) { 4 == tensor_shape_info->at(input_0).size() &&
(out_dims_size == 4 || out_dims_size == 2)) {
return {DeviceType::CPU, DeviceType::GPU}; return {DeviceType::CPU, DeviceType::GPU};
} }
return {DeviceType::CPU}; return {DeviceType::CPU};
})); }));
} }
......
...@@ -30,49 +30,48 @@ from utils import config_parser ...@@ -30,49 +30,48 @@ from utils import config_parser
from utils.config_parser import CPP_KEYWORDS from utils.config_parser import CPP_KEYWORDS
from utils.config_parser import ModelKeys from utils.config_parser import ModelKeys
def generate_obfuscated_name(namespace, name, model_names_set):
    """Return a short, deterministic obfuscated alias for *name*.

    The alias is the first 8 hex chars of md5(namespace + name); on a
    collision with an already-issued alias the full 32-char digest is
    tried, then numbered variants of it.  Every issued alias is recorded
    in *model_names_set* so later calls never reuse it.

    Args:
        namespace: grouping prefix, e.g. "tensor", "op", "in", "out".
        name: original (str or bytes) name to obfuscate.
        model_names_set: set of aliases already issued for this model;
            mutated in place.

    Returns:
        The new alias (str), guaranteed not previously in model_names_set.
    """
    md5 = hashlib.md5()
    # hashlib requires bytes on Python 3; accept str transparently.
    md5.update(namespace.encode("utf-8")
               if isinstance(namespace, str) else namespace)
    md5.update(name.encode("utf-8") if isinstance(name, str) else name)
    md5_digest = md5.hexdigest()

    new_name = md5_digest[:8]
    if new_name in model_names_set:
        new_name = md5_digest
    # Extend deterministically on a full-digest collision instead of
    # looping forever (the original `while ...: name = md5_digest` spins
    # if the digest itself is already taken).
    suffix = 0
    while new_name in model_names_set:
        new_name = "%s_%d" % (md5_digest, suffix)
        suffix += 1
    model_names_set.add(new_name)
    return new_name
def generate_tensor_map(tensors, model_names_set):
    """Build a mapping from each distinct tensor name to its obfuscated
    alias, recording every issued alias in *model_names_set*."""
    tensor_map = {}
    for tensor in tensors:
        if tensor.name in tensor_map:
            continue  # first occurrence wins; duplicates keep one alias
        tensor_map[tensor.name] = generate_obfuscated_name(
            "tensor", tensor.name, model_names_set)
    return tensor_map
def generate_in_out_map(ops, tensor_map, model_names_set):
    """Obfuscate every op name in place and return a lookup table that
    maps each op input/output edge name to its obfuscated alias.

    Edges that refer to a known tensor reuse the alias from *tensor_map*;
    all other edges get a fresh alias registered in *model_names_set*.
    """
    in_out_map = {}

    def _edge_alias(edge_name, tag):
        # Prefer the tensor alias so tensors and edges stay consistent.
        if edge_name in tensor_map:
            return tensor_map[edge_name]
        return generate_obfuscated_name(tag, edge_name, model_names_set)

    for op in ops:
        op.name = generate_obfuscated_name("op", op.name, model_names_set)
        for edge_name in op.input:
            if edge_name not in in_out_map:
                in_out_map[edge_name] = _edge_alias(edge_name, "in")
        for edge_name in op.output:
            if edge_name not in in_out_map:
                in_out_map[edge_name] = _edge_alias(edge_name, "out")
    return in_out_map
...@@ -87,8 +86,9 @@ def obfuscate_name(model): ...@@ -87,8 +86,9 @@ def obfuscate_name(model):
output_nodes = set() output_nodes = set()
for output_node in model.output_info: for output_node in model.output_info:
output_nodes.add(output_node.name) output_nodes.add(output_node.name)
tensor_map = generate_tensor_map(model.tensors) model_names_set = set()
in_out_map = generate_in_out_map(model.op, tensor_map) tensor_map = generate_tensor_map(model.tensors, model_names_set)
in_out_map = generate_in_out_map(model.op, tensor_map, model_names_set)
for t in model.tensors: for t in model.tensors:
if t.name not in input_nodes and t.name not in output_nodes: if t.name not in input_nodes and t.name not in output_nodes:
t.name = tensor_map[t.name] t.name = tensor_map[t.name]
......
...@@ -1402,12 +1402,21 @@ class Transformer(base_converter.ConverterInterface): ...@@ -1402,12 +1402,21 @@ class Transformer(base_converter.ConverterInterface):
def is_transposable_data_format_ops(self, op):
    """Return True if *op*'s layout can be transposed with the data format.

    For Reshape this is only safe when both the producer's output and the
    reshape output are 4-D feature maps whose batch and channel sizes
    match (i.e. only the spatial dims are being rearranged); otherwise
    transposing NHWC<->NCHW around the reshape would change its meaning.
    All other op types defer to the MaceTransposableDataFormatOps set.

    NOTE(review): reconstructed from a garbled side-by-side diff; the
    merged old/new text was untangled into the post-merge version —
    verify against upstream before relying on exact behavior.
    """
    if op.type == MaceOp.Reshape:
        input_op = self._producer[op.input[0]]
        input_dims = input_op.output_shape[0].dims
        output_dims = op.output_shape[0].dims
        transposable = True
        # Only a single 4-D producer output reshaped to 4-D qualifies.
        if len(input_op.output_shape) != 1 or \
                len(input_dims) != 4 or len(output_dims) != 4:
            transposable = False
        else:
            # Compare batch/channel under each side's own data format.
            in_b, in_h, in_w, in_c = self.sort_feature_map_shape(
                input_dims, ConverterUtil.data_format(input_op))
            ou_b, ou_h, ou_w, ou_c = self.sort_feature_map_shape(
                output_dims, ConverterUtil.data_format(op))
            transposable = (in_b == ou_b and in_c == ou_c)
        if not transposable:
            print("In this model, reshape is not transposable op.")
        return transposable
    return op.type in MaceTransposableDataFormatOps
def update_data_format(self): def update_data_format(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册