提交 7234efe1 编写于 作者: M Megvii Engine Team

feat(opr): let random op support empty output

GitOrigin-RevId: a2174975aac82724e435f09b4c815a82c8b70c69
上级 3bc94738
......@@ -209,7 +209,8 @@ def _poisson(
def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor:
assert isinstance(n, int) and n > 0, "Permutation is not defined when n <= 0"
assert isinstance(n, int)
assert n >= 0, "Permutation is not defined when n < 0"
size = (n,)
op = PermutationRNG(seed=seed, handle=handle, dtype=dtype)
_ref = Tensor([], dtype="int32", device=device)
......
......@@ -10,7 +10,7 @@ import numpy as np
import pytest
import megengine.functional as F
from megengine import Tensor
from megengine import Tensor, jit, random
from megengine.core._imperative_rt import CompNode
from megengine.core._imperative_rt.core2 import apply
from megengine.core._imperative_rt.ops import (
......@@ -402,3 +402,44 @@ def test_seed():
seed(11)
out4 = uniform(size=[10, 10])
assert not (out1.numpy() == out4.numpy()).all()
@pytest.mark.parametrize("is_symbolic", [None, False, True])
def test_rng_empty_tensor(is_symbolic):
shapes = [
(0,),
(0, 0, 0),
(10, 0, 10),
]
def fn(shape):
o1 = random.uniform(0, 1, shape)
o2 = random.normal(0, 1, shape)
o3 = random.gamma(2, 1, shape)
o4 = random.beta(2, 1, shape)
o5 = random.poisson(2, shape)
return o1, o2, o3, o4, o5
for shape in shapes:
if is_symbolic is not None:
fn_ = jit.trace(symbolic=is_symbolic)(fn)
else:
fn_ = fn
for _ in range(3):
outs = fn_(shape)
for out in outs:
np.testing.assert_equal(out.numpy().shape, shape)
if is_symbolic is None:
break
def fn2(n):
return random.permutation(n=n)
if is_symbolic is not None:
fn2 = jit.trace(symbolic=is_symbolic)(fn2)
for _ in range(3):
out = fn2(0)
np.testing.assert_equal(out.numpy().shape, (0,))
if is_symbolic is None:
break
......@@ -312,12 +312,7 @@ struct _InferLayout<false>
template<typename Op>
static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){
size_t size = inp.layout.total_nr_elems();
mgb_assert(
size > 0,
"target size of %s expects size>0; got size=%lu actually",
rng.dyn_typeinfo()->name,
size);
mgb_assert(inp.layout.ndim);
return inp.layout;
}
};
......@@ -376,6 +371,7 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
auto&& rng = op.cast_final_safe<Op>();
auto dest = outputs[0];
if (dest->layout().is_empty()) return;
auto cn = dest->comp_node();
auto handle = rng.handle;
if (!handle) {
......
......@@ -48,6 +48,9 @@ cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const {
auto prop = Super::do_make_node_prop(); \
prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \
prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE}); \
for (auto i: input()) { \
prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); \
} \
return prop; \
} \
RNGOpr::RNGOpr(VarNode *shape, const Param &param, \
......@@ -56,7 +59,7 @@ RNGOpr::RNGOpr(VarNode *shape, const Param &param,
{ \
DType dtype = DType::from_enum(param.dtype); \
add_input({shape}); \
add_output(None)->dtype(dtype); \
add_output(None)->dtype(dtype).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); \
cg::add_workspace_output(this); \
add_equivalence_component<ScalarHash<void*>>(this); \
} \
......@@ -84,7 +87,12 @@ void RNGOpr::init_output_static_infer_desc() {
{SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk}); \
} \
void RNGOpr::scn_do_execute() { \
m_dnn_opr->exec(output(0)->dev_tensor().as_megdnn(), \
auto&& ret = output(0); \
if (ret->layout().is_empty()) { \
mgb_assert(ret->dev_tensor().empty()); \
return; \
} \
m_dnn_opr->exec(ret->dev_tensor().as_megdnn(), \
get_megdnn_workspace_from_var(output(1))); \
}
......@@ -105,7 +113,7 @@ RNGOpr::RNGOpr(_INPUTS(VarNode*,), const Param &param,
Super({i0->owner_graph(), config, (name), {_INPUTS(,)}}, param) \
{ \
add_input({_INPUTS(,)}); \
add_output(None)->dtype(i0->dtype()); \
add_output(None)->dtype(i0->dtype()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); \
cg::add_workspace_output(this); \
add_equivalence_component<ScalarHash<void*>>(this); \
} \
......@@ -132,9 +140,22 @@ void RNGOpr::add_input_layout_constraint(){
for (auto i : input()) i->add_layout_constraint_contiguous(); \
}; \
void RNGOpr::scn_do_execute() { \
auto&& ret = output(0); \
if (ret->layout().is_empty()) { \
mgb_assert(ret->dev_tensor().empty()); \
return; \
} \
m_dnn_opr->exec(_FOR_EACH(_AS_MEGDNN),output(0)->dev_tensor().as_megdnn(), \
get_megdnn_workspace_from_var(output(1))); \
}
} \
cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const { \
auto prop = Super::do_make_node_prop(); \
prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \
for (auto i: input()) { \
prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); \
} \
return prop; \
}
/* ================= 1 input ================= */
#define _INPUTS(prefix, subfix) prefix i0 subfix
......
......@@ -67,6 +67,7 @@ _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG)
#define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \
void add_input_layout_constraint() override; \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
public: \
RNG(_INPUTS(VarNode*), const Param &param, \
const OperatorNodeConfig &config); \
......
......@@ -247,6 +247,92 @@ TEST(TestOprRand, PermutationRNG) {
}
}
TEST(TestOprRand, EmptyShape) {
auto test_uniform = []() {
static constexpr size_t M = 128, N = 0;
auto graph = ComputingGraph::make();
SymbolVar dev_out = opr::UniformRNG::make(
*graph, {M, N}, {CompNode::load("xpu0")}, {23, DTypeEnum::Float32});
HostTensorND host_out;
auto func = graph->compile({make_callback_copy(dev_out, host_out)});
func->execute();
ASSERT_EQ(host_out.shape(), TensorShape({M, N}));
};
auto test_gaussian = []() {
size_t SIZE = 0;
constexpr float MEAN = 1, STD = 2;
auto graph = ComputingGraph::make();
auto y = opr::GaussianRNG::make(
SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}),
{23, MEAN, STD, DTypeEnum::Float32});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
ASSERT_EQ(TensorShape({SIZE}), host_y.shape());
};
auto test_gamma = []() {
std::shared_ptr<HostTensorND> shape_host(new HostTensorND{
CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
std::shared_ptr<HostTensorND> scale_host(new HostTensorND{
CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
auto graph = ComputingGraph::make();
auto shape_sym = opr::Host2DeviceCopy::make(*graph, shape_host);
auto scale_sym = opr::Host2DeviceCopy::make(*graph, scale_host);
auto y = opr::GammaRNG::make(shape_sym, scale_sym, {10});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
ASSERT_EQ(TensorShape({10, 0}), host_y.shape());
};
auto test_poisson = []() {
std::shared_ptr<HostTensorND> lam_host(new HostTensorND{
CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
auto graph = ComputingGraph::make();
auto lam_sym = opr::Host2DeviceCopy::make(*graph, lam_host);
auto y = opr::PoissonRNG::make(lam_sym, {10});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
ASSERT_EQ(TensorShape({10, 0}), host_y.shape());
};
auto test_beta = []() {
std::shared_ptr<HostTensorND> alpha_host(new HostTensorND{
CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
std::shared_ptr<HostTensorND> beta_host(new HostTensorND{
CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
auto graph = ComputingGraph::make();
auto alpha_sym = opr::Host2DeviceCopy::make(*graph, alpha_host);
auto beta_sym = opr::Host2DeviceCopy::make(*graph, beta_host);
auto y = opr::BetaRNG::make(alpha_sym,beta_sym, {10});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
ASSERT_EQ(TensorShape({10, 0}), host_y.shape());
};
auto test_permutation = []() {
static constexpr size_t SIZE = 0;
auto graph = ComputingGraph::make();
auto y = opr::PermutationRNG::make(
SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}),
{23, DTypeEnum::Int32});
HostTensorND host_y;
auto func = graph->compile({make_callback_copy(y, host_y)});
func->execute();
ASSERT_EQ(TensorShape({SIZE}), host_y.shape());
};
test_uniform();
test_gaussian();
test_gamma();
test_poisson();
test_beta();
test_permutation();
}
TEST(TestOprRand, UniformReprod) {
static constexpr size_t SIZE = 123;
auto graph = ComputingGraph::make();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册