未验证 提交 7baabd4e 编写于 作者: K Kevin_Xiong 提交者: GitHub

add randperm with test and docs (#5680)

* add randperm with test and docs

* format code

* format

* fix docs

* format the code and add more tests

* format code

* Update test_randperm.py

* Update randperm.py

* add head

* format codes

* docs

* Update test_randperm.py

* Update test_randperm.py

* add more tests

* format

* Update randperm.py

* Update randperm_kernel.cu

* Update randperm_kernel.cu

* Update randperm_kernel.cpp

* reconstruct the code

* format the code

* 2

* s

* 1

* 1

* refine

* add more test

* refine code

* fix according to comment

* add more test
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 442f77da
......@@ -82,6 +82,7 @@ oneflow
randn,
repeat,
reshape,
randperm,
reciprocal,
round,
save,
......
......@@ -1001,9 +1001,7 @@
- name: "scalar_logical_less_equal"
signature: "Tensor ScalarLogicalLessEqual(Tensor in, Scalar scalar)"
bind_python: False
- name: "split"
signature: "TensorTuple Split(Tensor x, *, Int64 split_size, Int64 dim=0)"
- name: "rand"
signature: "Tensor Rand(*, Shape shape, DataType dtype=None, Device device=None, Generator generator=None)"
bind_python: True
......@@ -1042,3 +1040,11 @@
- name: "l2_normalize_grad"
signature: "Tensor L2NormalizeGrad(Tensor dy, Tensor y, Tensor square_x_sum, Int32 axis, Float epsilon, *)"
bind_python: False
# Local variant: random permutation op taking `n`, with an optional device and generator.
# Exposed to Python (bind_python: True).
- name: "randperm"
signature: "Tensor Randperm(Int32 n,*,Device device=None, Generator generator=None)"
bind_python: True
# Consistent variant: same op, but the caller must supply a placement and an SBP list.
# Exposed to Python (bind_python: True).
- name: "consistent_randperm"
signature: "Tensor ConsistentRandperm(Int32 n,*, Placement placement, SbpList sbp_tuple, Generator generator=None)"
bind_python: True
......@@ -33,9 +33,8 @@ limitations under the License.
#include "oneflow/user/kernels/bernoulli_kernel.h"
#include "oneflow/user/kernels/distributions/normal_kernel.h"
#include "oneflow/user/kernels/distributions/uniform_kernel.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/job/lazy_mode.h"
namespace oneflow {
namespace one {
namespace functional {
......@@ -70,7 +69,6 @@ class BernoulliFunctor {
private:
std::shared_ptr<OpExpr> bernoulli_op_;
};
class RandFunctor {
public:
RandFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
......@@ -92,6 +90,7 @@ class RandFunctor {
JUST(attrs.SetAttr<DataType>("dtype", dtype_val));
std::shared_ptr<one::Generator> gen;
if (!generator) {
gen = JUST(one::DefaultAutoGenerator());
} else {
......@@ -253,10 +252,82 @@ class ConsistentRandNFunctor {
std::shared_ptr<OpExpr> op_;
};
// Functor registered as "Randperm": dispatches the "randperm" user op to create a local
// tensor, passing `n` and the generator's current seed as op attributes.
class RandPermFunctor {
 public:
  RandPermFunctor() { randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build()); }

  // n:         forwarded as the op's "n" attribute.
  // device:    optional target device; when absent the op is dispatched without one.
  // generator: optional random generator; one::DefaultAutoGenerator() is used when absent.
  Maybe<Tensor> operator()(const int32_t n, const Optional<Symbol<Device>>& device,
                           const Optional<one::Generator>& generator) const {
    MutableAttrMap attrs;
    JUST(attrs.SetAttr<int32_t>("n", n));

    // Resolve the generator that provides both the seed attribute and the kernel state.
    std::shared_ptr<one::Generator> rng;
    if (generator.has_value()) {
      rng = JUST(generator.value());
    } else {
      rng = JUST(one::DefaultAutoGenerator());
    }
    JUST(attrs.SetAttr<int64_t>("seed", rng->current_seed()));

    const auto kernel_state = std::make_shared<UniformKernelState>(rng);
    if (!device.has_value()) {
      return OpInterpUtil::Dispatch<Tensor>(*randperm_op_, {},
                                            OpExprInterpContext(attrs, kernel_state));
    }
    Symbol<Device> target_device = JUST(device.value());
    return OpInterpUtil::Dispatch<Tensor>(
        *randperm_op_, {}, OpExprInterpContext(attrs, target_device, kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> randperm_op_;
};
// Functor registered as "ConsistentRandperm": dispatches the "randperm" user op with a
// placement and an nd-sbp so that the result is a consistent tensor.
class ConsistentRandPermFunctor {
 public:
  ConsistentRandPermFunctor() {
    randperm_op_ = CHECK_JUST(one::OpBuilder("randperm").Output("out").Build());
  }

  // n:         forwarded as the op's "n" attribute.
  // placement: parallel description the op is dispatched with.
  // sbp_tuple: one SBP symbol per entry; converted with GetNdSbp() for dispatch.
  // generator: optional random generator; one::DefaultAutoGenerator() is used when absent.
  Maybe<Tensor> operator()(const int32_t n, const Symbol<ParallelDesc>& placement,
                           const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
                           const Optional<one::Generator>& generator) const {
    MutableAttrMap attrs;
    JUST(attrs.SetAttr<int32_t>("n", n));

    std::shared_ptr<one::Generator> rng;
    if (generator.has_value()) {
      rng = JUST(generator.value());
    } else {
      rng = JUST(one::DefaultAutoGenerator());
    }
    JUST(attrs.SetAttr<int64_t>("seed", rng->current_seed()));
    const auto kernel_state = std::make_shared<UniformKernelState>(rng);

    // Only in lazy mode is the SBP signature also recorded as a string attribute.
    // NOTE(review): this is stored as std::vector<std::string>, while the "randperm" op
    // registration declares "nd_sbp" as std::string -- confirm which type is intended.
    if (LazyMode::is_enabled()) {
      std::vector<std::string> sbp_strs;
      sbp_strs.reserve(sbp_tuple.size());
      for (const auto& sbp : sbp_tuple) { sbp_strs.emplace_back(SbpParallelToString(*sbp)); }
      JUST(attrs.SetAttr<std::vector<std::string>>("nd_sbp", sbp_strs));
    }

    const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
    return OpInterpUtil::Dispatch<Tensor>(
        *randperm_op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, kernel_state));
  }

 private:
  std::shared_ptr<OpExpr> randperm_op_;
};
} // namespace impl
ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::BernoulliFunctor>("Bernoulli");
m.add_functor<impl::RandPermFunctor>("Randperm");
m.add_functor<impl::ConsistentRandPermFunctor>("ConsistentRandperm");
m.add_functor<impl::RandFunctor>("Rand");
m.add_functor<impl::ConsistentRandFunctor>("ConsistentRand");
m.add_functor<impl::RandNFunctor>("RandN");
......
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/op_kernel_state_wrapper.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/user/kernels/range_kernel_util.h"
#include "oneflow/user/kernels/distributions/uniform_kernel.h"
namespace oneflow {
// CPU kernel for the "randperm" user op: writes a range of n int32 values into "out" and
// shuffles them in place with the CPU generator's engine.
class CpuRandPermKernel final : public user_op::OpKernel {
 public:
  CpuRandPermKernel() = default;
  ~CpuRandPermKernel() = default;

  // Creates the kernel state from a new auto generator whose current seed is set from the
  // op's "seed" attribute.
  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    const auto& generator = CHECK_JUST(one::MakeAutoGenerator());
    generator->set_current_seed(ctx->Attr<int64_t>("seed"));
    return std::make_shared<UniformKernelState>(generator);
  }

 private:
  // ctx:   provides the "out" tensor and the "n" attribute.
  // state: must be a UniformKernelState holding the generator (checked at runtime).
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    int32_t* output = out->mut_dptr<int32_t>();
    const int32_t n = ctx->Attr<int32_t>("n");

    auto* randperm_kernel_state = dynamic_cast<UniformKernelState*>(state);
    CHECK_NOTNULL(randperm_kernel_state);
    const auto& generator = randperm_kernel_state->generator();
    // Validate the generator before dereferencing it; the check used to come after the
    // first use, where it could no longer catch a null generator.
    CHECK_NOTNULL(generator);
    const auto& cpu_generator = CHECK_JUST(generator->Get<one::CPUGeneratorImpl>());

    // Fill the output with the initial sequence (arguments 0, 1, n), then permute it.
    user_op::RangeFunctor<DeviceType::kCPU, int32_t>()(ctx->device_ctx(), 0, 1, n, output);
    std::shuffle(output, output + n, cpu_generator->engine());
  }

  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
// Use CpuRandPermKernel for "randperm" whenever the op's device tag is "cpu".
REGISTER_USER_KERNEL("randperm")
    .SetCreateFn<CpuRandPermKernel>()
    .SetIsMatchedHob(user_op::HobDeviceTag() == "cpu");
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/op_kernel_state_wrapper.h"
#include "oneflow/core/common/data_type.h"
#include "oneflow/core/device/device_context.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/user/kernels/range_kernel_util.h"
#include "oneflow/user/kernels/distributions/uniform_kernel.h"
#include "oneflow/user/kernels/radix_sort.cuh"
#include <curand.h>
#include <curand_kernel.h>
namespace oneflow {
// For every idx in [0, n): stores idx into values[idx] and a curand() draw into keys[idx].
// NOTE(review): `state` is indexed by element index, so this presumably needs at least n
// initialized curandState entries -- confirm against the generator's allocation size.
__global__ void GeneKeysAndValues(const int32_t n, int32_t* values, int32_t* keys,
                                  curandState* state) {
  XPU_1D_KERNEL_LOOP(idx, n) {
    values[idx] = idx;
    keys[idx] = curand(&state[idx]);
  }
}
// GPU kernel for the "randperm" user op: pairs each index in [0, n) with a random key and
// sorts the pairs by key with cub::DeviceRadixSort, writing the reordered indices to "out".
class GpuRandPermKernel final : public user_op::OpKernel {
 public:
  GpuRandPermKernel() = default;
  ~GpuRandPermKernel() = default;

  // Creates the kernel state from a new auto generator whose current seed is set from the
  // op's "seed" attribute.
  std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
      user_op::KernelInitContext* ctx) const override {
    const auto& generator = CHECK_JUST(one::MakeAutoGenerator());
    generator->set_current_seed(ctx->Attr<int64_t>("seed"));
    return std::make_shared<UniformKernelState>(generator);
  }

 private:
  // ctx:   provides the "out" and "tmp_buffer" tensors, the "n" attribute and the stream.
  // state: must be a UniformKernelState holding the generator (checked at runtime).
  void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
    user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0);
    int32_t* output = out->mut_dptr<int32_t>();
    const int32_t n = ctx->Attr<int32_t>("n");
    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);

    auto* randperm_kernel_state = dynamic_cast<UniformKernelState*>(state);
    CHECK_NOTNULL(randperm_kernel_state);
    const auto& generator = randperm_kernel_state->generator();
    // Validate the generator before dereferencing it; the check used to come after the
    // first use, where it could no longer catch a null generator.
    CHECK_NOTNULL(generator);
    const auto& gpu_generator = CHECK_JUST(generator->Get<one::CUDAGeneratorImpl>());
    curandState* curand_states = gpu_generator->curand_states();

    // tmp_buffer layout:
    //   | keys (sort input, then sort output; 2 x aligned N) | values | CUB temp storage |
    // Values are the indices to be permuted; keys are generated randomly.
    // Byte counts are kept in size_t so that large n cannot overflow a 32-bit integer.
    void* tmp = tmp_buffer->mut_dptr<void>();
    int32_t* key_base = reinterpret_cast<int32_t*>(tmp);
    const size_t key_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t));
    int32_t* value_base =
        reinterpret_cast<int32_t*>(reinterpret_cast<char*>(key_base) + 2 * key_aligned_bytes);
    const size_t indices_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t));
    void* tmp_base =
        reinterpret_cast<void*>(reinterpret_cast<char*>(value_base) + indices_aligned_bytes);
    size_t temp_storage_bytes = InferTempStorageForSortPairsDescending<int32_t, int32_t>(1, n);

    GeneKeysAndValues<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
                        ctx->device_ctx()->cuda_stream()>>>(n, value_base, key_base, curand_states);

    // Sorting the (random key, index) pairs by key reorders the indices.
    auto err = cub::DeviceRadixSort::SortPairs(
        /* d_temp_storage */ tmp_base,
        /* temp_storage_bytes */ temp_storage_bytes,
        /* d_keys_in */ key_base,
        /* d_keys_out */ key_base + n,
        /* d_values_in */ value_base,
        /* d_values_out */ output,
        /* num_items */ n,
        /* begin_bit */ 0,
        /* end_bit */ sizeof(int32_t) * 8,
        /* stream */ ctx->device_ctx()->cuda_stream());
    OF_CUDA_CHECK(err);
  }

  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};
// Use GpuRandPermKernel for "randperm" whenever the op's device tag is "gpu", and size its
// temporary buffer: two key regions, one value region, then the CUB sort workspace.
REGISTER_USER_KERNEL("randperm")
    .SetCreateFn<GpuRandPermKernel>()
    .SetIsMatchedHob(user_op::HobDeviceTag() == "gpu")
    .SetInferTmpSizeFn([](user_op::InferContext* ctx) -> size_t {
      const int32_t n = ctx->Attr<int32_t>("n");
      // All byte counts are size_t: summing them in int32_t overflows for large n.
      /* Sorted In */
      const size_t sorted_in_aligned_bytes = 2 * GetCudaAlignedSize(n * sizeof(int32_t));
      /* Indices */
      const size_t indices_aligned_bytes = GetCudaAlignedSize(n * sizeof(int32_t));
      /* CUB Temp Storage */
      const size_t temp_storage_bytes =
          InferTempStorageForSortPairsDescending<int32_t, int32_t>(1, n);
      return sorted_in_aligned_bytes + indices_aligned_bytes + temp_storage_bytes;
    });
} // namespace oneflow
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/common/global.h"
#include "oneflow/core/job/global_for.h"
namespace oneflow {
Maybe<void> InferRandpermNdSbp(user_op::InferNdSbpFnContext* ctx);

// Registers the gradient-free "randperm" op: one int32 output "out" of shape (n,).
REGISTER_NO_GRAD_USER_OP("randperm")
    .Output("out")
    .Attr<int32_t>("n")
    // "seed" is set by the randperm functors and read by the CPU/GPU kernels when they
    // create their kernel state, so it has to be declared on the op as well.
    .Attr<int64_t>("seed")
    // NOTE(review): ConsistentRandPermFunctor stores "nd_sbp" as std::vector<std::string>
    // in lazy mode, while it is declared as std::string here -- confirm the intended type.
    .Attr<std::string>("nd_sbp")
    .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      Shape* out_shape = ctx->OutputShape("out", 0);
      int32_t n = ctx->Attr<int32_t>("n");
      // A negative length is rejected; n == 0 yields an empty tensor.
      CHECK_GE_OR_RETURN(n, 0);
      *out_shape = Shape({n});
      return Maybe<void>::Ok();
    })
    .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> { return Maybe<void>::Ok(); })
    .SetDataTypeInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
      *ctx->OutputDType("out", 0) = DataType::kInt32;
      return Maybe<void>::Ok();
    })
    .SetNdSbpInferFn(&InferRandpermNdSbp);
// Infers the nd-sbp of "out". In multi-client mode it is parsed from the op's "nd_sbp"
// string attribute (protobuf text format); otherwise a single broadcast SBP is used.
Maybe<void> InferRandpermNdSbp(user_op::InferNdSbpFnContext* ctx) {
  cfg::NdSbp* out_nd_sbp = ctx->NdSbp4ArgNameAndIndex("out", 0);
  const bool is_multi_client = JUST(*Global<Maybe<bool>, MultiClient>::Get());
  if (!is_multi_client) {
    out_nd_sbp->mutable_sbp_parallel()->Add()->mutable_broadcast_parallel();
    return Maybe<void>::Ok();
  }
  const auto& nd_sbp_text = ctx->user_op_conf().attr<std::string>("nd_sbp");
  NdSbp nd_sbp_proto;
  CHECK_OR_RETURN(TxtString2PbMessage(nd_sbp_text, &nd_sbp_proto));
  out_nd_sbp->InitFromProto(nd_sbp_proto);
  return Maybe<void>::Ok();
}
} // namespace oneflow
......@@ -257,6 +257,7 @@ from oneflow.nn.modules.nonzero import nonzero_op as nonzero
from oneflow.nn.modules.random_ops import bernoulli
from oneflow.nn.modules.random_ops import rand_op as rand
from oneflow.nn.modules.random_ops import randn_op as randn
from oneflow.nn.modules.random_ops import randperm
from oneflow.nn.modules.reduce_ops import _max as max
from oneflow.nn.modules.reduce_ops import _mean as mean
from oneflow.nn.modules.reduce_ops import _min as min
......
......@@ -264,6 +264,94 @@ def randn_op(
)()
class Randperm(Module):
    """Module that produces a random permutation of the integers ``0 .. n - 1``.

    Args:
        n: number of elements to permute; must be non-negative.
        generator: random generator handed to the underlying functional op.
        dtype: dtype the result is converted to before it is returned.
        layout: accepted for signature compatibility; not used by this class.
        device: device for the local tensor variant.
        placement: when not ``None``, the consistent variant of the op is used.
        sbp: sbp signature used together with ``placement``.
        requires_grad: value assigned to ``requires_grad`` of the returned tensor.
        pin_memory: accepted for signature compatibility; not used by this class.
    """

    def __init__(
        self,
        n,
        generator: flow.Generator = None,
        dtype: flow.dtype = flow.int32,
        layout=None,
        device: Union[flow.device, str, None] = None,
        placement: flow.placement = None,
        sbp: flow._oneflow_internal.sbp.sbp = None,
        requires_grad: bool = False,
        pin_memory: bool = False,
    ) -> None:
        super().__init__()
        assert n >= 0
        self.n = n
        self.requires_grad = requires_grad
        # The shared helper normalizes device / generator / placement / sbp; the size slot
        # is filled with a placeholder value of 1.
        (
            self.size,
            self.device,
            self.generator,
            self.placement,
            self.sbp,
        ) = _rand_op_common_process(1, device, generator, placement, sbp)
        self.dtype = dtype

    def forward(self, out=None):
        """Run the op and return the permutation converted to ``self.dtype``."""
        if self.placement is not None:
            res = flow.F.consistent_randperm(
                self.n, self.placement, self.sbp, self.generator
            )
        else:
            res = flow.F.randperm(self.n, self.device, self.generator)
        # Convert first, then set requires_grad on the tensor that is actually returned.
        # Setting it on the intermediate tensor before the conversion leaves the flag on
        # a different tensor whenever the conversion creates a new one.
        res = res.to(dtype=self.dtype)
        res.requires_grad = self.requires_grad
        return res
def randperm(
    n: flow.int32,
    generator: flow.Generator = None,
    out=None,
    dtype: flow.dtype = flow.int32,
    layout=None,
    device: Union[flow.device, str, None] = None,
    placement: flow.placement = None,
    sbp: flow._oneflow_internal.sbp.sbp = None,
    requires_grad: bool = False,
    pin_memory: bool = False,
):
    r"""
    Returns a random permutation of integers from ``0`` to ``n - 1``.

    Args:
        n (int): the upper bound (exclusive)

    Keyword args:
        generator(:class:`oneflow.Generator`, optional): a pseudorandom number generator for sampling
        out (Tensor, optional): output Tensor, not supported yet.
        dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor.
            Default: ``oneflow.int32``.
        layout: layout is not supported yet.
        device: the desired device of returned tensor. Default: cpu.
        placement (:class:`flow.placement`, optional): The desired placement of the returned consistent
            tensor. If None, a local tensor is constructed.
        sbp (:class:`flow.sbp`, optional): The desired sbp of the returned consistent tensor. The number
            of sbp entries must match the placement.
        requires_grad(bool, optional): If autograd should record operations on the returned tensor. Default: False.
        pin_memory(bool, optional): pin_memory is not supported yet.

    Example:

    .. code-block:: python

        >>> import oneflow as flow
        >>> generator = flow.Generator()
        >>> generator.manual_seed(0)
        >>> flow.randperm(5, generator=generator)
        tensor([2, 4, 3, 0, 1], dtype=oneflow.int32)
    """
    assert out is None, "out not supported yet"
    assert layout is None, "layout not supported yet"
    if generator is None:
        generator = flow.default_generator()
    module = Randperm(
        n,
        generator=generator,
        dtype=dtype,
        layout=layout,
        device=device,
        placement=placement,
        sbp=sbp,
        requires_grad=requires_grad,
        pin_memory=pin_memory,
    )
    return module(out)
if __name__ == "__main__":
import doctest
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import oneflow as flow
from collections import OrderedDict
from automated_test_util import *
from test_util import GenArgList
import numpy as np
import unittest
def _test_randperm_with_generator(test_case, N, device, dtype):
    """Two generators seeded with the same value must give the same permutation."""

    def _seeded_sample():
        # A fresh generator per sample, always seeded with 0.
        gen = flow.Generator()
        gen.manual_seed(0)
        return flow.randperm(N, device=device, dtype=dtype, generator=gen)

    first = _seeded_sample()
    second = _seeded_sample()
    test_case.assertTrue(np.allclose(first.numpy(), second.numpy()))
    expected_device = flow.device(device)
    test_case.assertTrue(
        first.device == expected_device and second.device == expected_device
    )
    test_case.assertTrue(first.dtype == dtype and second.dtype == dtype)
def _test_randperm_backward(test_case, N, device, dtype):
    """Backward through ``sum()`` must give a gradient of ones for every element."""
    perm = flow.randperm(N, device=device, dtype=dtype)
    perm.requires_grad = True
    total = perm.sum()
    total.backward()
    grad_matches = np.allclose(perm.grad.numpy(), np.ones(N), rtol=1e-05, atol=1e-05)
    test_case.assertTrue(grad_matches)
@flow.unittest.skip_unless_1n1d()
class Testrandperm(flow.unittest.TestCase):
    """Tests for ``flow.randperm`` on CPU and CUDA."""

    def test_randperm(test_case):
        # Every combination of helper, N, device and dtype is exercised.
        arg_dict = OrderedDict(
            [
                (
                    "test_functions",
                    [_test_randperm_with_generator, _test_randperm_backward],
                ),
                ("N", list(range(10, 100, 5))),
                ("device", ["cpu", "cuda"]),
                ("dtype", [flow.int32, flow.int64, flow.float32, flow.float64]),
            ]
        )
        for case_fn, *case_args in GenArgList(arg_dict):
            case_fn(test_case, *case_args)

    @autotest(auto_backward=False)
    def test_auto_1(test_case):
        # n == 1 through the autotest harness.
        device = random_device()
        y = torch.randperm(1, device=device)
        return y

    @autotest(n=5, auto_backward=False)
    def test_auto_0(test_case):
        # n == 0 (empty result) through the autotest harness.
        device = random_device()
        y = torch.randperm(0, device=device)
        return y

    def test_randperm_randomness(test_case):
        # Two unseeded draws of the same length must not be identical, first on CUDA and
        # then on CPU.
        for device in ("cuda", "cpu"):
            n = np.random.randint(100, 1000)
            first = flow.randperm(n, device=device)
            second = flow.randperm(n, device=device)
            test_case.assertTrue(not np.all(first.numpy() == second.numpy()))
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册