diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 36f18697705888104b486ec1e13cc829b92131c4..037b0d05403388764e631d0ec65ef615dd42485b 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -18,6 +18,7 @@ import megengine.amp as amp
 import megengine.core.ops.builtin as builtin
 import megengine.core.tensor.dtype as dtype
 import megengine.functional as F
+import megengine.jit as jit
 from megengine import Parameter, Tensor, is_cuda_available, tensor
 from megengine.core._trace_option import use_symbolic_shape
 from megengine.core.autodiff.grad import Grad
@@ -859,6 +860,35 @@ def test_condtake():
     np.testing.assert_equal(idx.numpy(), np.where(y.reshape(-1))[0])
 
 
+@pytest.mark.parametrize("is_symbolic", [None, False, True])
+def test_condtake_empty_io(is_symbolic):
+    shapes = [
+        (3, 3, 3),
+        (0,),
+        (3, 0, 3),
+    ]
+
+    def fn(mask, data):
+        return F.cond_take(mask, data)
+
+    if is_symbolic is not None:
+        fn = jit.trace(symbolic=is_symbolic)(fn)
+
+    for shp in shapes:
+        x_np = np.random.randn(*shp).astype("float32")
+        mask_np = x_np > 0
+        x = tensor(x_np)
+        mask = tensor(mask_np)
+        ref_out = x_np[mask_np]
+        ref_idx = mask_np.flatten().nonzero()[0]
+        for i in range(3):
+            out, idx = fn(mask, x)
+            np.testing.assert_equal(out.numpy(), ref_out)
+            np.testing.assert_equal(idx.numpy(), ref_idx)
+            if is_symbolic is None:
+                break
+
+
 def test_condtake_is_same():
     op1 = builtin.CondTake()
     op2 = builtin.CondTake()
diff --git a/imperative/src/impl/ops/cond_take.cpp b/imperative/src/impl/ops/cond_take.cpp
index 51b3ee0427ebe9b69f0e760ab4f434a44dd2c4fd..88ad983ed5fdbfe16fdcd5fc79569d1d14a4fdfc 100644
--- a/imperative/src/impl/ops/cond_take.cpp
+++ b/imperative/src/impl/ops/cond_take.cpp
@@ -45,25 +45,30 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     auto&& inp = inputs[0];
     auto&& msk = inputs[1];
+    SmallVector<TensorPtr> out;
     mgb_assert(inp->layout().eq_shape(msk->layout()),
             "input shape does not match mask shape");
     mgb_assert(msk->get_value().dtype().enumv() == DTypeEnum::Bool,
             "mask dtype must be bool");
-    DnnOprCaller<megdnn::CondTake> dnn_op(inp->comp_node());
-    dnn_op.op->param().val = 1;
-
-    TensorLayout m_layout({dnn_op.op->get_workspace_in_bytes(inp->layout())},
-                          dtype::Byte());
-
-    auto dnn_workspace = dnn_op.create_workspace(m_layout);
     MegDNNDynOutMallocImpl<2> policy{inp->comp_node()};
-
-    dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
-                    msk->dev_tensor().as_megdnn(),
-                    dnn_workspace,
-                    &policy);
-
-    SmallVector<TensorPtr> out;
+    if (inp->layout().is_empty()) {
+        // empty input: allocate empty outputs without calling the dnn kernel
+        policy.alloc_output(0, inp->layout().dtype, {0}, nullptr);
+        policy.alloc_output(1, dtype::Int32(), {0}, nullptr);
+    } else {
+        DnnOprCaller<megdnn::CondTake> dnn_op(inp->comp_node());
+        dnn_op.op->param().val = 1;
+
+        TensorLayout m_layout({dnn_op.op->get_workspace_in_bytes(inp->layout())},
+                              dtype::Byte());
+
+        auto dnn_workspace = dnn_op.create_workspace(m_layout);
+
+        dnn_op.op->exec(inp->dev_tensor().as_megdnn(),
+                        msk->dev_tensor().as_megdnn(),
+                        dnn_workspace,
+                        &policy);
+    }
     out.push_back(policy.at(0));
     out.push_back(policy.at(1));
     return out;
 }
diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp
index b8c1127848cc420fdc3f4305d9f4810a46015645..c73752d0181cc3d47f412aca850225979ad8a86e 100644
--- a/src/opr/impl/misc.cpp
+++ b/src/opr/impl/misc.cpp
@@ -264,6 +264,15 @@ CondTake::CondTake(VarNode *data, VarNode *mask,
     }
 }
 
+CondTake::NodeProp* CondTake::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+            NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 #if MGB_ENABLE_GRAD
 MGB_IMPL_OPR_GRAD(CondTake) {
     mgb_assert(out_grad.size() == 3 && !out_grad[2]);
@@ -305,11 +314,21 @@ void CondTake::add_input_layout_constraint() {
 }
 
 void CondTake::scn_do_execute() {
+    auto&& data = input(0)->dev_tensor();
+    auto&& mask = input(1)->dev_tensor();
     intl::MegDNNDynOutMallocImpl dyn_malloc{this, comp_node()};
-    megdnn_opr()->exec(input(0)->dev_tensor().as_megdnn(),
-                       input(1)->dev_tensor().as_megdnn(),
-                       intl::get_megdnn_workspace_from_var(output().back()),
-                       &dyn_malloc);
+    if (data.layout().is_empty()) {
+        mgb_assert(data.layout().eq_shape(mask.layout()),
+                "CondTake shape differs: data=%s mask=%s",
+                data.layout().TensorShape::to_string().c_str(),
+                mask.layout().TensorShape::to_string().c_str());
+        dyn_malloc.alloc_output(0, data.layout().dtype, {0}, nullptr);
+        dyn_malloc.alloc_output(1, dtype::Int32(), {0}, nullptr);
+    } else {
+        megdnn_opr()->exec(data.as_megdnn(), mask.as_megdnn(),
+                intl::get_megdnn_workspace_from_var(output().back()),
+                &dyn_malloc);
+    }
 }
 
 /* ================= TopK ================= */
diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h
index 79c0a17712866d6847c6ed7f040b0a968c28282f..e39146844a91a3bf90224884b556825b2dccf385 100644
--- a/src/opr/include/megbrain/opr/misc.h
+++ b/src/opr/include/megbrain/opr/misc.h
@@ -151,6 +151,7 @@ MGB_DEFINE_OPR_CLASS(CondTake, intl::CondTakeBase) // {
     void init_output_static_infer_desc() override;
     void scn_do_execute() override;
     void add_input_layout_constraint() override;
+    NodeProp* do_make_node_prop() const override;
 
     public:
         CondTake(VarNode *data, VarNode *mask,
diff --git a/src/opr/test/misc.cpp b/src/opr/test/misc.cpp
index e754130489b723d4f9216d28b58c8ec82c337096..68195b1862c2f3eace51cd9f2540e90da4e06e5c 100644
--- a/src/opr/test/misc.cpp
+++ b/src/opr/test/misc.cpp
@@ -256,20 +256,25 @@ TEST(TestOprMisc, CondTake) {
     run(mki({100}));
 }
 
-TEST(TestOprMisc, CondTakeEmptyOut) {
+TEST(TestOprMisc, CondTakeEmptyIO) {
     using Param = opr::CondTake::Param;
     HostTensorGenerator<> gen;
-    auto host_x = gen({1});
-    host_x->ptr<float>()[0] = 1;
-    auto graph = ComputingGraph::make();
-    auto x = opr::Host2DeviceCopy::make(*graph, host_x);
-    auto out = opr::CondTake::make(x, x, {Param::Mode::LT});
-    HostTensorND host_out0, host_out1;
-    auto func = graph->compile({make_callback_copy(out[0], host_out0),
-                                make_callback_copy(out[1], host_out1)});
-    func->execute();
-    ASSERT_EQ(TensorShape{0}, host_out0.shape());
-    ASSERT_EQ(TensorShape{0}, host_out1.shape());
+    auto check = [&](const TensorShape& shp) {
+        auto host_x = gen(shp);
+        auto graph = ComputingGraph::make();
+        auto x = opr::Host2DeviceCopy::make(*graph, host_x);
+        auto y = x + 1;
+        auto out = opr::CondTake::make(x, y, {Param::Mode::EQ});
+        HostTensorND host_out0, host_out1;
+        auto func = graph->compile({make_callback_copy(out[0], host_out0),
+                                    make_callback_copy(out[1], host_out1)});
+        func->execute();
+        ASSERT_EQ(TensorShape{0}, host_out0.shape());
+        ASSERT_EQ(TensorShape{0}, host_out1.shape());
+    };
+    check({1});
+    check({0});
+    check({1, 0});
 }
 
 TEST(TestOprMisc, TopKValueOnly) {
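
Behavior sketch (not part of the patch; a minimal repro assuming a MegEngine build with this change applied): `F.cond_take` now accepts empty inputs in both eager and traced execution, short-circuiting to empty outputs instead of invoking the MegDNN kernel. The shapes and reference computation below mirror the new test in test_functional.py:

    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    x_np = np.random.randn(3, 0, 3).astype("float32")  # empty input tensor
    mask_np = x_np > 0                                  # empty bool mask, same shape
    out, idx = F.cond_take(tensor(mask_np), tensor(x_np))
    # both outputs are empty 1-d tensors, matching the numpy reference
    np.testing.assert_equal(out.numpy(), x_np[mask_np])                    # values, shape (0,)
    np.testing.assert_equal(idx.numpy(), mask_np.flatten().nonzero()[0])   # indices, shape (0,)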