diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 598604ca7aeeb9938d124606b4efbf61397ee41f..6c61a3d63df3f1734cf4a27e7e27e6b954232af3 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -255,6 +255,7 @@ paddle.fluid.layers.reverse (ArgSpec(args=['x', 'axis'], varargs=None, keywords= paddle.fluid.layers.has_inf (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '8f8c0306117ea441f20dcbbdba1f0ecc')) paddle.fluid.layers.has_nan (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '2e53e83127dbfd86e7098bdfe9a549e8')) paddle.fluid.layers.isfinite (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', '0a437011c3906079fd8947ed3e52d292')) +paddle.fluid.layers.range (ArgSpec(args=['start', 'end', 'step', 'dtype'], varargs=None, keywords=None, defaults=None), ('document', '2ec937ede953ded2fdff2675883900bb')) paddle.fluid.layers.While.__init__ (ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.While.block (ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.Switch.__init__ (ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) diff --git a/paddle/fluid/operators/range_op.cc b/paddle/fluid/operators/range_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..ee8c68fd008c8c9764e9ef74dc37fa08cf31be19 --- /dev/null +++ b/paddle/fluid/operators/range_op.cc @@ -0,0 +1,69 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/range_op.h" + +namespace paddle { +namespace operators { + +class RangeOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + if (ctx->HasInput("Start")) { + auto s_dims = ctx->GetInputDim("Start"); + PADDLE_ENFORCE((s_dims.size() == 1) && (s_dims[0] == 1), + "The shape of Input(Start) should be [1]."); + } + if (ctx->HasInput("End")) { + auto e_dims = ctx->GetInputDim("End"); + PADDLE_ENFORCE((e_dims.size() == 1) && (e_dims[0] == 1), + "The shape of Input(End) should be [1]."); + } + if (ctx->HasInput("Step")) { + auto step_dims = ctx->GetInputDim("Step"); + PADDLE_ENFORCE((step_dims.size() == 1) && (step_dims[0] == 1), + "The shape of Input(Step) should be [1]."); + } + ctx->SetOutputDim("Out", {-1}); + } +}; + +class RangeOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("Start", + "Start of interval. The interval includes this value. It is a " + "tensor with shape=[1]."); + AddInput("End", + "End of interval. The interval does not include this value, " + "except in some cases where step is not an integer and floating " + "point round-off affects the length of out. It is a tensor with " + "shape=[1]."); + AddInput("Step", "Spacing between values. It is a tensor with shape=[1]."); + AddOutput("Out", "A sequence of numbers."); + AddComment(R"DOC( + Return evenly spaced values within a given interval. Values are generated within the half-open interval [start, stop) (in other words, the interval including start but excluding stop). Like arange function of numpy. +)DOC"); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(range, ops::RangeOp, ops::RangeOpMaker); +REGISTER_OP_CPU_KERNEL(range, ops::CPURangeKernel, + ops::CPURangeKernel, ops::CPURangeKernel, + ops::CPURangeKernel); diff --git a/paddle/fluid/operators/range_op.cu b/paddle/fluid/operators/range_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..e2c03716d55ee41ce3a9053b48b5c6d4c70e391f --- /dev/null +++ b/paddle/fluid/operators/range_op.cu @@ -0,0 +1,67 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/range_op.h" +#include "paddle/fluid/platform/cuda_primitives.h" + +namespace paddle { +namespace operators { + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + +template +__global__ void RangeKernel(T start, T step, int64_t size, T* out) { + CUDA_1D_KERNEL_LOOP(index, size) { out[index] = start + step * index; } +} + +template +class CUDARangeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* start_t = context.Input("Start"); + auto* end_t = context.Input("End"); + auto* step_t = context.Input("Step"); + auto* out = context.Output("Out"); + + framework::Tensor n; + framework::TensorCopy(*start_t, platform::CPUPlace(), &n); + T start = n.data()[0]; + framework::TensorCopy(*end_t, platform::CPUPlace(), &n); + T end = n.data()[0]; + framework::TensorCopy(*step_t, platform::CPUPlace(), &n); + T step = n.data()[0]; + + int64_t size = 0; + GetSize(start, end, step, &size); + out->Resize(framework::make_ddim({size})); + T* out_data = out->mutable_data(context.GetPlace()); + + auto stream = context.cuda_device_context().stream(); + int block = 512; + int grid = (size + block - 1) / block; + RangeKernel<<>>(start, step, size, out_data); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_CUDA_KERNEL(range, ops::CUDARangeKernel, + ops::CUDARangeKernel, + ops::CUDARangeKernel, + ops::CUDARangeKernel); diff --git a/paddle/fluid/operators/range_op.h b/paddle/fluid/operators/range_op.h new file mode 100644 index 0000000000000000000000000000000000000000..fce58b45c96ad76dfdd4ed7f54becde327070002 --- /dev/null +++ b/paddle/fluid/operators/range_op.h @@ -0,0 +1,56 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" + +namespace paddle { +namespace operators { + +template +void GetSize(T start, T end, T step, int64_t* size) { + PADDLE_ENFORCE(!std::equal_to()(step, 0), + "The step of range op should not be 0."); + PADDLE_ENFORCE(((start < end) && (step > 0)) || ((start > end) && (step < 0)), + "The step should be greater than 0 while start < end. And the " + "step should be less than 0 while start > end."); + *size = std::is_integral::value + ? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step)) + : std::ceil(std::abs((end - start) / step)); +} + +template +class CPURangeKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + T start = context.Input("Start")->data()[0]; + T end = context.Input("End")->data()[0]; + T step = context.Input("Step")->data()[0]; + auto* out = context.Output("Out"); + int64_t size = 0; + GetSize(start, end, step, &size); + out->Resize(framework::make_ddim({size})); + T* out_data = out->mutable_data(context.GetPlace()); + T value = start; + for (int64_t i = 0; i < size; ++i) { + out_data[i] = value; + value += step; + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/tensor.py b/python/paddle/fluid/layers/tensor.py index cb973986988c2909f5ef1e15dd32db3e83b1d269..a18e5b6a9c3fe69ee0bcadc150f07b72227df85e 100644 --- a/python/paddle/fluid/layers/tensor.py +++ b/python/paddle/fluid/layers/tensor.py @@ -25,10 +25,26 @@ from .layer_function_generator import templatedoc import numpy __all__ = [ - 'create_tensor', 'create_parameter', 'create_global_var', 'cast', - 'tensor_array_to_tensor', 'concat', 'sums', 'assign', - 'fill_constant_batch_size_like', 'fill_constant', 'argmin', 'argmax', - 'argsort', 'ones', 'zeros', 'reverse', 'has_inf', 'has_nan', 'isfinite' + 'create_tensor', + 'create_parameter', + 'create_global_var', + 'cast', + 'tensor_array_to_tensor', + 'concat', + 'sums', + 'assign', + 'fill_constant_batch_size_like', + 'fill_constant', + 'argmin', + 'argmax', + 'argsort', + 'ones', + 'zeros', + 'reverse', + 'has_inf', + 'has_nan', + 'isfinite', + 'range', ] @@ -764,3 +780,50 @@ def isfinite(x): out = helper.create_variable_for_type_inference(dtype=x.dtype) helper.append_op(type="isfinite", inputs={"X": x}, outputs={"Out": out}) return out + + +def range(start, end, step, dtype): + """ + Return evenly spaced values within a given interval. + + Values are generated within the half-open interval [start, stop) (in other words, + the interval including start but excluding stop). + + args: + start(int|float|Variable): Start of interval. The interval includes this value. + end(int|float|Variable): End of interval. The interval does not include this + value, except in some cases where step is not an integer + and floating point round-off affects the length of out. + step(int|float|Variable): Spacing between values. For any output out, this is the + distance between two adjacent values, out[i+1] - out[i]. + The default step size is 1. + dtype(string): 'float32'|'int32'|..., the data type of the output tensor. + + returns: + Evenly spaced values within a given interval. + + examples: + + .. code-block:: python + + data = fluid.layers.range(0, 10, 2, 'int32') + + """ + helper = LayerHelper("range", **locals()) + + if not isinstance(start, Variable): + start = fill_constant([1], dtype, start) + if not isinstance(end, Variable): + end = fill_constant([1], dtype, end) + if not isinstance(step, Variable): + step = fill_constant([1], dtype, step) + + out = helper.create_variable_for_type_inference(dtype=start.dtype) + + helper.append_op( + type='range', + inputs={'Start': start, + 'End': end, + 'Step': step}, + outputs={'Out': [out]}) + return out diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 885ee170e8032ef865ebfdd646fed1e995e9e60b..1672c3600f389d87e85f965f96122065137cf0ac 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -1240,6 +1240,14 @@ class TestBook(unittest.TestCase): print(str(program)) + def test_range(self): + program = Program() + with program_guard(program): + layers.range(0, 10, 2, 'int32') + layers.range(0.1, 10.0, 0.2, 'float32') + + print(str(program)) + def test_spectral_norm(self): program = Program() with program_guard(program): diff --git a/python/paddle/fluid/tests/unittests/test_range.py b/python/paddle/fluid/tests/unittests/test_range.py new file mode 100644 index 0000000000000000000000000000000000000000..f129ae78cbf7e2ccd5d974de265b8e95d1391df8 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_range.py @@ -0,0 +1,70 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest + + +class TestRangeOp(OpTest): + def setUp(self): + self.op_type = "range" + self.init_config() + self.inputs = { + 'Start': np.array([self.case[0]]).astype(self.dtype), + 'End': np.array([self.case[1]]).astype(self.dtype), + 'Step': np.array([self.case[2]]).astype(self.dtype) + } + + self.outputs = { + 'Out': np.arange(self.case[0], self.case[1], + self.case[2]).astype(self.dtype) + } + + def init_config(self): + self.dtype = np.float32 + self.case = (0, 1, 0.2) + + def test_check_output(self): + self.check_output() + + +class TestFloatRangeOpCase0(TestRangeOp): + def init_config(self): + self.dtype = np.float32 + self.case = (0, 5, 1) + + +class TestInt32RangeOpCase0(TestRangeOp): + def init_config(self): + self.dtype = np.int32 + self.case = (0, 5, 2) + + +class TestInt32RangeOpCase1(TestRangeOp): + def init_config(self): + self.dtype = np.int32 + self.case = (10, 1, -2) + + +class TestInt32RangeOpCase2(TestRangeOp): + def init_config(self): + self.dtype = np.int32 + self.case = (-1, -10, -2) + + +if __name__ == "__main__": + unittest.main()