diff --git a/x2paddle/convert.py b/x2paddle/convert.py index 717bd5fc43138bac1d826b458281f6e97017112c..42b4cbfe9d7b7d8d73bbbb512b68f9dce8f6af5d 100644 --- a/x2paddle/convert.py +++ b/x2paddle/convert.py @@ -67,7 +67,7 @@ def arg_parser(): "--without_data_format_optimization", "-wo", action="store_true", - default=False, + default=True, help="tf model conversion without data format optimization") parser.add_argument( "--define_input_shape", diff --git a/x2paddle/op_mapper/tf_op_mapper_nhwc.py b/x2paddle/op_mapper/tf_op_mapper_nhwc.py index a5198cc780203722042b8ae043fb169b94eeb3be..a83a303e5f0fe3f21f32b44c06c7c5d44b59bd4d 100644 --- a/x2paddle/op_mapper/tf_op_mapper_nhwc.py +++ b/x2paddle/op_mapper/tf_op_mapper_nhwc.py @@ -1068,13 +1068,25 @@ class TFOpMapperNHWC(OpMapper): axis = axis.value.tolist() assert axis == 0, "Only support axis=0 in GatherV2 OP" attr = {'overwrite': False} + embeddings_shape = embeddings.out_shapes[0][-1] + reshape_list = list() + reshape_name = index.layer_name if len(index.out_shapes[0]) != 1: + reshape_list = index.out_shapes[0] reshape_attr = {"shape": [-1]} + reshape_name = "{}_reshape".format(index.layer_name) node.fluid_code.add_layer( - "reshape", inputs=index, output=index, param_attr=reshape_attr) - inputs = {'input': embeddings, 'index': index} + "reshape", + inputs=index, + output=reshape_name, + param_attr=reshape_attr) + inputs = {'input': embeddings, 'index': reshape_name} node.fluid_code.add_layer( "gather", inputs=inputs, output=node, param_attr=attr) + if len(index.out_shapes[0]) != 1: + reshape_attr = {"shape": reshape_list + [embeddings_shape]} + node.fluid_code.add_layer( + "reshape", inputs=node, output=node, param_attr=reshape_attr) def OneShotIterator(self, node): return self.Placeholder(node)