diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index a4400d6272f9ef517a9a57595abcc3e5fd6ec1de..5747de7ddd2d492cc509aeec8fa97f8635fef3ab 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -1160,19 +1160,22 @@ def grad_var_name(var_name): def convert_np_dtype_to_dtype_(np_dtype): """ - Convert the data type in numpy to the data type in Paddle + Convert the data type in numpy to the data type in Paddle. Args: - np_dtype(np.dtype): the data type in numpy. + np_dtype (np.dtype|str): The data type in numpy or valid data type + string. Returns: - core.VarDesc.VarType: the data type in Paddle. + core.VarDesc.VarType: The data type in Paddle. """ - if np_dtype == "bfloat16": + # Convert the data type string to numpy data type. + if isinstance(np_dtype, str) and np_dtype == "bfloat16": dtype = np.uint16 else: dtype = np.dtype(np_dtype) + if dtype == np.float32: return core.VarDesc.VarType.FP32 elif dtype == np.float64: diff --git a/python/paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py index f62c5b47d5ab0cc89c4ec6e17e535b8d5e5750ce..2bd374fe6d0e7acdb6a327c11ffc51ea39018660 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_tril_triu_op_npu.py @@ -135,7 +135,7 @@ for _op_type in ['tril', 'triu']: class TestTrilTriuOpAPI(unittest.TestCase): - """ test case by using API and has -1 dimension + """ test case by using API and has -1 dimension """ def test_api(self):