Unverified commit 2a0bd17c, authored by FormlessUnit, committed by GitHub

fill_constant_batch_size_like support bf16 (#51396)

shape support bf16
Parent: b2385821
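Both new unit tests below represent bfloat16 data as uint16 bit patterns via the `convert_float_to_uint16` helper from the OpTest utilities. As a rough, hypothetical illustration (not the helper's actual implementation, which may round rather than truncate), the conversion amounts to keeping the upper 16 bits of each float32 value:

```python
import numpy as np


def float32_to_bf16_bits(x):
    """Illustrative sketch only: keep the high 16 bits of each float32.

    The real convert_float_to_uint16 helper in the Paddle test utilities
    may differ in rounding behaviour; this just shows why the tests store
    bfloat16 tensors as np.uint16 arrays.
    """
    x = np.asarray(x, dtype=np.float32)
    return (x.view(np.uint32) >> 16).astype(np.uint16)


print(float32_to_bf16_bits(np.array([1.0, 100.0])))  # [16256 17096]
```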
@@ -59,7 +59,8 @@ PD_REGISTER_KERNEL(full_batch_size_like,
                    int,
                    int64_t,
                    bool,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
 }
 #endif
@@ -65,7 +65,8 @@ PD_REGISTER_KERNEL(shape,
                    double,
                    phi::dtype::complex<float>,
                    phi::dtype::complex<double>,
-                   phi::dtype::float16) {
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16) {
   kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
   kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
   kernel->OutputAt(0).SetDataType(phi::DataType::INT32);
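With bfloat16 added to both kernel registrations, ops such as `shape` accept bf16 inputs on GPU. A minimal dygraph sanity check, assuming a CUDA build on a device that supports bfloat16 (this snippet is not part of the commit):

```python
import paddle

# Hypothetical check: cast a tensor to bfloat16 and query its shape.
x = paddle.cast(paddle.ones([2, 3]), 'bfloat16')  # bf16 tensor on GPU
print(paddle.shape(x))  # the shape kernel can now take a bf16 input -> [2, 3]
```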
@@ -15,9 +15,10 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
+import paddle.fluid.core as core
 from paddle.fluid.framework import convert_np_dtype_to_dtype_
 
 paddle.enable_static()
@@ -70,5 +71,46 @@ class TestFillConstatnBatchSizeLike1(OpTest):
         self.check_output()
 
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or not core.supports_bfloat16(),
+    "core is not compiled with CUDA or place do not support bfloat16",
+)
+class TestFillConstatnBatchSizeLikeBf16(OpTest):
+    # test bf16
+    def setUp(self):
+        self.op_type = "fill_constant_batch_size_like"
+        self.python_api = fill_constant_batch_size_like
+        self.init_data()
+
+        input = np.zeros(self.shape).astype("float32")
+        input_bf16 = convert_float_to_uint16(input)
+        out = np.full_like(input, self.value, np.float32)
+        out_bf16 = convert_float_to_uint16(out)
+
+        self.inputs = {'Input': input_bf16}
+        self.outputs = {'Out': out_bf16}
+        self.attrs = {
+            'shape': self.shape,
+            'dtype': convert_np_dtype_to_dtype_(self.dtype),
+            'value': self.value,
+            'input_dim_idx': self.input_dim_idx,
+            'output_dim_idx': self.output_dim_idx,
+            'force_cpu': self.force_cpu,
+        }
+
+    def init_data(self):
+        self.shape = [10, 10]
+        self.dtype = np.uint16
+        self.value = 100
+        self.input_dim_idx = 0
+        self.output_dim_idx = 0
+        self.force_cpu = False
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 if __name__ == "__main__":
     paddle.enable_static()
     unittest.main()
@@ -15,7 +15,7 @@
 import unittest
 
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 
 import paddle
 from paddle.fluid import core
@@ -85,5 +85,29 @@ class TestShapeWithSelectedRows(unittest.TestCase):
             self.check_with_place(place)
 
 
+@unittest.skipIf(
+    not core.is_compiled_with_cuda() or not core.supports_bfloat16(),
+    "core is not compiled with CUDA or place do not support bfloat16",
+)
+class TestShapeOpBf16(OpTest):
+    def setUp(self):
+        self.op_type = "shape"
+        self.dtype = 'bfloat16'
+        self.python_api = paddle.shape
+        self.config()
+        self.shape = [2, 3]
+        input = np.zeros(self.shape)
+        input = convert_float_to_uint16(input.astype('float32'))
+        self.inputs = {'Input': input}
+        self.outputs = {'Out': np.array(self.shape)}
+
+    def config(self):
+        self.shape = [2, 3]
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place)
+
+
 if __name__ == '__main__':
     unittest.main()