提交 a229bff2 编写于 作者: W wjj19950828

Support T5

上级 5f0b141b
...@@ -122,6 +122,7 @@ class OpSet9(): ...@@ -122,6 +122,7 @@ class OpSet9():
'Mul': 'paddle.multiply', 'Mul': 'paddle.multiply',
'Pow': 'paddle.pow', 'Pow': 'paddle.pow',
'Less': 'paddle.less_than', 'Less': 'paddle.less_than',
'LessOrEqual': 'paddle.less_equal',
} }
directly_map_ops = { directly_map_ops = {
...@@ -741,6 +742,7 @@ class OpSet9(): ...@@ -741,6 +742,7 @@ class OpSet9():
axes = node.get_attr('axes') axes = node.get_attr('axes')
if axes is None: if axes is None:
axes = self.graph.get_input_node(node, idx=1, copy=True) axes = self.graph.get_input_node(node, idx=1, copy=True)
axes = _const_weight_or_none(axes)
if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0: if len(val_x.out_shapes[0]) == 0 and len(axes) == 1 and axes[0] == 0:
if node.name: if node.name:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
...@@ -749,7 +751,8 @@ class OpSet9(): ...@@ -749,7 +751,8 @@ class OpSet9():
outputs=[node.name], outputs=[node.name],
shape=[1]) shape=[1])
else: else:
if isinstance(axes, list) or isinstance(axes, tuple): if isinstance(axes, list) or isinstance(axes, tuple) or isinstance(
axes, np.ndarray):
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.unsqueeze', 'paddle.unsqueeze',
inputs={"x": val_x.name}, inputs={"x": val_x.name},
...@@ -893,72 +896,24 @@ class OpSet9(): ...@@ -893,72 +896,24 @@ class OpSet9():
def Gather(self, node): def Gather(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
indices = self.graph.get_input_node(node, idx=1, copy=True) indices = self.graph.get_input_node(node, idx=1, copy=True)
indices_values = _const_weight_or_none(indices)
if isinstance(indices_values, np.ndarray):
indices_values = indices_values.tolist()
indices_shape = indices.out_shapes[0] indices_shape = indices.out_shapes[0]
val_x_shape = val_x.out_shapes[0]
axis = node.get_attr('axis', 0) axis = node.get_attr('axis', 0)
#assert len( if len(indices_shape) == 1 or \
# indices_shape) <= 2, "Gather op don't support dim of indice >2 " (indices_values is not None and isinstance(indices_values, int)) or \
if axis == 0 and len(indices_shape) <= 1: (indices_values is not None and len(indices_values) == 1):
if len(val_x.out_shapes[0]) <= 1:
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[node.name])
elif len(val_x.out_shapes[0]) > 1:
if len(indices_shape) == 0:
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": indices.name},
outputs=[indices.name],
shape=[-1, ])
gather_ = node.name + '_1'
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[gather_])
self.paddle_graph.add_layer(
'paddle.squeeze',
inputs={'x': gather_},
outputs=[node.name],
axis=[0])
else:
self.paddle_graph.add_layer(
'paddle.gather',
inputs={'x': val_x.name,
'index': indices.name},
outputs=[node.name])
elif axis > 0 and len(indices_shape) <= 1:
perm = list(range(len(val_x.out_shapes[0])))
perm = [axis] + perm[:axis] + perm[axis + 1:]
name_trans = val_x.name + '_trans'
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": val_x.name},
outputs=[name_trans],
perm=perm)
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.gather', 'paddle.gather',
inputs={'x': name_trans, inputs={'x': val_x.name,
'index': indices.name}, 'index': indices.name},
outputs=[node.name])
new_perm = [0] * len(perm)
for i in range(len(perm)):
new_perm[perm[i]] = i
self.paddle_graph.add_layer(
'paddle.transpose',
inputs={"x": node.name},
outputs=[node.name], outputs=[node.name],
perm=new_perm) axis=axis)
if len(indices_shape) < 1: else:
self.paddle_graph.add_layer( # if val_x is DataNode, convert gather to embedding
'paddle.squeeze', if axis == 0 and isinstance(val_x, ONNXGraphDataNode):
inputs={'x': node.name},
outputs=[node.name],
axis=[axis])
elif axis == 0 and len(indices_shape) > 1:
if val_x.out_shapes[0] is not None and isinstance(
val_x, ONNXGraphDataNode):
indices_cast = indices.name + '_cast' indices_cast = indices.name + '_cast'
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.cast', 'paddle.cast',
...@@ -973,79 +928,68 @@ class OpSet9(): ...@@ -973,79 +928,68 @@ class OpSet9():
'paddle.nn.Embedding', 'paddle.nn.Embedding',
inputs={"x": indices_cast}, inputs={"x": indices_cast},
outputs=layer_outputs, outputs=layer_outputs,
num_embeddings=val_x.out_shapes[0][0], num_embeddings=val_x_shape[0],
embedding_dim=val_x.out_shapes[0][1]) embedding_dim=val_x_shape[1])
else: else:
from functools import reduce
reshape_shape = reduce(lambda x, y: x * y, indices_shape)
indices_reshape = indices.name + '_shape'
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.reshape', 'paddle.reshape',
inputs={"x": indices.name}, inputs={"x": indices.name},
outputs=[indices_reshape], outputs=[indices.name + "_reshape"],
shape=[reshape_shape, ]) shape=[-1])
gather_1d = node.name + '_1D'
perm = list(range(len(val_x.out_shapes[0])))
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.gather', 'paddle.gather',
inputs={'x': val_x.name, inputs={
'index': indices_reshape}, 'x': val_x.name,
outputs=[node.name]) 'index': indices.name + "_reshape"
val_x_shape = val_x.out_shapes[0] },
reshaped_shape = [] outputs=[gather_1d],
for i in perm: axis=axis)
reshaped_shape.append(indices_shape[i]) # if shape is known
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]: if len(indices_shape) != 0 and len(val_x_shape) != 0:
reshaped_shape.append(i) self.paddle_graph.add_layer(
self.paddle_graph.add_layer( 'paddle.reshape',
'paddle.reshape', inputs={'x': gather_1d},
inputs={"x": node.name}, outputs=[node.name],
outputs=[node.name], shape=val_x_shape[:axis] + indices_shape +
shape=reshaped_shape) val_x_shape[axis + 1:])
elif axis > 0 and len(indices_shape) > 1: else:
from functools import reduce all_shape_name = list()
reshape_shape = reduce(lambda x, y: x * y, indices_shape) self.paddle_graph.add_layer(
indices_reshape = indices.name + '_shape' kernel="paddle.shape",
self.paddle_graph.add_layer( inputs={"input": val_x.name},
'paddle.reshape', outputs=[val_x.name + "_shape"])
inputs={"x": indices.name}, self.paddle_graph.add_layer(
outputs=[indices_reshape], kernel="paddle.shape",
shape=[reshape_shape, ]) inputs={"input": indices.name},
outputs=[indices.name + "_shape"])
perm = list(range(len(val_x.out_shapes[0]))) self.paddle_graph.add_layer(
perm = [axis] + perm[:axis] + perm[axis + 1:] "paddle.slice",
name_trans = val_x.name + '_transpose' inputs={"input": val_x.name + "_shape"},
self.paddle_graph.add_layer( outputs=[val_x.name + "_shape_slice_start"],
'paddle.transpose', axes=[0],
inputs={"x": val_x.name}, starts=[0],
outputs=[name_trans], ends=[axis])
perm=perm) all_shape_name.append(val_x.name + "_shape_slice_start")
self.paddle_graph.add_layer( all_shape_name.append(indices.name + "_shape")
'paddle.gather', self.paddle_graph.add_layer(
inputs={'x': name_trans, "paddle.slice",
'index': indices_reshape}, inputs={"input": val_x.name + "_shape"},
outputs=[node.name]) outputs=[val_x.name + "_shape_slice_end"],
input_transpose = node.name + '_transpose' axes=[0],
new_perm = [0] * len(perm) starts=[axis + 1],
for i in range(len(perm)): ends=[2147483647])
new_perm[perm[i]] = i all_shape_name.append(val_x.name + "_shape_slice_end")
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
'paddle.transpose', 'paddle.concat',
inputs={"x": node.name}, inputs={"x": all_shape_name},
outputs=[input_transpose], outputs=[node.name + "_all_shape"],
perm=new_perm) axis=0)
perm = new_perm self.paddle_graph.add_layer(
val_x_shape = val_x.out_shapes[0] 'paddle.reshape',
reshaped_shape = [] inputs={'x': gather_1d},
for i in perm: outputs=[node.name],
reshaped_shape.append(indices_shape[i]) shape=node.name + "_all_shape")
for i in val_x_shape[:axis] + val_x_shape[axis + 1:]:
reshaped_shape.append(i)
self.paddle_graph.add_layer(
'paddle.reshape',
inputs={"x": input_transpose},
outputs=[node.name],
shape=reshaped_shape)
@print_mapping_info @print_mapping_info
def ScatterND(self, node): def ScatterND(self, node):
...@@ -1255,16 +1199,6 @@ class OpSet9(): ...@@ -1255,16 +1199,6 @@ class OpSet9():
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
@print_mapping_info
def GatherND(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer(
"paddle.gather_nd",
inputs={"x": val_x.name,
"index": val_y.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def Clip(self, node): def Clip(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -1420,16 +1354,6 @@ class OpSet9(): ...@@ -1420,16 +1354,6 @@ class OpSet9():
"y": val_y.name}, "y": val_y.name},
outputs=[node.name]) outputs=[node.name])
@print_mapping_info
def GatherND(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True)
val_y = self.graph.get_input_node(node, idx=1, copy=True)
self.paddle_graph.add_layer(
"paddle.gather_nd",
inputs={"x": val_x.name,
"index": val_y.name},
outputs=[node.name])
@print_mapping_info @print_mapping_info
def And(self, node): def And(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
...@@ -1710,7 +1634,8 @@ class OpSet9(): ...@@ -1710,7 +1634,8 @@ class OpSet9():
x_shape = val_x.out_shapes[0] x_shape = val_x.out_shapes[0]
y_shape = val_y.out_shapes[0] y_shape = val_y.out_shapes[0]
inputs_dict = {"x": val_x.name, "y": val_y.name} inputs_dict = {"x": val_x.name, "y": val_y.name}
if y_shape[0] == 1 and x_shape[-1] != 1 and x_shape[0] != 1: if len(y_shape) != 0 and y_shape[0] == 1 and len(
x_shape) != 0 and x_shape[-1] != 1 and x_shape[0] != 1:
y_squeeze = val_y.name + '_squeeze' y_squeeze = val_y.name + '_squeeze'
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.squeeze", "paddle.squeeze",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册