From 13e6ea349d020e71924026a4f40fd21c48cc0a29 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 24 Feb 2021 21:16:33 +0800 Subject: [PATCH] feat(imperative/opr): rebase rng refactoring to dev & add python module GitOrigin-RevId: ee5984c52d3fa346d5f26d737bf40ec4ed43b2c7 --- .../megengine/core/tensor/array_method.py | 3 +- .../python/megengine/distributed/group.py | 3 + .../python/megengine/random/__init__.py | 2 +- .../python/megengine/random/distribution.py | 42 ++--- imperative/python/megengine/random/rng.py | 89 +++++++++- imperative/python/src/ops.cpp | 31 ++-- .../python/test/unit/random/test_rng.py | 121 +++++++++++++ imperative/python/test/unit/test_rng.py | 76 -------- imperative/src/impl/ops/rng.cpp | 168 +++++++++--------- imperative/src/impl/ops/specializations.cpp | 28 --- .../src/include/megbrain/imperative/ops/rng.h | 87 +-------- imperative/src/test/rng.cpp | 20 ++- src/core/include/megbrain/ir/ops.td | 24 ++- 13 files changed, 371 insertions(+), 323 deletions(-) create mode 100644 imperative/python/test/unit/random/test_rng.py delete mode 100644 imperative/python/test/unit/test_rng.py diff --git a/imperative/python/megengine/core/tensor/array_method.py b/imperative/python/megengine/core/tensor/array_method.py index 55099840c..d69852743 100644 --- a/imperative/python/megengine/core/tensor/array_method.py +++ b/imperative/python/megengine/core/tensor/array_method.py @@ -156,7 +156,8 @@ def _logical_binary_elwise(mode, rev=False): def _remove_axis(inp: Tensor, axis) -> Tensor: def get_axes(): if axis is None: - return [i for i, s in enumerate(inp.shape) if s == 1] + shp = inp.shape + return [i for i, s in enumerate(shp) if s == 1] try: return [int(axis)] except (TypeError, ValueError): diff --git a/imperative/python/megengine/distributed/group.py b/imperative/python/megengine/distributed/group.py index fad649be2..c00933e99 100644 --- a/imperative/python/megengine/distributed/group.py +++ b/imperative/python/megengine/distributed/group.py @@ -6,9 +6,11 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import time from typing import List, Optional, Tuple from ..device import set_default_device, what_is_xpu +from ..random import seed from .server import Client, Server @@ -156,6 +158,7 @@ def init_process_group( WORLD.reset(list(range(world_size))) set_default_device("{}{}".format(device_type, device)) + seed(int(time.time()) + rank) def is_distributed() -> bool: diff --git a/imperative/python/megengine/random/__init__.py b/imperative/python/megengine/random/__init__.py index bedf4340e..996be02ba 100644 --- a/imperative/python/megengine/random/__init__.py +++ b/imperative/python/megengine/random/__init__.py @@ -7,7 +7,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from .distribution import normal, uniform -from .rng import seed +from .rng import RNG, seed # pylint: disable=undefined-variable del distribution, rng # type: ignore[name-defined] diff --git a/imperative/python/megengine/random/distribution.py b/imperative/python/megengine/random/distribution.py index 199778ebb..be74a0d64 100644 --- a/imperative/python/megengine/random/distribution.py +++ b/imperative/python/megengine/random/distribution.py @@ -9,11 +9,8 @@ from typing import Iterable, Optional from .. 
import Tensor -from ..core._imperative_rt import invoke_op -from ..core._imperative_rt.core2 import apply -from ..core.ops.builtin import GaussianRNG, UniformRNG -from ..core.tensor import utils -from .rng import _random_seed_generator +from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed +from .rng import _normal, _uniform __all__ = ["normal", "uniform"] @@ -48,14 +45,14 @@ def normal( [-1.4939808 -1.5824696 ]] """ - if size is None: - size = (1,) - op = GaussianRNG(mean, std) - _ref = Tensor([], dtype="int32") - shape = utils.astensor1d(size, _ref, dtype="int32") - shape = Tensor(shape, dtype="int32") - (output,) = apply(op, shape) - return output + return _normal( + mean=mean, + std=std, + size=size, + seed=_get_global_rng_seed(), + device=None, + handle=0, + ) def uniform( @@ -88,14 +85,11 @@ def uniform( [0.09365904 0.62957656]] """ - assert low < high, "Uniform is not defined when low >= high" - - if size is None: - size = (1,) - op = UniformRNG() - _ref = Tensor([], dtype="int32") - shape = utils.astensor1d(size, _ref, dtype="int32") - shape = Tensor(shape, dtype="int32") - (output,) = apply(op, shape) - - return low + (high - low) * output + return _uniform( + low=low, + high=high, + size=size, + seed=_get_global_rng_seed(), + device=None, + handle=0, + ) diff --git a/imperative/python/megengine/random/rng.py b/imperative/python/megengine/random/rng.py index 8cdfd0e18..448a232f5 100644 --- a/imperative/python/megengine/random/rng.py +++ b/imperative/python/megengine/random/rng.py @@ -7,17 +7,94 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import time +from typing import Iterable, Optional from numpy.random import MT19937 +from .. 
import Tensor +from ..core._imperative_rt.core2 import apply +from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle +from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed +from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle +from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed +from ..core.ops.builtin import GaussianRNG, UniformRNG +from ..core.tensor import utils +from ..device import get_default_device + _rng = None -def _random_seed_generator(): - if _rng is None: - from ..distributed.group import get_rank +def _normal( + mean: float, + std: float, + size: Optional[Iterable[int]], + seed: int, + device: str, + handle: int, +) -> Tensor: + if size is None: + size = (1,) + op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle) + _ref = Tensor([], dtype="int32", device=device) + shape = utils.astensor1d(size, _ref, dtype="int32", device=device) + (output,) = apply(op, shape) + return output + + +def _uniform( + low: float, + high: float, + size: Optional[Iterable[int]], + seed: int, + device: str, + handle: int, +) -> Tensor: + assert low < high, "Uniform is not defined when low >= high" + if size is None: + size = (1,) + op = UniformRNG(seed=seed, handle=handle) + _ref = Tensor([], dtype="int32", device=device) + shape = utils.astensor1d(size, _ref, dtype="int32", device=device) + (output,) = apply(op, shape) + return low + (high - low) * output + + +class RNG: + def __init__(self, seed=0, device=None): + self.seed = seed + self.device = device if device else get_default_device() + self.handle = _new_rng_handle(self.device, self.seed) + + def uniform( + self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None + ): + return _uniform( + low=low, + high=high, + size=size, + seed=self.seed, + device=self.device, + handle=self.handle, + ) - seed(seed=int(time.time()) + get_rank()) + def normal( + self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None + ): + return _normal( + mean=mean, + std=std, + size=size, + seed=self.seed, + device=self.device, + handle=self.handle, + ) + + def __del__(self): + _delete_rng_handle(self.handle) + + +def _random_seed_generator(): + assert _rng while True: yield _rng.random_raw() @@ -25,3 +102,7 @@ def _random_seed_generator(): def seed(seed: int): global _rng # pylint: disable=global-statement _rng = MT19937(seed=seed) + _set_global_rng_seed(seed) + + +seed(int(time.time())) diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index 5dc3be143..414edb855 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -10,7 +10,10 @@ */ #include "./ops.h" +#include "./helper.h" +#include "./tensor.h" +#include "megbrain/common.h" #include "megbrain/imperative.h" #include "megbrain/imperative/ops/backward_graph.h" #include "megbrain/imperative/ops/opr_attr.h" @@ -491,21 +494,15 @@ void init_ops(py::module m) { _init_py_op_base(m); INIT_ALL_OP(m) - m.def("new_rng_handle", &RNGMixin::new_handle); - // FIXME: RNG op might execute after handle released due to async dispatch, - // which would cause memory leak or use-after-free - m.def("delete_rng_handle", &RNGMixin::delete_handle); - m.def("set_rng_seed", &set_rng_seed); - - py::class_, OpDef>(m, "UniformRNG") - .def(py::init<>()) - .def(py::init()) - .def(py::init()); - - py::class_, OpDef>(m, "GaussianRNG") - .def(py::init<>()) - .def(py::init()) - .def(py::init()) - .def(py::init()) - .def(py::init()); + m.def("new_rng_handle", 
&rng::new_handle); + m.def("delete_rng_handle", [](size_t handle){ + // RNG op might execute after handle released due to async dispatch, so + // we need sync before delete a handle to avoid memory leak or use-after-free + python::interpreter_for_py->sync(); + mgb::CompNode::sync_all(); + py_task_q.wait_all_task_finish(); + rng::delete_handle(handle); + }, py::call_guard()); + m.def("set_global_rng_seed", &rng::set_global_rng_seed); + m.def("get_global_rng_seed", &rng::get_global_rng_seed); } diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py new file mode 100644 index 000000000..1979eebfe --- /dev/null +++ b/imperative/python/test/unit/random/test_rng.py @@ -0,0 +1,121 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import numpy as np + +import megengine +from megengine import tensor +from megengine.core._imperative_rt import CompNode +from megengine.core._imperative_rt.core2 import apply +from megengine.core._imperative_rt.ops import ( + delete_rng_handle, + get_global_rng_seed, + new_rng_handle, +) +from megengine.core.ops.builtin import GaussianRNG, UniformRNG +from megengine.random import RNG +from megengine.random.rng import _normal, _uniform + + +def test_gaussian_op(): + shape = ( + 8, + 9, + 11, + 12, + ) + shape = tensor(shape, dtype="int32") + op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0) + (output,) = apply(op, shape) + assert np.fabs(output.numpy().mean() - 1.0) < 1e-1 + assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1 + assert str(output.device) == str(CompNode("xpux")) + + cn = CompNode("xpu2") + seed = 233333 + h = new_rng_handle(cn, seed) + op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h) + (output,) = apply(op, shape) + delete_rng_handle(h) + assert np.fabs(output.numpy().mean() - 3.0) < 1e-1 + assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1 + assert str(output.device) == str(cn) + + +def test_uniform_op(): + shape = ( + 8, + 9, + 11, + 12, + ) + shape = tensor(shape, dtype="int32") + op = UniformRNG(seed=get_global_rng_seed()) + (output,) = apply(op, shape) + assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 + assert str(output.device) == str(CompNode("xpux")) + + cn = CompNode("xpu2") + seed = 233333 + h = new_rng_handle(cn, seed) + op = UniformRNG(seed=seed, handle=h) + (output,) = apply(op, shape) + delete_rng_handle(h) + assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 + assert str(output.device) == str(cn) + + +def test_UniformRNG(): + m1 = RNG(seed=111, device="xpu0") + m2 = RNG(seed=111, device="xpu1") + m3 = RNG(seed=222, device="xpu0") + out1 = m1.uniform(size=(100,)) + out1_ = m1.uniform(size=(100,)) + out2 = m2.uniform(size=(100,)) + out3 = m3.uniform(size=(100,)) + + np.testing.assert_equal(out1.numpy(), out2.numpy()) + assert out1.device == "xpu0" and out2.device == "xpu1" + assert not (out1.numpy() == out3.numpy()).all() + assert not (out1.numpy() == out1_.numpy()).all() + + low = -234 + high = 123 + out = m1.uniform(low=low, high=high, size=(20, 30, 40)) + out_shp = out.shape + if isinstance(out_shp, tuple): + assert out_shp == (20, 30, 40) + else: + assert all(out.shape.numpy() == np.array([20, 30, 40])) + assert 
np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1 + + +def test_NormalRNG(): + m1 = RNG(seed=111, device="xpu0") + m2 = RNG(seed=111, device="xpu1") + m3 = RNG(seed=222, device="xpu0") + out1 = m1.normal(size=(100,)) + out1_ = m1.uniform(size=(100,)) + out2 = m2.normal(size=(100,)) + out3 = m3.normal(size=(100,)) + + np.testing.assert_equal(out1.numpy(), out2.numpy()) + assert out1.device == "xpu0" and out2.device == "xpu1" + assert not (out1.numpy() == out3.numpy()).all() + assert not (out1.numpy() == out1_.numpy()).all() + + mean = -1 + std = 2 + out = m1.normal(mean=mean, std=std, size=(20, 30, 40)) + out_shp = out.shape + if isinstance(out_shp, tuple): + assert out_shp == (20, 30, 40) + else: + assert all(out.shape.numpy() == np.array([20, 30, 40])) + assert np.abs(out.mean().numpy() - mean) / std < 0.1 + assert np.abs(np.std(out.numpy()) - std) < 0.1 diff --git a/imperative/python/test/unit/test_rng.py b/imperative/python/test/unit/test_rng.py deleted file mode 100644 index d5d8e1643..000000000 --- a/imperative/python/test/unit/test_rng.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -import numpy as np - -from megengine import tensor -from megengine.core._imperative_rt import CompNode -from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle -from megengine.core.ops.builtin import GaussianRNG, UniformRNG -from megengine.core.tensor.core import apply - - -def test_gaussian_rng(): - shape = ( - 8, - 9, - 11, - 12, - ) - shape = tensor(shape, dtype="int32") - op = GaussianRNG(1.0, 3.0) - (output,) = apply(op, shape) - assert np.fabs(output.numpy().mean() - 1.0) < 1e-1 - assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1 - assert str(output.device) == str(CompNode("xpux")) - - cn = CompNode("xpu1") - op = GaussianRNG(-1.0, 2.0, cn) - (output,) = apply(op, shape) - assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1 - assert np.sqrt(output.numpy().var()) - 2.0 < 1e-1 - assert str(output.device) == str(cn) - - cn = CompNode("xpu2") - seed = 233333 - h = new_rng_handle(cn, seed) - op = GaussianRNG(3.0, 1.0, h) - (output,) = apply(op, shape) - delete_rng_handle(h) - assert np.fabs(output.numpy().mean() - 3.0) < 1e-1 - assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1 - assert str(output.device) == str(cn) - - -def test_uniform_rng(): - shape = ( - 8, - 9, - 11, - 12, - ) - shape = tensor(shape, dtype="int32") - op = UniformRNG() - (output,) = apply(op, shape) - assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 - assert str(output.device) == str(CompNode("xpux")) - - cn = CompNode("xpu1") - op = UniformRNG(cn) - (output,) = apply(op, shape) - assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 - assert str(output.device) == str(cn) - - cn = CompNode("xpu2") - seed = 233333 - h = new_rng_handle(cn, seed) - op = UniformRNG(h) - (output,) = apply(op, shape) - delete_rng_handle(h) - assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 - assert str(output.device) == str(cn) diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index 21b373786..0e6adbcad 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -2,7 +2,7 @@ * \file 
imperative/src/impl/ops/rng.cpp * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an @@ -10,23 +10,23 @@ */ #include "megbrain/imperative/ops/rng.h" -#include #include "megbrain/comp_node_env.h" #include "megbrain/graph/helper.h" #include "megbrain/opr/rand.h" -//#include "megbrain/common.h" #include "../op_trait.h" +#include "../dnn_op_helper.h" -namespace mgb { -namespace imperative { +namespace mgb::imperative::rng { namespace { template class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj { public: + using DT = CompNode::DeviceType; using Handle = THandle; + using OpTypeInfo = size_t; template Handle new_handle(Args&&... args) { @@ -38,27 +38,26 @@ public: size_t removed = 0; if (!is_finalized()) { MGB_LOCK_GUARD(m_mtx); - removed = m_handle2op.erase(handle); + removed = m_handle2ops.erase(handle); } static_cast(this)->do_delete_handle(handle); return removed; } template - auto get_dnn_op(Handle handle, CompNode cn) { + auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) { mgb_assert(!is_finalized()); DnnOpWithMutex* dnn_op_with_mtx; { MGB_LOCK_GUARD(m_mtx); - dnn_op_with_mtx = &m_handle2op[handle]; + dnn_op_with_mtx = &m_handle2ops[handle][tpinfo]; } auto dnn_handle = MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle(); - DnnOp* dnn_op; std::unique_lock lock(dnn_op_with_mtx->mtx); bool initialized = false; - if ((dnn_op = dynamic_cast(dnn_op_with_mtx->op.get())) != - nullptr) { + DnnOp* dnn_op = static_cast(dnn_op_with_mtx->op.get()); + if (dnn_op != nullptr) { mgb_assert(dnn_op->handle() == dnn_handle); initialized = true; } else { @@ -77,35 +76,30 @@ private: struct DnnOpWithMutex { std::mutex mtx; std::unique_ptr op; + DnnOpWithMutex(): op{nullptr} {} }; std::shared_ptr on_comp_node_finalize() override { MGB_LOCK_GUARD(m_mtx); - m_handle2op.clear(); + m_handle2ops.clear(); return {}; } - std::unordered_map m_handle2op; + std::unordered_map > m_handle2ops; std::mutex m_mtx; }; class RNGDnnOpManager final - : public DnnOpManagerT { + : public DnnOpManagerT { public: + Handle new_handle(CompNode comp_node, uint64_t seed) { + MGB_LOCK_GUARD(sm_mtx); + return DnnOpManagerBase::new_handle(comp_node, seed); + } + size_t delete_handle(Handle handle) { - size_t ret = 0; - { - MGB_LOCK_GUARD(sm_mtx); - auto iter = sm_partial2full.find(handle); - if (iter != sm_partial2full.end()) { - for (auto&& h : iter->second) { - ret += DnnOpManagerBase::delete_handle(h.second); - } - sm_partial2full.erase(iter); - } - } - ret += DnnOpManagerBase::delete_handle(handle); - return ret; + MGB_LOCK_GUARD(sm_mtx); + return DnnOpManagerBase::delete_handle(handle); } Handle do_new_handle(CompNode comp_node, uint64_t seed) { @@ -118,32 +112,26 @@ public: } static uint64_t get_seed(Handle handle) { + if (!handle) { return glob_default_seed; } return reinterpret_cast(handle)->seed; } static CompNode get_comp_node(Handle handle) { + mgb_assert(handle, "invalid handle"); return reinterpret_cast(handle)->comp_node; } - static Handle get_full_handle(Handle handle, CompNode comp_node) { - if (get_comp_node(handle).valid()) { - return handle; - } - MGB_LOCK_GUARD(sm_mtx); - auto&& full = sm_partial2full[handle][comp_node]; - if (!full) { - full = inst().new_handle(comp_node, get_seed(handle)); - } - return 
full; - } - static Handle get_default_handle(CompNode comp_node) { - static Handle glob_partial_handle = - inst().new_handle(CompNode{}, glob_default_seed); - if (!comp_node.valid()) { - return glob_partial_handle; + mgb_assert(comp_node.valid()); + MGB_LOCK_GUARD(sm_mtx); + auto&& glob_handle = glob_default_handles[comp_node]; + if (!glob_handle) { + glob_handle = inst().do_new_handle(comp_node, glob_default_seed); + } else if (get_seed(glob_handle) != glob_default_seed) { + inst().DnnOpManagerBase::delete_handle(glob_handle); + glob_handle = inst().do_new_handle(comp_node, glob_default_seed); } - return get_full_handle(glob_partial_handle, comp_node); + return glob_handle; } static RNGDnnOpManager& inst() { @@ -152,9 +140,15 @@ public: } static void set_glob_default_seed(uint64_t seed) { + MGB_LOCK_GUARD(sm_mtx); glob_default_seed = seed; } + static uint64_t get_glob_default_seed() { + MGB_LOCK_GUARD(sm_mtx); + return glob_default_seed; + } + private: struct HandleData { CompNode comp_node; @@ -165,16 +159,13 @@ private: MemPool m_handle_pool; static std::mutex sm_mtx; - static std::unordered_map> - sm_partial2full; + static CompNode::UnorderedMap glob_default_handles; static uint64_t glob_default_seed; }; uint64_t RNGDnnOpManager::glob_default_seed = 0; std::mutex RNGDnnOpManager::sm_mtx; -std::unordered_map> - RNGDnnOpManager::sm_partial2full; +CompNode::UnorderedMap RNGDnnOpManager::glob_default_handles; template struct OpMeth; @@ -185,7 +176,11 @@ struct OpMeth { using Param = DnnOp::Param; using OpNode = mgb::opr::UniformRNG; static Param make_param(const UniformRNG& rng) { - return {RNGDnnOpManager::get_seed(rng.handle())}; + auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); + mgb_assert(handle_seed == rng.seed, + "inconsistent rng seed: rng op: %lu handle: %lu", + handle_seed, rng.seed); + return {handle_seed}; } }; @@ -195,7 +190,11 @@ struct OpMeth { using Param = DnnOp::Param; using OpNode = mgb::opr::GaussianRNG; static Param make_param(const GaussianRNG& rng) { - return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std}; + auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); + mgb_assert(handle_seed == rng.seed, + "inconsistent rng seed: rng op: %lu handle: %lu", + handle_seed, rng.seed); + return {handle_seed, rng.mean, rng.std}; } }; @@ -206,23 +205,22 @@ void exec(const OpDef& op, const SmallVector& inputs, auto dest = outputs[0]; auto cn = dest->comp_node(); - auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn); - { - auto handle_cn = RNGDnnOpManager::get_comp_node(handle); - mgb_assert(cn == handle_cn, - "inconsistent comp_node: handle: %s, output: %s", - cn.to_string().c_str(), handle_cn.to_string().c_str()); + auto handle = rng.handle; + if (!handle) { + handle = RNGDnnOpManager::get_default_handle(cn); } // retrieve dnn_op from glob cache auto dnn_op_thread_safe = RNGDnnOpManager::inst() - .get_dnn_op::DnnOp>(handle, cn); + .get_dnn_op::DnnOp>( + handle, reinterpret_cast(op.dyn_typeinfo()), + cn); auto initialized = std::get<0>(dnn_op_thread_safe); auto dnn_op = std::get<1>(dnn_op_thread_safe); if (initialized) { auto handle_seed = RNGDnnOpManager::get_seed(handle); mgb_assert(dnn_op->param().seed == handle_seed, - "inconsistent rng seed: handle: %zu, dnn_op: %zu", + "inconsistent rng seed: handle: %lu, dnn_op: %lu", handle_seed, dnn_op->param().seed); } dnn_op->param() = OpMeth::make_param(rng); @@ -239,9 +237,12 @@ template SmallVector infer_output_attrs( const OpDef& op, const SmallVector& inputs) { LogicalTensorDesc dest; - 
dest.comp_node = op.cast_final_safe().comp_node(); - if (!dest.comp_node.valid()) + auto handle = op.cast_final_safe().handle; + if (handle) { + dest.comp_node = RNGDnnOpManager::get_comp_node(handle); + } else { dest.comp_node = inputs[0]->comp_node(); + } auto hv = inputs[0]->get_value().proxy_to_default_cpu(); TensorShape tshape; @@ -263,15 +264,22 @@ SmallVector apply_on_physical_tensor( } template -cg::OperatorNodeBase* apply_on_var_node( - const OpDef& def, const VarNodeArray& inputs) { +SymbolVar apply_on_var_node( + const OpDef& def, + const VarNodeArray& inputs) { size_t nr_inp = inputs.size(); - mgb_assert(nr_inp == 1, "UniformRNG expects 1 inputs; got %lu actually", - nr_inp); auto&& rng = def.cast_final_safe(); + mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually", + rng.dyn_typeinfo()->name, + nr_inp); auto param = OpMeth::make_param(rng); - return OpMeth::OpNode::make( - inputs[0], param, {rng.comp_node()}).node()->owner_opr(); + OperatorNodeConfig config; + if (rng.handle) { + config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; + } else { + config = {rng.make_name()}; + } + return OpMeth::OpNode::make(inputs[0], param, config); } template @@ -309,28 +317,22 @@ std::tuple, bool> infer_output_attrs_fallible( } // anonymous namespace -RNGMixin::RNGMixin(CompNode cn): - m_handle(RNGDnnOpManager::get_default_handle(cn)) {} - -uint64_t RNGMixin::seed() const { - return RNGDnnOpManager::get_seed(m_handle); -} - -CompNode RNGMixin::comp_node() const { - return RNGDnnOpManager::get_comp_node(m_handle); -} - -RNGMixin::Handle RNGMixin::new_handle(CompNode comp_node, uint64_t seed) { +Handle new_handle(CompNode comp_node, uint64_t seed) { return RNGDnnOpManager::inst().new_handle(comp_node, seed); } -size_t RNGMixin::delete_handle(Handle handle) { +size_t delete_handle(Handle handle) { return RNGDnnOpManager::inst().delete_handle(handle); } -void set_rng_seed(uint64_t seed) { +void set_global_rng_seed(uint64_t seed) { RNGDnnOpManager::set_glob_default_seed(seed); } + +uint64_t get_global_rng_seed() { + return RNGDnnOpManager::get_glob_default_seed(); +} + #define REG_RNG_OP(NAME)\ namespace { \ OP_TRAIT_REG(NAME, NAME, OpMeth::OpNode) \ @@ -339,12 +341,10 @@ OP_TRAIT_REG(NAME, NAME, OpMeth::OpNode) \ .infer_output_attrs_fallible(infer_output_attrs_fallible) \ .fallback(); \ } \ -MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME); REG_RNG_OP(UniformRNG) REG_RNG_OP(GaussianRNG) -} // namespace imperative -} // namespace mgb +} // namespace mgb::imperative::rng // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp index 23bdfb0dc..71cbebb5c 100644 --- a/imperative/src/impl/ops/specializations.cpp +++ b/imperative/src/impl/ops/specializations.cpp @@ -429,34 +429,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual) .fallback(); }} // assert_equal -namespace { namespace uniform_rng { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { - auto&& op = static_cast(def); - mgb_assert(inputs.size() == 1); - OperatorNodeConfig config{op.make_name()}; - return opr::UniformRNG::make(inputs[0], op.param(), config); -} -OP_TRAIT_REG(UniformRNG, UniformRNG) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // uniform_rng - -namespace { namespace gaussian_rng { -auto apply_on_var_node( - const OpDef& def, - const VarNodeArray& inputs) { - auto&& op = static_cast(def); - mgb_assert(inputs.size() == 1); - OperatorNodeConfig config{op.make_name()}; - 
return opr::GaussianRNG::make(inputs[0], op.param(), config); -} -OP_TRAIT_REG(GaussianRNG, GaussianRNG) - .apply_on_var_node(apply_on_var_node) - .fallback(); -}} // gaussian_rng - namespace { namespace roi_align { VarNodeArray apply_on_var_node( const OpDef& def, diff --git a/imperative/src/include/megbrain/imperative/ops/rng.h b/imperative/src/include/megbrain/imperative/ops/rng.h index eb3ed7b41..7f7e55052 100644 --- a/imperative/src/include/megbrain/imperative/ops/rng.h +++ b/imperative/src/include/megbrain/imperative/ops/rng.h @@ -2,7 +2,7 @@ * \file imperative/src/include/megbrain/imperative/ops/rng.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * - * Copyright (c) 2014-2020 Megvii Inc. All rights reserved. + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an @@ -12,84 +12,15 @@ #pragma once #include "megbrain/imperative/op_def.h" +#include "megbrain/imperative/ops/autogen.h" -namespace mgb::imperative { +namespace mgb::imperative::rng { -class RNGMixin { -public: - using Handle = size_t; +using Handle = size_t; - static Handle new_handle( - CompNode comp_node={}, uint64_t seed=0); +Handle new_handle(CompNode comp_node, uint64_t seed); +size_t delete_handle(Handle handle); +void set_global_rng_seed(uint64_t seed); +uint64_t get_global_rng_seed(); - static size_t delete_handle(Handle handle); - - Handle handle() const { - return m_handle; - } - - uint64_t seed() const; - - CompNode comp_node() const; -protected: - RNGMixin(Handle handle): m_handle(handle) {} - RNGMixin(CompNode comp_node); -private: - Handle m_handle; -}; - -class GaussianRNG : public OpDefImplBase, - public RNGMixin { - MGB_DYN_TYPE_OBJ_FINAL_DECL; -public: - float mean = 1.0f, std = 0.0; - GaussianRNG(CompNode comp_node_): RNGMixin(comp_node_) {} - GaussianRNG(float mean_=1.0, float std_=0.0, CompNode comp_node_={}): - GaussianRNG(comp_node_) { mean = mean_; std = std_; } - GaussianRNG(float mean_, float std_, Handle handle): - RNGMixin(handle), mean(mean_), std(std_) {} - size_t hash() const override { - XXHash xxhash{}; - auto append = [&xxhash](auto field){ - auto hash_val = HashTrait::eval(field); - xxhash.update(reinterpret_cast(&hash_val), sizeof(hash_val)); - }; - append(dyn_typeinfo()); - append(seed()); - append(mean); - append(std); - return xxhash.digest(); - } - - - bool is_same_st(const Hashable& rhs_) const override { - auto&& rhs = static_cast(rhs_); - return rhs.seed() == seed() - && rhs.mean == mean - && rhs.std == std; - } -}; - -class UniformRNG : public OpDefImplBase, - public RNGMixin { - MGB_DYN_TYPE_OBJ_FINAL_DECL; -public: - UniformRNG(CompNode comp_node_={}): RNGMixin(comp_node_) {} - UniformRNG(Handle handle): RNGMixin(handle) {} - - size_t hash() const override { - return hash_pair_combine( - mgb::hash(seed()), - reinterpret_cast(dyn_typeinfo())); - } - - bool is_same_st(const Hashable& rhs_) const override { - auto&& rhs = static_cast(rhs_); - return rhs.dyn_typeinfo() == dyn_typeinfo() - && rhs.seed() == seed(); - } - -}; - -void set_rng_seed(uint64_t seed); -} // namespace mgb::imperative +} // namespace mgb::imperative::rng diff --git a/imperative/src/test/rng.cpp b/imperative/src/test/rng.cpp index cf3786e54..9604df988 100644 --- a/imperative/src/test/rng.cpp +++ b/imperative/src/test/rng.cpp @@ -14,6 +14,7 @@ using namespace mgb; using namespace imperative; +using namespace imperative::rng; template void 
check_rng_basic(Args&& ...args) { @@ -22,24 +23,31 @@ void check_rng_basic(Args&& ...args) { {3, 4, 5, 6}, {2333}}) for (auto&& cn: { - CompNode::load("cpu0"), - CompNode::load("xpu0")}) + CompNode::load("xpu0"), + CompNode::load("xpu1")}) { - auto op = Op::make(std::forward(args)..., cn); + Handle h = new_handle(cn, 123); + auto op = Op::make(std::forward(args)..., h); DeviceTensorND tshape_dev; cg::copy_shape_to_tensor_value(tshape_dev, tshape); - auto outputs = OpDef::apply_on_physical_tensor(*op, {Tensor::make(tshape_dev)}); + SmallVector inputs = {Tensor::make(tshape_dev)}; + auto outputs = OpDef::apply_on_physical_tensor(*op, inputs); ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape)); ASSERT_TRUE(cn == outputs[0]->comp_node()); + // sync before delete handle + for (auto&& p: outputs) { + p->get_value(); + } + delete_handle(h); } } TEST(TestImperative, UniformRNGBasic) { - check_rng_basic(); + check_rng_basic(123); } TEST(TestImperative, GaussianRNGBasic) { - check_rng_basic(2.f, 3.f); + check_rng_basic(123, 2.f, 3.f); } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index e13bb859a..a7e2c6555 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -114,17 +114,33 @@ def TopK: MgbHashableOp<"TopK", [TopKParam]>; def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>; def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> { - let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}]; - let cmpFunction = [{return true;}]; + let extraArguments = (ins + MgbSizeTAddr:$handle + ); + let hashFunction = [{ + return mgb::hash_pair_combine( + mgb::hash($_self.dyn_typeinfo()), + mgb::hash($_self.handle)); + }]; + let cmpFunction = [{return $0.handle == $1.handle;}]; } def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> { + let extraArguments = (ins + MgbSizeTAddr:$handle + ); let hashFunction = [{ return mgb::hash_pair_combine( mgb::hash($_self.dyn_typeinfo()), - mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std))); + mgb::hash_pair_combine( + mgb::hash($_self.handle), + mgb::hash_pair_combine( + mgb::hash($_self.mean), + mgb::hash($_self.std)) + ) + ); }]; - let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}]; + let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std;}]; } def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> { -- GitLab
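Reviewer note: below is a minimal usage sketch of the Python API this patch introduces, assuming a MegEngine build with the patch applied. The module path megengine.random, the RNG class and its seed/device arguments, the free functions seed/uniform/normal, and device strings such as "xpu0" are taken from the diff above; the variable names are illustrative only, and this is not part of the patch.

    import megengine.random as rand

    # Global-state API: seed() reseeds both the Python-side MT19937 helper
    # and the global RNG seed consumed by uniform()/normal() (see rng.py).
    rand.seed(42)
    x = rand.uniform(low=-1.0, high=1.0, size=(2, 3))
    y = rand.normal(mean=0.0, std=2.0, size=(2, 3))

    # Per-instance API: each RNG object owns a dedicated handle created via
    # new_rng_handle(device, seed) and released in __del__ (see rng.py).
    # Instances constructed with the same seed are expected to produce the
    # same sequence, as asserted in test_UniformRNG above.
    rng_a = rand.RNG(seed=111, device="xpu0")
    rng_b = rand.RNG(seed=111, device="xpu0")
    u1 = rng_a.uniform(size=(100,))
    u2 = rng_b.uniform(size=(100,))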