From dea5278172b7e0801e5374af8245d93fcda39b7f Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Fri, 25 Jun 2021 16:34:58 +0800 Subject: [PATCH] feat(mgb/opr): let PowC & TypeCvt support empty IO GitOrigin-RevId: f97b3005fd3d60c7c8d1159672debf3a5e30cc64 --- src/opr/impl/basic_arith.cpp | 28 ++++++++++++++++++- src/opr/include/megbrain/opr/basic_arith.h | 2 ++ src/opr/test/basic_arith/others.cpp | 32 ++++++++++++++++++++++ 3 files changed, 61 insertions(+), 1 deletion(-) diff --git a/src/opr/impl/basic_arith.cpp b/src/opr/impl/basic_arith.cpp index c3d29398c..fc24d0533 100644 --- a/src/opr/impl/basic_arith.cpp +++ b/src/opr/impl/basic_arith.cpp @@ -776,6 +776,10 @@ void TypeCvt::perform(DeviceTensorND &dest, intl::UniqPtrWithCN &opr) { mgb_assert(src.comp_node() == opr.comp_node()); mgb_assert(dest_type.valid()); + if (src.empty()) { + mgb_assert(dest.empty()); + return; + } if (src.dtype() == dest_type) { dest.copy_from(src); return; @@ -1739,7 +1743,13 @@ void Reduce::record_execute_deps(ExecDependencyArray& deps) { MGB_DYN_TYPE_OBJ_FINAL_IMPL(PowC); -MEGDNN_OPR_CTOR_INIT1(PowC, ssprintf("powc_%g", param.exp)) +PowC::PowC(VarNode *i0, const Param ¶m, const OperatorNodeConfig &config) + : Super(OperatorNodeBaseCtorParam{ i0->owner_graph(), config, ssprintf("powc_%g", param.exp), {i0}} ) { + init_megdnn_opr(*this, param); + add_input({i0}); + output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE); + intl::MegDNNOprInitPostCtor::apply(*this); +} SymbolVar PowC::make(SymbolVar x, const Param& param, const OperatorNodeConfig& config) { @@ -1778,6 +1788,22 @@ void PowC::init_output_static_infer_desc() { {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_value}); } +void PowC::scn_do_execute() { + if (input(0)->dev_tensor().empty()) { + mgb_assert(output(0)->dev_tensor().empty()); + return; + } + mgb_assert(!output(0)->dev_tensor().empty()); + Super::scn_do_execute(); +} + +PowC::NodeProp* PowC::do_make_node_prop() const { + auto ret = Super::do_make_node_prop(); + ret->add_dep_type_existing_var(input(0), + NodeProp::DepType::VALUE_ALLOW_EMPTY); + return ret; +} + #if MGB_ENABLE_GRAD MGB_IMPL_OPR_GRAD(PowC) { auto exp = opr.param().exp; diff --git a/src/opr/include/megbrain/opr/basic_arith.h b/src/opr/include/megbrain/opr/basic_arith.h index 848ecc44e..73a67757f 100644 --- a/src/opr/include/megbrain/opr/basic_arith.h +++ b/src/opr/include/megbrain/opr/basic_arith.h @@ -352,6 +352,8 @@ MGB_DEFINE_OPR_CLASS(PowC, intl::MegDNNOprWrapperFwd) // { void add_input_layout_constraint() override; void init_output_static_infer_desc() override; void mem_plan_fwd_in2out_writable() override; + NodeProp* do_make_node_prop() const override; + void scn_do_execute() override; public: PowC(VarNode* inp, const Param& param, const OperatorNodeConfig& config); diff --git a/src/opr/test/basic_arith/others.cpp b/src/opr/test/basic_arith/others.cpp index 951a2aa4e..5179b0be0 100644 --- a/src/opr/test/basic_arith/others.cpp +++ b/src/opr/test/basic_arith/others.cpp @@ -589,6 +589,23 @@ TEST(TestOprBasicArith, TypeCvtFromBool) { ASSERT_EQ(TensorShape({2}), host_y.shape()); } +TEST(TestOprBasicArith, TypeCvtPerformEmptyIO) { + HostTensorGenerator<> gen; + auto cn = CompNode::load("xpu0"); + auto host_x = gen({2, 0, 3, 4}); + auto dev_x = std::make_shared(cn); + dev_x->copy_from(*host_x); + + auto dev_y = std::make_shared(cn, dtype::Int32{}); + dev_y->resize(dev_x->shape()); + + auto dnn_opr = opr::intl::create_megdnn_opr(cn); + ASSERT_NO_THROW(opr::TypeCvt::perform(*dev_y, dtype::Int32{}, *dev_x, dnn_opr)); + ASSERT_TRUE(dev_y->empty()); + ASSERT_TRUE(dev_y->shape().is_empty()); + MGB_ASSERT_SHAPE_EQ(dev_x->shape(), dev_y->shape()); +} + TEST(TestOprBasicArith, ElemwiseMemFwd) { auto graph = ComputingGraph::make(); graph->options().graph_opt_level = 0; @@ -756,4 +773,19 @@ TEST(TestOprBasicArith, PowCInfer) { run(true); } +TEST(TestOprBasicArith, PowCEmptyIO) { + HostTensorGenerator<> gen; + auto graph = ComputingGraph::make(); + // empty input + auto host_x = gen({4, 0, 2, 3}); + auto x = opr::Host2DeviceCopy::make(*graph, host_x), + y = opr::PowC::make(x, 3.f); + HostTensorND host_y; + auto func = graph->compile({make_callback_copy(y, host_y)}); + ASSERT_NO_THROW(func->execute().wait()); + ASSERT_TRUE(host_y.empty()); + ASSERT_TRUE(host_y.shape().is_empty()); + MGB_ASSERT_SHAPE_EQ(host_x->shape(), host_y.shape()); +} + // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} -- GitLab