未验证 提交 cd88156a 编写于 作者: 骑马小猫 提交者: GitHub

[Bug fixes] enable two ops to support bf16 in llama model (#53026)

上级 c59debe2
......@@ -965,7 +965,7 @@ def silu(x, name=None):
Where :math:`x` is the input Tensor.
Parameters:
x (Tensor): The input Tensor with data type float32, float64.
x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
......
......@@ -1899,7 +1899,7 @@ def split(x, num_or_sections, axis=0, name=None):
Split the input tensor into multiple sub-Tensors.
Args:
x (Tensor): A N-D Tensor. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64.
x (Tensor): A N-D Tensor. The data type is bool, bfloat16, float16, float32, float64, uint8, int8, int32 or int64.
num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections``
indicates the number of equal sized sub-Tensors that the ``x`` will be divided into.
If ``num_or_sections`` is a list or tuple, the length of it indicates the number of
......@@ -1970,6 +1970,7 @@ def split(x, num_or_sections, axis=0, name=None):
'input',
[
'bool',
'bfloat16',
'float16',
'uint16',
'float32',
......@@ -2546,7 +2547,7 @@ def unsqueeze(x, axis, name=None):
please use `Tensor.clone` like ``unsqueeze_clone_x = x.unsqueeze(-1).clone()``.
Args:
x (Tensor): The input Tensor to be unsqueezed. Supported data type: float32, float64, bool, int8, int32, int64.
x (Tensor): The input Tensor to be unsqueezed. Supported data type: bfloat16, float16, float32, float64, bool, int8, int32, int64.
axis (int|list|tuple|Tensor): Indicates the dimensions to be inserted. The data type is ``int32`` .
If ``axis`` is a list or tuple, the elements of it should be integers or Tensors with shape [1].
If ``axis`` is a Tensor, it should be an 1-D Tensor .
......@@ -2600,6 +2601,7 @@ def unsqueeze(x, axis, name=None):
input,
'input',
[
'uint16',
'float16',
'uint16',
'float32',
......
......@@ -499,6 +499,7 @@ def _elementwise_op(helper):
"elementwise_sub",
"elementwise_mul",
"elementwise_div",
"elementwise_max",
]
if original_op_type in bf16_and_complex_supported_ops:
data_type = [
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册