diff --git a/src/opr/impl/cond.cpp b/src/opr/impl/cond.cpp index 90023c17a11819982188c93b8875eec54fb23227..2f9e4daf802589e50abff3661889ce72c2a126da 100644 --- a/src/opr/impl/cond.cpp +++ b/src/opr/impl/cond.cpp @@ -699,7 +699,9 @@ CondExecMark::CondExecMark(VarNode* ppv, const VarNodeArrayView& inputs, for (size_t i = 0; i < inputs.size(); ++i) { add_input({inputs[i]}); - add_output(ssprintf("fwd%zu", i))->dtype(inputs[i]->dtype()); + add_output(ssprintf("fwd%zu", i)) + ->dtype(inputs[i]->dtype()) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); } add_input({ppv}); add_equivalence_component>(&m_param); @@ -789,6 +791,10 @@ void CondExecMark::add_input_layout_constraint() { CondExecMark::NodeProp* CondExecMark::do_make_node_prop() const { auto ret = Super::do_make_node_prop(); ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER; + for (size_t i = 0; i < input().size() - 1; ++ i) { + ret->add_dep_type_existing_var(input(i), + NodeProp::DepType::VALUE_ALLOW_EMPTY); + } return ret; } @@ -859,7 +865,8 @@ CondExecMerge::CondExecMerge(const VarNodeArrayView& inputs, // 2. dynamic allocator would wait for all inputs to become ready (see // VarNodeMemManager::DynamicAllocOprInfo::host_wait_input_ready), // which would cause infinite waiting for unselected inputs. - ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC); + ovar->add_flag(VarNode::Flag::NO_SYS_MEM_ALLOC) + .add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); } MGB_MARK_USED_VAR(mask2str); @@ -1056,7 +1063,9 @@ void CondExecMerge::init_output_static_infer_desc() { desc.infer_func = [this](DeviceTensorND& dest, const InpVal& inp) { auto nr_branch = m_branch_masks.size(); bool found = false, first = true; - for (size_t i = 0; i < nr_branch; ++i) { + auto&& shape = inp.val.at(nr_branch).shape(); + + for (size_t i = 0; i < nr_branch && !shape.is_empty(); ++i) { if (!inp.val[i].value().ptr()[0]) continue; auto&& cur = inp.val.at(nr_branch + i).value(); @@ -1083,7 +1092,6 @@ void CondExecMerge::init_output_static_infer_desc() { } } if (!found) { - auto&& shape = inp.val.at(nr_branch).shape(); if (dest.storage().raw_storage().use_count() > 1) { // likely to be assigned from some input in previous // runs; we create a new tensor to avoid modifying input @@ -1115,6 +1123,7 @@ void CondExecMerge::scn_do_execute() { bool first = true; auto&& forwarded = m_mem_forwarded; + std::vector is_shape_empty(nr_out, false); for (size_t br = 0; br < m_branch_masks.size(); ++br) { if (!m_branch_masks[br]->enabled()) { continue; @@ -1125,6 +1134,10 @@ void CondExecMerge::scn_do_execute() { for (size_t oidx = 0; oidx < nr_out; ++oidx) { bool succ = output(oidx)->reset_dev_tensor_from_other_var( inp(br, oidx)); + if (inp(br, oidx)->shape().is_empty()) { + is_shape_empty[oidx] = true; + continue; + } if (!is_exact_one()) { if (forwarded.empty()) { forwarded.resize(nr_out); @@ -1144,6 +1157,11 @@ void CondExecMerge::scn_do_execute() { auto ovar = output(oidx); auto&& src = inp(br, oidx)->dev_tensor().as_megdnn(); auto&& dest = ovar->dev_tensor().as_megdnn(); + mgb_assert(src.layout.eq_shape(dest.layout), + "shape mismatch: %s vs %s in CondExecMerge", + src.layout.to_string().c_str(), + dest.layout.to_string().c_str()); + if (is_shape_empty[oidx]) continue; if (forwarded[oidx]) { ovar->shape_alloc(ovar->shape()); auto&& own_dest = ovar->dev_tensor().as_megdnn(); @@ -1200,6 +1218,10 @@ CondExecMerge::NodeProp* CondExecMerge::do_make_node_prop() const { // directly ret->dep_map().at(input().back()) = NodeProp::DepType::DEV_COMP_ORDER; } + for (size_t i = 0; i < m_param.nr_output * m_branch_masks.size(); ++ i) { + ret->add_dep_type_existing_var(input(i), + NodeProp::DepType::VALUE_ALLOW_EMPTY); + } return ret; } diff --git a/src/opr/test/cond.cpp b/src/opr/test/cond.cpp index 4b3d4ce13d71a7b2aed3ea4448e9fb511ae4543b..ab8e6d1d28b9fedeb152ab020e616ee8c891af05 100644 --- a/src/opr/test/cond.cpp +++ b/src/opr/test/cond.cpp @@ -14,6 +14,7 @@ #include "megbrain/opr/basic_arith_wrapper.h" #include "megbrain/opr/cond.h" #include "megbrain/opr/io.h" +#include "megbrain/opr/misc.h" #include "megbrain/opr/tensor_manip.h" #include "megbrain/opr/utility.h" #include "megbrain/utils/timer.h" @@ -1285,6 +1286,80 @@ TEST(TestCondExec, MultiShape) { check(host_d2); } +TEST(TestCondExec, EmptyShape) { + HostTensorGenerator<> gen; + auto host_pred = gen({1}); + host_pred->ptr()[0] = 0; + static auto empty_in_empty_out = [](SymbolVar x) { + return x; + }; + static auto empty_in_scalar_out = [](SymbolVar x) { + return opr::Concat::make({x, x.make_scalar(1.f)}, 0); + }; + static auto scalar_in_empty_out = [](SymbolVar x) { + return opr::CondTake::make(x, x, {})[0]; // whether eq 0 + }; + { // EXACT_ONE + auto graph = ComputingGraph::make(); + auto pred = opr::Host2DeviceCopy::make(*graph, host_pred), + empty = opr::ImmutableTensor::make(*graph, *gen({0})), + scalar = pred.make_scalar(1.f), + y0 = empty_in_empty_out(make_one_cond(pred + 1, empty)), + y1 = empty_in_scalar_out(make_one_cond(pred, empty)), + y2 = scalar_in_empty_out(make_one_cond(pred - 1, scalar)), + z = merge_one_out({y0, y1, y2}, MergeMode::EXACT_ONE); + + HostTensorND host_z; + auto func = graph->compile({make_callback_copy(z, host_z)}); + func->execute(); + ASSERT_TRUE(host_z.layout().is_empty()); + + host_pred->ptr()[0] = 1; + func->execute(); + ASSERT_EQ(1.f, host_z.ptr()[0]); + + host_pred->ptr()[0] = 2; + func->execute(); + ASSERT_TRUE(host_z.layout().is_empty()); + } + { // SUM + auto graph = ComputingGraph::make(); + host_pred->ptr()[0] = 1; + auto pred = opr::Host2DeviceCopy::make(*graph, host_pred), + empty = opr::ImmutableTensor::make(*graph, *gen({0})), + scalar = pred.make_scalar(1.f), + y0 = empty_in_empty_out(make_one_cond(pred, empty)), + y1 = scalar_in_empty_out(make_one_cond(pred, scalar)), + z = merge_one_out({y0, y1}, MergeMode::SUM); + + HostTensorND host_z; + auto func = graph->compile({make_callback_copy(z, host_z)}); + func->execute(); + ASSERT_TRUE(host_z.layout().is_empty()); + } + { // TAKE GRAD + auto graph = ComputingGraph::make(); + host_pred->ptr()[0] = 0; + auto pred = opr::Host2DeviceCopy::make(*graph, host_pred), + x = pred.make_scalar(1.2f), + y0 = opr::CondTake::make(make_one_cond(pred + 1, x), pred, {})[0], + y1 = make_one_cond(pred, x.make_scalar(3.4f)), + z = merge_one_out({y0, y1}, MergeMode::EXACT_ONE), + g = cg::grad(z, x); + + HostTensorND host_z, host_g; + auto func = graph->compile({ + make_callback_copy(z, host_z), make_callback_copy(g, host_g)}); + func->execute(); + ASSERT_EQ(1.2f, host_z.ptr()[0]); + ASSERT_EQ(1.f, host_g.ptr()[0]); + + host_pred->ptr()[0] = 1; + func->execute(); + ASSERT_EQ(3.4f, host_z.ptr()[0]); + ASSERT_EQ(0.f, host_g.ptr()[0]); + } +} #endif // MGB_ENABLE_COND_EXEC // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}