From 7baabd4e6bf92595417e6fd688817d3df0e670f3 Mon Sep 17 00:00:00 2001 From: Kevin_Xiong Date: Tue, 17 Aug 2021 13:37:03 +0800 Subject: [PATCH] add randperm with test and docs (#5680) * add randperm with test and docs * format code * format * fix docs * format the code and add more tests * format code * Update test_randperm.py * Update randperm.py * add head * format codes * docs * Update test_randperm.py * Update test_randperm.py * add more tests * format * Update randperm.py * Update randperm_kernel.cu * Update randperm_kernel.cu * Update randperm_kernel.cpp * reconstruct the code * format the code * 2 * s * 1 * 1 * refine * add more test * refine code * fix according to comment * add more test Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- docs/source/oneflow.rst | 1 + oneflow/core/functional/functional_api.yaml | 12 +- .../core/functional/impl/random_functor.cpp | 79 +++++++++++- oneflow/user/kernels/randperm_kernel.cpp | 56 +++++++++ oneflow/user/kernels/randperm_kernel.cu | 112 ++++++++++++++++++ oneflow/user/ops/randperm_op.cpp | 55 +++++++++ python/oneflow/__init__.py | 1 + python/oneflow/nn/modules/random_ops.py | 88 ++++++++++++++ python/oneflow/test/modules/test_randperm.py | 82 +++++++++++++ 9 files changed, 479 insertions(+), 7 deletions(-) create mode 100644 oneflow/user/kernels/randperm_kernel.cpp create mode 100644 oneflow/user/kernels/randperm_kernel.cu create mode 100644 oneflow/user/ops/randperm_op.cpp create mode 100644 python/oneflow/test/modules/test_randperm.py diff --git a/docs/source/oneflow.rst b/docs/source/oneflow.rst index c26ee76cf4..3a64cad28c 100644 --- a/docs/source/oneflow.rst +++ b/docs/source/oneflow.rst @@ -82,6 +82,7 @@ oneflow randn, repeat, reshape, + randperm, reciprocal, round, save, diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 3378136a27..f059e45f5b 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -1001,9 +1001,7 @@ - name: "scalar_logical_less_equal" signature: "Tensor ScalarLogicalLessEqual(Tensor in, Scalar scalar)" bind_python: False - -- name: "split" - signature: "TensorTuple Split(Tensor x, *, Int64 split_size, Int64 dim=0)" + - name: "rand" signature: "Tensor Rand(*, Shape shape, DataType dtype=None, Device device=None, Generator generator=None)" bind_python: True @@ -1042,3 +1040,11 @@ - name: "l2_normalize_grad" signature: "Tensor L2NormalizeGrad(Tensor dy, Tensor y, Tensor square_x_sum, Int32 axis, Float epsilon, *)" bind_python: False + +- name: "randperm" + signature: "Tensor Randperm(Int32 n,*,Device device=None, Generator generator=None)" + bind_python: True + +- name: "consistent_randperm" + signature: "Tensor ConsistentRandperm(Int32 n,*, Placement placement, SbpList sbp_tuple, Generator generator=None)" + bind_python: True diff --git a/oneflow/core/functional/impl/random_functor.cpp b/oneflow/core/functional/impl/random_functor.cpp index 1814e25336..0183d36616 100644 --- a/oneflow/core/functional/impl/random_functor.cpp +++ b/oneflow/core/functional/impl/random_functor.cpp @@ -33,9 +33,8 @@ limitations under the License. #include "oneflow/user/kernels/bernoulli_kernel.h" #include "oneflow/user/kernels/distributions/normal_kernel.h" #include "oneflow/user/kernels/distributions/uniform_kernel.h" -#include "oneflow/core/job/parallel_desc.h" -#include "oneflow/core/job/global_for.h" - +#include "oneflow/core/job/sbp_parallel.h" +#include "oneflow/core/job/lazy_mode.h" namespace oneflow { namespace one { namespace functional { @@ -70,7 +69,6 @@ class BernoulliFunctor { private: std::shared_ptr bernoulli_op_; }; - class RandFunctor { public: RandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); } @@ -92,6 +90,7 @@ class RandFunctor { JUST(attrs.SetAttr("dtype", dtype_val)); std::shared_ptr gen; + if (!generator) { gen = JUST(one::DefaultAutoGenerator()); } else { @@ -253,10 +252,82 @@ class ConsistentRandNFunctor { std::shared_ptr op_; }; +class RandPermFunctor { + public: + RandPermFunctor() { randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build()); } + Maybe operator()(const int32_t n, const Optional>& device, + const Optional& generator) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("n", n)); + std::shared_ptr gen; + if (!generator) { + gen = JUST(one::DefaultAutoGenerator()); + } else { + gen = JUST(generator.value()); + } + + JUST(attrs.SetAttr("seed", gen->current_seed())); + + const auto& randperm_kernel_state = std::make_shared(gen); + if (device.has_value()) { + Symbol device_symbol = JUST(device.value()); + return OpInterpUtil::Dispatch( + *randperm_op_, {}, OpExprInterpContext(attrs, device_symbol, randperm_kernel_state)); + } else { + return OpInterpUtil::Dispatch(*randperm_op_, {}, + OpExprInterpContext(attrs, randperm_kernel_state)); + } + } + + private: + std::shared_ptr randperm_op_; +}; + +class ConsistentRandPermFunctor { + public: + ConsistentRandPermFunctor() { + randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build()); + } + Maybe operator()(const int32_t n, const Symbol& placement, + const std::vector>& sbp_tuple, + const Optional& generator) const { + MutableAttrMap attrs; + JUST(attrs.SetAttr("n", n)); + + std::shared_ptr gen; + if (!generator) { + gen = JUST(one::DefaultAutoGenerator()); + } else { + gen = JUST(generator.value()); + } + + JUST(attrs.SetAttr("seed", gen->current_seed())); + + const auto& uniform_kernel_state = std::make_shared(gen); + + if (LazyMode::is_enabled()) { + std::vector nd_sbp(sbp_tuple.size()); + { + for (int i = 0; i < sbp_tuple.size(); ++i) { + nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i)); + } + } + JUST(attrs.SetAttr>("nd_sbp", nd_sbp)); + } + const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple)); + return OpInterpUtil::Dispatch( + *randperm_op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, uniform_kernel_state)); + } + + private: + std::shared_ptr randperm_op_; +}; } // namespace impl ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Bernoulli"); + m.add_functor("Randperm"); + m.add_functor("ConsistentRandperm"); m.add_functor("Rand"); m.add_functor("ConsistentRand"); m.add_functor("RandN"); diff --git a/oneflow/user/kernels/randperm_kernel.cpp b/oneflow/user/kernels/randperm_kernel.cpp new file mode 100644 index 0000000000..31eaac34bd --- /dev/null +++ b/oneflow/user/kernels/randperm_kernel.cpp @@ -0,0 +1,56 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/device/device_context.h" +#include "oneflow/core/framework/random_generator.h" +#include "oneflow/user/kernels/range_kernel_util.h" +#include "oneflow/user/kernels/distributions/uniform_kernel.h" +namespace oneflow { + +class CpuRandPermKernel final : public user_op::OpKernel { + public: + CpuRandPermKernel() = default; + ~CpuRandPermKernel() = default; + std::shared_ptr CreateOpKernelState( + user_op::KernelInitContext* ctx) const override { + const auto& generator = CHECK_JUST(one::MakeAutoGenerator()); + generator->set_current_seed(ctx->Attr("seed")); + return std::make_shared(generator); + } + + private: + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + int32_t* output = out->mut_dptr(); + const int32_t n = ctx->Attr("n"); + auto* randperm_kernel_state = dynamic_cast(state); + CHECK_NOTNULL(randperm_kernel_state); + const auto& generator = randperm_kernel_state->generator(); + const auto& cpu_generator = CHECK_JUST(generator->Get()); + CHECK_NOTNULL(generator); + user_op::RangeFunctor()(ctx->device_ctx(), 0, 1, n, output); + std::shuffle(output, output + n, cpu_generator->engine()); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +REGISTER_USER_KERNEL("randperm") + .SetCreateFn() + .SetIsMatchedHob((user_op::HobDeviceTag() == "cpu")); + +} // namespace oneflow diff --git a/oneflow/user/kernels/randperm_kernel.cu b/oneflow/user/kernels/randperm_kernel.cu new file mode 100644 index 0000000000..8fea7635cc --- /dev/null +++ b/oneflow/user/kernels/randperm_kernel.cu @@ -0,0 +1,112 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/framework/framework.h" +#include "oneflow/user/kernels/op_kernel_state_wrapper.h" +#include "oneflow/core/common/data_type.h" +#include "oneflow/core/device/device_context.h" +#include "oneflow/core/framework/random_generator.h" +#include "oneflow/user/kernels/range_kernel_util.h" +#include "oneflow/user/kernels/distributions/uniform_kernel.h" +#include "oneflow/user/kernels/radix_sort.cuh" +#include +#include +namespace oneflow { +__global__ void GeneKeysAndValues(const int32_t n, int32_t* values, int32_t* keys, + curandState* state) { + XPU_1D_KERNEL_LOOP(i, n) { + keys[i] = curand(state + i); + values[i] = i; + } +} + +class GpuRandPermKernel final : public user_op::OpKernel { + public: + GpuRandPermKernel() = default; + ~GpuRandPermKernel() = default; + std::shared_ptr CreateOpKernelState( + user_op::KernelInitContext* ctx) const override { + const auto& generator = CHECK_JUST(one::MakeAutoGenerator()); + generator->set_current_seed(ctx->Attr("seed")); + return std::make_shared(generator); + } + + private: + void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + int32_t* output = out->mut_dptr(); + const int32_t n = ctx->Attr("n"); + user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0); + + auto* randperm_kernel_state = dynamic_cast(state); + CHECK_NOTNULL(randperm_kernel_state); + const auto& generator = randperm_kernel_state->generator(); + const auto& gpu_generator = CHECK_JUST(generator->Get()); + CHECK_NOTNULL(generator); + + int32_t block_num = gpu_generator->max_block_num(); + int32_t thread_num = gpu_generator->max_thread_num(); + curandState* curand_states = gpu_generator->curand_states(); + + // layout for tmp |...key(in and out,2xN)..|....value....|.... space for sort function....| + // values are the desired indexes ,and keys are generated randomly. + void* tmp = tmp_buffer->mut_dptr(); + int32_t* key_base = reinterpret_cast(tmp); + + const int32_t key_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t)); + int32_t* value_base = + reinterpret_cast(reinterpret_cast(key_base) + 2 * key_aligned_bytes); + + const int32_t indices_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t)); + void* tmp_base = + reinterpret_cast(reinterpret_cast(value_base) + indices_aligned_bytes); + size_t temp_storage_bytes = InferTempStorageForSortPairsDescending(1, n); + + GeneKeysAndValues<<device_ctx()->cuda_stream()>>>(n, value_base, key_base, curand_states); + + auto err = cub::DeviceRadixSort::SortPairs( + /* d_temp_storage */ tmp_base, + /* temp_storage_bytes */ temp_storage_bytes, + /* d_keys_in */ key_base, + /* d_keys_out */ key_base + n, + /* d_values_in */ value_base, + /* d_values_out */ output, + /* num_items */ n, + /* begin_bit */ 0, + /* end_bit */ sizeof(int32_t) * 8, + /* stream */ ctx->device_ctx()->cuda_stream()); + OF_CUDA_CHECK(err); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; +REGISTER_USER_KERNEL("randperm") + .SetCreateFn() + .SetIsMatchedHob(user_op::HobDeviceTag() == "gpu") + .SetInferTmpSizeFn([](user_op::InferContext* ctx) { + const int32_t n = ctx->Attr("n"); + /* Sorted In */ + const int32_t sorted_in_aligned_bytes = 2 * GetCudaAlignedSize(n * sizeof(int32_t)); + /* Indices */ + const int32_t indices_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t)); + + /* CUB Temp Storage */ + const int32_t temp_storage_bytes = + InferTempStorageForSortPairsDescending(1, n); + + return sorted_in_aligned_bytes + indices_aligned_bytes + temp_storage_bytes; + }); + +} // namespace oneflow diff --git a/oneflow/user/ops/randperm_op.cpp b/oneflow/user/ops/randperm_op.cpp new file mode 100644 index 0000000000..6696064331 --- /dev/null +++ b/oneflow/user/ops/randperm_op.cpp @@ -0,0 +1,55 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/common/protobuf.h" +#include "oneflow/core/common/global.h" +#include "oneflow/core/job/global_for.h" +namespace oneflow { + +Maybe InferRandpermNdSbp(user_op::InferNdSbpFnContext* ctx); +REGISTER_NO_GRAD_USER_OP("randperm") + .Output("out") + .Attr("n") + .Attr("nd_sbp") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { + Shape* out_shape = ctx->OutputShape("out", 0); + int32_t n = ctx->Attr("n"); + CHECK_GE_OR_RETURN(n, 0); + *out_shape = Shape({n}); + return Maybe::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { return Maybe::Ok(); }) + .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe { + *ctx->OutputDType("out", 0) = DataType::kInt32; + return Maybe::Ok(); + }) + .SetNdSbpInferFn(&InferRandpermNdSbp); + +Maybe InferRandpermNdSbp(user_op::InferNdSbpFnContext* ctx) { + cfg::NdSbp* out = ctx->NdSbp4ArgNameAndIndex("out", 0); + if (JUST(*Global, MultiClient>::Get())) { + const auto& pb_str = ctx->user_op_conf().attr("nd_sbp"); + NdSbp pb; + CHECK_OR_RETURN(TxtString2PbMessage(pb_str, &pb)); + out->InitFromProto(pb); + } else { + out->mutable_sbp_parallel()->Add()->mutable_broadcast_parallel(); + } + return Maybe::Ok(); +} + +} // namespace oneflow diff --git a/python/oneflow/__init__.py b/python/oneflow/__init__.py index cb49dd388a..7f6e23bfd2 100644 --- a/python/oneflow/__init__.py +++ b/python/oneflow/__init__.py @@ -257,6 +257,7 @@ from oneflow.nn.modules.nonzero import nonzero_op as nonzero from oneflow.nn.modules.random_ops import bernoulli from oneflow.nn.modules.random_ops import rand_op as rand from oneflow.nn.modules.random_ops import randn_op as randn +from oneflow.nn.modules.random_ops import randperm from oneflow.nn.modules.reduce_ops import _max as max from oneflow.nn.modules.reduce_ops import _mean as mean from oneflow.nn.modules.reduce_ops import _min as min diff --git a/python/oneflow/nn/modules/random_ops.py b/python/oneflow/nn/modules/random_ops.py index 6752b0a880..fb7503bc89 100644 --- a/python/oneflow/nn/modules/random_ops.py +++ b/python/oneflow/nn/modules/random_ops.py @@ -264,6 +264,94 @@ def randn_op( )() +class Randperm(Module): + def __init__( + self, + n, + generator: flow.Generator = None, + dtype: flow.dtype = flow.int32, + layout=None, + device: Union[flow.device, str, None] = None, + placement: flow.placement = None, + sbp: flow._oneflow_internal.sbp.sbp = None, + requires_grad: bool = False, + pin_memory: bool = False, + ) -> None: + super().__init__() + assert n >= 0 + self.n = n + self.requires_grad = requires_grad + ( + self.size, + self.device, + self.generator, + self.placement, + self.sbp, + ) = _rand_op_common_process(1, device, generator, placement, sbp) + self.dtype = dtype + + def forward(self, out=None): + if self.placement is not None: + res = flow.F.consistent_randperm( + self.n, self.placement, self.sbp, self.generator + ) + else: + res = flow.F.randperm(self.n, self.device, self.generator) + res.requires_grad = self.requires_grad + return res.to(dtype=self.dtype) + + +def randperm( + n: flow.int32, + generator: flow.Generator = None, + out=None, + dtype: flow.dtype = flow.int32, + layout=None, + device: Union[flow.device, str, None] = None, + placement: flow.placement = None, + sbp: flow._oneflow_internal.sbp.sbp = None, + requires_grad: bool = False, + pin_memory: bool = False, +): + r""" + Returns a random permutation of integers from ``0`` to ``n - 1``. + + Args: + n (int): the upper bound (exclusive) + + Keyword args: + generator(:class:`oneflow.Generator`, optional): a pseudorandom number generator for sampling + out (Tensor, optional): output Tensor,not supported yet. + dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor. + Default: ``oneflow.int32``. + layout: layout is not supported yet. + device: the desired device of returned tensor. Default: cpu. + placement:(:class:`flow.placement`, optional): The desired device of returned consistent tensor. If None, + will construct local tensor. + sbp: (:class:`flow.sbp`, optional): The desired sbp of returned consistent tensor. It must be equal with the + numbers of placement. + requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False. + pin_memory(bool, optional):pin_memory is not supported yet. + + Example: + + .. code-block:: python + + >>> import oneflow as flow + >>> generator = flow.Generator() + >>> generator.manual_seed(0) + >>> flow.randperm(5, generator=generator) + tensor([2, 4, 3, 0, 1], dtype=oneflow.int32) + """ + assert out is None, "out not supported yet" + assert layout is None, "layout not supported yet" + if generator is None: + generator = flow.default_generator() + return Randperm( + n, generator, dtype, layout, device, placement, sbp, requires_grad, pin_memory + )(out) + + if __name__ == "__main__": import doctest diff --git a/python/oneflow/test/modules/test_randperm.py b/python/oneflow/test/modules/test_randperm.py new file mode 100644 index 0000000000..a826324237 --- /dev/null +++ b/python/oneflow/test/modules/test_randperm.py @@ -0,0 +1,82 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import oneflow as flow +from collections import OrderedDict +from automated_test_util import * +from test_util import GenArgList +import numpy as np +import unittest + + +def _test_randperm_with_generator(test_case, N, device, dtype): + generator = flow.Generator() + generator.manual_seed(0) + y_1 = flow.randperm(N, device=device, dtype=dtype, generator=generator) + generator = flow.Generator() + generator.manual_seed(0) + y_2 = flow.randperm(N, device=device, dtype=dtype, generator=generator) + test_case.assertTrue(np.allclose(y_1.numpy(), y_2.numpy())) + test_case.assertTrue( + y_1.device == flow.device(device) and y_2.device == flow.device(device) + ) + test_case.assertTrue(y_1.dtype == dtype and y_2.dtype == dtype) + + +def _test_randperm_backward(test_case, N, device, dtype): + x = flow.randperm(N, device=device, dtype=dtype) + x.requires_grad = True + y = x.sum() + y.backward() + test_case.assertTrue(np.allclose(x.grad.numpy(), np.ones(N), 1e-05, 1e-05)) + + +@flow.unittest.skip_unless_1n1d() +class Testrandperm(flow.unittest.TestCase): + def test_randperm(test_case): + arg_dict = OrderedDict() + arg_dict["test_functions"] = [ + _test_randperm_with_generator, + _test_randperm_backward, + ] + arg_dict["N"] = [i for i in range(10, 100, 5)] + arg_dict["device"] = ["cpu", "cuda"] + arg_dict["dtype"] = [flow.int32, flow.int64, flow.float32, flow.float64] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + @autotest(auto_backward=False) + def test_auto_1(test_case): + device = random_device() + y = torch.randperm(1, device=device) + return y + + @autotest(n=5, auto_backward=False) + def test_auto_0(test_case): + device = random_device() + y = torch.randperm(0, device=device) + return y + + def test_randperm_randomness(test_case): + device = "cuda" + n = np.random.randint(100, 1000) + x1 = flow.randperm(n, device=device) + x2 = flow.randperm(n, device=device) + test_case.assertTrue(not np.all(x1.numpy() == x2.numpy())) + device = "cpu" + n = np.random.randint(100, 1000) + x1 = flow.randperm(n, device=device) + x2 = flow.randperm(n, device=device) + test_case.assertTrue(not np.all(x1.numpy() == x2.numpy())) -- GitLab