提交 13e6ea34 编写于 作者: M Megvii Engine Team

feat(imperative/opr): rebase rng refactoring to dev & add python module

GitOrigin-RevId: ee5984c52d3fa346d5f26d737bf40ec4ed43b2c7
上级 cded8ef1
......@@ -156,7 +156,8 @@ def _logical_binary_elwise(mode, rev=False):
def _remove_axis(inp: Tensor, axis) -> Tensor:
def get_axes():
if axis is None:
return [i for i, s in enumerate(inp.shape) if s == 1]
shp = inp.shape
return [i for i, s in enumerate(shp) if s == 1]
try:
return [int(axis)]
except (TypeError, ValueError):
......
......@@ -6,9 +6,11 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from typing import List, Optional, Tuple
from ..device import set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server
......@@ -156,6 +158,7 @@ def init_process_group(
WORLD.reset(list(range(world_size)))
set_default_device("{}{}".format(device_type, device))
seed(int(time.time()) + rank)
def is_distributed() -> bool:
......
......@@ -7,7 +7,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .distribution import normal, uniform
from .rng import seed
from .rng import RNG, seed
# pylint: disable=undefined-variable
del distribution, rng # type: ignore[name-defined]
......@@ -9,11 +9,8 @@
from typing import Iterable, Optional
from .. import Tensor
from ..core._imperative_rt import invoke_op
from ..core._imperative_rt.core2 import apply
from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils
from .rng import _random_seed_generator
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from .rng import _normal, _uniform
__all__ = ["normal", "uniform"]
......@@ -48,14 +45,14 @@ def normal(
[-1.4939808 -1.5824696 ]]
"""
if size is None:
size = (1,)
op = GaussianRNG(mean, std)
_ref = Tensor([], dtype="int32")
shape = utils.astensor1d(size, _ref, dtype="int32")
shape = Tensor(shape, dtype="int32")
(output,) = apply(op, shape)
return output
return _normal(
mean=mean,
std=std,
size=size,
seed=_get_global_rng_seed(),
device=None,
handle=0,
)
def uniform(
......@@ -88,14 +85,11 @@ def uniform(
[0.09365904 0.62957656]]
"""
assert low < high, "Uniform is not defined when low >= high"
if size is None:
size = (1,)
op = UniformRNG()
_ref = Tensor([], dtype="int32")
shape = utils.astensor1d(size, _ref, dtype="int32")
shape = Tensor(shape, dtype="int32")
(output,) = apply(op, shape)
return low + (high - low) * output
return _uniform(
low=low,
high=high,
size=size,
seed=_get_global_rng_seed(),
device=None,
handle=0,
)
......@@ -7,17 +7,94 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from typing import Iterable, Optional
from numpy.random import MT19937
from .. import Tensor
from ..core._imperative_rt.core2 import apply
from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle
from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed
from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle
from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed
from ..core.ops.builtin import GaussianRNG, UniformRNG
from ..core.tensor import utils
from ..device import get_default_device
_rng = None
def _random_seed_generator():
if _rng is None:
from ..distributed.group import get_rank
def _normal(
mean: float,
std: float,
size: Optional[Iterable[int]],
seed: int,
device: str,
handle: int,
) -> Tensor:
if size is None:
size = (1,)
op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle)
_ref = Tensor([], dtype="int32", device=device)
shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
(output,) = apply(op, shape)
return output
def _uniform(
low: float,
high: float,
size: Optional[Iterable[int]],
seed: int,
device: str,
handle: int,
) -> Tensor:
assert low < high, "Uniform is not defined when low >= high"
if size is None:
size = (1,)
op = UniformRNG(seed=seed, handle=handle)
_ref = Tensor([], dtype="int32", device=device)
shape = utils.astensor1d(size, _ref, dtype="int32", device=device)
(output,) = apply(op, shape)
return low + (high - low) * output
class RNG:
def __init__(self, seed=0, device=None):
self.seed = seed
self.device = device if device else get_default_device()
self.handle = _new_rng_handle(self.device, self.seed)
def uniform(
self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None
):
return _uniform(
low=low,
high=high,
size=size,
seed=self.seed,
device=self.device,
handle=self.handle,
)
seed(seed=int(time.time()) + get_rank())
def normal(
self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None
):
return _normal(
mean=mean,
std=std,
size=size,
seed=self.seed,
device=self.device,
handle=self.handle,
)
def __del__(self):
_delete_rng_handle(self.handle)
def _random_seed_generator():
assert _rng
while True:
yield _rng.random_raw()
......@@ -25,3 +102,7 @@ def _random_seed_generator():
def seed(seed: int):
global _rng # pylint: disable=global-statement
_rng = MT19937(seed=seed)
_set_global_rng_seed(seed)
seed(int(time.time()))
......@@ -10,7 +10,10 @@
*/
#include "./ops.h"
#include "./helper.h"
#include "./tensor.h"
#include "megbrain/common.h"
#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
......@@ -491,21 +494,15 @@ void init_ops(py::module m) {
_init_py_op_base(m);
INIT_ALL_OP(m)
m.def("new_rng_handle", &RNGMixin::new_handle);
// FIXME: RNG op might execute after handle released due to async dispatch,
// which would cause memory leak or use-after-free
m.def("delete_rng_handle", &RNGMixin::delete_handle);
m.def("set_rng_seed", &set_rng_seed);
py::class_<UniformRNG, std::shared_ptr<UniformRNG>, OpDef>(m, "UniformRNG")
.def(py::init<>())
.def(py::init<mgb::CompNode>())
.def(py::init<RNGMixin::Handle>());
py::class_<GaussianRNG, std::shared_ptr<GaussianRNG>, OpDef>(m, "GaussianRNG")
.def(py::init<>())
.def(py::init<mgb::CompNode>())
.def(py::init<float ,float>())
.def(py::init<float ,float, mgb::CompNode>())
.def(py::init<float ,float, RNGMixin::Handle>());
m.def("new_rng_handle", &rng::new_handle);
m.def("delete_rng_handle", [](size_t handle){
// RNG op might execute after handle released due to async dispatch, so
// we need sync before delete a handle to avoid memory leak or use-after-free
python::interpreter_for_py->sync();
mgb::CompNode::sync_all();
py_task_q.wait_all_task_finish();
rng::delete_handle(handle);
}, py::call_guard<py::gil_scoped_release>());
m.def("set_global_rng_seed", &rng::set_global_rng_seed);
m.def("get_global_rng_seed", &rng::get_global_rng_seed);
}
......@@ -8,14 +8,21 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import megengine
from megengine import tensor
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.ops import delete_rng_handle, new_rng_handle
from megengine.core._imperative_rt.core2 import apply
from megengine.core._imperative_rt.ops import (
delete_rng_handle,
get_global_rng_seed,
new_rng_handle,
)
from megengine.core.ops.builtin import GaussianRNG, UniformRNG
from megengine.core.tensor.core import apply
from megengine.random import RNG
from megengine.random.rng import _normal, _uniform
def test_gaussian_rng():
def test_gaussian_op():
shape = (
8,
9,
......@@ -23,23 +30,16 @@ def test_gaussian_rng():
12,
)
shape = tensor(shape, dtype="int32")
op = GaussianRNG(1.0, 3.0)
op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0)
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 1.0) < 1e-1
assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1
assert str(output.device) == str(CompNode("xpux"))
cn = CompNode("xpu1")
op = GaussianRNG(-1.0, 2.0, cn)
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - (-1.0)) < 1e-1
assert np.sqrt(output.numpy().var()) - 2.0 < 1e-1
assert str(output.device) == str(cn)
cn = CompNode("xpu2")
seed = 233333
h = new_rng_handle(cn, seed)
op = GaussianRNG(3.0, 1.0, h)
op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h)
(output,) = apply(op, shape)
delete_rng_handle(h)
assert np.fabs(output.numpy().mean() - 3.0) < 1e-1
......@@ -47,7 +47,7 @@ def test_gaussian_rng():
assert str(output.device) == str(cn)
def test_uniform_rng():
def test_uniform_op():
shape = (
8,
9,
......@@ -55,22 +55,67 @@ def test_uniform_rng():
12,
)
shape = tensor(shape, dtype="int32")
op = UniformRNG()
op = UniformRNG(seed=get_global_rng_seed())
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(CompNode("xpux"))
cn = CompNode("xpu1")
op = UniformRNG(cn)
(output,) = apply(op, shape)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(cn)
cn = CompNode("xpu2")
seed = 233333
h = new_rng_handle(cn, seed)
op = UniformRNG(h)
op = UniformRNG(seed=seed, handle=h)
(output,) = apply(op, shape)
delete_rng_handle(h)
assert np.fabs(output.numpy().mean() - 0.5) < 1e-1
assert str(output.device) == str(cn)
def test_UniformRNG():
m1 = RNG(seed=111, device="xpu0")
m2 = RNG(seed=111, device="xpu1")
m3 = RNG(seed=222, device="xpu0")
out1 = m1.uniform(size=(100,))
out1_ = m1.uniform(size=(100,))
out2 = m2.uniform(size=(100,))
out3 = m3.uniform(size=(100,))
np.testing.assert_equal(out1.numpy(), out2.numpy())
assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all()
assert not (out1.numpy() == out1_.numpy()).all()
low = -234
high = 123
out = m1.uniform(low=low, high=high, size=(20, 30, 40))
out_shp = out.shape
if isinstance(out_shp, tuple):
assert out_shp == (20, 30, 40)
else:
assert all(out.shape.numpy() == np.array([20, 30, 40]))
assert np.abs(out.mean().numpy() - ((low + high) / 2)) / (high - low) < 0.1
def test_NormalRNG():
m1 = RNG(seed=111, device="xpu0")
m2 = RNG(seed=111, device="xpu1")
m3 = RNG(seed=222, device="xpu0")
out1 = m1.normal(size=(100,))
out1_ = m1.uniform(size=(100,))
out2 = m2.normal(size=(100,))
out3 = m3.normal(size=(100,))
np.testing.assert_equal(out1.numpy(), out2.numpy())
assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all()
assert not (out1.numpy() == out1_.numpy()).all()
mean = -1
std = 2
out = m1.normal(mean=mean, std=std, size=(20, 30, 40))
out_shp = out.shape
if isinstance(out_shp, tuple):
assert out_shp == (20, 30, 40)
else:
assert all(out.shape.numpy() == np.array([20, 30, 40]))
assert np.abs(out.mean().numpy() - mean) / std < 0.1
assert np.abs(np.std(out.numpy()) - std) < 0.1
......@@ -2,7 +2,7 @@
* \file imperative/src/impl/ops/rng.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
......@@ -10,23 +10,23 @@
*/
#include "megbrain/imperative/ops/rng.h"
#include <bits/stdint-uintn.h>
#include "megbrain/comp_node_env.h"
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"
//#include "megbrain/common.h"
#include "../op_trait.h"
#include "../dnn_op_helper.h"
namespace mgb {
namespace imperative {
namespace mgb::imperative::rng {
namespace {
template <typename HandleFactory, typename THandle>
class DnnOpManagerT : public CompNodeDepedentObject, public NonCopyableObj {
public:
using DT = CompNode::DeviceType;
using Handle = THandle;
using OpTypeInfo = size_t;
template <typename... Args>
Handle new_handle(Args&&... args) {
......@@ -38,27 +38,26 @@ public:
size_t removed = 0;
if (!is_finalized()) {
MGB_LOCK_GUARD(m_mtx);
removed = m_handle2op.erase(handle);
removed = m_handle2ops.erase(handle);
}
static_cast<HandleFactory*>(this)->do_delete_handle(handle);
return removed;
}
template <typename DnnOp>
auto get_dnn_op(Handle handle, CompNode cn) {
auto get_dnn_op(Handle handle, OpTypeInfo tpinfo, CompNode cn) {
mgb_assert(!is_finalized());
DnnOpWithMutex* dnn_op_with_mtx;
{
MGB_LOCK_GUARD(m_mtx);
dnn_op_with_mtx = &m_handle2op[handle];
dnn_op_with_mtx = &m_handle2ops[handle][tpinfo];
}
auto dnn_handle =
MegDNNHandle::get(CompNodeEnv::from_comp_node(cn)).handle();
DnnOp* dnn_op;
std::unique_lock<std::mutex> lock(dnn_op_with_mtx->mtx);
bool initialized = false;
if ((dnn_op = dynamic_cast<DnnOp*>(dnn_op_with_mtx->op.get())) !=
nullptr) {
DnnOp* dnn_op = static_cast<DnnOp*>(dnn_op_with_mtx->op.get());
if (dnn_op != nullptr) {
mgb_assert(dnn_op->handle() == dnn_handle);
initialized = true;
} else {
......@@ -77,35 +76,30 @@ private:
struct DnnOpWithMutex {
std::mutex mtx;
std::unique_ptr<megdnn::OperatorBase> op;
DnnOpWithMutex(): op{nullptr} {}
};
std::shared_ptr<void> on_comp_node_finalize() override {
MGB_LOCK_GUARD(m_mtx);
m_handle2op.clear();
m_handle2ops.clear();
return {};
}
std::unordered_map<Handle, DnnOpWithMutex> m_handle2op;
std::unordered_map<Handle, std::unordered_map<OpTypeInfo, DnnOpWithMutex> > m_handle2ops;
std::mutex m_mtx;
};
class RNGDnnOpManager final
: public DnnOpManagerT<RNGDnnOpManager, RNGMixin::Handle> {
: public DnnOpManagerT<RNGDnnOpManager, Handle> {
public:
Handle new_handle(CompNode comp_node, uint64_t seed) {
MGB_LOCK_GUARD(sm_mtx);
return DnnOpManagerBase::new_handle(comp_node, seed);
}
size_t delete_handle(Handle handle) {
size_t ret = 0;
{
MGB_LOCK_GUARD(sm_mtx);
auto iter = sm_partial2full.find(handle);
if (iter != sm_partial2full.end()) {
for (auto&& h : iter->second) {
ret += DnnOpManagerBase::delete_handle(h.second);
}
sm_partial2full.erase(iter);
}
}
ret += DnnOpManagerBase::delete_handle(handle);
return ret;
MGB_LOCK_GUARD(sm_mtx);
return DnnOpManagerBase::delete_handle(handle);
}
Handle do_new_handle(CompNode comp_node, uint64_t seed) {
......@@ -118,32 +112,26 @@ public:
}
static uint64_t get_seed(Handle handle) {
if (!handle) { return glob_default_seed; }
return reinterpret_cast<HandleData*>(handle)->seed;
}
static CompNode get_comp_node(Handle handle) {
mgb_assert(handle, "invalid handle");
return reinterpret_cast<HandleData*>(handle)->comp_node;
}
static Handle get_full_handle(Handle handle, CompNode comp_node) {
if (get_comp_node(handle).valid()) {
return handle;
}
MGB_LOCK_GUARD(sm_mtx);
auto&& full = sm_partial2full[handle][comp_node];
if (!full) {
full = inst().new_handle(comp_node, get_seed(handle));
}
return full;
}
static Handle get_default_handle(CompNode comp_node) {
static Handle glob_partial_handle =
inst().new_handle(CompNode{}, glob_default_seed);
if (!comp_node.valid()) {
return glob_partial_handle;
mgb_assert(comp_node.valid());
MGB_LOCK_GUARD(sm_mtx);
auto&& glob_handle = glob_default_handles[comp_node];
if (!glob_handle) {
glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
} else if (get_seed(glob_handle) != glob_default_seed) {
inst().DnnOpManagerBase::delete_handle(glob_handle);
glob_handle = inst().do_new_handle(comp_node, glob_default_seed);
}
return get_full_handle(glob_partial_handle, comp_node);
return glob_handle;
}
static RNGDnnOpManager& inst() {
......@@ -152,9 +140,15 @@ public:
}
static void set_glob_default_seed(uint64_t seed) {
MGB_LOCK_GUARD(sm_mtx);
glob_default_seed = seed;
}
static uint64_t get_glob_default_seed() {
MGB_LOCK_GUARD(sm_mtx);
return glob_default_seed;
}
private:
struct HandleData {
CompNode comp_node;
......@@ -165,16 +159,13 @@ private:
MemPool<HandleData> m_handle_pool;
static std::mutex sm_mtx;
static std::unordered_map<Handle, CompNode::UnorderedMap<Handle>>
sm_partial2full;
static CompNode::UnorderedMap<Handle> glob_default_handles;
static uint64_t glob_default_seed;
};
uint64_t RNGDnnOpManager::glob_default_seed = 0;
std::mutex RNGDnnOpManager::sm_mtx;
std::unordered_map<RNGDnnOpManager::Handle,
CompNode::UnorderedMap<RNGDnnOpManager::Handle>>
RNGDnnOpManager::sm_partial2full;
CompNode::UnorderedMap<Handle> RNGDnnOpManager::glob_default_handles;
template <typename Op>
struct OpMeth;
......@@ -185,7 +176,11 @@ struct OpMeth<UniformRNG> {
using Param = DnnOp::Param;
using OpNode = mgb::opr::UniformRNG;
static Param make_param(const UniformRNG& rng) {
return {RNGDnnOpManager::get_seed(rng.handle())};
auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
mgb_assert(handle_seed == rng.seed,
"inconsistent rng seed: rng op: %lu handle: %lu",
handle_seed, rng.seed);
return {handle_seed};
}
};
......@@ -195,7 +190,11 @@ struct OpMeth<GaussianRNG> {
using Param = DnnOp::Param;
using OpNode = mgb::opr::GaussianRNG;
static Param make_param(const GaussianRNG& rng) {
return {RNGDnnOpManager::get_seed(rng.handle()), rng.mean, rng.std};
auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
mgb_assert(handle_seed == rng.seed,
"inconsistent rng seed: rng op: %lu handle: %lu",
handle_seed, rng.seed);
return {handle_seed, rng.mean, rng.std};
}
};
......@@ -206,23 +205,22 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
auto dest = outputs[0];
auto cn = dest->comp_node();
auto handle = RNGDnnOpManager::get_full_handle(rng.handle(), cn);
{
auto handle_cn = RNGDnnOpManager::get_comp_node(handle);
mgb_assert(cn == handle_cn,
"inconsistent comp_node: handle: %s, output: %s",
cn.to_string().c_str(), handle_cn.to_string().c_str());
auto handle = rng.handle;
if (!handle) {
handle = RNGDnnOpManager::get_default_handle(cn);
}
// retrieve dnn_op from glob cache
auto dnn_op_thread_safe = RNGDnnOpManager::inst()
.get_dnn_op<typename OpMeth<Op>::DnnOp>(handle, cn);
.get_dnn_op<typename OpMeth<Op>::DnnOp>(
handle, reinterpret_cast<size_t>(op.dyn_typeinfo()),
cn);
auto initialized = std::get<0>(dnn_op_thread_safe);
auto dnn_op = std::get<1>(dnn_op_thread_safe);
if (initialized) {
auto handle_seed = RNGDnnOpManager::get_seed(handle);
mgb_assert(dnn_op->param().seed == handle_seed,
"inconsistent rng seed: handle: %zu, dnn_op: %zu",
"inconsistent rng seed: handle: %lu, dnn_op: %lu",
handle_seed, dnn_op->param().seed);
}
dnn_op->param() = OpMeth<Op>::make_param(rng);
......@@ -239,9 +237,12 @@ template <typename Op>
SmallVector<LogicalTensorDesc> infer_output_attrs(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
LogicalTensorDesc dest;
dest.comp_node = op.cast_final_safe<Op>().comp_node();
if (!dest.comp_node.valid())
auto handle = op.cast_final_safe<Op>().handle;
if (handle) {
dest.comp_node = RNGDnnOpManager::get_comp_node(handle);
} else {
dest.comp_node = inputs[0]->comp_node();
}
auto hv = inputs[0]->get_value().proxy_to_default_cpu();
TensorShape tshape;
......@@ -263,15 +264,22 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
}
template<typename Op>
cg::OperatorNodeBase* apply_on_var_node(
const OpDef& def, const VarNodeArray& inputs) {
SymbolVar apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
size_t nr_inp = inputs.size();
mgb_assert(nr_inp == 1, "UniformRNG expects 1 inputs; got %lu actually",
nr_inp);
auto&& rng = def.cast_final_safe<Op>();
mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually",
rng.dyn_typeinfo()->name,
nr_inp);
auto param = OpMeth<Op>::make_param(rng);
return OpMeth<Op>::OpNode::make(
inputs[0], param, {rng.comp_node()}).node()->owner_opr();
OperatorNodeConfig config;
if (rng.handle) {
config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)};
} else {
config = {rng.make_name()};
}
return OpMeth<Op>::OpNode::make(inputs[0], param, config);
}
template<typename T>
......@@ -309,28 +317,22 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
} // anonymous namespace
RNGMixin::RNGMixin(CompNode cn):
m_handle(RNGDnnOpManager::get_default_handle(cn)) {}
uint64_t RNGMixin::seed() const {
return RNGDnnOpManager::get_seed(m_handle);
}
CompNode RNGMixin::comp_node() const {
return RNGDnnOpManager::get_comp_node(m_handle);
}
RNGMixin::Handle RNGMixin::new_handle(CompNode comp_node, uint64_t seed) {
Handle new_handle(CompNode comp_node, uint64_t seed) {
return RNGDnnOpManager::inst().new_handle(comp_node, seed);
}
size_t RNGMixin::delete_handle(Handle handle) {
size_t delete_handle(Handle handle) {
return RNGDnnOpManager::inst().delete_handle(handle);
}
void set_rng_seed(uint64_t seed) {
void set_global_rng_seed(uint64_t seed) {
RNGDnnOpManager::set_glob_default_seed(seed);
}
uint64_t get_global_rng_seed() {
return RNGDnnOpManager::get_glob_default_seed();
}
#define REG_RNG_OP(NAME)\
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
......@@ -339,12 +341,10 @@ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.fallback(); \
} \
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NAME);
REG_RNG_OP(UniformRNG)
REG_RNG_OP(GaussianRNG)
} // namespace imperative
} // namespace mgb
} // namespace mgb::imperative::rng
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -429,34 +429,6 @@ OP_TRAIT_REG(AssertEqual, AssertEqual)
.fallback();
}} // assert_equal
namespace { namespace uniform_rng {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const UniformRNG&>(def);
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::UniformRNG::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(UniformRNG, UniformRNG)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // uniform_rng
namespace { namespace gaussian_rng {
auto apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
auto&& op = static_cast<const GaussianRNG&>(def);
mgb_assert(inputs.size() == 1);
OperatorNodeConfig config{op.make_name()};
return opr::GaussianRNG::make(inputs[0], op.param(), config);
}
OP_TRAIT_REG(GaussianRNG, GaussianRNG)
.apply_on_var_node(apply_on_var_node)
.fallback();
}} // gaussian_rng
namespace { namespace roi_align {
VarNodeArray apply_on_var_node(
const OpDef& def,
......
......@@ -2,7 +2,7 @@
* \file imperative/src/include/megbrain/imperative/ops/rng.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
......@@ -12,84 +12,15 @@
#pragma once
#include "megbrain/imperative/op_def.h"
#include "megbrain/imperative/ops/autogen.h"
namespace mgb::imperative {
namespace mgb::imperative::rng {
class RNGMixin {
public:
using Handle = size_t;
using Handle = size_t;
static Handle new_handle(
CompNode comp_node={}, uint64_t seed=0);
Handle new_handle(CompNode comp_node, uint64_t seed);
size_t delete_handle(Handle handle);
void set_global_rng_seed(uint64_t seed);
uint64_t get_global_rng_seed();
static size_t delete_handle(Handle handle);
Handle handle() const {
return m_handle;
}
uint64_t seed() const;
CompNode comp_node() const;
protected:
RNGMixin(Handle handle): m_handle(handle) {}
RNGMixin(CompNode comp_node);
private:
Handle m_handle;
};
class GaussianRNG : public OpDefImplBase<GaussianRNG>,
public RNGMixin {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
float mean = 1.0f, std = 0.0;
GaussianRNG(CompNode comp_node_): RNGMixin(comp_node_) {}
GaussianRNG(float mean_=1.0, float std_=0.0, CompNode comp_node_={}):
GaussianRNG(comp_node_) { mean = mean_; std = std_; }
GaussianRNG(float mean_, float std_, Handle handle):
RNGMixin(handle), mean(mean_), std(std_) {}
size_t hash() const override {
XXHash xxhash{};
auto append = [&xxhash](auto field){
auto hash_val = HashTrait<decltype(field)>::eval(field);
xxhash.update(reinterpret_cast<void*>(&hash_val), sizeof(hash_val));
};
append(dyn_typeinfo());
append(seed());
append(mean);
append(std);
return xxhash.digest();
}
bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const GaussianRNG&>(rhs_);
return rhs.seed() == seed()
&& rhs.mean == mean
&& rhs.std == std;
}
};
class UniformRNG : public OpDefImplBase<UniformRNG>,
public RNGMixin {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
UniformRNG(CompNode comp_node_={}): RNGMixin(comp_node_) {}
UniformRNG(Handle handle): RNGMixin(handle) {}
size_t hash() const override {
return hash_pair_combine(
mgb::hash(seed()),
reinterpret_cast<std::uintptr_t>(dyn_typeinfo()));
}
bool is_same_st(const Hashable& rhs_) const override {
auto&& rhs = static_cast<const UniformRNG&>(rhs_);
return rhs.dyn_typeinfo() == dyn_typeinfo()
&& rhs.seed() == seed();
}
};
void set_rng_seed(uint64_t seed);
} // namespace mgb::imperative
} // namespace mgb::imperative::rng
......@@ -14,6 +14,7 @@
using namespace mgb;
using namespace imperative;
using namespace imperative::rng;
template<typename Op, typename ...Args>
void check_rng_basic(Args&& ...args) {
......@@ -22,24 +23,31 @@ void check_rng_basic(Args&& ...args) {
{3, 4, 5, 6},
{2333}})
for (auto&& cn: {
CompNode::load("cpu0"),
CompNode::load("xpu0")})
CompNode::load("xpu0"),
CompNode::load("xpu1")})
{
auto op = Op::make(std::forward<Args>(args)..., cn);
Handle h = new_handle(cn, 123);
auto op = Op::make(std::forward<Args>(args)..., h);
DeviceTensorND tshape_dev;
cg::copy_shape_to_tensor_value(tshape_dev, tshape);
auto outputs = OpDef::apply_on_physical_tensor(*op, {Tensor::make(tshape_dev)});
SmallVector<TensorPtr> inputs = {Tensor::make(tshape_dev)};
auto outputs = OpDef::apply_on_physical_tensor(*op, inputs);
ASSERT_TRUE(outputs[0]->layout().eq_shape(tshape));
ASSERT_TRUE(cn == outputs[0]->comp_node());
// sync before delete handle
for (auto&& p: outputs) {
p->get_value();
}
delete_handle(h);
}
}
TEST(TestImperative, UniformRNGBasic) {
check_rng_basic<UniformRNG>();
check_rng_basic<UniformRNG>(123);
}
TEST(TestImperative, GaussianRNGBasic) {
check_rng_basic<GaussianRNG>(2.f, 3.f);
check_rng_basic<GaussianRNG>(123, 2.f, 3.f);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -114,17 +114,33 @@ def TopK: MgbHashableOp<"TopK", [TopKParam]>;
def NvOf: MgbHashableOp<"NvOf", [NvOfParam]>;
def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> {
let hashFunction = [{return mgb::hash($_self.dyn_typeinfo());}];
let cmpFunction = [{return true;}];
let extraArguments = (ins
MgbSizeTAddr:$handle
);
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash($_self.handle));
}];
let cmpFunction = [{return $0.handle == $1.handle;}];
}
def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> {
let extraArguments = (ins
MgbSizeTAddr:$handle
);
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash_pair_combine(mgb::hash($_self.mean), mgb::hash($_self.std)));
mgb::hash_pair_combine(
mgb::hash($_self.handle),
mgb::hash_pair_combine(
mgb::hash($_self.mean),
mgb::hash($_self.std))
)
);
}];
let cmpFunction = [{return $0.mean == $1.mean && $0.std == $1.std;}];
let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std;}];
}
def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册