diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc index 0a7146be83dcb673573f1fdcb94ed2d2c57bd2c3..2c3172d2a1112e2c79a3c1215ccd0d3f08d59451 100644 --- a/paddle/fluid/operators/linspace_op.cc +++ b/paddle/fluid/operators/linspace_op.cc @@ -53,11 +53,9 @@ class LinspaceOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext& ctx) const override { - framework::LibraryType library_{framework::LibraryType::kPlain}; - framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "Start"), - ctx.device_context(), layout_, library_); + framework::proto::VarType::Type(ctx.Attr("dtype")), + ctx.GetPlace()); } }; @@ -73,6 +71,7 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Num", "Number of entry in the sequence. It is a tensor of shape [1], " "should be of type int32."); + AddAttr("dtype", "The output data type."); AddOutput("Out", "A sequence of numbers."); AddComment(R"DOC( Return fixed number of evenly spaced values within a given interval. First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy. @@ -85,4 +84,6 @@ class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { namespace ops = paddle::operators; REGISTER_OP_WITHOUT_GRADIENT(linspace, ops::LinspaceOp, ops::LinspaceOpMaker); REGISTER_OP_CPU_KERNEL(linspace, ops::CPULinspaceKernel, + ops::CPULinspaceKernel, + ops::CPULinspaceKernel, ops::CPULinspaceKernel); diff --git a/paddle/fluid/operators/linspace_op.cu b/paddle/fluid/operators/linspace_op.cu index 47d4536dcfe2a0ab43b3584196a138214e438e3e..8aca892a81d41b1e0a9f7f9c14169c2817ae9452 100644 --- a/paddle/fluid/operators/linspace_op.cu +++ b/paddle/fluid/operators/linspace_op.cu @@ -20,13 +20,15 @@ namespace paddle { namespace operators { template -__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) { - CUDA_KERNEL_LOOP(index, size) { out[index] = start + step * index; } +__global__ void LinspaceKernel(T start, double step, int64_t size, T* out) { + CUDA_KERNEL_LOOP(index, size) { + out[index] = static_cast(start + step * index); + } } template __global__ void LinspaceSpecialKernel(T start, T* out) { - out[0] = start; + out[0] = static_cast(start); } template @@ -51,9 +53,9 @@ class CUDALinspaceKernel : public framework::OpKernel { out->Resize(framework::make_ddim({num})); T* out_data = out->mutable_data(context.GetPlace()); - T step = 0; + double step = 0; if (num != 1) { - step = (stop - start) / (num - 1); + step = (static_cast(stop - start)) / (num - 1); } auto stream = context.cuda_device_context().stream(); @@ -68,4 +70,6 @@ class CUDALinspaceKernel : public framework::OpKernel { namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel, + ops::CUDALinspaceKernel, + ops::CUDALinspaceKernel, ops::CUDALinspaceKernel); diff --git a/paddle/fluid/operators/linspace_op.h b/paddle/fluid/operators/linspace_op.h index b1fcac73b0ad249aa19859bde770a8554cdb7408..9fb4960375ed7be60598d558c65310bd4a4b84bc 100644 --- a/paddle/fluid/operators/linspace_op.h +++ b/paddle/fluid/operators/linspace_op.h @@ -35,14 +35,12 @@ class CPULinspaceKernel : public framework::OpKernel { T* out_data = out->mutable_data(context.GetPlace()); if (num > 1) { - T step = (stop - start) / (num - 1); - T value = start; + double step = (static_cast(stop - start)) / (num - 1); for (int i = 0; i < num; ++i) { - out_data[i] = value; - value += step; + out_data[i] = static_cast(start + step * i); } } else { - out_data[0] = start; + out_data[0] = static_cast(start); } } }; diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 7ac67b1bc817964ca65d5b7009b446458d2cc7ab..d8521586beac44f0439d780dde173b7f0769e17a 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -1435,14 +1435,14 @@ def linspace(start, stop, num, dtype=None, name=None): This OP return fixed number of evenly spaced values within a given interval. Args: - start(float|Tensor): The input :attr:`start` is start variable of range. It is a float scalar, \ - or a Tensor of shape [1] with input data type float32, float64. - stop(float|Tensor): The input :attr:`stop` is start variable of range. It is a float scalar, \ - or a Tensor of shape [1] with input data type float32, float64. + start(int|float|Tensor): The input :attr:`start` is start variable of range. It is a scalar, \ + or a Tensor of shape [1] with input data type int32, int64, float32 or float64. + stop(int|float|Tensor): The input :attr:`stop` is start variable of range. It is a scalar, \ + or a Tensor of shape [1] with input data type int32, int64, float32 or float64. num(int|Tensor): The input :attr:`num` is given num of the sequence. It is an int scalar, \ - or a Tensor of shape [1] with data type int32. - dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type of output tensor, it could be 'float32' and 'float64'. - Default: if None, the data type is float32. + or a Tensor of shape [1] with data type int32 or int64. + dtype(np.dtype|core.VarDesc.VarType|str, optional): The data type of output tensor, it could be + int32, int64, float32 and float64. Default: if None, the data type is float32. name(str, optional): Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name`.Default: None. @@ -1452,9 +1452,11 @@ def linspace(start, stop, num, dtype=None, name=None): the value with input :attr:`start`. Raises: - TypeError: The ``dtype`` must be one of float32 and float64. - TypeError: The data type of ``start`` and ``stop`` must be one of float32 and float64. - TypeError: The data type of ``num`` must be one of int32 and int64. + TypeError: The ``dtype`` must be one of int32, int64, float32 and float64. + TypeError: The type of ``num`` must be int When it's not a Tensor. + TypeError: The data type of ``num`` must be int32 When it's a Tensor. + TypeError: The data type of ``start`` and ``stop`` must be same as ``dtype`` When it's a Tensor. + Examples: @@ -1467,29 +1469,47 @@ def linspace(start, stop, num, dtype=None, name=None): """ if dtype is None: dtype = 'float32' + tensor_num = num + tensor_start = start + tensor_stop = stop + if not isinstance(dtype, core.VarDesc.VarType): + dtype = convert_np_dtype_to_dtype_(dtype) if not isinstance(start, Variable): - start = fill_constant([1], dtype, start) + tensor_start = fill_constant([1], dtype, start) if not isinstance(stop, Variable): - stop = fill_constant([1], dtype, stop) + tensor_stop = fill_constant([1], dtype, stop) if not isinstance(num, Variable): - num = fill_constant([1], 'int32', num) + tensor_num = fill_constant([1], 'int32', num) if in_dygraph_mode(): - return core.ops.linspace(start, stop, num) + return core.ops.linspace(tensor_start, tensor_stop, tensor_num, 'dtype', + dtype) helper = LayerHelper("linspace", **locals()) - check_dtype(start.dtype, 'start', ['float32', 'float64'], 'linspace') - check_dtype(stop.dtype, 'stop', ['float32', 'float64'], 'linspace') - check_dtype(num.dtype, 'num', ['int32', 'int64'], 'linspace') - check_dtype(dtype, 'dtype', ['float32', 'float64'], 'linspace') + if isinstance(start, Variable): + check_dtype(start.dtype, 'start', (convert_dtype(dtype)), 'linspace') + else: + check_type(start, 'start', (int, float), 'linspace') - out = helper.create_variable_for_type_inference(dtype=start.dtype) + if isinstance(stop, Variable): + check_dtype(stop.dtype, 'stop', (convert_dtype(dtype)), 'linspace') + else: + check_type(stop, 'stop', (int, float), 'linspace') + if isinstance(num, Variable): + check_dtype(num.dtype, 'num', ['int32'], 'linspace') + else: + check_type(num, 'num', (int), 'linspace') + check_dtype(dtype, 'dtype', ['int32', 'int64', 'float32', 'float64'], + 'linspace') + + out = helper.create_variable_for_type_inference(dtype=dtype) helper.append_op( type='linspace', - inputs={'Start': start, - 'Stop': stop, - 'Num': num}, + inputs={'Start': tensor_start, + 'Stop': tensor_stop, + 'Num': tensor_num}, + attrs={'dtype': dtype}, outputs={'Out': [out]}) return out diff --git a/python/paddle/fluid/tests/unittests/test_linspace.py b/python/paddle/fluid/tests/unittests/test_linspace.py index 068993c4c1c5e770dd6cf7dc7a35b9ccc3f49aae..6d1f42111eebff0f469317ddf2a9ec7698a7ae1e 100644 --- a/python/paddle/fluid/tests/unittests/test_linspace.py +++ b/python/paddle/fluid/tests/unittests/test_linspace.py @@ -32,6 +32,7 @@ class TestLinspaceOpCommonCase(OpTest): 'Stop': np.array([10]).astype(dtype), 'Num': np.array([11]).astype('int32') } + self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} self.outputs = {'Out': np.arange(0, 11).astype(dtype)} @@ -48,6 +49,7 @@ class TestLinspaceOpReverseCase(OpTest): 'Stop': np.array([0]).astype(dtype), 'Num': np.array([11]).astype('int32') } + self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)} @@ -64,6 +66,7 @@ class TestLinspaceOpNumOneCase(OpTest): 'Stop': np.array([0]).astype(dtype), 'Num': np.array([1]).astype('int32') } + self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)} self.outputs = {'Out': np.array(10, dtype=dtype)} @@ -72,6 +75,26 @@ class TestLinspaceOpNumOneCase(OpTest): class TestLinspaceAPI(unittest.TestCase): + def test_variable_input1(self): + start = paddle.full(shape=[1], fill_value=0, dtype='float32') + stop = paddle.full(shape=[1], fill_value=10, dtype='float32') + num = paddle.full(shape=[1], fill_value=5, dtype='int32') + out = paddle.linspace(start, stop, num, dtype='float32') + exe = fluid.Executor(place=fluid.CPUPlace()) + res = exe.run(fluid.default_main_program(), fetch_list=[out]) + np_res = np.linspace(0, 10, 5, dtype='float32') + self.assertEqual((res == np_res).all(), True) + + def test_variable_input2(self): + paddle.disable_static() + start = paddle.full(shape=[1], fill_value=0, dtype='float32') + stop = paddle.full(shape=[1], fill_value=10, dtype='float32') + num = paddle.full(shape=[1], fill_value=5, dtype='int32') + out = paddle.linspace(start, stop, num, dtype='float32') + np_res = np.linspace(0, 10, 5, dtype='float32') + self.assertEqual((out.numpy() == np_res).all(), True) + paddle.enable_static() + def test_dtype(self): out_1 = paddle.linspace(0, 10, 5, dtype='float32') out_2 = paddle.linspace(0, 10, 5, dtype=np.float32) @@ -89,10 +112,16 @@ class TestLinspaceAPI(unittest.TestCase): def test_imperative(self): paddle.disable_static() - out = paddle.linspace(0, 10, 5, dtype='float32') - np_out = np.linspace(0, 10, 5, dtype='float32') + out1 = paddle.linspace(0, 10, 5, dtype='float32') + np_out1 = np.linspace(0, 10, 5, dtype='float32') + out2 = paddle.linspace(0, 10, 5, dtype='int32') + np_out2 = np.linspace(0, 10, 5, dtype='int32') + out3 = paddle.linspace(0, 10, 200, dtype='int32') + np_out3 = np.linspace(0, 10, 200, dtype='int32') paddle.enable_static() - self.assertEqual((out.numpy() == np_out).all(), True) + self.assertEqual((out1.numpy() == np_out1).all(), True) + self.assertEqual((out2.numpy() == np_out2).all(), True) + self.assertEqual((out3.numpy() == np_out3).all(), True) class TestLinspaceOpError(unittest.TestCase): @@ -100,7 +129,12 @@ class TestLinspaceOpError(unittest.TestCase): with program_guard(Program(), Program()): def test_dtype(): - fluid.layers.linspace(0, 10, 1, dtype="int32") + fluid.layers.linspace(0, 10, 1, dtype="int8") + + self.assertRaises(TypeError, test_dtype) + + def test_dtype(): + fluid.layers.linspace(0, 10, 1.33, dtype="int32") self.assertRaises(TypeError, test_dtype) @@ -120,20 +154,20 @@ class TestLinspaceOpError(unittest.TestCase): self.assertRaises(TypeError, test_step_dtype) def test_start_dtype(): - start = fluid.data(shape=[1], type="int32", name="start") + start = fluid.data(shape=[1], dtype="int32", name="start") fluid.layers.linspace(start, 10, 1, dtype="float32") self.assertRaises(TypeError, test_start_dtype) def test_end_dtype(): - end = fluid.data(shape=[1], type="int32", name="end") + end = fluid.data(shape=[1], dtype="int32", name="end") fluid.layers.linspace(0, end, 1, dtype="float32") self.assertRaises(TypeError, test_end_dtype) - def test_step_dtype(): - step = fluid.data(shape=[1], type="int32", name="step") - fluid.layers.linspace(0, 10, step, dtype="float32") + def test_num_dtype(): + num = fluid.data(shape=[1], dtype="int32", name="step") + fluid.layers.linspace(0, 10, num, dtype="float32") self.assertRaises(TypeError, test_step_dtype)