From 2c831a8df41ef4aa97a53b545b923ada3b3168b6 Mon Sep 17 00:00:00 2001 From: mamingjie-China Date: Tue, 4 Aug 2020 17:32:37 +0800 Subject: [PATCH] fix bug in decode --- x2paddle/decoder/tf_decoder.py | 8 ++------ x2paddle/op_mapper/tf_op_mapper_nhwc.py | 8 ++++---- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/x2paddle/decoder/tf_decoder.py b/x2paddle/decoder/tf_decoder.py index e7d63aa..dc04172 100644 --- a/x2paddle/decoder/tf_decoder.py +++ b/x2paddle/decoder/tf_decoder.py @@ -91,12 +91,8 @@ class TFGraphNode(GraphNode): @property def name(self): - multi_out_ops = ['Split', 'SplitV', 'IteratorV2'] - if self.layer_type in multi_out_ops: - if self.layer_name.count(':') > 0: - return self.layer_name.replace(':', '_p') - else: - return "{}_p0".format(self.layer_name) + if hasattr(self, 'index'): + return self.layer_name + "_p{}".format(self.index) return self.layer_name def get_attr(self, name): diff --git a/x2paddle/op_mapper/tf_op_mapper_nhwc.py b/x2paddle/op_mapper/tf_op_mapper_nhwc.py index 81079df..c747611 100644 --- a/x2paddle/op_mapper/tf_op_mapper_nhwc.py +++ b/x2paddle/op_mapper/tf_op_mapper_nhwc.py @@ -322,14 +322,14 @@ class TFOpMapperNHWC(OpMapper): if kernel.layer_type == 'Const': kernel_value = kernel.value - kernel_weight_name = kernel.layer_name.replace('/', '_') + kernel_weight_name = kernel.name.replace('/', '_') else: kernel_value = self.decoder.infer_tensor(kernel) if kernel.layer_type == 'Split': - kernel_weight_name = "{}_{}_kernel".format(node.layer_name, - kernel.layer_name) + kernel_weight_name = "{}_{}_kernel".format(node.name, + kernel.name) else: - kernel_weight_name = kernel.layer_name.replace('/', '_') + kernel_weight_name = kernel.name.replace('/', '_') program.parameters[kernel_weight_name] = numpy.transpose(kernel_value, (3, 2, 0, 1)) -- GitLab