未验证 提交 2073ffc0 编写于 作者: W wangchaochaohu 提交者: GitHub

Enhance the data type of linspace API (#26247)

上级 bb7fd097
...@@ -53,11 +53,9 @@ class LinspaceOp : public framework::OperatorWithKernel { ...@@ -53,11 +53,9 @@ class LinspaceOp : public framework::OperatorWithKernel {
protected: protected:
framework::OpKernelType GetExpectedKernelType( framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override { const framework::ExecutionContext& ctx) const override {
framework::LibraryType library_{framework::LibraryType::kPlain};
framework::DataLayout layout_ = framework::DataLayout::kAnyLayout;
return framework::OpKernelType( return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "Start"), framework::proto::VarType::Type(ctx.Attr<int>("dtype")),
ctx.device_context(), layout_, library_); ctx.GetPlace());
} }
}; };
...@@ -73,6 +71,7 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -73,6 +71,7 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Num", AddInput("Num",
"Number of entry in the sequence. It is a tensor of shape [1], " "Number of entry in the sequence. It is a tensor of shape [1], "
"should be of type int32."); "should be of type int32.");
AddAttr<int>("dtype", "The output data type.");
AddOutput("Out", "A sequence of numbers."); AddOutput("Out", "A sequence of numbers.");
AddComment(R"DOC( AddComment(R"DOC(
Return fixed number of evenly spaced values within a given interval. First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy. Return fixed number of evenly spaced values within a given interval. First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy.
...@@ -85,4 +84,6 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -85,4 +84,6 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(linspace, ops::LinspaceOp, ops::LinspaceOpMaker); REGISTER_OP_WITHOUT_GRADIENT(linspace, ops::LinspaceOp, ops::LinspaceOpMaker);
REGISTER_OP_CPU_KERNEL(linspace, ops::CPULinspaceKernel<float>, REGISTER_OP_CPU_KERNEL(linspace, ops::CPULinspaceKernel<float>,
ops::CPULinspaceKernel<int32_t>,
ops::CPULinspaceKernel<int64_t>,
ops::CPULinspaceKernel<double>); ops::CPULinspaceKernel<double>);
...@@ -20,13 +20,15 @@ namespace paddle { ...@@ -20,13 +20,15 @@ namespace paddle {
namespace operators { namespace operators {
template <typename T> template <typename T>
__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) { __global__ void LinspaceKernel(T start, double step, int64_t size, T* out) {
CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; } CUDA_KERNEL_LOOP(index, size) {
out[index] = static_cast<T>(start + step * index);
}
} }
template <typename T> template <typename T>
__global__ void LinspaceSpecialKernel(T start, T* out) { __global__ void LinspaceSpecialKernel(T start, T* out) {
out[0] = start; out[0] = static_cast<T>(start);
} }
template <typename T> template <typename T>
...@@ -51,9 +53,9 @@ class CUDALinspaceKernel : public framework::OpKernel<T> { ...@@ -51,9 +53,9 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
out->Resize(framework::make_ddim({num})); out->Resize(framework::make_ddim({num}));
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
T step = 0; double step = 0;
if (num != 1) { if (num != 1) {
step = (stop - start) / (num - 1); step = (static_cast<double>(stop - start)) / (num - 1);
} }
auto stream = context.cuda_device_context().stream(); auto stream = context.cuda_device_context().stream();
...@@ -68,4 +70,6 @@ class CUDALinspaceKernel : public framework::OpKernel<T> { ...@@ -68,4 +70,6 @@ class CUDALinspaceKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel<float>, REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel<float>,
ops::CUDALinspaceKernel<int32_t>,
ops::CUDALinspaceKernel<int64_t>,
ops::CUDALinspaceKernel<double>); ops::CUDALinspaceKernel<double>);
...@@ -35,14 +35,12 @@ class CPULinspaceKernel : public framework::OpKernel<T> { ...@@ -35,14 +35,12 @@ class CPULinspaceKernel : public framework::OpKernel<T> {
T* out_data = out->mutable_data<T>(context.GetPlace()); T* out_data = out->mutable_data<T>(context.GetPlace());
if (num > 1) { if (num > 1) {
T step = (stop - start) / (num - 1); double step = (static_cast<double>(stop - start)) / (num - 1);
T value = start;
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
out_data[i] = value; out_data[i] = static_cast<T>(start + step * i);
value += step;
} }
} else { } else {
out_data[0] = start; out_data[0] = static_cast<T>(start);
} }
} }
}; };
......
...@@ -1435,14 +1435,14 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -1435,14 +1435,14 @@ def linspace(start, stop, num, dtype=None, name=None):
This OP return fixed number of evenly spaced values within a given interval. This OP return fixed number of evenly spaced values within a given interval.
Args: Args:
start(float|Tensor): The input :attr:`start` is start variable of range. It is a float scalar, \ start(int|float|Tensor): The input :attr:`start` is start variable of range. It is a scalar, \
or a Tensor of shape [1] with input data type float32, float64. or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
stop(float|Tensor): The input :attr:`stop` is start variable of range. It is a float scalar, \ stop(int|float|Tensor): The input :attr:`stop` is start variable of range. It is a scalar, \
or a Tensor of shape [1] with input data type float32, float64. or a Tensor of shape [1] with input data type int32, int64, float32 or float64.
num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \ num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \
or a Tensor of shape [1] with data type int32. or a Tensor of shape [1] with data type int32 or int64.
dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type of output tensor, it could be 'float32' and 'float64'. dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type of output tensor, it could be
Default: if None, the data type is float32. int32, int64, float32 and float64. Default: if None, the data type is float32.
name(str, optional): Normally there is no need for user to set this property. name(str, optional): Normally there is no need for user to set this property.
For more information, please refer to :ref:`api_guide_Name`.Default: None. For more information, please refer to :ref:`api_guide_Name`.Default: None.
...@@ -1452,9 +1452,11 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -1452,9 +1452,11 @@ def linspace(start, stop, num, dtype=None, name=None):
the value with input :attr:`start`. the value with input :attr:`start`.
Raises: Raises:
TypeError: The ``dtype`` must be one of float32 and float64. TypeError: The ``dtype`` must be one of int32, int64, float32 and float64.
TypeError: The data type of ``start`` and ``stop`` must be one of float32 and float64. TypeError: The type of ``num`` must be int When it's not a Tensor.
TypeError: The data type of ``num`` must be one of int32 and int64. TypeError: The data type of ``num`` must be int32 When it's a Tensor.
TypeError: The data type of ``start`` and ``stop`` must be same as ``dtype`` When it's a Tensor.
Examples: Examples:
...@@ -1467,29 +1469,47 @@ def linspace(start, stop, num, dtype=None, name=None): ...@@ -1467,29 +1469,47 @@ def linspace(start, stop, num, dtype=None, name=None):
""" """
if dtype is None: if dtype is None:
dtype = 'float32' dtype = 'float32'
tensor_num = num
tensor_start = start
tensor_stop = stop
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
if not isinstance(start, Variable): if not isinstance(start, Variable):
start = fill_constant([1], dtype, start) tensor_start = fill_constant([1], dtype, start)
if not isinstance(stop, Variable): if not isinstance(stop, Variable):
stop = fill_constant([1], dtype, stop) tensor_stop = fill_constant([1], dtype, stop)
if not isinstance(num, Variable): if not isinstance(num, Variable):
num = fill_constant([1], 'int32', num) tensor_num = fill_constant([1], 'int32', num)
if in_dygraph_mode(): if in_dygraph_mode():
return core.ops.linspace(start, stop, num) return core.ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype',
dtype)
helper = LayerHelper("linspace", **locals()) helper = LayerHelper("linspace", **locals())
check_dtype(start.dtype, 'start', ['float32', 'float64'], 'linspace') if isinstance(start, Variable):
check_dtype(stop.dtype, 'stop', ['float32', 'float64'], 'linspace') check_dtype(start.dtype, 'start', (convert_dtype(dtype)), 'linspace')
check_dtype(num.dtype, 'num', ['int32', 'int64'], 'linspace') else:
check_dtype(dtype, 'dtype', ['float32', 'float64'], 'linspace') check_type(start, 'start', (int, float), 'linspace')
if isinstance(stop, Variable):
check_dtype(stop.dtype, 'stop', (convert_dtype(dtype)), 'linspace')
else:
check_type(stop, 'stop', (int, float), 'linspace')
if isinstance(num, Variable):
check_dtype(num.dtype, 'num', ['int32'], 'linspace')
else:
check_type(num, 'num', (int), 'linspace')
check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'],
'linspace')
out = helper.create_variable_for_type_inference(dtype=start.dtype) out = helper.create_variable_for_type_inference(dtype=dtype)
helper.append_op( helper.append_op(
type='linspace', type='linspace',
inputs={'Start': start, inputs={'Start': tensor_start,
'Stop': stop, 'Stop': tensor_stop,
'Num': num}, 'Num': tensor_num},
attrs={'dtype': dtype},
outputs={'Out': [out]}) outputs={'Out': [out]})
return out return out
......
...@@ -32,6 +32,7 @@ class TestLinspaceOpCommonCase(OpTest): ...@@ -32,6 +32,7 @@ class TestLinspaceOpCommonCase(OpTest):
'Stop': np.array([10]).astype(dtype), 'Stop': np.array([10]).astype(dtype),
'Num': np.array([11]).astype('int32') 'Num': np.array([11]).astype('int32')
} }
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': np.arange(0, 11).astype(dtype)} self.outputs = {'Out': np.arange(0, 11).astype(dtype)}
...@@ -48,6 +49,7 @@ class TestLinspaceOpReverseCase(OpTest): ...@@ -48,6 +49,7 @@ class TestLinspaceOpReverseCase(OpTest):
'Stop': np.array([0]).astype(dtype), 'Stop': np.array([0]).astype(dtype),
'Num': np.array([11]).astype('int32') 'Num': np.array([11]).astype('int32')
} }
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)} self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)}
...@@ -64,6 +66,7 @@ class TestLinspaceOpNumOneCase(OpTest): ...@@ -64,6 +66,7 @@ class TestLinspaceOpNumOneCase(OpTest):
'Stop': np.array([0]).astype(dtype), 'Stop': np.array([0]).astype(dtype),
'Num': np.array([1]).astype('int32') 'Num': np.array([1]).astype('int32')
} }
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': np.array(10, dtype=dtype)} self.outputs = {'Out': np.array(10, dtype=dtype)}
...@@ -72,6 +75,26 @@ class TestLinspaceOpNumOneCase(OpTest): ...@@ -72,6 +75,26 @@ class TestLinspaceOpNumOneCase(OpTest):
class TestLinspaceAPI(unittest.TestCase): class TestLinspaceAPI(unittest.TestCase):
def test_variable_input1(self):
start = paddle.full(shape=[1], fill_value=0, dtype='float32')
stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
num = paddle.full(shape=[1], fill_value=5, dtype='int32')
out = paddle.linspace(start, stop, num, dtype='float32')
exe = fluid.Executor(place=fluid.CPUPlace())
res = exe.run(fluid.default_main_program(), fetch_list=[out])
np_res = np.linspace(0, 10, 5, dtype='float32')
self.assertEqual((res == np_res).all(), True)
def test_variable_input2(self):
paddle.disable_static()
start = paddle.full(shape=[1], fill_value=0, dtype='float32')
stop = paddle.full(shape=[1], fill_value=10, dtype='float32')
num = paddle.full(shape=[1], fill_value=5, dtype='int32')
out = paddle.linspace(start, stop, num, dtype='float32')
np_res = np.linspace(0, 10, 5, dtype='float32')
self.assertEqual((out.numpy() == np_res).all(), True)
paddle.enable_static()
def test_dtype(self): def test_dtype(self):
out_1 = paddle.linspace(0, 10, 5, dtype='float32') out_1 = paddle.linspace(0, 10, 5, dtype='float32')
out_2 = paddle.linspace(0, 10, 5, dtype=np.float32) out_2 = paddle.linspace(0, 10, 5, dtype=np.float32)
...@@ -89,10 +112,16 @@ class TestLinspaceAPI(unittest.TestCase): ...@@ -89,10 +112,16 @@ class TestLinspaceAPI(unittest.TestCase):
def test_imperative(self): def test_imperative(self):
paddle.disable_static() paddle.disable_static()
out = paddle.linspace(0, 10, 5, dtype='float32') out1 = paddle.linspace(0, 10, 5, dtype='float32')
np_out = np.linspace(0, 10, 5, dtype='float32') np_out1 = np.linspace(0, 10, 5, dtype='float32')
out2 = paddle.linspace(0, 10, 5, dtype='int32')
np_out2 = np.linspace(0, 10, 5, dtype='int32')
out3 = paddle.linspace(0, 10, 200, dtype='int32')
np_out3 = np.linspace(0, 10, 200, dtype='int32')
paddle.enable_static() paddle.enable_static()
self.assertEqual((out.numpy() == np_out).all(), True) self.assertEqual((out1.numpy() == np_out1).all(), True)
self.assertEqual((out2.numpy() == np_out2).all(), True)
self.assertEqual((out3.numpy() == np_out3).all(), True)
class TestLinspaceOpError(unittest.TestCase): class TestLinspaceOpError(unittest.TestCase):
...@@ -100,7 +129,12 @@ class TestLinspaceOpError(unittest.TestCase): ...@@ -100,7 +129,12 @@ class TestLinspaceOpError(unittest.TestCase):
with program_guard(Program(), Program()): with program_guard(Program(), Program()):
def test_dtype(): def test_dtype():
fluid.layers.linspace(0, 10, 1, dtype="int32") fluid.layers.linspace(0, 10, 1, dtype="int8")
self.assertRaises(TypeError, test_dtype)
def test_dtype():
fluid.layers.linspace(0, 10, 1.33, dtype="int32")
self.assertRaises(TypeError, test_dtype) self.assertRaises(TypeError, test_dtype)
...@@ -120,20 +154,20 @@ class TestLinspaceOpError(unittest.TestCase): ...@@ -120,20 +154,20 @@ class TestLinspaceOpError(unittest.TestCase):
self.assertRaises(TypeError, test_step_dtype) self.assertRaises(TypeError, test_step_dtype)
def test_start_dtype(): def test_start_dtype():
start = fluid.data(shape=[1], type="int32", name="start") start = fluid.data(shape=[1], dtype="int32", name="start")
fluid.layers.linspace(start, 10, 1, dtype="float32") fluid.layers.linspace(start, 10, 1, dtype="float32")
self.assertRaises(TypeError, test_start_dtype) self.assertRaises(TypeError, test_start_dtype)
def test_end_dtype(): def test_end_dtype():
end = fluid.data(shape=[1], type="int32", name="end") end = fluid.data(shape=[1], dtype="int32", name="end")
fluid.layers.linspace(0, end, 1, dtype="float32") fluid.layers.linspace(0, end, 1, dtype="float32")
self.assertRaises(TypeError, test_end_dtype) self.assertRaises(TypeError, test_end_dtype)
def test_step_dtype(): def test_num_dtype():
step = fluid.data(shape=[1], type="int32", name="step") num = fluid.data(shape=[1], dtype="int32", name="step")
fluid.layers.linspace(0, 10, step, dtype="float32") fluid.layers.linspace(0, 10, num, dtype="float32")
self.assertRaises(TypeError, test_step_dtype) self.assertRaises(TypeError, test_step_dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册