未验证 提交 c1fba5c1 编写于 作者: W WJJ1995 提交者: GitHub

Support gpt2 (#796)

* Support bigbird model

* Support GPT2

* rm useless code

* deal with comments
上级 e85f6924
...@@ -1093,6 +1093,12 @@ class OpSet9(): ...@@ -1093,6 +1093,12 @@ class OpSet9():
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
starts, ends, axes, steps = None, None, None, None starts, ends, axes, steps = None, None, None, None
layer_attrs = {} layer_attrs = {}
if val_x.dtype == 'uint8':
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": val_x.name},
outputs=[val_x.name],
dtype=string('int32'))
if len(node.inputs) > 1: if len(node.inputs) > 1:
starts = self.graph.get_input_node(node, idx=1, copy=True) starts = self.graph.get_input_node(node, idx=1, copy=True)
ends = self.graph.get_input_node(node, idx=2, copy=True) ends = self.graph.get_input_node(node, idx=2, copy=True)
...@@ -1121,8 +1127,9 @@ class OpSet9(): ...@@ -1121,8 +1127,9 @@ class OpSet9():
starts_value = starts_value.copy() starts_value = starts_value.copy()
ends_value = ends_value.copy() ends_value = ends_value.copy()
for idx in range(len(ends_value)): for idx in range(len(ends_value)):
if starts_value[idx] >= val_x.out_shapes[0][axes[ if len(val_x.out_shapes[0]) != 0 and starts_value[
idx]] and val_x.out_shapes[0][axes[idx]] > 0: idx] >= val_x.out_shapes[0][axes[
idx]] and val_x.out_shapes[0][axes[idx]] > 0:
starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1 starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1
ends_value[idx] = val_x.out_shapes[0][axes[idx]] ends_value[idx] = val_x.out_shapes[0][axes[idx]]
elif ends_value[idx] > 2**31 - 1: elif ends_value[idx] > 2**31 - 1:
...@@ -1178,6 +1185,12 @@ class OpSet9(): ...@@ -1178,6 +1185,12 @@ class OpSet9():
inputs={"input": val_x.name}, inputs={"input": val_x.name},
outputs=[node.name], outputs=[node.name],
**layer_attrs) **layer_attrs)
if val_x.dtype == 'uint8':
self.paddle_graph.add_layer(
'paddle.cast',
inputs={"x": node.name},
outputs=[node.name],
dtype=string('uint8'))
@print_mapping_info @print_mapping_info
def ConstantOfShape(self, node): def ConstantOfShape(self, node):
...@@ -1790,7 +1803,11 @@ class OpSet9(): ...@@ -1790,7 +1803,11 @@ class OpSet9():
def Squeeze(self, node): def Squeeze(self, node):
val_x = self.graph.get_input_node(node, idx=0, copy=True) val_x = self.graph.get_input_node(node, idx=0, copy=True)
axes = node.get_attr('axes') axes = node.get_attr('axes')
if len(val_x.out_shapes[0]) == 1: if axes is None:
axes_node = self.graph.get_input_node(node, idx=1, copy=True)
axes = _const_weight_or_none(axes_node, necessary=True)
# deal with scalar(0D) tensor
if len(val_x.out_shapes[0]) <= 1 and len(axes) == 1 and axes[0] == 0:
self.paddle_graph.add_layer( self.paddle_graph.add_layer(
"paddle.cast", "paddle.cast",
inputs={"x": val_x.name}, inputs={"x": val_x.name},
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册