Unverified commit cd88156a, authored by 骑马小猫, committed by GitHub

[Bug fixes] enable two ops to support bf16 in llama model (#53026)

Parent: c59debe2
@@ -965,7 +965,7 @@ def silu(x, name=None):
     Where :math:`x` is the input Tensor.
     Parameters:
-        x (Tensor): The input Tensor with data type float32, float64.
+        x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64.
         name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
     Returns:
...
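The hunk above adds bfloat16 to the dtypes accepted by `paddle.nn.functional.silu`. For reference, SiLU computes `x * sigmoid(x)`; a minimal NumPy sketch of that formula (float32 here, since stock NumPy has no bfloat16 type):

```python
import numpy as np

def silu(x):
    # SiLU (a.k.a. swish): x * sigmoid(x), the formula the docstring refers to.
    return x * (1.0 / (1.0 + np.exp(-x)))

x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
out = silu(x)  # silu(0) == 0; silu(1) ≈ 0.731
```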
@@ -1899,7 +1899,7 @@ def split(x, num_or_sections, axis=0, name=None):
     Split the input tensor into multiple sub-Tensors.
     Args:
-        x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64.
+        x (Tensor): A N-D Tensor. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64.
         num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
             indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
             If ``num_or_sections`` is a list or tuple, the length of it indicates the number of
@@ -1970,6 +1970,7 @@ def split(x, num_or_sections, axis=0, name=None):
         'input',
         [
             'bool',
+            'bfloat16',
             'float16',
             'uint16',
             'float32',
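The split hunk above adds 'bfloat16' to the op's accepted dtype list. The ``num_or_sections`` semantics described in the docstring can be illustrated with NumPy's analogous `np.split` (any float dtype works for illustration; actual bfloat16 execution needs a bf16-capable Paddle build):

```python
import numpy as np

x = np.arange(12, dtype=np.float32).reshape(3, 4)

# int form: three equal-sized sub-tensors along axis 0
parts = np.split(x, 3, axis=0)          # each part has shape (1, 4)

# list form: note np.split takes split *indices*, while paddle.split's
# list form takes section *sizes* -- here [1] yields widths 1 and 3.
left, right = np.split(x, [1], axis=1)  # shapes (3, 1) and (3, 3)
```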
@@ -2546,7 +2547,7 @@ def unsqueeze(x, axis, name=None):
     please use `Tensor.clone` like ``unsqueeze_clone_x = x.unsqueeze(-1).clone()``.
     Args:
-        x (Tensor): The input Tensor to be unsqueezed. Supported data type: float32, float64, bool, int8, int32, int64.
+        x (Tensor): The input Tensor to be unsqueezed. Supported data type: bfloat16, float16, float32, float64, bool, int8, int32, int64.
         axis (int|list|tuple|Tensor): Indicates the dimensions to be inserted. The data type is ``int32`` .
             If ``axis`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
             If ``axis`` is a Tensor, it should be an 1-D Tensor .
@@ -2600,6 +2601,7 @@ def unsqueeze(x, axis, name=None):
         input,
         'input',
         [
+            'uint16',
             'float16',
             'uint16',
             'float32',
...
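The unsqueeze hunk above extends the dtype whitelist (Paddle represents bfloat16 as 'uint16' in these static checks). The insertion semantics the docstring describes match NumPy's `np.expand_dims`, sketched here for illustration:

```python
import numpy as np

x = np.zeros((5, 10), dtype=np.float32)

y = np.expand_dims(x, axis=-1)      # trailing dim inserted -> (5, 10, 1)
z = np.expand_dims(x, axis=(0, 2))  # dims at output positions 0 and 2 -> (1, 5, 1, 10)
```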
@@ -499,6 +499,7 @@ def _elementwise_op(helper):
         "elementwise_sub",
         "elementwise_mul",
         "elementwise_div",
+        "elementwise_max",
     ]
     if original_op_type in bf16_and_complex_supported_ops:
         data_type = [
...
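The final hunk registers `elementwise_max` among the ops that accept bf16 (and complex) inputs in the static-graph dtype check. Its semantics are an elementwise, broadcasting maximum; a NumPy sketch with `np.maximum` (float32 stands in for bfloat16):

```python
import numpy as np

x = np.array([[1.0, 5.0], [3.0, 2.0]], dtype=np.float32)
y = np.array([2.0, 4.0], dtype=np.float32)   # broadcast across rows

out = np.maximum(x, y)  # [[2., 5.], [3., 4.]]
```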