Unverified · Commit b10e4577, authored by WJJ1995, committed via GitHub

[AMP OP&Test] Add fp16/bf16 support for logical ops (#52112)

* fixed glog

* add

* add bfloat16 test for logical op

* rm useless code

* add uint16

* deal with comments

* fixed code style

* fixed code style

* fixed for ci

* deal with comments

* fixed for ci
Parent commit: d95eaa17
File mode changed from 100644 to 100755
@@ -77,14 +77,15 @@ PD_REGISTER_KERNEL(logical_xor, KPS, ALL_LAYOUT, phi::LogicalXorKernel, int) {
   kernel->OutputAt(0).SetDataType(phi::DataType::BOOL);
 }
 #else
-using float16 = phi::dtype::float16;
 #define REGISTER_LOGICAL_CUDA_KERNEL(logical_and, func_type) \
   PD_REGISTER_KERNEL(logical_and,                            \
                      KPS,                                    \
                      ALL_LAYOUT,                             \
                      phi::Logical##func_type##Kernel,        \
                      float,                                  \
-                     float16,                                \
+                     phi::dtype::float16,                    \
+                     phi::dtype::bfloat16,                   \
                      double,                                 \
                      bool,                                   \
                      int64_t,                                \
......
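The registration above is what lets GPU builds dispatch the logical kernels for half-precision inputs. A minimal usage sketch from the Python side (hedged: it assumes a CUDA build of Paddle that includes this commit, and the cast-based construction is just one way to obtain fp16/bf16 tensors):

import paddle

paddle.set_device("gpu")  # assumption: CUDA build with this commit applied

x = paddle.to_tensor([0.0, 1.5, -2.0]).cast("float16")
y = paddle.to_tensor([0.0, 1.5, -2.0]).cast("bfloat16")

# Logical ops accept the half-precision inputs and, as before, always
# produce a bool tensor (the kernels set the output dtype to BOOL).
print(paddle.logical_and(x, x))
print(paddle.logical_not(y))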
-# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,6 +15,7 @@
 import unittest
 
 import numpy as np
+from eager_op_test import convert_float_to_uint16
 
 import paddle
 from paddle.framework import _non_static_mode
@@ -24,6 +25,7 @@ SUPPORTED_DTYPES = [
     bool,
     np.int8,
     np.int16,
+    np.uint16,
     np.int32,
     np.int64,
     np.float16,
@@ -119,6 +121,9 @@ def run_eager(x_np, y_np, op_str, use_gpu=False, binary_op=True):
 def np_data_generator(np_shape, dtype, *args, **kwargs):
     if dtype == bool:
         return np.random.choice(a=[True, False], size=np_shape).astype(bool)
+    elif dtype == np.uint16:
+        x = np.random.uniform(0.0, 1.0, np_shape).astype(np.float32)
+        return convert_float_to_uint16(x)
     else:
         return np.random.normal(0, 1, np_shape).astype(dtype)
@@ -133,9 +138,8 @@ def test(unit_test, use_gpu=False, test_error=False):
             META_DATA = dict(TEST_META_WRONG_SHAPE_DATA)
         for shape_data in META_DATA.values():
             for data_type in SUPPORTED_DTYPES:
-                if (
-                    not (paddle.is_compiled_with_cuda() and use_gpu)
-                    and data_type == np.float16
+                if not (paddle.is_compiled_with_cuda() and use_gpu) and (
+                    data_type in [np.float16, np.uint16]
                 ):
                     continue
                 meta_data['x_np'] = np_data_generator(
@@ -194,8 +198,8 @@ def test_type_error(unit_test, use_gpu, type_str_map):
             paddle.is_compiled_with_cuda()
             and use_gpu
             and (
-                type_str_map['x'] == np.float16
-                or type_str_map['y'] == np.float16
+                type_str_map['x'] in [np.float16, np.uint16]
+                or type_str_map['y'] in [np.float16, np.uint16]
             )
         ):
             continue
......
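The np.uint16 branch in np_data_generator works because the test suite stores bfloat16 data as its raw 16-bit pattern inside a uint16 numpy array. A hedged sketch of that idea follows (the test imports Paddle's own convert_float_to_uint16 helper; the function below only illustrates the conversion and is not that helper's actual code):

import numpy as np

def float32_to_bfloat16_bits(x):
    # Reinterpret each float32 as uint32 and keep the upper 16 bits,
    # which is the bfloat16 bit pattern (simple truncation rounding).
    x = np.ascontiguousarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)

sample = np.random.uniform(0.0, 1.0, (2, 3)).astype(np.float32)
print(float32_to_bfloat16_bits(sample).dtype)  # uint16, one value per input element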
@@ -50,6 +50,7 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
                 "float16",
                 "float32",
                 "float64",
+                "uint16",
             ],
             op_name,
         )
@@ -66,6 +67,7 @@ def _logical_op(op_name, x, y, out=None, name=None, binary_op=True):
                 "float16",
                 "float32",
                 "float64",
+                "uint16",
             ],
             op_name,
         )
......
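The "uint16" entries added above matter because Paddle's Python-side dtype checks refer to bfloat16 under the name "uint16". A hedged sketch of how such an allow-list is consumed (the import path is an assumption and has moved between Paddle releases; the wrapper and the entries outside the visible hunk are illustrative, not the library's exact code):

from paddle.fluid.data_feeder import check_variable_and_dtype  # assumed path

def _check_logical_operand(var, name, op_name):
    # Raises a TypeError unless var's dtype is in the allow-list; after this
    # change, bf16 operands (reported as "uint16") pass the check as well.
    check_variable_and_dtype(
        var,
        name,
        ["bool", "int8", "int16", "int32", "int64",
         "float16", "float32", "float64", "uint16"],
        op_name,
    )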