提交 2c831a8d 编写于 作者: M mamingjie-China

fix bug in decode

上级 c3773664
...@@ -91,12 +91,8 @@ class TFGraphNode(GraphNode): ...@@ -91,12 +91,8 @@ class TFGraphNode(GraphNode):
@property @property
def name(self): def name(self):
multi_out_ops = ['Split', 'SplitV', 'IteratorV2'] if hasattr(self, 'index'):
if self.layer_type in multi_out_ops: return self.layer_name + "_p{}".format(self.index)
if self.layer_name.count(':') > 0:
return self.layer_name.replace(':', '_p')
else:
return "{}_p0".format(self.layer_name)
return self.layer_name return self.layer_name
def get_attr(self, name): def get_attr(self, name):
......
...@@ -322,14 +322,14 @@ class TFOpMapperNHWC(OpMapper): ...@@ -322,14 +322,14 @@ class TFOpMapperNHWC(OpMapper):
if kernel.layer_type == 'Const': if kernel.layer_type == 'Const':
kernel_value = kernel.value kernel_value = kernel.value
kernel_weight_name = kernel.layer_name.replace('/', '_') kernel_weight_name = kernel.name.replace('/', '_')
else: else:
kernel_value = self.decoder.infer_tensor(kernel) kernel_value = self.decoder.infer_tensor(kernel)
if kernel.layer_type == 'Split': if kernel.layer_type == 'Split':
kernel_weight_name = "{}_{}_kernel".format(node.layer_name, kernel_weight_name = "{}_{}_kernel".format(node.name,
kernel.layer_name) kernel.name)
else: else:
kernel_weight_name = kernel.layer_name.replace('/', '_') kernel_weight_name = kernel.name.replace('/', '_')
program.parameters[kernel_weight_name] = numpy.transpose(kernel_value, program.parameters[kernel_weight_name] = numpy.transpose(kernel_value,
(3, 2, 0, 1)) (3, 2, 0, 1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册