提交 ba4d2185 编写于 作者: M mamingjie-China

fix bug in gatherv2 and change param without__data_format_optimization

上级 f745dab4
......@@ -67,7 +67,7 @@ def arg_parser():
"--without_data_format_optimization",
"-wo",
action="store_true",
default=False,
default=True,
help="tf model conversion without data format optimization")
parser.add_argument(
"--define_input_shape",
......
......@@ -1068,13 +1068,25 @@ class TFOpMapperNHWC(OpMapper):
axis = axis.value.tolist()
assert axis == 0, "Only support axis=0 in GatherV2 OP"
attr = {'overwrite': False}
embeddings_shape = embeddings.out_shapes[0][-1]
reshape_list = list()
reshape_name = index.layer_name
if len(index.out_shapes[0]) != 1:
reshape_list = index.out_shapes[0]
reshape_attr = {"shape": [-1]}
reshape_name = "{}_reshape".format(index.layer_name)
node.fluid_code.add_layer(
"reshape", inputs=index, output=index, param_attr=reshape_attr)
inputs = {'input': embeddings, 'index': index}
"reshape",
inputs=index,
output=reshape_name,
param_attr=reshape_attr)
inputs = {'input': embeddings, 'index': reshape_name}
node.fluid_code.add_layer(
"gather", inputs=inputs, output=node, param_attr=attr)
if len(index.out_shapes[0]) != 1:
reshape_attr = {"shape": reshape_list + [embeddings_shape]}
node.fluid_code.add_layer(
"reshape", inputs=node, output=node, param_attr=reshape_attr)
def OneShotIterator(self, node):
return self.Placeholder(node)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册