未验证 提交 40b30f50 编写于 作者: R Roc 提交者: GitHub

[AMP OP&Test] add fp16 test for linspace (#52161)

上级 73544322
...@@ -29,9 +29,10 @@ __global__ void LinspaceKernelInner( ...@@ -29,9 +29,10 @@ __global__ void LinspaceKernelInner(
for (; index < size; index += blockDim.x * gridDim.x) { for (; index < size; index += blockDim.x * gridDim.x) {
if (index < size / 2) { if (index < size / 2) {
out[index] = static_cast<T>(start + step * index); out[index] = static_cast<T>(static_cast<double>(start) + step * index);
} else { } else {
out[index] = static_cast<T>(stop - step * (size - index - 1)); out[index] =
static_cast<T>(static_cast<double>(stop) - step * (size - index - 1));
} }
} }
} }
...@@ -111,7 +112,9 @@ PD_REGISTER_KERNEL(linspace, ...@@ -111,7 +112,9 @@ PD_REGISTER_KERNEL(linspace,
float, float,
int32_t, int32_t,
int64_t, int64_t,
double) { double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
import unittest import unittest
import numpy as np import numpy as np
from eager_op_test import OpTest, paddle_static_guard from eager_op_test import OpTest, convert_float_to_uint16, paddle_static_guard
import paddle import paddle
from paddle import fluid from paddle import fluid
...@@ -26,56 +26,120 @@ class TestLinspaceOpCommonCase(OpTest): ...@@ -26,56 +26,120 @@ class TestLinspaceOpCommonCase(OpTest):
def setUp(self): def setUp(self):
self.op_type = "linspace" self.op_type = "linspace"
self.python_api = paddle.linspace self.python_api = paddle.linspace
dtype = 'float32' self._set_dtype()
self._set_data()
self.attrs = {'dtype': self.attr_dtype}
def _set_dtype(self):
self.dtype = "float32"
self.attr_dtype = int(core.VarDesc.VarType.FP32)
def _set_data(self):
self.outputs = {'Out': np.arange(0, 11).astype(self.dtype)}
self.inputs = { self.inputs = {
'Start': np.array([0]).astype(dtype), 'Start': np.array([0]).astype(self.dtype),
'Stop': np.array([10]).astype(dtype), 'Stop': np.array([10]).astype(self.dtype),
'Num': np.array([11]).astype('int32'), 'Num': np.array([11]).astype('int32'),
} }
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': np.arange(0, 11).astype(dtype)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
class TestLinspaceOpReverseCase(OpTest): class TestLinspaceOpReverseCase(TestLinspaceOpCommonCase):
def setUp(self): def _set_data(self):
self.op_type = "linspace"
self.python_api = paddle.linspace
dtype = 'float32'
self.inputs = { self.inputs = {
'Start': np.array([10]).astype(dtype), 'Start': np.array([10]).astype(self.dtype),
'Stop': np.array([0]).astype(dtype), 'Stop': np.array([0]).astype(self.dtype),
'Num': np.array([11]).astype('int32'), 'Num': np.array([11]).astype('int32'),
} }
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} self.outputs = {'Out': np.arange(10, -1, -1).astype(self.dtype)}
self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
class TestLinspaceOpNumOneCase(OpTest): class TestLinspaceOpNumOneCase(TestLinspaceOpCommonCase):
def setUp(self): def _set_data(self):
self.op_type = "linspace"
self.python_api = paddle.linspace
dtype = 'float32'
self.inputs = { self.inputs = {
'Start': np.array([10]).astype(dtype), 'Start': np.array([10]).astype(self.dtype),
'Stop': np.array([0]).astype(dtype), 'Stop': np.array([0]).astype(self.dtype),
'Num': np.array([1]).astype('int32'), 'Num': np.array([1]).astype('int32'),
} }
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} self.outputs = {'Out': np.array(10, dtype=self.dtype)}
self.outputs = {'Out': np.array(10, dtype=dtype)}
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
class TestLinspaceOpCommonCaseFP16(TestLinspaceOpCommonCase):
def _set_dtype(self):
self.dtype = np.float16
self.attr_dtype = int(core.VarDesc.VarType.FP16)
class TestLinspaceOpReverseCaseFP16(TestLinspaceOpReverseCase):
def _set_dtype(self):
self.dtype = np.float16
self.attr_dtype = int(core.VarDesc.VarType.FP16)
class TestLinspaceOpNumOneCaseFP16(TestLinspaceOpNumOneCase):
def _set_dtype(self):
self.dtype = np.float16
self.attr_dtype = int(core.VarDesc.VarType.FP16)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
'not supported bf16',
)
class TestLinspaceOpCommonCaseBF16(TestLinspaceOpCommonCaseFP16):
def _set_dtype(self):
self.dtype = np.uint16
self.attr_dtype = int(core.VarDesc.VarType.BF16)
def _set_data(self):
self.outputs = {
'Out': convert_float_to_uint16(np.arange(0, 11).astype("float32"))
}
self.inputs = {
'Start': convert_float_to_uint16(np.array([0]).astype("float32")),
'Stop': convert_float_to_uint16(np.array([10]).astype("float32")),
'Num': np.array([11]).astype('int32'),
}
def test_check_output(self):
return self.check_output_with_place(core.CUDAPlace(0))
class TestLinspaceOpReverseCaseBF16(TestLinspaceOpCommonCaseBF16):
def _set_data(self):
self.inputs = {
'Start': convert_float_to_uint16(np.array([10]).astype("float32")),
'Stop': convert_float_to_uint16(np.array([0]).astype("float32")),
'Num': np.array([11]).astype('int32'),
}
self.outputs = {
'Out': convert_float_to_uint16(
np.arange(10, -1, -1).astype("float32")
)
}
class TestLinspaceOpNumOneCaseBF16(TestLinspaceOpCommonCaseBF16):
def _set_data(self):
self.inputs = {
'Start': convert_float_to_uint16(np.array([10]).astype("float32")),
'Stop': convert_float_to_uint16(np.array([0]).astype("float32")),
'Num': np.array([1]).astype('int32'),
}
self.outputs = {
'Out': convert_float_to_uint16(np.array(10, dtype="float32"))
}
class TestLinspaceAPI(unittest.TestCase): class TestLinspaceAPI(unittest.TestCase):
def test_variable_input1(self): def test_variable_input1(self):
with paddle_static_guard(): with paddle_static_guard():
......
...@@ -332,7 +332,7 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -332,7 +332,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype( check_dtype(
start.dtype, start.dtype,
'start', 'start',
['float32', 'float64', 'int32', 'int64'], ['float32', 'float64', 'int32', 'int64', 'float16', 'bfloat16'],
'linspace', 'linspace',
) )
else: else:
...@@ -342,7 +342,7 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -342,7 +342,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype( check_dtype(
stop.dtype, stop.dtype,
'stop', 'stop',
['float32', 'float64', 'int32', 'int64'], ['float32', 'float64', 'int32', 'int64', 'float16', 'bfloat16'],
'linspace', 'linspace',
) )
else: else:
...@@ -350,7 +350,10 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -350,7 +350,10 @@ def linspace(start, stop, num, dtype=None, name=None):
if isinstance(num, Variable): if isinstance(num, Variable):
check_dtype(num.dtype, 'num', ['int32'], 'linspace') check_dtype(num.dtype, 'num', ['int32'], 'linspace')
check_dtype( check_dtype(
dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], 'linspace' dtype,
'dtype',
['int32', 'int64', 'float32', 'float64', 'float16', 'bfloat16'],
'linspace',
) )
if ( if (
(stop_dtype == "float64" or start_dtype == "float64") (stop_dtype == "float64" or start_dtype == "float64")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册