未验证 提交 40b30f50 编写于 作者: R Roc 提交者: GitHub

[AMP OP&Test] add fp16 test for linspace (#52161)

上级 73544322
......@@ -29,9 +29,10 @@ __global__ void LinspaceKernelInner(
for (; index < size; index += blockDim.x * gridDim.x) {
if (index < size / 2) {
out[index] = static_cast<T>(start + step * index);
out[index] = static_cast<T>(static_cast<double>(start) + step * index);
} else {
out[index] = static_cast<T>(stop - step * (size - index - 1));
out[index] =
static_cast<T>(static_cast<double>(stop) - step * (size - index - 1));
}
}
}
......@@ -111,7 +112,9 @@ PD_REGISTER_KERNEL(linspace,
float,
int32_t,
int64_t,
double) {
double,
phi::dtype::float16,
phi::dtype::bfloat16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->InputAt(2).SetBackend(phi::Backend::ALL_BACKEND);
......
......@@ -15,7 +15,7 @@
import unittest
import numpy as np
from eager_op_test import OpTest, paddle_static_guard
from eager_op_test import OpTest, convert_float_to_uint16, paddle_static_guard
import paddle
from paddle import fluid
......@@ -26,56 +26,120 @@ class TestLinspaceOpCommonCase(OpTest):
def setUp(self):
self.op_type = "linspace"
self.python_api = paddle.linspace
dtype = 'float32'
self._set_dtype()
self._set_data()
self.attrs = {'dtype': self.attr_dtype}
def _set_dtype(self):
self.dtype = "float32"
self.attr_dtype = int(core.VarDesc.VarType.FP32)
def _set_data(self):
self.outputs = {'Out': np.arange(0, 11).astype(self.dtype)}
self.inputs = {
'Start': np.array([0]).astype(dtype),
'Stop': np.array([10]).astype(dtype),
'Start': np.array([0]).astype(self.dtype),
'Stop': np.array([10]).astype(self.dtype),
'Num': np.array([11]).astype('int32'),
}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': np.arange(0, 11).astype(dtype)}
def test_check_output(self):
self.check_output()
class TestLinspaceOpReverseCase(OpTest):
def setUp(self):
self.op_type = "linspace"
self.python_api = paddle.linspace
dtype = 'float32'
class TestLinspaceOpReverseCase(TestLinspaceOpCommonCase):
def _set_data(self):
self.inputs = {
'Start': np.array([10]).astype(dtype),
'Stop': np.array([0]).astype(dtype),
'Start': np.array([10]).astype(self.dtype),
'Stop': np.array([0]).astype(self.dtype),
'Num': np.array([11]).astype('int32'),
}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)}
self.outputs = {'Out': np.arange(10, -1, -1).astype(self.dtype)}
def test_check_output(self):
self.check_output()
class TestLinspaceOpNumOneCase(OpTest):
def setUp(self):
self.op_type = "linspace"
self.python_api = paddle.linspace
dtype = 'float32'
class TestLinspaceOpNumOneCase(TestLinspaceOpCommonCase):
def _set_data(self):
self.inputs = {
'Start': np.array([10]).astype(dtype),
'Stop': np.array([0]).astype(dtype),
'Start': np.array([10]).astype(self.dtype),
'Stop': np.array([0]).astype(self.dtype),
'Num': np.array([1]).astype('int32'),
}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': np.array(10, dtype=dtype)}
self.outputs = {'Out': np.array(10, dtype=self.dtype)}
def test_check_output(self):
self.check_output()
class TestLinspaceOpCommonCaseFP16(TestLinspaceOpCommonCase):
def _set_dtype(self):
self.dtype = np.float16
self.attr_dtype = int(core.VarDesc.VarType.FP16)
class TestLinspaceOpReverseCaseFP16(TestLinspaceOpReverseCase):
def _set_dtype(self):
self.dtype = np.float16
self.attr_dtype = int(core.VarDesc.VarType.FP16)
class TestLinspaceOpNumOneCaseFP16(TestLinspaceOpNumOneCase):
def _set_dtype(self):
self.dtype = np.float16
self.attr_dtype = int(core.VarDesc.VarType.FP16)
@unittest.skipIf(
not core.is_compiled_with_cuda()
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
'not supported bf16',
)
class TestLinspaceOpCommonCaseBF16(TestLinspaceOpCommonCaseFP16):
def _set_dtype(self):
self.dtype = np.uint16
self.attr_dtype = int(core.VarDesc.VarType.BF16)
def _set_data(self):
self.outputs = {
'Out': convert_float_to_uint16(np.arange(0, 11).astype("float32"))
}
self.inputs = {
'Start': convert_float_to_uint16(np.array([0]).astype("float32")),
'Stop': convert_float_to_uint16(np.array([10]).astype("float32")),
'Num': np.array([11]).astype('int32'),
}
def test_check_output(self):
return self.check_output_with_place(core.CUDAPlace(0))
class TestLinspaceOpReverseCaseBF16(TestLinspaceOpCommonCaseBF16):
def _set_data(self):
self.inputs = {
'Start': convert_float_to_uint16(np.array([10]).astype("float32")),
'Stop': convert_float_to_uint16(np.array([0]).astype("float32")),
'Num': np.array([11]).astype('int32'),
}
self.outputs = {
'Out': convert_float_to_uint16(
np.arange(10, -1, -1).astype("float32")
)
}
class TestLinspaceOpNumOneCaseBF16(TestLinspaceOpCommonCaseBF16):
def _set_data(self):
self.inputs = {
'Start': convert_float_to_uint16(np.array([10]).astype("float32")),
'Stop': convert_float_to_uint16(np.array([0]).astype("float32")),
'Num': np.array([1]).astype('int32'),
}
self.outputs = {
'Out': convert_float_to_uint16(np.array(10, dtype="float32"))
}
class TestLinspaceAPI(unittest.TestCase):
def test_variable_input1(self):
with paddle_static_guard():
......
......@@ -332,7 +332,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype(
start.dtype,
'start',
['float32', 'float64', 'int32', 'int64'],
['float32', 'float64', 'int32', 'int64', 'float16', 'bfloat16'],
'linspace',
)
else:
......@@ -342,7 +342,7 @@ def linspace(start, stop, num, dtype=None, name=None):
check_dtype(
stop.dtype,
'stop',
['float32', 'float64', 'int32', 'int64'],
['float32', 'float64', 'int32', 'int64', 'float16', 'bfloat16'],
'linspace',
)
else:
......@@ -350,7 +350,10 @@ def linspace(start, stop, num, dtype=None, name=None):
if isinstance(num, Variable):
check_dtype(num.dtype, 'num', ['int32'], 'linspace')
check_dtype(
dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], 'linspace'
dtype,
'dtype',
['int32', 'int64', 'float32', 'float64', 'float16', 'bfloat16'],
'linspace',
)
if (
(stop_dtype == "float64" or start_dtype == "float64")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册