提交 2c831a8d 编写于 作者: M mamingjie-China

fix bug in decode

上级 c3773664
......@@ -91,12 +91,8 @@ class TFGraphNode(GraphNode):
@property
def name(self):
multi_out_ops = ['Split', 'SplitV', 'IteratorV2']
if self.layer_type in multi_out_ops:
if self.layer_name.count(':') > 0:
return self.layer_name.replace(':', '_p')
else:
return "{}_p0".format(self.layer_name)
if hasattr(self, 'index'):
return self.layer_name + "_p{}".format(self.index)
return self.layer_name
def get_attr(self, name):
......
......@@ -322,14 +322,14 @@ class TFOpMapperNHWC(OpMapper):
if kernel.layer_type == 'Const':
kernel_value = kernel.value
kernel_weight_name = kernel.layer_name.replace('/', '_')
kernel_weight_name = kernel.name.replace('/', '_')
else:
kernel_value = self.decoder.infer_tensor(kernel)
if kernel.layer_type == 'Split':
kernel_weight_name = "{}_{}_kernel".format(node.layer_name,
kernel.layer_name)
kernel_weight_name = "{}_{}_kernel".format(node.name,
kernel.name)
else:
kernel_weight_name = kernel.layer_name.replace('/', '_')
kernel_weight_name = kernel.name.replace('/', '_')
program.parameters[kernel_weight_name] = numpy.transpose(kernel_value,
(3, 2, 0, 1))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册