Commit 13e6ea34 authored by Megvii Engine Team

feat(imperative/opr): rebase rng refactoring to dev & add python module

GitOrigin-RevId: ee5984c52d3fa346d5f26d737bf40ec4ed43b2c7
Parent cded8ef1
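In short, this commit moves the imperative RNG ops to a handle-based design and exposes a seedable, device-aware RNG class from megengine.random. A minimal usage sketch of the new Python API (device names and tolerances here are illustrative assumptions, mirroring the tests added further down):

import numpy as np
import megengine.random as rand

rand.seed(42)                                    # reseed the global generator behind uniform()/normal()
x = rand.uniform(low=0, high=1, size=(1000,))
y = rand.normal(mean=0, std=1, size=(1000,))
assert abs(float(x.numpy().mean()) - 0.5) < 0.1  # loose sanity check only

# independently seeded, per-device stream via the new RNG class
gen = rand.RNG(seed=111, device="xpu0")          # "xpu0" assumes at least one visible device
z = gen.uniform(size=(20, 30))
print(z.shape, float(z.numpy().mean()))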
@@ -156,7 +156,8 @@ def _logical_binary_elwise(mode, rev=False):
 def _remove_axis(inp: Tensor, axis) -> Tensor:
     def get_axes():
         if axis is None:
-            return [i for i, s in enumerate(inp.shape) if s == 1]
+            shp = inp.shape
+            return [i for i, s in enumerate(shp) if s == 1]
         try:
             return [int(axis)]
         except (TypeError, ValueError):
...
@@ -6,9 +6,11 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import time
 from typing import List, Optional, Tuple

 from ..device import set_default_device, what_is_xpu
+from ..random import seed
 from .server import Client, Server
@@ -156,6 +158,7 @@ def init_process_group(
     WORLD.reset(list(range(world_size)))
     set_default_device("{}{}".format(device_type, device))
+    seed(int(time.time()) + rank)


 def is_distributed() -> bool:
...
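The added seed(int(time.time()) + rank) call gives every worker a distinct wall-clock-derived seed, so ranks do not draw identical random streams after init_process_group. A toy illustration of the offset scheme (plain Python, hypothetical rank count, not part of the diff):

import time

base = int(time.time())
seeds = [base + rank for rank in range(4)]  # one seed per rank, offset by rank index
assert len(set(seeds)) == len(seeds)        # distinct as long as they share the same base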
@@ -7,7 +7,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .distribution import normal, uniform
-from .rng import seed
+from .rng import RNG, seed

 # pylint: disable=undefined-variable
 del distribution, rng  # type: ignore[name-defined]
@@ -9,11 +9,8 @@
 from typing import Iterable, Optional

 from .. import Tensor
-from ..core._imperative_rt import invoke_op
-from ..core._imperative_rt.core2 import apply
-from ..core.ops.builtin import GaussianRNG, UniformRNG
-from ..core.tensor import utils
-from .rng import _random_seed_generator
+from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
+from .rng import _normal, _uniform

 __all__ = ["normal", "uniform"]
@@ -48,14 +45,14 @@ def normal(
      [-1.4939808 -1.5824696 ]]
     """
-    if size is None:
-        size = (1,)
-    op = GaussianRNG(mean, std)
-    _ref = Tensor([], dtype="int32")
-    shape = utils.astensor1d(size, _ref, dtype="int32")
-    shape = Tensor(shape, dtype="int32")
-    (output,) = apply(op, shape)
-    return output
+    return _normal(
+        mean=mean,
+        std=std,
+        size=size,
+        seed=_get_global_rng_seed(),
+        device=None,
+        handle=0,
+    )


 def uniform(
@@ -88,14 +85,11 @@ def uniform(
      [0.09365904 0.62957656]]
     """
-    assert low < high, "Uniform is not defined when low >= high"
-
-    if size is None:
-        size = (1,)
-    op = UniformRNG()
-    _ref = Tensor([], dtype="int32")
-    shape = utils.astensor1d(size, _ref, dtype="int32")
-    shape = Tensor(shape, dtype="int32")
-    (output,) = apply(op, shape)
-    return low + (high - low) * output
+    return _uniform(
+        low=low,
+        high=high,
+        size=size,
+        seed=_get_global_rng_seed(),
+        device=None,
+        handle=0,
+    )
@@ -7,17 +7,94 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import time
+from typing import Iterable, Optional

 from numpy.random import MT19937

+from .. import Tensor
+from ..core._imperative_rt.core2 import apply
+from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
+from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
+from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
+from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
+from ..core.ops.builtin import GaussianRNG, UniformRNG
+from ..core.tensor import utils
+from ..device import get_default_device
+
 _rng = None


-def _random_seed_generator():
-    if _rng is None:
-        from ..distributed.group import get_rank
-
-        seed(seed=int(time.time()) + get_rank())
+def _normal(
+    mean: float,
+    std: float,
+    size: Optional[Iterable[int]],
+    seed: int,
+    device: str,
+    handle: int,
+) -> Tensor:
+    if size is None:
+        size = (1,)
+    op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle)
+    _ref = Tensor([], dtype="int32", device=device)
+    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
+    (output,) = apply(op, shape)
+    return output
+
+
+def _uniform(
+    low: float,
+    high: float,
+    size: Optional[Iterable[int]],
+    seed: int,
+    device: str,
+    handle: int,
+) -> Tensor:
+    assert low < high, "Uniform is not defined when low >= high"
+    if size is None:
+        size = (1,)
+    op = UniformRNG(seed=seed, handle=handle)
+    _ref = Tensor([], dtype="int32", device=device)
+    shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
+    (output,) = apply(op, shape)
+    return low + (high - low) * output
+
+
+class RNG:
+    def __init__(self, seed=0, device=None):
+        self.seed = seed
+        self.device = device if device else get_default_device()
+        self.handle = _new_rng_handle(self.device, self.seed)
+
+    def uniform(
+        self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
+    ):
+        return _uniform(
+            low=low,
+            high=high,
+            size=size,
+            seed=self.seed,
+            device=self.device,
+            handle=self.handle,
+        )
+
+    def normal(
+        self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
+    ):
+        return _normal(
+            mean=mean,
+            std=std,
+            size=size,
+            seed=self.seed,
+            device=self.device,
+            handle=self.handle,
+        )
+
+    def __del__(self):
+        _delete_rng_handle(self.handle)
+
+
+def _random_seed_generator():
+    assert _rng
     while True:
         yield _rng.random_raw()
@@ -25,3 +102,7 @@ def _random_seed_generator():
 def seed(seed: int):
     global _rng  # pylint: disable=global-statement
     _rng = MT19937(seed=seed)
+    _set_global_rng_seed(seed)
+
+
+seed(int(time.time()))
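Two RNG instances constructed with the same seed are expected to produce identical streams even on different devices, while a different seed, or a second draw from the same instance, diverges; this is exactly what the test_UniformRNG/test_NormalRNG cases further down verify. A minimal sketch (the two device strings assume two visible devices):

import numpy as np
from megengine.random import RNG

m1 = RNG(seed=111, device="xpu0")
m2 = RNG(seed=111, device="xpu1")

a = m1.uniform(size=(100,)).numpy()
b = m2.uniform(size=(100,)).numpy()
np.testing.assert_equal(a, b)            # same seed -> same stream, regardless of device

c = m1.uniform(size=(100,)).numpy()      # a second draw advances m1's stream
assert not (a == c).all()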
@@ -10,7 +10,10 @@
 */

 #include "./ops.h"
+#include "./helper.h"
+#include "./tensor.h"

+#include "megbrain/common.h"
 #include "megbrain/imperative.h"
 #include "megbrain/imperative/ops/backward_graph.h"
 #include "megbrain/imperative/ops/opr_attr.h"
@@ -491,21 +494,15 @@ void init_ops(py::module m) {
     _init_py_op_base(m);
     INIT_ALL_OP(m)

-    m.def("new_rng_handle", &RNGMixin::new_handle);
-    // FIXME: RNG op might execute after handle released due to async dispatch,
-    // which would cause memory leak or use-after-free
-    m.def("delete_rng_handle", &RNGMixin::delete_handle);
-    m.def("set_rng_seed", &set_rng_seed);
-
-    py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef>(m, "UniformRNG")
-        .def(py::init<>())
-        .def(py::init<mgb::CompNode>())
-        .def(py::init<RNGMixin::Handle>());
-
-    py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef>(m, "GaussianRNG")
-        .def(py::init<>())
-        .def(py::init<mgb::CompNode>())
-        .def(py::init<float, float>())
-        .def(py::init<float, float, mgb::CompNode>())
-        .def(py::init<float, float, RNGMixin::Handle>());
+    m.def("new_rng_handle", &rng::new_handle);
+    m.def("delete_rng_handle", [](size_t handle){
+        // RNG op might execute after handle released due to async dispatch, so
+        // we need sync before delete a handle to avoid memory leak or use-after-free
+        python::interpreter_for_py->sync();
+        mgb::CompNode::sync_all();
+        py_task_q.wait_all_task_finish();
+        rng::delete_handle(handle);
+    }, py::call_guard<py::gil_scoped_release>());
+    m.def("set_global_rng_seed", &rng::set_global_rng_seed);
+    m.def("get_global_rng_seed", &rng::get_global_rng_seed);
 }
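The rewritten binding folds the required synchronization into delete_rng_handle itself, so Python callers can free a handle without syncing explicitly. A small sketch of driving the ops through the raw bindings, following the same pattern as the Python tests below ("xpu0" assumes a visible device; the seed value is arbitrary):

from megengine import tensor
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.core2 import apply
from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
from megengine.core.ops.builtin import GaussianRNG

cn = CompNode("xpu0")
handle = new_rng_handle(cn, 233333)                    # handle carries the device and the seed
op = GaussianRNG(seed=233333, mean=0.0, std=1.0, handle=handle)
(out,) = apply(op, tensor((8, 9), dtype="int32"))      # the single input is the target shape
delete_rng_handle(handle)                              # the binding syncs before freeing
print(out.numpy().mean())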
@@ -8,14 +8,21 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import numpy as np

+import megengine
 from megengine import tensor
 from megengine.core._imperative_rt import CompNode
-from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
+from megengine.core._imperative_rt.core2 import apply
+from megengine.core._imperative_rt.ops import (
+    delete_rng_handle,
+    get_global_rng_seed,
+    new_rng_handle,
+)
 from megengine.core.ops.builtin import GaussianRNG, UniformRNG
-from megengine.core.tensor.core import apply
+from megengine.random import RNG
+from megengine.random.rng import _normal, _uniform


-def test_gaussian_rng():
+def test_gaussian_op():
     shape = (
         8,
         9,
@@ -23,23 +30,16 @@ def test_gaussian_rng():
         12,
     )
     shape = tensor(shape, dtype="int32")
-    op = GaussianRNG(1.0, 3.0)
+    op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0)
     (output,) = apply(op, shape)
     assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
     assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
     assert str(output.device) == str(CompNode("xpux"))

-    cn = CompNode("xpu1")
-    op = GaussianRNG(-1.0, 2.0, cn)
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1
-    assert np.sqrt(output.numpy().var()) - 2.0 < 1e-1
-    assert str(output.device) == str(cn)
-
     cn = CompNode("xpu2")
     seed = 233333
     h = new_rng_handle(cn, seed)
-    op = GaussianRNG(3.0, 1.0, h)
+    op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h)
     (output,) = apply(op, shape)
     delete_rng_handle(h)
     assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
@@ -47,7 +47,7 @@ def test_gaussian_rng():
     assert str(output.device) == str(cn)


-def test_uniform_rng():
+def test_uniform_op():
     shape = (
         8,
         9,
@@ -55,22 +55,67 @@ def test_uniform_rng():
         12,
     )
     shape = tensor(shape, dtype="int32")
-    op = UniformRNG()
+    op = UniformRNG(seed=get_global_rng_seed())
     (output,) = apply(op, shape)
     assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
     assert str(output.device) == str(CompNode("xpux"))

-    cn = CompNode("xpu1")
-    op = UniformRNG(cn)
-    (output,) = apply(op, shape)
-    assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
-    assert str(output.device) == str(cn)
-
     cn = CompNode("xpu2")
     seed = 233333
     h = new_rng_handle(cn, seed)
-    op = UniformRNG(h)
+    op = UniformRNG(seed=seed, handle=h)
     (output,) = apply(op, shape)
     delete_rng_handle(h)
     assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
     assert str(output.device) == str(cn)
+
+
+def test_UniformRNG():
+    m1 = RNG(seed=111, device="xpu0")
+    m2 = RNG(seed=111, device="xpu1")
+    m3 = RNG(seed=222, device="xpu0")
+    out1 = m1.uniform(size=(100,))
+    out1_ = m1.uniform(size=(100,))
+    out2 = m2.uniform(size=(100,))
+    out3 = m3.uniform(size=(100,))
+
+    np.testing.assert_equal(out1.numpy(), out2.numpy())
+    assert out1.device == "xpu0" and out2.device == "xpu1"
+    assert not (out1.numpy() == out3.numpy()).all()
+    assert not (out1.numpy() == out1_.numpy()).all()
+
+    low = -234
+    high = 123
+    out = m1.uniform(low=low, high=high, size=(20, 30, 40))
+    out_shp = out.shape
+    if isinstance(out_shp, tuple):
+        assert out_shp == (20, 30, 40)
+    else:
+        assert all(out.shape.numpy() == np.array([20, 30, 40]))
+    assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
+
+
+def test_NormalRNG():
+    m1 = RNG(seed=111, device="xpu0")
+    m2 = RNG(seed=111, device="xpu1")
+    m3 = RNG(seed=222, device="xpu0")
+    out1 = m1.normal(size=(100,))
+    out1_ = m1.uniform(size=(100,))
+    out2 = m2.normal(size=(100,))
+    out3 = m3.normal(size=(100,))
+
+    np.testing.assert_equal(out1.numpy(), out2.numpy())
+    assert out1.device == "xpu0" and out2.device == "xpu1"
+    assert not (out1.numpy() == out3.numpy()).all()
+    assert not (out1.numpy() == out1_.numpy()).all()
+
+    mean = -1
+    std = 2
+    out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
+    out_shp = out.shape
+    if isinstance(out_shp, tuple):
+        assert out_shp == (20, 30, 40)
+    else:
+        assert all(out.shape.numpy() == np.array([20, 30, 40]))
+    assert np.abs(out.mean().numpy() - mean) / std < 0.1
+    assert np.abs(np.std(out.numpy()) - std) < 0.1
@@ -2,7 +2,7 @@
 * \file imperative/src/impl/ops/rng.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
@@ -10,23 +10,23 @@
 */

 #include "megbrain/imperative/ops/rng.h"
-#include <bits/stdint-uintn.h>

 #include "megbrain/comp_node_env.h"
 #include "megbrain/graph/helper.h"
 #include "megbrain/opr/rand.h"
-//#include "megbrain/common.h"

 #include "../op_trait.h"
+#include "../dnn_op_helper.h"

-namespace mgb {
-namespace imperative {
+namespace mgb::imperative::rng {

 namespace {

 template <typename HandleFactory, typename THandle>
 class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
 public:
-    using DT = CompNode::DeviceType;
     using Handle = THandle;
+    using OpTypeInfo = size_t;

     template <typename... Args>
     Handle new_handle(Args&&... args) {
@@ -38,27 +38,26 @@ public:
         size_t removed = 0;
         if (!is_finalized()) {
             MGB_LOCK_GUARD(m_mtx);
-            removed = m_handle2op.erase(handle);
+            removed = m_handle2ops.erase(handle);
         }
         static_cast<HandleFactory*>(this)->do_delete_handle(handle);
         return removed;
     }

     template <typename DnnOp>
-    auto get_dnn_op(Handle handle, CompNode cn) {
+    auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
         mgb_assert(!is_finalized());
         DnnOpWithMutex* dnn_op_with_mtx;
         {
             MGB_LOCK_GUARD(m_mtx);
-            dnn_op_with_mtx = &m_handle2op[handle];
+            dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
         }
         auto dnn_handle =
                 MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
-        DnnOp* dnn_op;
         std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
         bool initialized = false;
-        if ((dnn_op = dynamic_cast<DnnOp*>(dnn_op_with_mtx->op.get())) !=
-            nullptr) {
+        DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
+        if (dnn_op != nullptr) {
             mgb_assert(dnn_op->handle() == dnn_handle);
             initialized = true;
         } else {
@@ -77,35 +76,30 @@ private:
     struct DnnOpWithMutex {
         std::mutex mtx;
         std::unique_ptr<megdnn::OperatorBase> op;
+        DnnOpWithMutex(): op{nullptr} {}
     };

     std::shared_ptr<void> on_comp_node_finalize() override {
         MGB_LOCK_GUARD(m_mtx);
-        m_handle2op.clear();
+        m_handle2ops.clear();
         return {};
     }

-    std::unordered_map<Handle, DnnOpWithMutex> m_handle2op;
+    std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex> > m_handle2ops;
     std::mutex m_mtx;
 };

 class RNGDnnOpManager final
-        : public DnnOpManagerT<RNGDnnOpManager, RNGMixin::Handle> {
+        : public DnnOpManagerT<RNGDnnOpManager, Handle> {
 public:
+    Handle new_handle(CompNode comp_node, uint64_t seed) {
+        MGB_LOCK_GUARD(sm_mtx);
+        return DnnOpManagerBase::new_handle(comp_node, seed);
+    }
+
     size_t delete_handle(Handle handle) {
-        size_t ret = 0;
-        {
-            MGB_LOCK_GUARD(sm_mtx);
-            auto iter = sm_partial2full.find(handle);
-            if (iter != sm_partial2full.end()) {
-                for (auto&& h : iter->second) {
-                    ret += DnnOpManagerBase::delete_handle(h.second);
-                }
-                sm_partial2full.erase(iter);
-            }
-        }
-        ret += DnnOpManagerBase::delete_handle(handle);
-        return ret;
+        MGB_LOCK_GUARD(sm_mtx);
+        return DnnOpManagerBase::delete_handle(handle);
     }

     Handle do_new_handle(CompNode comp_node, uint64_t seed) {
@@ -118,32 +112,26 @@ public:
     }

     static uint64_t get_seed(Handle handle) {
+        if (!handle) { return glob_default_seed; }
         return reinterpret_cast<HandleData*>(handle)->seed;
     }

     static CompNode get_comp_node(Handle handle) {
+        mgb_assert(handle, "invalid handle");
         return reinterpret_cast<HandleData*>(handle)->comp_node;
     }

-    static Handle get_full_handle(Handle handle, CompNode comp_node) {
-        if (get_comp_node(handle).valid()) {
-            return handle;
-        }
-        MGB_LOCK_GUARD(sm_mtx);
-        auto&& full = sm_partial2full[handle][comp_node];
-        if (!full) {
-            full = inst().new_handle(comp_node, get_seed(handle));
-        }
-        return full;
-    }
-
     static Handle get_default_handle(CompNode comp_node) {
-        static Handle glob_partial_handle =
-                inst().new_handle(CompNode{}, glob_default_seed);
-        if (!comp_node.valid()) {
-            return glob_partial_handle;
+        mgb_assert(comp_node.valid());
+        MGB_LOCK_GUARD(sm_mtx);
+        auto&& glob_handle = glob_default_handles[comp_node];
+        if (!glob_handle) {
+            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
+        } else if (get_seed(glob_handle) != glob_default_seed) {
+            inst().DnnOpManagerBase::delete_handle(glob_handle);
+            glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
         }
-        return get_full_handle(glob_partial_handle, comp_node);
+        return glob_handle;
     }

     static RNGDnnOpManager& inst() {
@@ -152,9 +140,15 @@ public:
     }

     static void set_glob_default_seed(uint64_t seed) {
+        MGB_LOCK_GUARD(sm_mtx);
         glob_default_seed = seed;
     }

+    static uint64_t get_glob_default_seed() {
+        MGB_LOCK_GUARD(sm_mtx);
+        return glob_default_seed;
+    }
+
 private:
     struct HandleData {
         CompNode comp_node;
@@ -165,16 +159,13 @@ private:
     MemPool<HandleData> m_handle_pool;

     static std::mutex sm_mtx;
-    static std::unordered_map<Handle, CompNode::UnorderedMap<Handle>>
-            sm_partial2full;
+    static CompNode::UnorderedMap<Handle> glob_default_handles;
     static uint64_t glob_default_seed;
 };

 uint64_t RNGDnnOpManager::glob_default_seed = 0;
 std::mutex RNGDnnOpManager::sm_mtx;
-std::unordered_map<RNGDnnOpManager::Handle,
-                   CompNode::UnorderedMap<RNGDnnOpManager::Handle>>
-        RNGDnnOpManager::sm_partial2full;
+CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;

 template <typename Op>
 struct OpMeth;
@@ -185,7 +176,11 @@ struct OpMeth<UniformRNG> {
     using Param = DnnOp::Param;
     using OpNode = mgb::opr::UniformRNG;
     static Param make_param(const UniformRNG& rng) {
-        return {RNGDnnOpManager::get_seed(rng.handle())};
+        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
+        mgb_assert(handle_seed == rng.seed,
+                   "inconsistent rng seed: rng op: %lu handle: %lu",
+                   handle_seed, rng.seed);
+        return {handle_seed};
     }
 };
@@ -195,7 +190,11 @@ struct OpMeth<GaussianRNG> {
     using Param = DnnOp::Param;
     using OpNode = mgb::opr::GaussianRNG;
     static Param make_param(const GaussianRNG& rng) {
-        return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std};
+        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
+        mgb_assert(handle_seed == rng.seed,
+                   "inconsistent rng seed: rng op: %lu handle: %lu",
+                   handle_seed, rng.seed);
+        return {handle_seed, rng.mean, rng.std};
     }
 };
@@ -206,23 +205,22 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
     auto dest = outputs[0];
     auto cn = dest->comp_node();
-    auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn);
-    {
-        auto handle_cn = RNGDnnOpManager::get_comp_node(handle);
-        mgb_assert(cn == handle_cn,
-                "inconsistent comp_node: handle: %s, output: %s",
-                cn.to_string().c_str(), handle_cn.to_string().c_str());
+    auto handle = rng.handle;
+    if (!handle) {
+        handle = RNGDnnOpManager::get_default_handle(cn);
     }

     // retrieve dnn_op from glob cache
     auto dnn_op_thread_safe = RNGDnnOpManager::inst()
-            .get_dnn_op<typename OpMeth<Op>::DnnOp>(handle, cn);
+            .get_dnn_op<typename OpMeth<Op>::DnnOp>(
+                    handle, reinterpret_cast<size_t>(op.dyn_typeinfo()),
+                    cn);
     auto initialized = std::get<0>(dnn_op_thread_safe);
     auto dnn_op = std::get<1>(dnn_op_thread_safe);
     if (initialized) {
         auto handle_seed = RNGDnnOpManager::get_seed(handle);
         mgb_assert(dnn_op->param().seed == handle_seed,
-                "inconsistent rng seed: handle: %zu, dnn_op: %zu",
+                "inconsistent rng seed: handle: %lu, dnn_op: %lu",
                 handle_seed, dnn_op->param().seed);
     }
     dnn_op->param() = OpMeth<Op>::make_param(rng);
@@ -239,9 +237,12 @@ template <typename Op>
 SmallVector<LogicalTensorDesc> infer_output_attrs(
         const OpDef& op, const SmallVector<TensorPtr>& inputs) {
     LogicalTensorDesc dest;
-    dest.comp_node = op.cast_final_safe<Op>().comp_node();
-    if (!dest.comp_node.valid())
+    auto handle = op.cast_final_safe<Op>().handle;
+    if (handle) {
+        dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
+    } else {
         dest.comp_node = inputs[0]->comp_node();
+    }

     auto hv = inputs[0]->get_value().proxy_to_default_cpu();
     TensorShape tshape;
@@ -263,15 +264,22 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
 }

 template<typename Op>
-cg::OperatorNodeBase* apply_on_var_node(
-        const OpDef& def, const VarNodeArray& inputs) {
+SymbolVar apply_on_var_node(
+        const OpDef& def,
+        const VarNodeArray& inputs) {
     size_t nr_inp = inputs.size();
-    mgb_assert(nr_inp == 1, "UniformRNG expects 1 inputs; got %lu actually",
-               nr_inp);
     auto&& rng = def.cast_final_safe<Op>();
+    mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
+               rng.dyn_typeinfo()->name,
+               nr_inp);
     auto param = OpMeth<Op>::make_param(rng);
-    return OpMeth<Op>::OpNode::make(
-            inputs[0], param, {rng.comp_node()}).node()->owner_opr();
+    OperatorNodeConfig config;
+    if (rng.handle) {
+        config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)};
+    } else {
+        config = {rng.make_name()};
+    }
+    return OpMeth<Op>::OpNode::make(inputs[0], param, config);
 }

 template<typename T>
@@ -309,28 +317,22 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
 } // anonymous namespace

-RNGMixin::RNGMixin(CompNode cn):
-    m_handle(RNGDnnOpManager::get_default_handle(cn)) {}
-
-uint64_t RNGMixin::seed() const {
-    return RNGDnnOpManager::get_seed(m_handle);
-}
-
-CompNode RNGMixin::comp_node() const {
-    return RNGDnnOpManager::get_comp_node(m_handle);
-}
-
-RNGMixin::Handle RNGMixin::new_handle(CompNode comp_node, uint64_t seed) {
+Handle new_handle(CompNode comp_node, uint64_t seed) {
     return RNGDnnOpManager::inst().new_handle(comp_node, seed);
 }

-size_t RNGMixin::delete_handle(Handle handle) {
+size_t delete_handle(Handle handle) {
     return RNGDnnOpManager::inst().delete_handle(handle);
 }

-void set_rng_seed(uint64_t seed) {
+void set_global_rng_seed(uint64_t seed) {
     RNGDnnOpManager::set_glob_default_seed(seed);
 }

+uint64_t get_global_rng_seed() {
+    return RNGDnnOpManager::get_glob_default_seed();
+}
+
 #define REG_RNG_OP(NAME)\
 namespace { \
 OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
@@ -339,12 +341,10 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
     .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
     .fallback(); \
 } \
-MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);

 REG_RNG_OP(UniformRNG)
 REG_RNG_OP(GaussianRNG)

-} // namespace imperative
-} // namespace mgb
+} // namespace mgb::imperative::rng

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -429,34 +429,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual)
     .fallback();
 }} // assert_equal

-namespace { namespace uniform_rng {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const UniformRNG&>(def);
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::UniformRNG::make(inputs[0], op.param(), config);
-}
-OP_TRAIT_REG(UniformRNG, UniformRNG)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // uniform_rng
-
-namespace { namespace gaussian_rng {
-auto apply_on_var_node(
-        const OpDef& def,
-        const VarNodeArray& inputs) {
-    auto&& op = static_cast<const GaussianRNG&>(def);
-    mgb_assert(inputs.size() == 1);
-    OperatorNodeConfig config{op.make_name()};
-    return opr::GaussianRNG::make(inputs[0], op.param(), config);
-}
-OP_TRAIT_REG(GaussianRNG, GaussianRNG)
-    .apply_on_var_node(apply_on_var_node)
-    .fallback();
-}} // gaussian_rng
-
 namespace { namespace roi_align {
 VarNodeArray apply_on_var_node(
         const OpDef& def,
...
@@ -2,7 +2,7 @@
 * \file imperative/src/include/megbrain/imperative/ops/rng.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
- * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
@@ -12,84 +12,15 @@
 #pragma once

 #include "megbrain/imperative/op_def.h"
+#include "megbrain/imperative/ops/autogen.h"

-namespace mgb::imperative {
+namespace mgb::imperative::rng {

-class RNGMixin {
-public:
-    using Handle = size_t;
-
-    static Handle new_handle(
-            CompNode comp_node={}, uint64_t seed=0);
-
-    static size_t delete_handle(Handle handle);
-
-    Handle handle() const {
-        return m_handle;
-    }
-
-    uint64_t seed() const;
-
-    CompNode comp_node() const;
-
-protected:
-    RNGMixin(Handle handle): m_handle(handle) {}
-    RNGMixin(CompNode comp_node);
-
-private:
-    Handle m_handle;
-};
-
-class GaussianRNG : public OpDefImplBase<GaussianRNG>,
-                    public RNGMixin {
-    MGB_DYN_TYPE_OBJ_FINAL_DECL;
-
-public:
-    float mean = 1.0f, std = 0.0;
-
-    GaussianRNG(CompNode comp_node_): RNGMixin(comp_node_) {}
-    GaussianRNG(float mean_=1.0, float std_=0.0, CompNode comp_node_={}):
-        GaussianRNG(comp_node_) { mean = mean_; std = std_; }
-    GaussianRNG(float mean_, float std_, Handle handle):
-        RNGMixin(handle), mean(mean_), std(std_) {}
-
-    size_t hash() const override {
-        XXHash xxhash{};
-        auto append = [&xxhash](auto field){
-            auto hash_val = HashTrait<decltype(field)>::eval(field);
-            xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
-        };
-        append(dyn_typeinfo());
-        append(seed());
-        append(mean);
-        append(std);
-        return xxhash.digest();
-    }
-
-    bool is_same_st(const Hashable& rhs_) const override {
-        auto&& rhs = static_cast<const GaussianRNG&>(rhs_);
-        return rhs.seed() == seed()
-            && rhs.mean == mean
-            && rhs.std == std;
-    }
-};
-
-class UniformRNG : public OpDefImplBase<UniformRNG>,
-                   public RNGMixin {
-    MGB_DYN_TYPE_OBJ_FINAL_DECL;
-
-public:
-    UniformRNG(CompNode comp_node_={}): RNGMixin(comp_node_) {}
-    UniformRNG(Handle handle): RNGMixin(handle) {}
-
-    size_t hash() const override {
-        return hash_pair_combine(
-                mgb::hash(seed()),
-                reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
-    }
-
-    bool is_same_st(const Hashable& rhs_) const override {
-        auto&& rhs = static_cast<const UniformRNG&>(rhs_);
-        return rhs.dyn_typeinfo() == dyn_typeinfo()
-            && rhs.seed() == seed();
-    }
-};
-
-void set_rng_seed(uint64_t seed);
-
-} // namespace mgb::imperative
+using Handle = size_t;
+
+Handle new_handle(CompNode comp_node, uint64_t seed);
+size_t delete_handle(Handle handle);
+void set_global_rng_seed(uint64_t seed);
+uint64_t get_global_rng_seed();
+
+} // namespace mgb::imperative::rng
@@ -14,6 +14,7 @@

 using namespace mgb;
 using namespace imperative;
+using namespace imperative::rng;

 template<typename Op, typename ...Args>
 void check_rng_basic(Args&& ...args) {
@@ -22,24 +23,31 @@ void check_rng_basic(Args&& ...args) {
             {3, 4, 5, 6},
             {2333}})
     for (auto&& cn: {
-            CompNode::load("cpu0"),
-            CompNode::load("xpu0")})
+            CompNode::load("xpu0"),
+            CompNode::load("xpu1")})
     {
-        auto op = Op::make(std::forward<Args>(args)..., cn);
+        Handle h = new_handle(cn, 123);
+        auto op = Op::make(std::forward<Args>(args)..., h);
         DeviceTensorND tshape_dev;
         cg::copy_shape_to_tensor_value(tshape_dev, tshape);
-        auto outputs = OpDef::apply_on_physical_tensor(*op, {Tensor::make(tshape_dev)});
+        SmallVector<TensorPtr> inputs = {Tensor::make(tshape_dev)};
+        auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
         ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape));
         ASSERT_TRUE(cn == outputs[0]->comp_node());
+        // sync before delete handle
+        for (auto&& p: outputs) {
+            p->get_value();
+        }
+        delete_handle(h);
     }
 }

 TEST(TestImperative, UniformRNGBasic) {
-    check_rng_basic<UniformRNG>();
+    check_rng_basic<UniformRNG>(123);
 }

 TEST(TestImperative, GaussianRNGBasic) {
-    check_rng_basic<GaussianRNG>(2.f, 3.f);
+    check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
 }

 // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -114,17 +114,33 @@ def TopK: MgbHashableOp<"TopK", [TopKParam]>;
 def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;

 def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
-  let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
-  let cmpFunction = [{return true;}];
+  let extraArguments = (ins
+    MgbSizeTAddr:$handle
+  );
+  let hashFunction = [{
+    return mgb::hash_pair_combine(
+      mgb::hash($_self.dyn_typeinfo()),
+      mgb::hash($_self.handle));
+  }];
+  let cmpFunction = [{return $0.handle == $1.handle;}];
 }

 def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
+  let extraArguments = (ins
+    MgbSizeTAddr:$handle
+  );
   let hashFunction = [{
     return mgb::hash_pair_combine(
       mgb::hash($_self.dyn_typeinfo()),
-      mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
+      mgb::hash_pair_combine(
+        mgb::hash($_self.handle),
+        mgb::hash_pair_combine(
+          mgb::hash($_self.mean),
+          mgb::hash($_self.std))
+      )
+    );
   }];
-  let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
+  let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std;}];
 }

 def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
...