From c1fba5c1e56c1e5d17ef5c60747dec713b7ae3c5 Mon Sep 17 00:00:00 2001 From: WJJ1995 Date: Tue, 31 May 2022 21:39:34 +0800 Subject: [PATCH] Support gpt2 (#796) * Support bigbird model * Support GPT2 * rm useless code * deal with comments --- .../op_mapper/onnx2paddle/opset9/opset.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py index 3e7911c..58bc0a5 100755 --- a/x2paddle/op_mapper/onnx2paddle/opset9/opset.py +++ b/x2paddle/op_mapper/onnx2paddle/opset9/opset.py @@ -1093,6 +1093,12 @@ class OpSet9(): val_x = self.graph.get_input_node(node, idx=0, copy=True) starts, ends, axes, steps = None, None, None, None layer_attrs = {} + if val_x.dtype == 'uint8': + self.paddle_graph.add_layer( + 'paddle.cast', + inputs={"x": val_x.name}, + outputs=[val_x.name], + dtype=string('int32')) if len(node.inputs) > 1: starts = self.graph.get_input_node(node, idx=1, copy=True) ends = self.graph.get_input_node(node, idx=2, copy=True) @@ -1121,8 +1127,9 @@ class OpSet9(): starts_value = starts_value.copy() ends_value = ends_value.copy() for idx in range(len(ends_value)): - if starts_value[idx] >= val_x.out_shapes[0][axes[ - idx]] and val_x.out_shapes[0][axes[idx]] > 0: + if len(val_x.out_shapes[0]) != 0 and starts_value[ + idx] >= val_x.out_shapes[0][axes[ + idx]] and val_x.out_shapes[0][axes[idx]] > 0: starts_value[idx] = val_x.out_shapes[0][axes[idx]] - 1 ends_value[idx] = val_x.out_shapes[0][axes[idx]] elif ends_value[idx] > 2**31 - 1: @@ -1178,6 +1185,12 @@ class OpSet9(): inputs={"input": val_x.name}, outputs=[node.name], **layer_attrs) + if val_x.dtype == 'uint8': + self.paddle_graph.add_layer( + 'paddle.cast', + inputs={"x": node.name}, + outputs=[node.name], + dtype=string('uint8')) @print_mapping_info def ConstantOfShape(self, node): @@ -1790,7 +1803,11 @@ class OpSet9(): def Squeeze(self, node): val_x = self.graph.get_input_node(node, idx=0, copy=True) axes = node.get_attr('axes') - if len(val_x.out_shapes[0]) == 1: + if axes is None: + axes_node = self.graph.get_input_node(node, idx=1, copy=True) + axes = _const_weight_or_none(axes_node, necessary=True) + # deal with scalar(0D) tensor + if len(val_x.out_shapes[0]) <= 1 and len(axes) == 1 and axes[0] == 0: self.paddle_graph.add_layer( "paddle.cast", inputs={"x": val_x.name}, -- GitLab