Unverified commit 2073ffc0, authored by wangchaochaohu, committed by GitHub

Enhance the data type of linspace API (#26247)

Parent bb7fd097
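
A minimal usage sketch of the enhanced API (assuming a Paddle build that contains this commit; only dtypes named in the diff below are used):

import paddle

paddle.disable_static()
# Integer output dtypes are now accepted in addition to float32/float64.
ints = paddle.linspace(0, 10, 11, dtype='int64')       # 0, 1, ..., 10
doubles = paddle.linspace(0, 10, 11, dtype='float64')
paddle.enable_static()
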
......@@ -53,11 +53,9 @@ class LinspaceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Start"),
ctx.device_context(), layout_, library_);
framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
......@@ -73,6 +71,7 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Num",
"Number of entry in the sequence. It is a tensor of shape [1], "
"should be of type int32.");
AddAttr<int>("dtype", "The output data type.");
AddOutput("Out", "A sequence of numbers.");
AddComment(R"DOC(
Return a fixed number of evenly spaced values within a given interval. The first entry is start, and the last entry is stop. In the case when Num is 1, only Start is returned. Like the linspace function of numpy.
......@@ -85,4 +84,6 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(linspace, ops::LinspaceOp, ops::LinspaceOpMaker);
REGISTER_OP_CPU_KERNEL(linspace, ops::CPULinspaceKernel<float>,
ops::CPULinspaceKernel<int32_t>,
ops::CPULinspaceKernel<int64_t>,
ops::CPULinspaceKernel<double>);
......@@ -20,13 +20,15 @@ namespace paddle {
namespace operators {
template <typename T>
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; }
__global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) {
out[index] = static_cast<T>(start + step * index);
}
}
template <typename T>
__global__ void LinspaceSpecialKernel(T start, T* out) {
out[0] = start;
out[0] = static_cast<T>(start);
}
template <typename T>
......@@ -51,9 +53,9 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
out->Resize(framework::make_ddim({num}));
T* out_data = out->mutable_data<T>(context.GetPlace());
T step = 0;
double step = 0;
if (num != 1) {
step = (stop - start) / (num - 1);
step = (static_cast<double>(stop - start)) / (num - 1);
}
auto stream = context.cuda_device_context().stream();
......@@ -68,4 +70,6 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel<float>,
ops::CUDALinspaceKernel<int32_t>,
ops::CUDALinspaceKernel<int64_t>,
ops::CUDALinspaceKernel<double>);
......@@ -35,14 +35,12 @@ class CPULinspaceKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(context.GetPlace());
if (num > 1) {
T step = (stop - start) / (num - 1);
T value = start;
double step = (static_cast<double>(stop - start)) / (num - 1);
for (int i = 0; i < num; ++i) {
out_data[i] = value;
value += step;
out_data[i] = static_cast<T>(start + step * i);
}
} else {
out_data[0] = start;
out_data[0] = static_cast<T>(start);
}
}
};
......
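
For illustration only, a rough NumPy mirror of the kernel arithmetic above (the helper name linspace_like is made up for this sketch, it is not part of the commit): the step is held in double precision and each element is computed as start + step * i and then cast back to the output type, which keeps integer outputs exact and avoids the rounding drift that accumulating value += step in T could cause.

import numpy as np

def linspace_like(start, stop, num, dtype):
    # Mirrors the kernels above: double-precision step, per-element cast.
    if num == 1:
        return np.array([start], dtype=dtype)
    step = float(stop - start) / (num - 1)
    return np.array([start + step * i for i in range(num)], dtype=dtype)

print(linspace_like(0, 10, 11, 'int32'))  # [ 0  1  2  3  4  5  6  7  8  9 10]
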
......@@ -1435,14 +1435,14 @@ def linspace(start, stop, num, dtype=None, name=None):
This OP returns a fixed number of evenly spaced values within a given interval.
Args:
start(float|Tensor): The input :attr:`start` is start variable of range. It is a float scalar, \
or a Tensor of shape [1] with input data type float32, float64.
stop(float|Tensor): The input :attr:`stop` is start variable of range. It is a float scalar, \
or a Tensor of shape [1] with input data type float32, float64.
start(int|float|Tensor): The input :attr:`start` is the start of the range. It is a scalar, \
or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
stop(int|float|Tensor): The input :attr:`stop` is the end of the range. It is a scalar, \
or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \
or a Tensor of shape [1] with data type int32.
dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type of output tensor, it could be 'float32' and 'float64'.
Default: if None, the data type is float32.
or a Tensor of shape [1] with data type int32 or int64.
dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type of output tensor, it could be
int32, int64, float32 and float64. Default: if None, the data type is float32.
name(str, optional): Normally there is no need for the user to set this property.
For more information, please refer to :ref:`api_guide_Name`. Default: None.
......@@ -1452,9 +1452,11 @@ def linspace(start, stop, num, dtype=None, name=None):
the value with input :attr:`start`.
Raises:
TypeError: The ``dtype`` must be one of float32 and float64.
TypeError: The data type of ``start`` and ``stop`` must be one of float32 and float64.
TypeError: The data type of ``num`` must be one of int32 and int64.
TypeError: The ``dtype`` must be one of int32, int64, float32 and float64.
TypeError: The type of ``num`` must be int when it is not a Tensor.
TypeError: The data type of ``num`` must be int32 when it is a Tensor.
TypeError: The data type of ``start`` and ``stop`` must be the same as ``dtype`` when they are Tensors.
Examples:
......@@ -1467,29 +1469,47 @@ def linspace(start, stop, num, dtype=None, name=None):
"""
if dtype is None:
dtype = 'float32'
tensor_num = num
tensor_start = start
tensor_stop = stop
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if not isinstance(start, Variable):
start = fill_constant([1], dtype, start)
tensor_start = fill_constant([1], dtype, start)
if not isinstance(stop, Variable):
stop = fill_constant([1], dtype, stop)
tensor_stop = fill_constant([1], dtype, stop)
if not isinstance(num, Variable):
num = fill_constant([1], 'int32', num)
tensor_num = fill_constant([1], 'int32', num)
if in_dygraph_mode():
return core.ops.linspace(start, stop, num)
return core.ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
dtype)
helper = LayerHelper("linspace", **locals())
check_dtype(start.dtype, 'start', ['float32', 'float64'], 'linspace')
check_dtype(stop.dtype, 'stop', ['float32', 'float64'], 'linspace')
check_dtype(num.dtype, 'num', ['int32', 'int64'], 'linspace')
check_dtype(dtype, 'dtype', ['float32', 'float64'], 'linspace')
if isinstance(start, Variable):
check_dtype(start.dtype, 'start', (convert_dtype(dtype)), 'linspace')
else:
check_type(start, 'start', (int, float), 'linspace')
out = helper.create_variable_for_type_inference(dtype=start.dtype)
if isinstance(stop, Variable):
check_dtype(stop.dtype, 'stop', (convert_dtype(dtype)), 'linspace')
else:
check_type(stop, 'stop', (int, float), 'linspace')
if isinstance(num, Variable):
check_dtype(num.dtype, 'num', ['int32'], 'linspace')
else:
check_type(num, 'num', (int), 'linspace')
check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'],
'linspace')
out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op(
type='linspace',
inputs={'Start': start,
'Stop': stop,
'Num': num},
inputs={'Start': tensor_start,
'Stop': tensor_stop,
'Num': tensor_num},
attrs={'dtype': dtype},
outputs={'Out': [out]})
return out
......
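
A hedged sketch of the validation added above, run in static-graph mode (the dygraph branch returns before these checks execute): a Tensor start whose dtype differs from the requested dtype is rejected.

import paddle
import paddle.fluid as fluid

paddle.enable_static()
with fluid.program_guard(fluid.Program(), fluid.Program()):
    start = fluid.data(shape=[1], dtype='float32', name='start')
    try:
        # start is float32 but the requested dtype is int32, so check_dtype raises.
        fluid.layers.linspace(start, 10, 1, dtype='int32')
    except TypeError as e:
        print('rejected:', e)
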
......@@ -32,6 +32,7 @@ class TestLinspaceOpCommonCase(OpTest):
'Stop': np.array([10]).astype(dtype),
'Num': np.array([11]).astype('int32')
}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': np.arange(0, 11).astype(dtype)}
......@@ -48,6 +49,7 @@ class TestLinspaceOpReverseCase(OpTest):
'Stop': np.array([0]).astype(dtype),
'Num': np.array([11]).astype('int32')
}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)}
......@@ -64,6 +66,7 @@ class TestLinspaceOpNumOneCase(OpTest):
'Stop': np.array([0]).astype(dtype),
'Num': np.array([1]).astype('int32')
}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': np.array(10, dtype=dtype)}
......@@ -72,6 +75,26 @@ class TestLinspaceOpNumOneCase(OpTest):
class TestLinspaceAPI(unittest.TestCase):
def test_variable_input1(self):
start = paddle.full(shape=[1], fill_value=0, dtype='float32')
stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
num = paddle.full(shape=[1], fill_value=5, dtype='int32')
out = paddle.linspace(start, stop, num, dtype='float32')
exe = fluid.Executor(place=fluid.CPUPlace())
res = exe.run(fluid.default_main_program(), fetch_list=[out])
np_res = np.linspace(0, 10, 5, dtype='float32')
self.assertEqual((res == np_res).all(), True)
def test_variable_input2(self):
paddle.disable_static()
start = paddle.full(shape=[1], fill_value=0, dtype='float32')
stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
num = paddle.full(shape=[1], fill_value=5, dtype='int32')
out = paddle.linspace(start, stop, num, dtype='float32')
np_res = np.linspace(0, 10, 5, dtype='float32')
self.assertEqual((out.numpy() == np_res).all(), True)
paddle.enable_static()
def test_dtype(self):
out_1 = paddle.linspace(0, 10, 5, dtype='float32')
out_2 = paddle.linspace(0, 10, 5, dtype=np.float32)
......@@ -89,10 +112,16 @@ class TestLinspaceAPI(unittest.TestCase):
def test_imperative(self):
paddle.disable_static()
out = paddle.linspace(0, 10, 5, dtype='float32')
np_out = np.linspace(0, 10, 5, dtype='float32')
out1 = paddle.linspace(0, 10, 5, dtype='float32')
np_out1 = np.linspace(0, 10, 5, dtype='float32')
out2 = paddle.linspace(0, 10, 5, dtype='int32')
np_out2 = np.linspace(0, 10, 5, dtype='int32')
out3 = paddle.linspace(0, 10, 200, dtype='int32')
np_out3 = np.linspace(0, 10, 200, dtype='int32')
paddle.enable_static()
self.assertEqual((out.numpy() == np_out).all(), True)
self.assertEqual((out1.numpy() == np_out1).all(), True)
self.assertEqual((out2.numpy() == np_out2).all(), True)
self.assertEqual((out3.numpy() == np_out3).all(), True)
class TestLinspaceOpError(unittest.TestCase):
......@@ -100,7 +129,12 @@ class TestLinspaceOpError(unittest.TestCase):
with program_guard(Program(), Program()):
def test_dtype():
fluid.layers.linspace(0, 10, 1, dtype="int32")
fluid.layers.linspace(0, 10, 1, dtype="int8")
self.assertRaises(TypeError, test_dtype)
def test_dtype():
fluid.layers.linspace(0, 10, 1.33, dtype="int32")
self.assertRaises(TypeError, test_dtype)
......@@ -120,20 +154,20 @@ class TestLinspaceOpError(unittest.TestCase):
self.assertRaises(TypeError, test_step_dtype)
def test_start_dtype():
start = fluid.data(shape=[1], type="int32", name="start")
start = fluid.data(shape=[1], dtype="int32", name="start")
fluid.layers.linspace(start, 10, 1, dtype="float32")
self.assertRaises(TypeError, test_start_dtype)
def test_end_dtype():
end = fluid.data(shape=[1], type="int32", name="end")
end = fluid.data(shape=[1], dtype="int32", name="end")
fluid.layers.linspace(0, end, 1, dtype="float32")
self.assertRaises(TypeError, test_end_dtype)
def test_step_dtype():
step = fluid.data(shape=[1], type="int32", name="step")
fluid.layers.linspace(0, 10, step, dtype="float32")
def test_num_dtype():
num = fluid.data(shape=[1], dtype="int32", name="step")
fluid.layers.linspace(0, 10, num, dtype="float32")
self.assertRaises(TypeError, test_step_dtype)
......