From ba4d2185728338eff88c39f40784bb03cb24cde4 Mon Sep 17 00:00:00 2001 From: mamingjie-China Date: Tue, 11 Aug 2020 16:18:47 +0800 Subject: [PATCH] fix bug in gatherv2 and change param without__data_format_optimization --- x2paddle/convert.py | 2 +- x2paddle/op_mapper/tf_op_mapper_nhwc.py | 16 ++++++++++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 717bd5f..42b4cbf 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -67,7 +67,7 @@ def arg_parser(): "--without_data_format_optimization", "-wo", action="store_true", - default=False, + default=True, help="tf model conversion without data format optimization") parser.add_argument( "--define_input_shape", diff --git a/x2paddle/op_mapper/tf_op_mapper_nhwc.py b/x2paddle/op_mapper/tf_op_mapper_nhwc.py index a5198cc..a83a303 100644 --- a/x2paddle/op_mapper/tf_op_mapper_nhwc.py +++ b/x2paddle/op_mapper/tf_op_mapper_nhwc.py @@ -1068,13 +1068,25 @@ class TFOpMapperNHWC(OpMapper): axis = axis.value.tolist() assert axis == 0, "Only support axis=0 in GatherV2 OP" attr = {'overwrite': False} + embeddings_shape = embeddings.out_shapes[0][-1] + reshape_list = list() + reshape_name = index.layer_name if len(index.out_shapes[0]) != 1: + reshape_list = index.out_shapes[0] reshape_attr = {"shape": [-1]} + reshape_name = "{}_reshape".format(index.layer_name) node.fluid_code.add_layer( - "reshape", inputs=index, output=index, param_attr=reshape_attr) - inputs = {'input': embeddings, 'index': index} + "reshape", + inputs=index, + output=reshape_name, + param_attr=reshape_attr) + inputs = {'input': embeddings, 'index': reshape_name} node.fluid_code.add_layer( "gather", inputs=inputs, output=node, param_attr=attr) + if len(index.out_shapes[0]) != 1: + reshape_attr = {"shape": reshape_list + [embeddings_shape]} + node.fluid_code.add_layer( + "reshape", inputs=node, output=node, param_attr=reshape_attr) def OneShotIterator(self, node): return self.Placeholder(node) -- GitLab