Commit 7fa03dd2 authored by 李寅

Merge branch 'check' into 'master'

Fix check tensor

See merge request !1116
@@ -137,7 +137,6 @@ class Transformer(base_converter.ConverterInterface):
                     changed = transformer()
                     if not changed:
                         break
-        self.delete_after_check_nodes()
         return self._model, self._quantize_activation_info
 
     def initialize_name_map(self):
@@ -1332,7 +1331,10 @@ class Transformer(base_converter.ConverterInterface):
         visited = set()
         sorted_nodes = []
-        for output_node in self._option.output_nodes:
+        output_nodes = self._option.check_nodes
+        if not self._quantize_activation_info:
+            output_nodes.update(self._option.output_nodes)
+        for output_node in output_nodes:
             mace_check(output_node in self._producer,
                        "output_tensor %s not existed in model" % output_node)
             self.sort_dfs(self._producer[output_node], visited, sorted_nodes)
@@ -2010,18 +2012,6 @@ class Transformer(base_converter.ConverterInterface):
             arg.i = mace_pb2.GPU_IMAGE if self._option.cl_mem_type == "image"\
                 else mace_pb2.GPU_BUFFER
 
-    def delete_after_check_nodes(self):
-        if self._option.check_nodes != self._option.output_nodes:
-            mace_check(len(self._option.check_nodes) == 1,
-                       "Only support one check node now.")
-            check_node = None
-            for i in six.moves.range(len(self._model.op)):
-                if self._model.op[i].output[0] in self._option.check_nodes:
-                    check_node = self._model.op[i]
-                    del self._model.op[i+1:]
-                    break
-            mace_check(check_node is not None, "check node not found.")
-
     def transform_caffe_reshape_and_flatten(self):
         net = self._model
         for op in net.op:
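For context, a minimal, self-contained sketch of the idea behind this change (this is not the MACE source; the helper name sort_by_outputs, the Op tuple, and its fields are assumptions for illustration). Sorting the graph by DFS from the requested check/output tensors keeps only the ops that contribute to those tensors, which is why the separate delete_after_check_nodes() truncation pass above can be removed.

    from collections import namedtuple

    # Toy stand-in for the protobuf op: lists of input/output tensor names.
    Op = namedtuple("Op", ["name", "input", "output"])

    def sort_by_outputs(ops, output_tensors):
        # Map each tensor name to the op that produces it.
        producer = {}
        for op in ops:
            for tensor in op.output:
                producer[tensor] = op

        visited = set()
        sorted_ops = []

        def dfs(op):
            if id(op) in visited:
                return
            visited.add(id(op))
            for tensor in op.input:
                if tensor in producer:   # tensors with no producer are model inputs/weights
                    dfs(producer[tensor])
            sorted_ops.append(op)        # post-order append => topological order

        for tensor in output_tensors:
            assert tensor in producer, \
                "output_tensor %s not existed in model" % tensor
            dfs(producer[tensor])
        return sorted_ops

    ops = [
        Op("conv", ["data"], ["conv_out"]),
        Op("relu", ["conv_out"], ["relu_out"]),
        Op("softmax", ["relu_out"], ["prob"]),   # not needed when checking "relu_out"
    ]

    # Starting the sort from the check tensor keeps only conv and relu;
    # softmax is dropped without any post-hoc truncation of the op list.
    print([op.name for op in sort_by_outputs(ops, {"relu_out"})])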