Unverified commit aadbb464 authored by Kevin_Xiong, committed by GitHub

add randint (#5718)

* add randint

* add

* add doc test

* Update randint_kernel.cu

* Update randint_kernel.cpp

* Update randint_op.cpp

* reconstruct

* refine the code

* add test

* add test

* add test

* format

* Dev randint refine (#5981)

* disable backward pass consistent tensor meta check. (#5871)

* disable backward pass consistent tensor meta check.

* auto format by CI
Co-authored-by: binbinHan <han_binbin@163.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>

* ddp broadcast params and buffers (#5913)

* ddp broadcast params and buffers
Signed-off-by: daquexian <daquexian566@gmail.com>

* auto format by CI
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>

* add clang tidy target (#5957)

* add clang tidy target

* fix a bug

* refine

* refine

* reformat
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* cfg: add move assignment operator for performance (#5962)
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* add zhangshen op-test (#5600)

* add some op-test

* fix dims_error in my branch

* Fix the bad backward kernel function by using 'cuda::atomic::Add' (#5614)

* Test `nn.AdaptiveAvgPoolXd` (#5615)

* Fix the bad backward kernel function by using 'cuda::atomic::Add'

* Support the 'NoneType' annotation

* Support objects of 'collections.abc.Iterable' as 'output_size'

* Test with all cases of 'output_size'

* Update adaptive_pool_gpu_kernel.cu

* Skip testing `nn.AdaptiveAvgPool3d` for the current PyTorch

* remove some useless test

* Format TODO

* Add the assertion messages for 'output_size'

* Reformat codes

* Remove raw tests for `flow.negative`

* Remove unnecessary codes and add the assertion messages

* Merge updates for 'generators.py' from master

* Remove unnecessary 'random()'

* Delete the separate test for `AvgPool2d`

* Fix import paths

* Fix import problems

* Remove the PyTorch import

* Denote the annotations for `tile` and `repeat` ops

* Add the test for `nn.AvgPool1d`

* Choose better generators for `nn.MaxPoolXd`

* Randomly choose `dilation` and default values

* auto format by CI

* Test more kwargs for `nn.AvgPoolXd`

* Add tests for `return_indices`

* auto format by CI
Co-authored-by: Tianyu Zhao <guikarist@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>

* fix wrong names (#5951)

* fix wrong names

* auto format by CI

* refine

* auto format by CI
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Enable more checkers for clang-tidy in CI (#5738)

* CI: enable more checkers for clang-tidy

* .clang-tidy: remove cppcoreguidelines-pro-type-vararg

* CI: remove duplicate checkers

* CI: remove clang-analyzer-alpha.deadcode.*

* .clang-tidy: add performance-*

* oneflow/core/eager: remove unnecessary malloc & free

* .clang-tidy: add clang-analyzer-cplusplus.* to werror

* user_kernel: remove useless move

* quantization_aware_training: fix move return

* .clang-tidy: add google-*

* CI: fix clang tidy command

* CI: fix test
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Feat grad mode classes (#5956)

* feat(no_grad): support no_grad decorator

* feat(AutogradMode): export flow.autograd_mode

* feat(GradMode): export some grad_mode class

* docs(GradMode): export documents

* refine

* docs(GradMode): export document for is_grad_enabled

* auto format by CI

* fix(GradMode): fix single client bug

* fix bug
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>

* extract_consistent_to_consistent_op_expr (#5870)

* abstract_consistent_to_consistent_op_expr

* fix compiler complaint

* refactor consistent-to-consistent eager consistent op interpreter

* fix compiler complaint

* refactor ConsistentToConsistentOpExpr

* lazy interpreter (#5903)

* fix bugs about consistent_id

* refactor functional::ToConsistent

* refactor GetNdSbp

* Update eager_consistent_op_interpreter.cpp

* Update eager_mirrored_op_interpreter.cpp

* fix error

* fix error

* auto format by CI

* Update nd_sbp.h

* refine identity boxing

* fix sync checkmeta error

* avoid consistent id check in lazy
Co-authored-by: Xinqi Li <lixinqi0703106@163.com>
Co-authored-by: leaves-zwx <kunta0932@gmail.com>
Co-authored-by: Li Xinqi <lixinqi2010@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>

* add CMAKE_INTERPROCEDURAL_OPTIMIZATION in fast cmake cache (#5970)

* add CMAKE_INTERPROCEDURAL_OPTIMIZATION in fast cmake cache

* skip test targets of re2
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* check: fix clang-tidy-diff commands (#5972)

* check: fix clang-tidy-diff commands

* CI: fix step names
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Cpu mpi (#5865)

* cuda base cpu mpi boxing

* cpu_mpi

* fix conflicts

* add cpu mpi unittests

* more checks and unittests

* abstract_consistent_to_consistent_op_expr

* fix compiler complaint

* refactor consistent-to-consistent eager consistent op interpreter

* fix compiler complaint

* refactor ConsistentToConsistentOpExpr

* lazy interpreter (#5903)

* fix bugs about consistent_id

* more test_consistent_cast unittests

* refactor functional::ToConsistent

* refactor GetNdSbp

* fix compiler complaints

* refactor GetDevice4CurrentProcessCtx

* fix error
Co-authored-by: clackhan <han_binbin@163.com>
Co-authored-by: leaves-zwx <kunta0932@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* fix_bug_test_tensor_str (#5958)

* fix bug in test_tensor_str

* format

* fix comment

* fix bug to(cuda) is unavailable in cpu env
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* common/error: fix build error in mac (#5971)
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>

* Prevent running oneflow in forked subprocess (#5976)

* prevent_running_oneflow_in_forked_subprocess

* add line change

* IsFork => IsForkedSubProcess

* auto format by CI
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>

* refine randint
Co-authored-by: Li Xinqi <lixinqi2010@gmail.com>
Co-authored-by: binbinHan <han_binbin@163.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: daquexian <daquexian566@gmail.com>
Co-authored-by: Peihong Liu <mosout@qq.com>
Co-authored-by: Twice <i@twice.moe>
Co-authored-by: ZhangShen <55383772+zhangshen12356@users.noreply.github.com>
Co-authored-by: Tianyu Zhao <guikarist@gmail.com>
Co-authored-by: Luyang <flowingsun007@163.com>
Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: Xinqi Li <lixinqi0703106@163.com>
Co-authored-by: leaves-zwx <kunta0932@gmail.com>
Co-authored-by: Shenghang Tsai <jackalcooper@gmail.com>
Co-authored-by: liufengwei0103 <2472937968@qq.com>

* refine

* refine

* auto format by CI

* refine

* Update functional_api.yaml

* Update functional_api.yaml

* refine the code

* auto format by CI

* refine

* fix ci error

* fix test

* auto format by CI

* fix test

* refine code

* auto format by CI

* refine code

* auto format by CI

* fix ci fail

* remove redefined api

* fix ci test

* auto format by CI

* fix consistency with torch

* auto format by CI

* unittest fixed

* fix doctest
Co-authored-by: Bowen Chen <bob2420083992@gmail.com>
Co-authored-by: Li Xinqi <lixinqi2010@gmail.com>
Co-authored-by: binbinHan <han_binbin@163.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Co-authored-by: oneflow-ci-bot <ci-bot@oneflow.org>
Co-authored-by: daquexian <daquexian566@gmail.com>
Co-authored-by: Peihong Liu <mosout@qq.com>
Co-authored-by: Twice <i@twice.moe>
Co-authored-by: ZhangShen <55383772+zhangshen12356@users.noreply.github.com>
Co-authored-by: Tianyu Zhao <guikarist@gmail.com>
Co-authored-by: Luyang <flowingsun007@163.com>
Co-authored-by: Yinggang Wang <wyg19970408@gmail.com>
Co-authored-by: Xinqi Li <lixinqi0703106@163.com>
Co-authored-by: leaves-zwx <kunta0932@gmail.com>
Co-authored-by: Shenghang Tsai <jackalcooper@gmail.com>
Co-authored-by: liufengwei0103 <2472937968@qq.com>
Parent 92929e18
@@ -84,6 +84,7 @@ oneflow
randn,
repeat,
reshape,
randint,
randperm,
reciprocal,
round,
......
@@ -1048,19 +1048,27 @@
Generator generator=None) => ConsistentRandN"
bind_python: True
- name: "randint"
signature: "Tensor (Int64 low, Int64 high, Shape shape, DataType dtype=None, Device device=None, Generator generator=None)=> RandInt"
bind_python: True
- name: "consistent_randint"
signature: "Tensor (Int64 low, Int64 high, Shape shape, Placement placement, SbpList sbp_tuple, DataType dtype=None, Generator generator=None)=> ConsistentRandInt"
bind_python: True
- name: "unfold"
signature: "Tensor (Tensor x, String data_format=\"channels_first\", Int32List kernel_size,
Int32List dilation_rate, Int32List padding,
Int32List strides) => Unfold"
bind_python: True
- name: "fold"
signature: "Tensor (Tensor x, String data_format=\"channels_first\",
Int32List output_size, Int32List kernel_size,
Int32List dilation_rate, Int32List padding,
Int32List strides) => Fold"
bind_python: True
- name: "split"
signature: "TensorTuple (Tensor x, Int64 split_size, Int64 dim=0) => Split"
bind_python: True
@@ -1078,7 +1086,7 @@
bind_python: False
- name: "randperm"
signature: "Tensor (Int32 n, Device device=None, Generator generator=None) => Randperm"
signature: "Tensor (Int32 n, Device device=None, Generator generator=None) => RandPerm"
bind_python: True
- name: "fused_self_attention"
......
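The "randint" and "consistent_randint" entries above bind the new functors to Python as flow.F.randint and flow.F.consistent_randint. A minimal sketch of a local call (illustrative only, not part of this commit; it mirrors the RandInt module further down in the diff):

import oneflow as flow

gen = flow.Generator()
# Sample a 2x3 tensor of integers drawn uniformly from [0, 10); this mirrors the
# call made in RandInt.forward() when no placement is given.
x = flow.F.randint(0, 10, shape=(2, 3), dtype=None, device=flow.device("cpu"), generator=gen)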
@@ -30,11 +30,12 @@ limitations under the License.
#include "oneflow/core/functional/impl/unary_functor.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/job/lazy_mode.h"
#include "oneflow/user/kernels/bernoulli_kernel.h"
#include "oneflow/user/kernels/distributions/normal_kernel.h"
#include "oneflow/user/kernels/distributions/uniform_kernel.h"
#include "oneflow/core/job/sbp_parallel.h"
#include "oneflow/core/job/lazy_mode.h"
namespace oneflow {
namespace one {
namespace functional {
@@ -251,6 +252,103 @@ class ConsistentRandNFunctor {
private:
std::shared_ptr<OpExpr> op_;
};
class RandIntFunctor {
public:
RandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
Maybe<Tensor> operator()(const int64_t low, const int64_t high, const Shape& shape,
const Optional<Symbol<DType>>& dtype,
const Optional<Symbol<Device>>& device,
const Optional<one::Generator>& generator) const {
DataType dtype_val = DataType::kInt64;
if (dtype.has_value()) {
dtype_val = JUST(dtype.value())->data_type();
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << dtype_val << " is not supported in randint";
}
}
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<double>("low", low));
JUST(attrs.SetAttr<double>("high", high - 1));
JUST(attrs.SetAttr<DataType>("dtype", dtype_val));
std::shared_ptr<one::Generator> gen;
if (!generator) {
gen = JUST(one::DefaultAutoGenerator());
} else {
gen = JUST(generator.value());
}
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
if (device.has_value()) {
Symbol<Device> device_symbol = JUST(device.value());
return OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, device_symbol, uniform_kernel_state));
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_, {},
OpExprInterpContext(attrs, uniform_kernel_state));
}
}
private:
std::shared_ptr<OpExpr> op_;
};
class ConsistentRandIntFunctor {
public:
ConsistentRandIntFunctor() { op_ = CHECK_JUST(one::OpBuilder("uniform").Output("out").Build()); }
Maybe<Tensor> operator()(const int64_t low, const int64_t high, const Shape& shape,
const Symbol<ParallelDesc>& placement,
const std::vector<Symbol<cfg::SbpParallel>>& sbp_tuple,
const Optional<Symbol<DType>>& dtype,
const Optional<one::Generator>& generator) const {
DataType dtype_val = DataType::kInt64;
if (dtype.has_value()) {
dtype_val = JUST(dtype.value())->data_type();
if (dtype_val != DataType::kFloat && dtype_val != DataType::kDouble) {
OF_UNIMPLEMENTED() << dtype_val << " is not supported in randint";
}
}
MutableAttrMap attrs;
JUST(attrs.SetAttr<Shape>("shape", shape));
JUST(attrs.SetAttr<double>("low", low));
JUST(attrs.SetAttr<double>("high", high - 1));
JUST(attrs.SetAttr<DataType>("dtype", dtype_val));
std::shared_ptr<one::Generator> gen;
if (!generator) {
gen = JUST(one::DefaultAutoGenerator());
} else {
gen = JUST(generator.value());
}
JUST(attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& uniform_kernel_state = std::make_shared<UniformKernelState>(gen);
if (LazyMode::is_enabled()) {
std::vector<std::string> nd_sbp(sbp_tuple.size());
{
for (int i = 0; i < sbp_tuple.size(); ++i) {
nd_sbp.at(i) = SbpParallelToString(*sbp_tuple.at(i));
}
}
JUST(attrs.SetAttr<std::vector<std::string>>("nd_sbp", nd_sbp));
}
const auto& nd_sbp = JUST(GetNdSbp(sbp_tuple));
return OpInterpUtil::Dispatch<Tensor>(
*op_, {}, OpExprInterpContext(attrs, placement, nd_sbp, uniform_kernel_state));
}
private:
std::shared_ptr<OpExpr> op_;
};
class RandPermFunctor {
public:
@@ -326,12 +424,14 @@ class ConsistentRandPermFunctor {
ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::BernoulliFunctor>("Bernoulli");
m.add_functor<impl::RandPermFunctor>("Randperm");
m.add_functor<impl::ConsistentRandPermFunctor>("ConsistentRandperm");
m.add_functor<impl::RandPermFunctor>("RandPerm");
m.add_functor<impl::ConsistentRandPermFunctor>("ConsistentRandPerm");
m.add_functor<impl::RandFunctor>("Rand");
m.add_functor<impl::ConsistentRandFunctor>("ConsistentRand");
m.add_functor<impl::RandNFunctor>("RandN");
m.add_functor<impl::ConsistentRandNFunctor>("ConsistentRandN");
m.add_functor<impl::RandIntFunctor>("RandInt");
m.add_functor<impl::ConsistentRandIntFunctor>("ConsistentRandInt");
};
} // namespace functional
......
@@ -76,8 +76,8 @@ class GpuRandPermKernel final : public user_op::OpKernel {
reinterpret_cast<void*>(reinterpret_cast<char*>(value_base) + indices_aligned_bytes);
size_t temp_storage_bytes = InferTempStorageForSortPairsDescending<int32_t, int32_t>(1, n);
GeneKeysAndValues<<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
ctx->device_ctx()->cuda_stream()>>>(n, value_base, key_base, curand_states);
GeneKeysAndValues<<<block_num, kCudaThreadsNumPerBlock, 0, ctx->device_ctx()->cuda_stream()>>>(
n, value_base, key_base, curand_states);
auto err = cub::DeviceRadixSort::SortPairs(
/* d_temp_storage */ tmp_base,
......
@@ -291,7 +291,8 @@ from oneflow.nn.modules.nonzero import nonzero_op as nonzero
from oneflow.nn.modules.random_ops import bernoulli
from oneflow.nn.modules.random_ops import rand_op as rand
from oneflow.nn.modules.random_ops import randn_op as randn
from oneflow.nn.modules.random_ops import randperm
from oneflow.nn.modules.random_ops import randint_op as randint
from oneflow.nn.modules.random_ops import randperm_op as randperm
from oneflow.nn.modules.reduce_ops import _max as max
from oneflow.nn.modules.reduce_ops import _mean as mean
from oneflow.nn.modules.reduce_ops import _min as min
......
@@ -264,12 +264,121 @@ def randn_op(
)()
class Randperm(Module):
class RandInt(Module):
def __init__(
self,
low: flow.int64,
high: flow.int64,
size: tuple,
generator: flow.Generator = None,
dtype: Optional[flow.dtype] = None,
device=None,
placement=None,
sbp=None,
requires_grad=False,
) -> None:
super().__init__()
if generator is None:
generator = flow.Generator()
assert low < high
self.requires_grad = requires_grad
(
self.size,
self.device,
self.generator,
self.placement,
self.sbp,
) = _rand_op_common_process(size, device, generator, placement, sbp)
self.dtype = dtype
self.low = low
self.high = high
def forward(self):
if self.placement is not None:
res = flow.F.consistent_randint(
self.low,
self.high,
shape=self.size,
placement=self.placement,
sbp_tuple=self.sbp,
dtype=self.dtype,
generator=self.generator,
)
else:
res = flow.F.randint(
self.low,
self.high,
shape=self.size,
dtype=self.dtype,
device=self.device,
generator=self.generator,
)
res.requires_grad = self.requires_grad
return res.to(dtype=self.dtype)
def randint_op(
low: flow.int64,
high: flow.int64,
size: tuple,
out=None,
generator=None,
dtype: Optional[flow.dtype] = None,
layout=None,
device: Union[flow.device, str, None] = None,
placement: flow.placement = None,
sbp: flow._oneflow_internal.sbp.sbp = None,
requires_grad: bool = False,
):
"""
Returns a tensor filled with random integers generated uniformly between low (inclusive) and high (exclusive).
The shape of the tensor is defined by the variable argument ``size``.
Args:
low (int): The lowest integer to be drawn from the distribution (inclusive).
high (int): One above the highest integer to be drawn from the distribution (exclusive).
size (int... or flow.Size): Defining the shape of the output tensor.
Can be a variable number of arguments or a collection like a list or tuple or flow.Size.
out (optional): The output tensor.
dtype (flow.dtype, optional): The desired data type of returned tensor. Default: ``flow.int64``.
layout (optional): The desired layout of returned Tensor.
generator (flow.Generator, optional): a pseudorandom number generator for sampling.
device (flow.device, optional): The desired device of returned local tensor. If None, uses the
current device.
placement (flow.placement, optional): The desired device of returned consistent tensor. If None, will
construct local tensor.
sbp (flow.sbp, optional): The desired sbp of returned consistent tensor. Its length must match the
number of dimensions of the placement.
requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.
For example:
.. code-block:: python
>>> import oneflow as flow
>>> generator = flow.Generator()
>>> generator.manual_seed(0)
>>> flow.randint(0, 5, (3,3), generator=generator)
tensor([[2, 2, 3],
[4, 3, 4],
[2, 4, 2]], dtype=oneflow.int64)
"""
assert out is None, "out not supported yet"
assert layout is None, "layout not supported yet"
if generator is None:
generator = flow.default_generator()
return RandInt(
low, high, size, generator, dtype, device, placement, sbp, requires_grad
)()
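A short usage sketch of the wrapper above (illustrative only, not part of this diff; the same calls appear in the unit tests added below and assume a single-node, single-device environment):

import oneflow as flow

# Local tensor with an explicit dtype and device, as in _test_different_dtype.
y = flow.randint(0, 5, (2, 3), dtype=flow.float32, device=flow.device("cpu"))

# Consistent tensor with placement and sbp, as in test_consistent_naive.
placement = flow.placement("cpu", {0: [0]})
x = flow.randint(0, 16, (10, 1), placement=placement, sbp=(flow.sbp.broadcast,))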
class RandPerm(Module):
def __init__(
self,
n,
generator: flow.Generator = None,
dtype: flow.dtype = flow.int32,
dtype: Optional[flow.dtype] = None,
layout=None,
device: Union[flow.device, str, None] = None,
placement: flow.placement = None,
@@ -280,15 +389,15 @@ class Randperm(Module):
super().__init__()
assert n >= 0
self.n = n
self.requires_grad = requires_grad
self.dtype = dtype
(
self.size,
_,
self.device,
self.generator,
self.placement,
self.sbp,
) = _rand_op_common_process(1, device, generator, placement, sbp)
self.dtype = dtype
) = _rand_op_common_process((), device, generator, placement, sbp)
self.requires_grad = requires_grad
def forward(self, out=None):
if self.placement is not None:
@@ -301,29 +410,29 @@ class Randperm(Module):
return res.to(dtype=self.dtype)
def randperm(
def randperm_op(
n: flow.int32,
generator: flow.Generator = None,
out=None,
dtype: flow.dtype = flow.int32,
dtype: Optional[flow.dtype] = None,
layout=None,
device: Union[flow.device, str, None] = None,
placement: flow.placement = None,
sbp: flow._oneflow_internal.sbp.sbp = None,
requires_grad: bool = False,
pin_memory: bool = False,
):
) -> flow.Tensor:
r"""
Returns a random permutation of integers from ``0`` to ``n - 1``.
Args:
n (int): the upper bound (exclusive)
Keyword args:
generator(:class:`oneflow.Generator`, optional): a pseudorandom number generator for sampling
out (Tensor, optional): output Tensor, not supported yet.
dtype (:class:`oneflow.dtype`, optional): the desired data type of returned tensor.
Default: ``oneflow.int32``.
Default: ``oneflow.int64``.
layout: layout is not supported yet.
device: the desired device of returned tensor. Default: cpu.
placement (:class:`flow.placement`, optional): The desired device of returned consistent tensor. If None,
@@ -345,11 +454,12 @@ def randperm(
"""
assert out is None, "out not supported yet"
assert layout is None, "layout not supported yet"
assert pin_memory is False, "pin_memory not supported yet"
if generator is None:
generator = flow.default_generator()
return Randperm(
n, generator, dtype, layout, device, placement, sbp, requires_grad, pin_memory
)(out)
return RandPerm(n, generator, dtype, layout, device, placement, sbp, requires_grad)(
out
)
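A matching sketch for the renamed randperm_op (illustrative only, based on the signature above):

import oneflow as flow

gen = flow.Generator()
gen.manual_seed(0)
# A random permutation of the integers 0 .. 4, reproducible through the seeded generator.
p = flow.randperm(5, generator=gen)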
if __name__ == "__main__":
......
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import unittest
from collections import OrderedDict
import numpy as np
import oneflow as flow
import oneflow.unittest
from test_util import GenArgList
def _test_rand(test_case, device, shape, low, high):
y1 = flow.randint(low, high, shape, device=flow.device(device))
y2 = flow.randint(low, high, shape, device=flow.device(device))
test_case.assertFalse(np.all(y1.numpy() == y2.numpy()))
test_case.assertTrue(shape == y1.shape)
def _test_0d_rand(test_case, device, shape, low, high):
y1 = flow.randint(low, high, shape, device=flow.device(device))
y2 = flow.randint(low, high, shape, device=flow.device(device))
test_case.assertTrue(
np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4)
) # 0d is [] and []
test_case.assertTrue(shape == y1.shape)
def _test_different_dtype(test_case, device, shape, low, high):
y1 = flow.randint(low, high, shape, dtype=flow.float32, device=flow.device(device))
y2 = flow.randint(low, high, shape, dtype=flow.float64, device=flow.device(device))
test_case.assertTrue(not np.allclose(y1.numpy(), y2.numpy(), atol=1e-4, rtol=1e-4))
test_case.assertTrue(shape == y1.shape)
def _test_backward(test_case, device, shape, low, high):
x = flow.randint(low, high, shape, device=flow.device(device), requires_grad=True)
y = x.sum()
y.backward()
test_case.assertTrue(
np.allclose(np.ones(shape), x.grad.numpy(), atol=1e-4, rtol=1e-4)
)
def _test_with_generator(test_case, device, shape, low, high):
gen = flow.Generator()
gen.manual_seed(0)
y1 = flow.randint(
low, high, shape, dtype=flow.float32, device=flow.device(device), generator=gen
)
y1_np = y1.numpy()
gen.manual_seed(0)
y2 = flow.randint(
low, high, shape, dtype=flow.float32, device=flow.device(device), generator=gen
)
test_case.assertTrue(np.allclose(y1_np, y2.numpy(), atol=1e-4, rtol=1e-4))
def _test_high(test_case, device, shape, low, high):
y1 = flow.randint(low, high, shape, device=flow.device(device))
y2 = flow.randint(low, high, shape, device=flow.device(device))
test_case.assertFalse(np.all(y1.numpy() == y2.numpy()))
test_case.assertTrue(shape == y1.shape)
def _test_0rank(test_case, device, shape, low, high):
y1 = flow.randint(low, high, shape, device=flow.device(device))
test_case.assertTrue(y1.shape == shape)
@flow.unittest.skip_unless_1n1d()
@unittest.skipIf(os.getenv("ONEFLOW_TEST_CPU_ONLY"), "only test cpu cases")
class TestRandint(flow.unittest.TestCase):
def test_consistent_naive(test_case):
placement = flow.placement("cpu", {0: [0]})
sbp = (flow.sbp.broadcast,)
x = flow.randint(0, 16, (10, 1), placement=placement, sbp=sbp)
test_case.assertEqual(x.sbp, sbp)
test_case.assertEqual(x.placement, placement)
def test_randint(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [
_test_rand,
_test_different_dtype,
_test_backward,
_test_with_generator,
]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]
arg_dict["low"] = [i for i in range(10)]
arg_dict["high"] = [10 + np.random.randint(10, 20) for i in range(10)]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
def test_0d_randint(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [_test_0d_rand]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(2, 0, 4), (2, 0, 2)]
arg_dict["low"] = [i for i in range(10)]
arg_dict["high"] = [10 + np.random.randint(1, 20) for i in range(10)]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
def test_high_randint(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [_test_high]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [(2, 3, 4), (2, 5, 2)]
arg_dict["low"] = [i for i in range(10)]
arg_dict["high"] = [10 + np.random.randint(10, 20) for i in range(10)]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
def test_0rank_randint(test_case):
arg_dict = OrderedDict()
arg_dict["test_fun"] = [_test_0rank]
arg_dict["device"] = ["cpu", "cuda"]
arg_dict["shape"] = [()]
arg_dict["low"] = [i for i in range(10)]
arg_dict["high"] = [1000 + np.random.randint(1, 10) for i in range(10)]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
unittest.main()