未验证 提交 5022dd9b 编写于 作者: N Nyakku Shigure 提交者: GitHub

[CodeStyle][W291] trim trailing whitespace in NPU unittest file (#46042)

上级 710efdae
...@@ -1160,19 +1160,22 @@ def grad_var_name(var_name): ...@@ -1160,19 +1160,22 @@ def grad_var_name(var_name):
def convert_np_dtype_to_dtype_(np_dtype): def convert_np_dtype_to_dtype_(np_dtype):
""" """
Convert the data type in numpy to the data type in Paddle Convert the data type in numpy to the data type in Paddle.
Args: Args:
np_dtype(np.dtype): the data type in numpy. np_dtype (np.dtype|str): The data type in numpy or valid data type
string.
Returns: Returns:
core.VarDesc.VarType: the data type in Paddle. core.VarDesc.VarType: The data type in Paddle.
""" """
if np_dtype == "bfloat16": # Convert the data type string to numpy data type.
if isinstance(np_dtype, str) and np_dtype == "bfloat16":
dtype = np.uint16 dtype = np.uint16
else: else:
dtype = np.dtype(np_dtype) dtype = np.dtype(np_dtype)
if dtype == np.float32: if dtype == np.float32:
return core.VarDesc.VarType.FP32 return core.VarDesc.VarType.FP32
elif dtype == np.float64: elif dtype == np.float64:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册