diff --git a/paddle/fluid/operators/randperm_op.cc b/paddle/fluid/operators/randperm_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..67d7c578dcd777d084fcbe14658a9ae2cd3e0ed6 --- /dev/null +++ b/paddle/fluid/operators/randperm_op.cc @@ -0,0 +1,96 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/randperm_op.h" +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" + +namespace paddle { +namespace operators { + +class RandpermOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + platform::errors::NotFound( + "The output(Out) of randperm op must not be null.")); + int n = ctx->Attrs().Get("n"); + PADDLE_ENFORCE_GT( + n, 0, platform::errors::InvalidArgument( + "The input(n) of randperm op must be greater than 0.")); + + ctx->SetOutputDim("Out", framework::make_ddim({n})); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto data_type = + static_cast(ctx.Attr("dtype")); + return framework::OpKernelType(data_type, ctx.GetPlace()); + } +}; + +class RandpermOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("Out", "The output tensor of randperm op."); + + AddAttr( + "n", "The upper bound (exclusive), and it should be greater than 0."); + AddAttr("dtype", + "The data type of output tensor. " + "Default: 3[int64].") + .SetDefault(framework::proto::VarType::INT64); + AddAttr("seed", + "Random seed used for permute samples. " + "0 means use a seed generated by the system." + "Note that if seed is not 0, this operator will always " + "generate the same random permutation every time. " + "Default: 0.") + .SetDefault(0); + + AddComment(R"DOC( +This operator returns a random permutation of integers from 0 to n-1. +)DOC"); + } +}; + +class RandpermOpVarTypeInference : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + auto var_data_type = static_cast( + boost::get(ctx->GetAttr("dtype"))); + auto out_var_name = ctx->Output("Out").front(); + ctx->SetDataType(out_var_name, var_data_type); + } +}; + +} // namespace operators +} // namespace paddle + +REGISTER_OPERATOR( + randperm, paddle::operators::RandpermOp, paddle::operators::RandpermOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::operators::RandpermOpVarTypeInference); + +template +using kernel = + paddle::operators::RandpermKernel; + +REGISTER_OP_CPU_KERNEL(randperm, kernel, kernel); diff --git a/paddle/fluid/operators/randperm_op.cu b/paddle/fluid/operators/randperm_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..21ae1a4968a7e1fd9fd8aee3a12ea71c42a74d46 --- /dev/null +++ b/paddle/fluid/operators/randperm_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/randperm_op.h" + +template +using kernel = + paddle::operators::RandpermKernel; + +REGISTER_OP_CUDA_KERNEL(randperm, kernel, kernel); diff --git a/paddle/fluid/operators/randperm_op.h b/paddle/fluid/operators/randperm_op.h new file mode 100644 index 0000000000000000000000000000000000000000..64ef1c771423f2d820c73df8ed9ff25834f07875 --- /dev/null +++ b/paddle/fluid/operators/randperm_op.h @@ -0,0 +1,65 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include +#include +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace operators { + +template +static inline void random_permate(T* data_ptr, int num, unsigned int seed) { + for (int i = 0; i < num; ++i) { + data_ptr[i] = static_cast(i); + } + if (seed == 0) { + seed = std::random_device()(); + } + std::srand(seed); + std::random_shuffle(data_ptr, data_ptr + num); +} + +template +class RandpermKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + int n = ctx.Attr("n"); + unsigned int seed = static_cast(ctx.Attr("seed")); + framework::Variable* out_var = ctx.OutputVar("Out"); + framework::Tensor* out_tensor = + framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var); + + if (platform::is_cpu_place(ctx.GetPlace())) { + T* out_data = out_tensor->mutable_data(platform::CPUPlace()); + random_permate(out_data, n, seed); + } else { + framework::Tensor tmp_tensor; + tmp_tensor.Resize(framework::make_ddim({n})); + T* tmp_data = tmp_tensor.mutable_data(platform::CPUPlace()); + random_permate(tmp_data, n, seed); + framework::TensorCopy(tmp_tensor, platform::CUDAPlace(), out_tensor); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 2b13bab747995d303b0619e6e86242ee3a048e5d..c9606f6fb4fd4ca4cacd86881860bb50d942ab21 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -87,7 +87,7 @@ from .tensor.logic import elementwise_equal #DEFINE_ALIAS # from .tensor.random import uniform #DEFINE_ALIAS # from .tensor.random import shuffle #DEFINE_ALIAS # from .tensor.random import randn #DEFINE_ALIAS -# from .tensor.random import randperm #DEFINE_ALIAS +from .tensor.random import randperm # from .tensor.random import rand #DEFINE_ALIAS # from .tensor.random import randint #DEFINE_ALIAS # from .tensor.math import abs #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_randperm_op.py b/python/paddle/fluid/tests/unittests/test_randperm_op.py new file mode 100644 index 0000000000000000000000000000000000000000..2fbdc83f3abffb7f832e0b0396b745635cc47a00 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_randperm_op.py @@ -0,0 +1,175 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import numpy as np +from op_test import OpTest +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.op import Operator +from paddle.fluid import Program, program_guard + + +def check_randperm_out(n, data_np): + assert isinstance(data_np, np.ndarray), \ + "The input data_np should be np.ndarray." + gt_sorted = np.arange(n) + out_sorted = np.sort(data_np) + return list(gt_sorted == out_sorted) + + +def error_msg(data_np): + return "The sorted ground truth and sorted out should " + \ + "be equal, out = " + str(data_np) + + +def convert_dtype(dtype_str): + dtype_str_list = ["int32", "int64"] + dtype_num_list = [2, 3] + assert dtype_str in dtype_str_list, dtype_str + \ + " should in " + str(dtype_str_list) + return dtype_num_list[dtype_str_list.index(dtype_str)] + + +class TestRandpermOp(OpTest): + """ Test randperm op.""" + + def setUp(self): + self.op_type = "randperm" + self.n = 200 + self.dtype = "int64" + self.device = None + self.seed = 0 + + self.inputs = {} + self.outputs = {"Out": np.zeros((self.n)).astype(self.dtype)} + self.init_attrs() + self.attrs = { + "n": self.n, + "dtype": convert_dtype(self.dtype), + "device": self.device, + "seed": self.seed, + } + + def init_attrs(self): + pass + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + out_np = np.array(outs[0]) + self.assertTrue( + check_randperm_out(self.n, out_np), msg=error_msg(out_np)) + + +class TestRandpermOp_attr_n(TestRandpermOp): + """ Test randperm op for attr n. """ + + def init_attrs(self): + self.n = 10000 + + +class TestRandpermOp_attr_int32(TestRandpermOp): + """ Test randperm op for attr int32 dtype. """ + + def init_attrs(self): + self.dtype = "int32" + + +class TestRandpermOp_attr_device_cpu(TestRandpermOp): + """ Test randperm op for cpu device. """ + + def init_attrs(self): + self.device = "cpu" + + +class TestRandpermOp_attr_device_gpu(TestRandpermOp): + """ Test randperm op for gpu device. """ + + def init_attrs(self): + self.device = "gpu" + + +class TestRandpermOp_attr_seed(TestRandpermOp): + """ Test randperm op for attr seed. """ + + def init_attrs(self): + self.seed = 10 + + +class TestRandpermOpError(unittest.TestCase): + """ Test randperm op for raise error. """ + + def test_errors(self): + main_prog = Program() + start_prog = Program() + with program_guard(main_prog, start_prog): + + def test_Variable(): + out = np.arange(10) + paddle.randperm(n=10, out=out) + + self.assertRaises(TypeError, test_Variable) + + def test_value(): + paddle.randperm(n=-3) + + self.assertRaises(ValueError, test_value) + + +class TestRandpermOp_attr_out(unittest.TestCase): + """ Test randperm op for attr out. """ + + def test_attr_tensor_API(self): + startup_program = fluid.Program() + train_program = fluid.Program() + with fluid.program_guard(train_program, startup_program): + n = 10 + data_1 = fluid.layers.fill_constant([n], "int64", 3) + paddle.randperm(n=n, out=data_1) + + data_2 = paddle.randperm(n=n, dtype="int32", device="cpu") + + place = fluid.CPUPlace() + if fluid.core.is_compiled_with_cuda(): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + + exe.run(startup_program) + outs = exe.run(train_program, fetch_list=[data_1, data_2]) + + out_np = np.array(outs[0]) + self.assertTrue( + check_randperm_out(n, out_np), msg=error_msg(out_np)) + + +class TestRandpermDygraphMode(unittest.TestCase): + def test_check_output(self): + with fluid.dygraph.guard(): + n = 10 + data_1 = paddle.randperm(n, dtype="int64") + data_1_np = data_1.numpy() + self.assertTrue( + check_randperm_out(n, data_1_np), msg=error_msg(data_1_np)) + + data_2 = paddle.randperm(n, dtype="int32", device="cpu") + data_2_np = data_2.numpy() + self.assertTrue( + check_randperm_out(n, data_2_np), msg=error_msg(data_2_np)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 2923f2bf1c6f2703d86b408694568b329fa7cf04..68b487107c325f88c1c01456f93c1a3ad96e1019 100644 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -63,9 +63,9 @@ from .logic import elementwise_equal #DEFINE_ALIAS # from .random import uniform #DEFINE_ALIAS # from .random import shuffle #DEFINE_ALIAS # from .random import randn #DEFINE_ALIAS -# from .random import randperm #DEFINE_ALIAS # from .random import rand #DEFINE_ALIAS # from .random import randint #DEFINE_ALIAS +from .random import randperm # from .math import abs #DEFINE_ALIAS # from .math import acos #DEFINE_ALIAS # from .math import asin #DEFINE_ALIAS diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 6f8ff58d651bec5356c09d60a78d8a7deac44208..97fd319e7db514300914a8536a64052a4aa63ec8 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -17,6 +17,100 @@ # 'uniform', # 'shuffle', # 'randn', -# 'randperm', # 'rand', # 'randint'] + +from ..fluid import core +from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator +from ..fluid.layers.layer_function_generator import templatedoc +from ..fluid.layer_helper import LayerHelper +from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype + +__all__ = ['randperm'] + + +@templatedoc() +def randperm(n, + out=None, + dtype="int64", + device=None, + stop_gradient=True, + seed=0): + """ + ${comment} + + Args: + n (int): The upper bound (exclusive), and it should be greater than 0. + out (Variable, optional): Optional output which can be any created + Variable that meets the requirements to store the result of operation. + If out is None, a new Varibale will be create to store the result. + Default: None. + dtype (np.dtype|core.VarDesc.VarType|str, optional): The type of the + output Tensor. Supported data types: int64, int32. Default: int32. + device (str, optional): Specific the output variable to be saved in cpu + or gpu memory. Supported None, 'cpu', 'gpu'. If it is None, the output + variable will be automatically assigned devices. + Default: None. + stop_gradient (bool, optional): Whether grad should record operations + on the returned tensor. Default: True. + seed (int, optional): Random seed used for permute samples. If seed is + equal to 0, it means use a seed generated by the system. Note that + if seed is not 0, this operator will always generate the same random + permutation every time. Default: 0. + + Returns: + ${out_comment}. + + Return Type: + ${out_type} + + Examples: + .. code-block:: python + + import paddle + import paddle.fluid as fluid + + num = 6 + is_use_gpu = False + + data_1 = paddle.randperm(num) + fluid.layers.Print(data_1) + + data_2 = paddle.randperm(num, dtype="int32", seed=1) + fluid.layers.Print(data_2) + + data_3 = paddle.randperm(num, stop_gradient=False, device="cpu") + fluid.layers.Print(data_3) + + paddle.randperm(num, out=data_3) + fluid.layers.Print(data_3) + + place = fluid.CUDAPlace(0) if is_use_gpu else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(fluid.default_startup_program()) + exe.run() + + """ + + if n < 1: + raise ValueError("The input n should be greater than 0 in randperm op.") + check_dtype(dtype, 'dtype', ['int64', 'int32'], 'randperm') + dtype = convert_dtype(dtype) + if device not in [None, 'cpu', 'gpu']: + raise ValueError("The input device should in [None, 'cpu', 'gpu'].") + check_type(stop_gradient, 'stop_gradient', bool, 'randperm') + + helper = LayerHelper("randperm", **locals()) + if out is None: + out = helper.create_variable_for_type_inference(dtype=dtype) + else: + check_variable_and_dtype(out, 'out', [dtype], 'randperm') + if stop_gradient: + out.stop_gradient = True + inputs = dict() + outputs = {'Out': [out]} + attrs = {'n': n, 'dtype': out.dtype, 'seed': seed} + with device_guard(device): + helper.append_op( + type='randperm', inputs=inputs, outputs=outputs, attrs=attrs) + return out