Commit a229bff2 authored by wjj19950828

Support T5

Parent 5f0b141b
@@ -122,6 +122,7 @@ class OpSet9():
'Mul': 'paddle.multiply',
'Pow': 'paddle.pow',
'Less': 'paddle.less_than',
'LessOrEqual': 'paddle.less_equal',
}
directly_map_ops = {
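Note: the only change in this hunk is registering ONNX LessOrEqual as a directly mapped op. A minimal sketch (not part of the commit, assumes a local Paddle install) of what the mapping lowers to:

import paddle

x = paddle.to_tensor([1, 2, 3])
y = paddle.to_tensor([2, 2, 2])
# ONNX LessOrEqual is forwarded one-to-one to paddle.less_equal
print(paddle.less_equal(x, y))  # [True, True, False]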
@@ -741,6 +742,7 @@ class OpSet9():
axes = node.get_attr('axes')
if axes is None:
axes = self.graph.get_input_node(node, idx=1, copy=True)
axes = _const_weight_or_none(axes)
if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
if node.name:
self.paddle_graph.add_layer(
@@ -749,7 +751,8 @@ class OpSet9():
outputs=[node.name],
shape=[1])
else:
if isinstance(axes, list) or isinstance(axes, tuple):
if isinstance(axes, list) or isinstance(axes, tuple) or isinstance(
axes, np.ndarray):
self.paddle_graph.add_layer(
'paddle.unsqueeze',
inputs={"x": val_x.name},
@@ -893,72 +896,24 @@ class OpSet9():
def Gather(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True)
indices_values = _const_weight_or_none(indices)
if isinstance(indices_values, np.ndarray):
indices_values = indices_values.tolist()
indices_shape = indices.out_shapes[0]
val_x_shape = val_x.out_shapes[0]
axis = node.get_attr('axis', 0)
#assert len(
# indices_shape) <= 2, "Gather op don't support dim of indice >2 "
if axis == 0 and len(indices_shape) <= 1:
if len(val_x.out_shapes[0]) <= 1:
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[node.name])
elif len(val_x.out_shapes[0]) > 1:
if len(indices_shape) == 0:
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": indices.name},
outputs=[indices.name],
shape=[-1, ])
gather_ = node.name + '_1'
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[gather_])
self.paddle_graph.add_layer(
'paddle.squeeze',
inputs={'x': gather_},
outputs=[node.name],
axis=[0])
else:
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[node.name])
elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
name_trans = val_x.name + '_trans'
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": val_x.name},
outputs=[name_trans],
perm=perm)
if len(indices_shape) == 1 or \
(indices_values is not None and isinstance(indices_values, int)) or \
(indices_values is not None and len(indices_values) == 1):
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': name_trans,
inputs={'x': val_x.name,
'index': indices.name},
outputs=[node.name])
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": node.name},
outputs=[node.name],
perm=new_perm)
if len(indices_shape) < 1:
self.paddle_graph.add_layer(
'paddle.squeeze',
inputs={'x': node.name},
outputs=[node.name],
axis=[axis])
elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode):
axis=axis)
else:
# if val_x is DataNode, convert gather to embedding
if axis == 0 and isinstance(val_x, ONNXGraphDataNode):
indices_cast = indices.name + '_cast'
self.paddle_graph.add_layer(
'paddle.cast',
@@ -973,79 +928,68 @@ class OpSet9():
'paddle.nn.Embedding',
inputs={"x": indices_cast},
outputs=layer_outputs,
num_embeddings=val_x.out_shapes[0][0],
embedding_dim=val_x.out_shapes[0][1])
num_embeddings=val_x_shape[0],
embedding_dim=val_x_shape[1])
else:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.name + '_shape'
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": indices.name},
outputs=[indices_reshape],
shape=[reshape_shape, ])
perm = list(range(len(val_x.out_shapes[0])))
outputs=[indices.name + "_reshape"],
shape=[-1])
gather_1d = node.name + '_1D'
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices_reshape},
outputs=[node.name])
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i)
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": node.name},
outputs=[node.name],
shape=reshaped_shape)
elif axis > 0 and len(indices_shape) > 1:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.name + '_shape'
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": indices.name},
outputs=[indices_reshape],
shape=[reshape_shape, ])
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
name_trans = val_x.name + '_transpose'
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": val_x.name},
outputs=[name_trans],
perm=perm)
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': name_trans,
'index': indices_reshape},
outputs=[node.name])
input_transpose = node.name + '_transpose'
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": node.name},
outputs=[input_transpose],
perm=new_perm)
perm = new_perm
val_x_shape = val_x.out_shapes[0]
reshaped_shape = []
for i in perm:
reshaped_shape.append(indices_shape[i])
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i)
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": input_transpose},
outputs=[node.name],
shape=reshaped_shape)
inputs={
'x': val_x.name,
'index': indices.name + "_reshape"
},
outputs=[gather_1d],
axis=axis)
# if shape is known
if len(indices_shape) != 0 and len(val_x_shape) != 0:
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': gather_1d},
outputs=[node.name],
shape=val_x_shape[:axis] + indices_shape +
val_x_shape[axis + 1:])
else:
all_shape_name = list()
self.paddle_graph.add_layer(
kernel="paddle.shape",
inputs={"input": val_x.name},
outputs=[val_x.name + "_shape"])
self.paddle_graph.add_layer(
kernel="paddle.shape",
inputs={"input": indices.name},
outputs=[indices.name + "_shape"])
self.paddle_graph.add_layer(
"paddle.slice",
inputs={"input": val_x.name + "_shape"},
outputs=[val_x.name + "_shape_slice_start"],
axes=[0],
starts=[0],
ends=[axis])
all_shape_name.append(val_x.name + "_shape_slice_start")
all_shape_name.append(indices.name + "_shape")
self.paddle_graph.add_layer(
"paddle.slice",
inputs={"input": val_x.name + "_shape"},
outputs=[val_x.name + "_shape_slice_end"],
axes=[0],
starts=[axis + 1],
ends=[2147483647])
all_shape_name.append(val_x.name + "_shape_slice_end")
self.paddle_graph.add_layer(
'paddle.concat',
inputs={"x": all_shape_name},
outputs=[node.name + "_all_shape"],
axis=0)
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={'x': gather_1d},
outputs=[node.name],
shape=node.name + "_all_shape")
@print_mapping_info
def ScatterND(self, node):
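The rewritten Gather fallback flattens the indices to 1-D, gathers along axis, then reshapes the result to val_x_shape[:axis] + indices_shape + val_x_shape[axis + 1:] (or assembles that shape at runtime via paddle.shape/slice/concat when static shapes are unknown). A small NumPy sketch of the same flatten-gather-reshape equivalence (illustration only; names are hypothetical, not X2Paddle code):

import numpy as np

def gather_via_flatten(x, indices, axis):
    # Mirrors the emitted sequence: reshape(indices, [-1]) -> gather(axis) -> reshape back
    flat = indices.reshape(-1)
    gathered = np.take(x, flat, axis=axis)  # stand-in for paddle.gather(..., axis=axis)
    out_shape = x.shape[:axis] + indices.shape + x.shape[axis + 1:]
    return gathered.reshape(out_shape)

x = np.arange(24).reshape(2, 3, 4)
idx = np.array([[0, 2], [1, 1]])
assert np.array_equal(gather_via_flatten(x, idx, axis=1), np.take(x, idx, axis=1))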
@@ -1255,16 +1199,6 @@ class OpSet9():
outputs=[node.name],
**layer_attrs)
@print_mapping_info
def GatherND(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer(
"paddle.gather_nd",
inputs={"x": val_x.name,
"index": val_y.name},
outputs=[node.name])
@print_mapping_info
def Clip(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
@@ -1420,16 +1354,6 @@ class OpSet9():
"y": val_y.name},
outputs=[node.name])
@print_mapping_info
def GatherND(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer(
"paddle.gather_nd",
inputs={"x": val_x.name,
"index": val_y.name},
outputs=[node.name])
@print_mapping_info
def And(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
@@ -1710,7 +1634,8 @@ class OpSet9():
x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0]
inputs_dict = {"x": val_x.name, "y": val_y.name}
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1:
if len(y_shape) != 0 and y_shape[0] == 1 and len(
x_shape) != 0 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.name + '_squeeze'
self.paddle_graph.add_layer(
"paddle.squeeze",
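The last visible hunk only hardens a shape check: when ONNX shape inference fails, out_shapes entries can be empty lists, and indexing them directly raises IndexError. A minimal sketch of why the added len() guards matter (hypothetical shapes, not from the commit):

x_shape, y_shape = [], [1, 64]  # e.g. x's static shape is unknown
needs_squeeze = (len(y_shape) != 0 and y_shape[0] == 1 and
                 len(x_shape) != 0 and x_shape[-1] != 1 and x_shape[0] != 1)
print(needs_squeeze)  # False; without the len() guards, x_shape[-1] would raise IndexError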