diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 95e4004066e657afe2b1663b9cf4ea1ea18a4251..3f576a45169c9c7e4581d304efc7cf0bca1b310a 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -275,6 +275,7 @@ paddle.fluid.layers.has_inf (ArgSpec(args=['x'], varargs=None, keywords=None, de paddle.fluid.layers.has_nan (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '2e53e83127dbfd86e7098bdfe9a549e8')) paddle.fluid.layers.isfinite (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '0a437011c3906079fd8947ed3e52d292')) paddle.fluid.layers.range (ArgSpec(args=['start', 'end', 'step', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '2ec937ede953ded2fdff2675883900bb')) +paddle.fluid.layers.linspace (ArgSpec(args=['start', 'stop', 'num', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '495e21e9a848c2d075a102802fc67756')) paddle.fluid.layers.While.__init__ (ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.While.block (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.Switch.__init__ (ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) diff --git a/paddle/fluid/operators/linspace_op.cc b/paddle/fluid/operators/linspace_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..f4aeb062d8dfae31a72b8ebccb3d377276662da6 --- /dev/null +++ b/paddle/fluid/operators/linspace_op.cc @@ -0,0 +1,84 @@ +/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/linspace_op.h" + +namespace paddle { +namespace operators { + +class LinspaceOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Start"), + "Input(Start) of LinspaceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Stop"), + "Input(Stop) of LinspaceOp should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Num"), + "Input(Num) of LinspaceOp should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(OUt) of LinspaceOp should not be null."); + + auto s_dims = ctx->GetInputDim("Start"); + PADDLE_ENFORCE((s_dims.size() == 1) && (s_dims[0] == 1), + "The shape of Input(Start) should be [1]."); + + auto e_dims = ctx->GetInputDim("Stop"); + PADDLE_ENFORCE((e_dims.size() == 1) && (e_dims[0] == 1), + "The shape of Input(Stop) should be [1]."); + + auto step_dims = ctx->GetInputDim("Num"); + PADDLE_ENFORCE((step_dims.size() == 1) && (step_dims[0] == 1), + "The shape of Input(Num) should be [1]."); + + ctx->SetOutputDim("Out", {-1}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + framework::LibraryType library_{framework::LibraryType::kPlain}; + framework::DataLayout layout_ = framework::DataLayout::kAnyLayout; + return framework::OpKernelType( + ctx.Input("Start")->type(), ctx.device_context(), + layout_, library_); + } +}; + +class LinspaceOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Start", + "First entry in the sequence. It is a tensor of shape [1], should " + "be of type float32 or float64."); + AddInput("Stop", + "Last entry in the sequence. It is a tensor of shape [1], should " + "be of type float32 or float64."); + AddInput("Num", + "Number of entry in the sequence. It is a tensor of shape [1], " + "should be of type int32."); + AddOutput("Out", "A sequence of numbers."); + AddComment(R"DOC( + Return fixed number of evenly spaced values within a given interval. First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(linspace, ops::LinspaceOp, ops::LinspaceOpMaker); +REGISTER_OP_CPU_KERNEL(linspace, ops::CPULinspaceKernel, + ops::CPULinspaceKernel); diff --git a/paddle/fluid/operators/linspace_op.cu b/paddle/fluid/operators/linspace_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..90bd17cda0e0d1f78810233537bb502f9115fbd0 --- /dev/null +++ b/paddle/fluid/operators/linspace_op.cu @@ -0,0 +1,75 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/linspace_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void LinspaceKernel(T start, T step, int64_t size, T* out) { + CUDA_1D_KERNEL_LOOP(index, size) { out[index] = start + step * index; } +} + +template +__global__ void LinspaceSpecialKernel(T start, T* out) { + out[0] = start; +} + +template +class CUDALinspaceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* start_t = context.Input("Start"); + auto* stop_t = context.Input("Stop"); + auto* num_t = context.Input("Num"); + auto* out = context.Output("Out"); + + framework::Tensor n; + framework::TensorCopy(*start_t, platform::CPUPlace(), &n); + T start = n.data()[0]; + framework::TensorCopy(*stop_t, platform::CPUPlace(), &n); + T stop = n.data()[0]; + framework::TensorCopy(*num_t, platform::CPUPlace(), &n); + int32_t num = n.data()[0]; + + PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0."); + + out->Resize(framework::make_ddim({num})); + T* out_data = out->mutable_data(context.GetPlace()); + + T step = 0; + if (num != 1) { + step = (stop - start) / (num - 1); + } + + auto stream = context.cuda_device_context().stream(); + int block = 512; + int grid = (num + block - 1) / block; + LinspaceKernel<<>>(start, step, num, out_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(linspace, ops::CUDALinspaceKernel, + ops::CUDALinspaceKernel); diff --git a/paddle/fluid/operators/linspace_op.h b/paddle/fluid/operators/linspace_op.h new file mode 100644 index 0000000000000000000000000000000000000000..b1fcac73b0ad249aa19859bde770a8554cdb7408 --- /dev/null +++ b/paddle/fluid/operators/linspace_op.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +class CPULinspaceKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + T start = context.Input("Start")->data()[0]; + T stop = context.Input("Stop")->data()[0]; + int32_t num = context.Input("Num")->data()[0]; + auto* out = context.Output("Out"); + PADDLE_ENFORCE(num > 0, "The num of linspace op should be larger than 0."); + + out->Resize(framework::make_ddim({num})); + + T* out_data = out->mutable_data(context.GetPlace()); + + if (num > 1) { + T step = (stop - start) / (num - 1); + T value = start; + for (int i = 0; i < num; ++i) { + out_data[i] = value; + value += step; + } + } else { + out_data[0] = start; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index 80450119f44e93aae4b483983484ea18be5b2035..03ebd41fa00c69bfce66d325e32fc9aeb25a2486 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -24,26 +24,11 @@ from .layer_function_generator import templatedoc import numpy __all__ = [ - 'create_tensor', - 'create_parameter', - 'create_global_var', - 'cast', - 'tensor_array_to_tensor', - 'concat', - 'sums', - 'assign', - 'fill_constant_batch_size_like', - 'fill_constant', - 'argmin', - 'argmax', - 'argsort', - 'ones', - 'zeros', - 'reverse', - 'has_inf', - 'has_nan', - 'isfinite', - 'range', + 'create_tensor', 'create_parameter', 'create_global_var', 'cast', + 'tensor_array_to_tensor', 'concat', 'sums', 'assign', + 'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax', + 'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite', + 'range', 'linspace' ] @@ -826,3 +811,45 @@ def range(start, end, step, dtype): 'Step': step}, outputs={'Out': [out]}) return out + + +def linspace(start, stop, num, dtype): + """ + Return fixed number of evenly spaced values within a given interval. + + First entry is start, and last entry is stop. In the case when Num is 1, only Start is returned. Like linspace function of numpy. + + Args: + start(float|Variable): First entry in the sequence. It is a float scalar, or a tensor of shape [1] with type 'float32'|'float64'. + stop(float|Variable): Last entry in the sequence. It is a float scalar, or a tensor of shape [1] with type 'float32'|'float64'. + num(int|Variable): Number of entry in the sequence. It is an int scalar, or a tensor of shape [1] with type int32. + dtype(string): 'float32'|'float64', the data type of the output tensor. + + Returns: + Variable: The tensor variable storing a 1-D tensor. + + Examples: + .. code-block:: python + + data = fluid.layers.linspace(0, 10, 5, 'float32') # [0.0, 2.5, 5.0, 7.5, 10.0] + data = fluid.layers.linspace(0, 10, 1, 'float32') # [0.0] + + """ + helper = LayerHelper("linspace", **locals()) + + if not isinstance(start, Variable): + start = fill_constant([1], dtype, start) + if not isinstance(stop, Variable): + stop = fill_constant([1], dtype, stop) + if not isinstance(num, Variable): + num = fill_constant([1], 'int32', num) + + out = helper.create_variable_for_type_inference(dtype=start.dtype) + + helper.append_op( + type='linspace', + inputs={'Start': start, + 'Stop': stop, + 'Num': num}, + outputs={'Out': [out]}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 38d0533a7ec820241c6a08f2180a7426984068f2..6630fb26aff9a8c570e65c34a753595da883bea1 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1925,6 +1925,13 @@ class TestBook(LayerTest): out = layers.flatten(x, axis=1, name="flatten") return (out) + def test_linspace(self): + program = Program() + with program_guard(program): + out = layers.linspace(20, 10, 5, 'float64') + self.assertIsNotNone(out) + print(str(program)) + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_linspace.py b/python/paddle/fluid/tests/unittests/test_linspace.py new file mode 100644 index 0000000000000000000000000000000000000000..eeecf178320327cc251f32bfe46c1622200339f4 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_linspace.py @@ -0,0 +1,71 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class TestLinspaceOpCommonCase(OpTest): + def setUp(self): + self.op_type = "linspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([0]).astype(dtype), + 'Stop': np.array([10]).astype(dtype), + 'Num': np.array([11]).astype('int32') + } + + self.outputs = {'Out': np.arange(0, 11).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLinspaceOpReverseCase(OpTest): + def setUp(self): + self.op_type = "linspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([10]).astype(dtype), + 'Stop': np.array([0]).astype(dtype), + 'Num': np.array([11]).astype('int32') + } + + self.outputs = {'Out': np.arange(10, -1, -1).astype(dtype)} + + def test_check_output(self): + self.check_output() + + +class TestLinspaceOpNumOneCase(OpTest): + def setUp(self): + self.op_type = "linspace" + dtype = 'float32' + self.inputs = { + 'Start': np.array([10]).astype(dtype), + 'Stop': np.array([0]).astype(dtype), + 'Num': np.array([1]).astype('int32') + } + + self.outputs = {'Out': np.array(10, dtype=dtype)} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main()