Commit 7234efe1 authored by Megvii Engine Team

feat(opr): let random op support empty output

GitOrigin-RevId: a2174975aac82724e435f09b4c815a82c8b70c69
Parent: 3bc94738
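
For context (this note and the sketch are editorial, not part of the patch): the user-visible effect is that the Python random API no longer rejects output shapes that contain a zero dimension. A minimal sketch, assuming the megengine.random functions exercised by the new tests below:

    import numpy as np
    import megengine.random as random

    # A zero-sized dimension is now a valid output shape for random ops;
    # previously this tripped the size > 0 assertion during layout inference.
    out = random.uniform(0, 1, (10, 0, 10))
    np.testing.assert_equal(out.numpy().shape, (10, 0, 10))

    out = random.poisson(2, (0,))
    np.testing.assert_equal(out.numpy().shape, (0,))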
@@ -209,7 +209,8 @@ def _poisson(
 def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor:
-    assert isinstance(n, int) and n > 0, "Permutation is not defined when n <= 0"
+    assert isinstance(n, int)
+    assert n >= 0, "Permutation is not defined when n < 0"
     size = (n,)
     op = PermutationRNG(seed=seed, handle=handle, dtype=dtype)
     _ref = Tensor([], dtype="int32", device=device)
...
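The relaxed assertion above makes n == 0 legal (only negative n is rejected), so an empty permutation is now representable. A minimal sketch of the resulting behavior through the public wrapper, assuming it forwards to _permutation as in the test added below:

    import numpy as np
    import megengine.random as random

    # Previously failed with "Permutation is not defined when n <= 0".
    out = random.permutation(n=0)
    np.testing.assert_equal(out.numpy().shape, (0,))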
@@ -10,7 +10,7 @@ import numpy as np
 import pytest
 
 import megengine.functional as F
-from megengine import Tensor
+from megengine import Tensor, jit, random
 from megengine.core._imperative_rt import CompNode
 from megengine.core._imperative_rt.core2 import apply
 from megengine.core._imperative_rt.ops import (
@@ -402,3 +402,44 @@ def test_seed():
     seed(11)
     out4 = uniform(size=[10, 10])
     assert not (out1.numpy() == out4.numpy()).all()
+
+
+@pytest.mark.parametrize("is_symbolic", [None, False, True])
+def test_rng_empty_tensor(is_symbolic):
+    shapes = [
+        (0,),
+        (0, 0, 0),
+        (10, 0, 10),
+    ]
+
+    def fn(shape):
+        o1 = random.uniform(0, 1, shape)
+        o2 = random.normal(0, 1, shape)
+        o3 = random.gamma(2, 1, shape)
+        o4 = random.beta(2, 1, shape)
+        o5 = random.poisson(2, shape)
+        return o1, o2, o3, o4, o5
+
+    for shape in shapes:
+        if is_symbolic is not None:
+            fn_ = jit.trace(symbolic=is_symbolic)(fn)
+        else:
+            fn_ = fn
+        for _ in range(3):
+            outs = fn_(shape)
+            for out in outs:
+                np.testing.assert_equal(out.numpy().shape, shape)
+            if is_symbolic is None:
+                break
+
+    def fn2(n):
+        return random.permutation(n=n)
+
+    if is_symbolic is not None:
+        fn2 = jit.trace(symbolic=is_symbolic)(fn2)
+    for _ in range(3):
+        out = fn2(0)
+        np.testing.assert_equal(out.numpy().shape, (0,))
+        if is_symbolic is None:
+            break
...
@@ -312,12 +312,7 @@ struct _InferLayout<false>
     template<typename Op>
     static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){
-        size_t size = inp.layout.total_nr_elems();
-        mgb_assert(
-                size > 0,
-                "target size of %s expects size>0; got size=%lu actually",
-                rng.dyn_typeinfo()->name,
-                size);
+        mgb_assert(inp.layout.ndim);
         return inp.layout;
     }
 };
@@ -376,6 +371,7 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
     auto&& rng = op.cast_final_safe<Op>();
     auto dest = outputs[0];
+    if (dest->layout().is_empty()) return;
     auto cn = dest->comp_node();
     auto handle = rng.handle;
     if (!handle) {
...
@@ -48,6 +48,9 @@ cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const {
         auto prop = Super::do_make_node_prop(); \
         prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \
         prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE}); \
+        for (auto i: input()) { \
+            prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); \
+        } \
         return prop; \
     } \
     RNGOpr::RNGOpr(VarNode *shape, const Param &param, \
@@ -56,7 +59,7 @@ RNGOpr::RNGOpr(VarNode *shape, const Param &param,
     { \
         DType dtype = DType::from_enum(param.dtype); \
         add_input({shape}); \
-        add_output(None)->dtype(dtype); \
+        add_output(None)->dtype(dtype).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); \
         cg::add_workspace_output(this); \
         add_equivalence_component<ScalarHash<void*>>(this); \
     } \
@@ -84,7 +87,12 @@ void RNGOpr::init_output_static_infer_desc() {
             {SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk}); \
     } \
     void RNGOpr::scn_do_execute() { \
-        m_dnn_opr->exec(output(0)->dev_tensor().as_megdnn(), \
+        auto&& ret = output(0); \
+        if (ret->layout().is_empty()) { \
+            mgb_assert(ret->dev_tensor().empty()); \
+            return; \
+        } \
+        m_dnn_opr->exec(ret->dev_tensor().as_megdnn(), \
                 get_megdnn_workspace_from_var(output(1))); \
     }
@@ -105,7 +113,7 @@ RNGOpr::RNGOpr(_INPUTS(VarNode*,), const Param &param,
         Super({i0->owner_graph(), config, (name), {_INPUTS(,)}}, param) \
     { \
         add_input({_INPUTS(,)}); \
-        add_output(None)->dtype(i0->dtype()); \
+        add_output(None)->dtype(i0->dtype()).add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); \
         cg::add_workspace_output(this); \
         add_equivalence_component<ScalarHash<void*>>(this); \
     } \
@@ -132,9 +140,22 @@ void RNGOpr::add_input_layout_constraint(){
         for (auto i : input()) i->add_layout_constraint_contiguous(); \
     }; \
     void RNGOpr::scn_do_execute() { \
+        auto&& ret = output(0); \
+        if (ret->layout().is_empty()) { \
+            mgb_assert(ret->dev_tensor().empty()); \
+            return; \
+        } \
         m_dnn_opr->exec(_FOR_EACH(_AS_MEGDNN),output(0)->dev_tensor().as_megdnn(), \
                 get_megdnn_workspace_from_var(output(1))); \
-    }
+    } \
+    cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const { \
+        auto prop = Super::do_make_node_prop(); \
+        prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \
+        for (auto i: input()) { \
+            prop->add_dep_type_existing_var(i, NodeProp::DepType::VALUE_ALLOW_EMPTY); \
+        } \
+        return prop; \
+    }
 
 /* ================= 1 input ================= */
 #define _INPUTS(prefix, subfix) prefix i0 subfix
...
@@ -67,6 +67,7 @@ _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG)
 #define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \
     MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \
         void add_input_layout_constraint() override; \
+        cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
         public: \
             RNG(_INPUTS(VarNode*), const Param &param, \
                 const OperatorNodeConfig &config); \
...
@@ -247,6 +247,92 @@ TEST(TestOprRand, PermutationRNG) {
     }
 }
 
+TEST(TestOprRand, EmptyShape) {
+    auto test_uniform = []() {
+        static constexpr size_t M = 128, N = 0;
+        auto graph = ComputingGraph::make();
+        SymbolVar dev_out = opr::UniformRNG::make(
+                *graph, {M, N}, {CompNode::load("xpu0")}, {23, DTypeEnum::Float32});
+        HostTensorND host_out;
+        auto func = graph->compile({make_callback_copy(dev_out, host_out)});
+        func->execute();
+        ASSERT_EQ(host_out.shape(), TensorShape({M, N}));
+    };
+    auto test_gaussian = []() {
+        size_t SIZE = 0;
+        constexpr float MEAN = 1, STD = 2;
+        auto graph = ComputingGraph::make();
+        auto y = opr::GaussianRNG::make(
+                SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}),
+                {23, MEAN, STD, DTypeEnum::Float32});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_EQ(TensorShape({SIZE}), host_y.shape());
+    };
+    auto test_gamma = []() {
+        std::shared_ptr<HostTensorND> shape_host(new HostTensorND{
+                CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
+        std::shared_ptr<HostTensorND> scale_host(new HostTensorND{
+                CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
+        auto graph = ComputingGraph::make();
+        auto shape_sym = opr::Host2DeviceCopy::make(*graph, shape_host);
+        auto scale_sym = opr::Host2DeviceCopy::make(*graph, scale_host);
+        auto y = opr::GammaRNG::make(shape_sym, scale_sym, {10});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_EQ(TensorShape({10, 0}), host_y.shape());
+    };
+    auto test_poisson = []() {
+        std::shared_ptr<HostTensorND> lam_host(new HostTensorND{
+                CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
+        auto graph = ComputingGraph::make();
+        auto lam_sym = opr::Host2DeviceCopy::make(*graph, lam_host);
+        auto y = opr::PoissonRNG::make(lam_sym, {10});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_EQ(TensorShape({10, 0}), host_y.shape());
+    };
+    auto test_beta = []() {
+        std::shared_ptr<HostTensorND> alpha_host(new HostTensorND{
+                CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
+        std::shared_ptr<HostTensorND> beta_host(new HostTensorND{
+                CompNode::load("xpux"), TensorShape{10, 0}, dtype::Float32()});
+        auto graph = ComputingGraph::make();
+        auto alpha_sym = opr::Host2DeviceCopy::make(*graph, alpha_host);
+        auto beta_sym = opr::Host2DeviceCopy::make(*graph, beta_host);
+        auto y = opr::BetaRNG::make(alpha_sym, beta_sym, {10});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_EQ(TensorShape({10, 0}), host_y.shape());
+    };
+    auto test_permutation = []() {
+        static constexpr size_t SIZE = 0;
+        auto graph = ComputingGraph::make();
+        auto y = opr::PermutationRNG::make(
+                SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}),
+                {23, DTypeEnum::Int32});
+        HostTensorND host_y;
+        auto func = graph->compile({make_callback_copy(y, host_y)});
+        func->execute();
+        ASSERT_EQ(TensorShape({SIZE}), host_y.shape());
+    };
+    test_uniform();
+    test_gaussian();
+    test_gamma();
+    test_poisson();
+    test_beta();
+    test_permutation();
+}
+
 TEST(TestOprRand, UniformReprod) {
     static constexpr size_t SIZE = 123;
     auto graph = ComputingGraph::make();
...