diff --git a/imperative/python/megengine/random/rng.py b/imperative/python/megengine/random/rng.py
index 661a74202276cbb3470dbbbbba362f902584e1c6..d61928f4f006e3e12654b600aaa52f25a14d03ba 100644
--- a/imperative/python/megengine/random/rng.py
+++ b/imperative/python/megengine/random/rng.py
@@ -209,7 +209,8 @@ def _poisson(


 def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor:
-    assert isinstance(n, int) and n > 0, "Permutation is not defined when n <= 0"
+    assert isinstance(n, int)
+    assert n >= 0, "Permutation is not defined when n < 0"
     size = (n,)
     op = PermutationRNG(seed=seed, handle=handle, dtype=dtype)
     _ref = Tensor([], dtype="int32", device=device)
diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py
index 3150ef8ad84999d4536dd06bc97debbb679c0502..a0a160a54deb03e45c37688128d65b4c98d3974b 100644
--- a/imperative/python/test/unit/random/test_rng.py
+++ b/imperative/python/test/unit/random/test_rng.py
@@ -10,7 +10,7 @@ import numpy as np
 import pytest

 import megengine.functional as F
-from megengine import Tensor
+from megengine import Tensor, jit, random
 from megengine.core._imperative_rt import CompNode
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core._imperative_rt.ops import (
@@ -402,3 +402,44 @@ def test_seed():
     seed(11)
     out4 = uniform(size=[10, 10])
     assert not (out1.numpy() == out4.numpy()).all()
+
+
+@pytest.mark.parametrize("is_symbolic", [None, False, True])
+def test_rng_empty_tensor(is_symbolic):
+    shapes = [
+        (0,),
+        (0, 0, 0),
+        (10, 0, 10),
+    ]
+
+    def fn(shape):
+        o1 = random.uniform(0, 1, shape)
+        o2 = random.normal(0, 1, shape)
+        o3 = random.gamma(2, 1, shape)
+        o4 = random.beta(2, 1, shape)
+        o5 = random.poisson(2, shape)
+        return o1, o2, o3, o4, o5
+
+    for shape in shapes:
+        if is_symbolic is not None:
+            fn_ = jit.trace(symbolic=is_symbolic)(fn)
+        else:
+            fn_ = fn
+        for _ in range(3):
+            outs = fn_(shape)
+            for out in outs:
+                np.testing.assert_equal(out.numpy().shape, shape)
+            if is_symbolic is None:
+                break
+
+    def fn2(n):
+        return random.permutation(n=n)
+
+    if is_symbolic is not None:
+        fn2 = jit.trace(symbolic=is_symbolic)(fn2)
+
+    for _ in range(3):
+        out = fn2(0)
+        np.testing.assert_equal(out.numpy().shape, (0,))
+        if is_symbolic is None:
+            break
diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp
index b5d147fc455802c0d965de28d65a848754fdd3ac..441bbcb52e1507cd74e552aaa6d3756523bb5ad7 100644
--- a/imperative/src/impl/ops/rng.cpp
+++ b/imperative/src/impl/ops/rng.cpp
@@ -312,12 +312,7 @@ struct _InferLayout
 {
     template <typename Op>
     static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){
-        size_t size = inp.layout.total_nr_elems();
-        mgb_assert(
-                size > 0,
-                "target size of %s expects size>0; got size=%lu actually",
-                rng.dyn_typeinfo()->name,
-                size);
+        mgb_assert(inp.layout.ndim);
         return inp.layout;
     }
 };
@@ -376,6 +371,7 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
         const SmallVector<TensorPtr>& outputs) {
     auto&& rng = op.cast_final_safe<Op>();
     auto dest = outputs[0];
+    if (dest->layout().is_empty()) return;
     auto cn = dest->comp_node();
     auto handle = rng.handle;
     if (!handle) {
diff --git a/src/opr/impl/rand.cpp b/src/opr/impl/rand.cpp
index 34583ea015910ea080264720fa2130443ae46fd6..7cd3f0765eda46a47dbbdae90d82cd479f6cd483 100644
--- a/src/opr/impl/rand.cpp
+++ b/src/opr/impl/rand.cpp
@@ -48,6 +48,9 @@ cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const {
     auto prop = Super::do_make_node_prop(); \
     prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \
     prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE}); \
+    for (auto i: input()) { \
+        prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); \
+    } \
     return prop; \
 } \
 RNGOpr::RNGOpr(VarNode *shape, const Param &param, \
@@ -56,7 +59,7 @@ RNGOpr::RNGOpr(VarNode *shape, const Param &param, \
 { \
     DType dtype = DType::from_enum(param.dtype); \
     add_input({shape}); \
-    add_output(None)->dtype(dtype); \
+    add_output(None)->dtype(dtype).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); \
     cg::add_workspace_output(this); \
     add_equivalence_component<ScalarHash<void*>>(this); \
 } \
@@ -84,7 +87,12 @@ void RNGOpr::init_output_static_infer_desc() {
             {SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk}); \
 } \
 void RNGOpr::scn_do_execute() { \
-    m_dnn_opr->exec(output(0)->dev_tensor().as_megdnn(), \
+    auto&& ret = output(0); \
+    if (ret->layout().is_empty()) { \
+        mgb_assert(ret->dev_tensor().empty()); \
+        return; \
+    } \
+    m_dnn_opr->exec(ret->dev_tensor().as_megdnn(), \
             get_megdnn_workspace_from_var(output(1))); \
 }

@@ -105,7 +113,7 @@ RNGOpr::RNGOpr(_INPUTS(VarNode*,), const Param &param,
     Super({i0->owner_graph(), config, (name), {_INPUTS(,)}}, param) \
 { \
     add_input({_INPUTS(,)}); \
-    add_output(None)->dtype(i0->dtype()); \
+    add_output(None)->dtype(i0->dtype()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); \
     cg::add_workspace_output(this); \
     add_equivalence_component<ScalarHash<void*>>(this); \
 } \
@@ -132,9 +140,22 @@ void RNGOpr::add_input_layout_constraint(){
     for (auto i : input()) i->add_layout_constraint_contiguous(); \
 }; \
 void RNGOpr::scn_do_execute() { \
+    auto&& ret = output(0); \
+    if (ret->layout().is_empty()) { \
+        mgb_assert(ret->dev_tensor().empty()); \
+        return; \
+    } \
     m_dnn_opr->exec(_FOR_EACH(_AS_MEGDNN), output(0)->dev_tensor().as_megdnn(), \
             get_megdnn_workspace_from_var(output(1))); \
-}
+} \
+cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const { \
+    auto prop = Super::do_make_node_prop(); \
+    prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \
+    for (auto i: input()) { \
+        prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); \
+    } \
+    return prop; \
+}

 /* ================= 1 input ================= */
 #define _INPUTS(prefix, subfix) prefix i0 subfix
diff --git a/src/opr/include/megbrain/opr/rand.h b/src/opr/include/megbrain/opr/rand.h
index 57d02248a7a34baac58fd1caa2927409fd0e091f..7bea8bfcb47705dfef92f1ada16b019775b568c1 100644
--- a/src/opr/include/megbrain/opr/rand.h
+++ b/src/opr/include/megbrain/opr/rand.h
@@ -67,6 +67,7 @@ _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG)
 #define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \
 MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \
     void add_input_layout_constraint() override; \
+    cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
     public: \
         RNG(_INPUTS(VarNode*), const Param &param, \
             const OperatorNodeConfig &config); \
diff --git a/src/opr/test/rand.cpp b/src/opr/test/rand.cpp
index 50c41ea6c9f65a24f79b51fb05a10d7052aedc46..d4e28491a8be63ca44a2de92b31140ce47fd45c2 100644
--- a/src/opr/test/rand.cpp
+++ b/src/opr/test/rand.cpp
@@ -247,6 +247,92 @@ TEST(TestOprRand, PermutationRNG) {
     }
 }

+TEST(TestOprRand, EmptyShape) {
+    auto test_uniform = []() {
+        static constexpr size_t M = 128, N = 0;
+        auto graph = ComputingGraph::make();
+        SymbolVar dev_out = opr::UniformRNG::make(
+                *graph, {M, N}, {CompNode::load("xpu0")}, {23, DTypeEnum::Float32});
+        HostTensorND host_out;
+        auto func = graph->compile({make_callback_copy(dev_out, host_out)});
+        func->execute();
+        ASSERT_EQ(host_out.shape(), TensorShape({M, N}));
+
+    };
+    auto test_gaussian = []() {
+        size_t SIZE = 0;
+        constexpr float MEAN = 1, STD = 2;
+        auto graph = ComputingGraph::make();
+        auto y = opr::GaussianRNG::make(
+                SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}),
+                {23, MEAN, STD, DTypeEnum::Float32});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_EQ(TensorShape({SIZE}), host_y.shape());
+    };
+    auto test_gamma = []() {
+        std::shared_ptr<HostTensorND> shape_host(new HostTensorND{
+                CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
+        std::shared_ptr<HostTensorND> scale_host(new HostTensorND{
+                CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
+        auto graph = ComputingGraph::make();
+        auto shape_sym = opr::Host2DeviceCopy::make(*graph, shape_host);
+        auto scale_sym = opr::Host2DeviceCopy::make(*graph, scale_host);
+
+        auto y = opr::GammaRNG::make(shape_sym, scale_sym, {10});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_EQ(TensorShape({10, 0}), host_y.shape());
+    };
+    auto test_poisson = []() {
+        std::shared_ptr<HostTensorND> lam_host(new HostTensorND{
+                CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
+        auto graph = ComputingGraph::make();
+        auto lam_sym = opr::Host2DeviceCopy::make(*graph, lam_host);
+        auto y = opr::PoissonRNG::make(lam_sym, {10});
+
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_EQ(TensorShape({10, 0}), host_y.shape());
+    };
+    auto test_beta = []() {
+        std::shared_ptr<HostTensorND> alpha_host(new HostTensorND{
+                CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
+        std::shared_ptr<HostTensorND> beta_host(new HostTensorND{
+                CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
+        auto graph = ComputingGraph::make();
+        auto alpha_sym = opr::Host2DeviceCopy::make(*graph, alpha_host);
+        auto beta_sym = opr::Host2DeviceCopy::make(*graph, beta_host);
+        auto y = opr::BetaRNG::make(alpha_sym, beta_sym, {10});
+
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_EQ(TensorShape({10, 0}), host_y.shape());
+    };
+    auto test_permutation = []() {
+        static constexpr size_t SIZE = 0;
+        auto graph = ComputingGraph::make();
+        auto y = opr::PermutationRNG::make(
+                SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}),
+                {23, DTypeEnum::Int32});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_EQ(TensorShape({SIZE}), host_y.shape());
+    };
+    test_uniform();
+    test_gaussian();
+    test_gamma();
+    test_poisson();
+    test_beta();
+    test_permutation();
+}
+
+
 TEST(TestOprRand, UniformReprod) {
     static constexpr size_t SIZE = 123;
     auto graph = ComputingGraph::make();
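
A minimal sketch of the user-visible behavior this patch enables, based on the Python tests in the diff above (the standalone assertions are illustrative, not part of the patch):

```python
import megengine.random as random

# Zero-sized shapes are now legal for the sampling ops; each call
# returns an empty tensor of the requested shape instead of tripping
# the old size > 0 assertion in layout inference.
out = random.uniform(0, 1, (10, 0, 10))
assert out.numpy().shape == (10, 0, 10)

# permutation(0) is now defined and yields an empty int32 tensor;
# only negative n is rejected by the relaxed assert in _permutation.
perm = random.permutation(0)
assert perm.numpy().shape == (0,)
```

On the graph side the same policy shows up three times: output vars are created with `VarNode::Flag::ALLOW_EMPTY_SHAPE`, input dependencies are tagged `NodeProp::DepType::VALUE_ALLOW_EMPTY`, and both `scn_do_execute` and the imperative `exec` return early instead of invoking the MegDNN kernel when the output layout is empty.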