diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py
index 8db40f870e08ccb6879cd19d9d4ff2d01067e2fb..d21363a50868fb7031c280890dfacb9fbd824326 100644
--- a/imperative/python/test/unit/functional/test_functional.py
+++ b/imperative/python/test/unit/functional/test_functional.py
@@ -142,6 +142,26 @@ def test_matmul():
     )
 
 
+@pytest.mark.parametrize(
+    "shape_a, shape_b", [((0,), (0,)), ((10, 0), (0, 10)), ((3, 10, 0), (3, 0, 10)),],
+)
+@pytest.mark.parametrize("is_symbolic", [None, True, False])
+def test_matmul_empty_tensor(shape_a, shape_b, is_symbolic):
+    def func(a, b):
+        return F.matmul(a, b)
+
+    if is_symbolic is not None:
+        func = jit.trace(symbolic=is_symbolic)(func)
+
+    a = tensor(np.random.randn(*shape_a))
+    b = tensor(np.random.randn(*shape_b))
+    for _ in range(3):
+        out = func(a, b)
+        assert np.all(out.numpy() == 0)
+        if is_symbolic is None:
+            break
+
+
 def test_interpolate():
     def linear_interpolate():
         inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
diff --git a/src/opr/impl/blas.cpp b/src/opr/impl/blas.cpp
index 336f330fb8f27bec34705b99309c1d7fab1d9ef1..249cbc6aab936ca44ea5db25ac59670e4064472a 100644
--- a/src/opr/impl/blas.cpp
+++ b/src/opr/impl/blas.cpp
@@ -45,6 +45,7 @@ MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param,
     init_megdnn_opr(*this, param);
     m_policy = policy;
     add_input({a, b});
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }
 
 SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
@@ -61,6 +62,15 @@ void MatrixMul::init_output_dtype() {
     output(0)->dtype(output_dtype);
 }
 
+MatrixMul::NodeProp* MatrixMul::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 bool MatrixMul::check_layout(const TensorLayout& layout, int transpose) {
     mgb_assert(layout.ndim == 2, "input to MatrixMul must be 2-dim; got %s",
                layout.to_string().c_str());
@@ -138,6 +148,17 @@ void MatrixMul::scn_do_execute() {
     auto inp0 = input(0)->dev_tensor().as_megdnn(),
          inp1 = input(1)->dev_tensor().as_megdnn(),
          out = output(0)->dev_tensor().as_megdnn();
+    if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
+        if (!out.layout.is_empty()) {
+            if (!m_fill_opr) {
+                m_fill_opr = intl::get_megdnn_handle(comp_node())->
+                        create_operator<megdnn::Fill>();
+            }
+            m_fill_opr->param() = 0;
+            m_fill_opr->exec(out, {});
+        }
+        return;
+    }
     auto transpose = [](TensorLayout& layout, bool& trans) {
         if (!check_layout(layout, 0)) {
             mgb_assert(check_layout(layout, 1));
@@ -193,6 +214,7 @@ BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param,
     init_megdnn_opr(*this, param);
     m_policy = policy;
     add_input({a, b});
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
 }
 
 SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
@@ -229,6 +251,15 @@ void BatchedMatrixMul::init_output_dtype() {
     output(0)->dtype(output_dtype);
 }
 
+BatchedMatrixMul::NodeProp* BatchedMatrixMul::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 bool BatchedMatrixMul::check_layout(const TensorLayout& layout,
                                     bool transpose) {
     int lhs = (transpose) ? 2 : 1, rhs = (transpose) ? 1 : 2;
@@ -294,6 +325,17 @@ void BatchedMatrixMul::scn_do_execute() {
     auto inp0 = input(0)->dev_tensor().as_megdnn(),
         inp1 = input(1)->dev_tensor().as_megdnn(),
         out = output(0)->dev_tensor().as_megdnn();
+    if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
+        if (!out.layout.is_empty()) {
+            if (!m_fill_opr) {
+                m_fill_opr = intl::get_megdnn_handle(comp_node())->
+                        create_operator<megdnn::Fill>();
+            }
+            m_fill_opr->param() = 0;
+            m_fill_opr->exec(out, {});
+        }
+        return;
+    }
     auto transpose = [](TensorLayout& layout, bool& trans) {
         if (!check_layout(layout, false)) {
             mgb_assert(check_layout(layout, true));
@@ -354,6 +396,7 @@ Dot::Dot(VarNode *opr0, VarNode *opr1, const OperatorNodeConfig &config):
 {
     init_megdnn_opr(*this, {});
     add_input({opr0, opr1}, AddInputSortType::CUR_ADDED);
+    output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
     static_assert(std::is_empty<Param>::value, "Dot param should be empty");
     mgb_assert(opr0->dtype().category() != DTypeCategory::QUANTIZED &&
                opr1->dtype().category() != DTypeCategory::QUANTIZED,
@@ -406,10 +449,28 @@ void Dot::scn_do_execute() {
             i1.layout.stride[0] = 0;
         }
     }
+    if ((i0.layout.is_empty() || i1.layout.is_empty())) {
+        if (!m_fill_opr) {
+            m_fill_opr = intl::get_megdnn_handle(comp_node())->
+                    create_operator<megdnn::Fill>();
+        }
+        m_fill_opr->param() = 0;
+        m_fill_opr->exec(output(0)->dev_tensor().as_megdnn(), {});
+        return;
+    }
     megdnn_opr()->exec(i0, i1, output(0)->dev_tensor().as_megdnn(),
                        intl::get_megdnn_workspace_from_var(output(1)));
 }
 
+Dot::NodeProp* Dot::do_make_node_prop() const {
+    auto ret = Super::do_make_node_prop();
+    ret->add_dep_type_existing_var(input(0),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    ret->add_dep_type_existing_var(input(1),
+                                   NodeProp::DepType::VALUE_ALLOW_EMPTY);
+    return ret;
+}
+
 void Dot::add_input_layout_constraint() {
     auto check = [](const TensorLayout &ly) {
         mgb_throw_if(ly.ndim != 1, GraphError,
diff --git a/src/opr/include/megbrain/opr/blas.h b/src/opr/include/megbrain/opr/blas.h
index 93f9049aeecd4cbe4bb7ce1f37f43541874e46aa..e8752ec94ca052e181187ac48c8a22b0a9a8810c 100644
--- a/src/opr/include/megbrain/opr/blas.h
+++ b/src/opr/include/megbrain/opr/blas.h
@@ -17,6 +17,7 @@
 
 #include "megbrain/graph.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megdnn/oprs/general.h"
 #include "megdnn/oprs/linalg.h"
 
 namespace mgb {
@@ -40,6 +41,7 @@ private:
     void add_input_layout_constraint() override;
     void scn_do_execute() override;
    void init_output_dtype() override;
+    NodeProp* do_make_node_prop() const override;
     size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
                                     const TensorShapeArray& output_shapes) const override;
 
@@ -47,6 +49,7 @@ private:
 
     //! store the policy of all transpose situations
     megdnn::ExecutionPolicy m_cadidate_execution_policies[4];
+    std::unique_ptr<megdnn::Fill> m_fill_opr;
 };
 
 /*!
@@ -70,6 +73,7 @@ private:
     void add_input_layout_constraint() override;
     void init_output_dtype() override;
     void scn_do_execute() override;
+    NodeProp* do_make_node_prop() const override;
     size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
                                     const TensorShapeArray& output_shapes) const override;
 
@@ -77,6 +81,7 @@ private:
     static bool check_layout(const TensorLayout& layout, bool transpose);
     //! store the policy of all transpose situations
     megdnn::ExecutionPolicy m_cadidate_execution_policies[4];
+    std::unique_ptr<megdnn::Fill> m_fill_opr;
 };
 
 /*!
@@ -101,7 +106,9 @@ MGB_DEFINE_OPR_CLASS(Dot, cg::SingleCNOperatorNodeBaseT<
         void add_input_layout_constraint() override;
         void scn_do_execute() override;
         void init_output_static_infer_desc() override;
+        NodeProp* do_make_node_prop() const override;
         void record_execute_deps(ExecDependencyArray &deps) override;
+        std::unique_ptr<megdnn::Fill> m_fill_opr;
 };
 
 MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(MatrixInverse);
diff --git a/src/opr/test/blas.cpp b/src/opr/test/blas.cpp
index d271828e5edc49c40f0c37a3251abd1e69309e2c..76bd8b67310ef26c555ebd3fbb458d790b27ea93 100644
--- a/src/opr/test/blas.cpp
+++ b/src/opr/test/blas.cpp
@@ -94,7 +94,9 @@ void run_sgemm_test(bool transa, bool transb) {
     Checker(make_graph, fwd)
             .run({mkx(4, 6), mky(6, 2)}, opt)
             .run({mkx(2, 3), mky(3, 100)}, opt)
-            .run({mkx(20, 3), mky(3, 20)}, opt);
+            .run({mkx(20, 3), mky(3, 20)}, opt)
+            .run({mkx(10, 0), mky(0, 10)}, opt)
+            .run({mkx(0, 0), mky(0, 0)}, opt);
 }
 
 #define FWD_BATCH_GEMM(dt_src, dt_dst) \
@@ -143,7 +145,9 @@ void run_batched_sgemm_test(bool transa, bool transb) {
     Checker(make_graph, fwd)
             .run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
             .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
-            .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
+            .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt)
+            .run({mkx(3, 0, 2), mky(3, 2, 0)}, opt)
+            .run({mkx(64, 10, 0), mky(64, 0, 10)}, opt);
 }
 
 auto gen_fp16 = [](HostTensorND& dest) {
@@ -198,6 +202,7 @@ void run_batched_hgemm_test(bool transa, bool transb) {
 
     checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
             .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
+            .run({mkx(64, 10, 0), mky(64, 0, 10)}, opt)
             .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
 }
 
@@ -236,6 +241,7 @@ void run_batched_igemm_test(bool transa, bool transb) {
 
     checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
             .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
+            .run({mkx(64, 10, 0), mky(64, 0, 10)}, opt)
             .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
 }
 
@@ -650,7 +656,8 @@ TEST(TestOprBlas, Dot) {
             .run({TensorShape{15}, TensorShape{1}})
             .run({TensorShape{1}, TensorShape{16}})
            .run({TensorShape{23}, TensorShape{23}})
-            .run({TensorShape{1000}, TensorShape{1000}});
+            .run({TensorShape{1000}, TensorShape{1000}})
+            .run({TensorShape{0}, TensorShape{0}});
 }
 
 TEST(TestOprBlas, TransMatMul) {
diff --git a/test/src/autocheck.cpp b/test/src/autocheck.cpp
index 2bb678affbd7e205dc36d80be58735b7a24ac8f1..24afaf6819c4f27acc55a3ceede8eb64fa4652c3 100644
--- a/test/src/autocheck.cpp
+++ b/test/src/autocheck.cpp
@@ -250,7 +250,6 @@ DEF_IMPL(void)::do_run(const ShapeInpArray& shapes, const RunOptions& opt) {
     for (size_t i = 0; i < nr_out; ++i) {
         if (m_outputs_allow_grad[i]) {
             auto nr = m_outputs_truth[i].shape().total_nr_elems();
-            mgb_assert(nr, "got empty output");
             if (opt.cont_loss_p) {
                 m_loss_p[i]->resize({nr});
                 auto ptr = m_loss_p[i]->template ptr<ctype>();
diff --git a/test/src/numerical_diff.cpp b/test/src/numerical_diff.cpp
index 7d248bd72745b7f1efdb543ddd7454f8f0a4232a..2ab4064aa74b58fd2a84a80a1c3f262d5a3fcdf5 100644
--- a/test/src/numerical_diff.cpp
+++ b/test/src/numerical_diff.cpp
@@ -36,7 +36,7 @@ std::vector<HostTensorND> mgb::numerical_diff_pt2(
             resize(cur_inp->shape());
     auto dptr = dest.ptr<dt_float32>();
 
-    mgb_assert(cur_inp->layout().is_contiguous());
+    mgb_assert(cur_inp->layout().is_contiguous() || cur_inp->layout().is_empty());
     auto cur_inp_ptr = cur_inp->ptr<dt_float32>();
 
     mgb::RealTimer timer;
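
For reference, a minimal eager-mode sketch of the behavior this patch enables; it is not part of the patch itself and simply mirrors what the new test_matmul_empty_tensor case checks, assuming the usual MegEngine imports (megengine.functional as F, megengine.tensor) and NumPy. With an empty operand, F.matmul now returns a zero-filled output of the deduced shape instead of hitting the non-empty assertion:

    # Sketch only: matmul of a (10, 0) by a (0, 10) tensor after this patch.
    import numpy as np
    import megengine.functional as F
    from megengine import tensor

    a = tensor(np.random.randn(10, 0).astype("float32"))  # empty lhs
    b = tensor(np.random.randn(0, 10).astype("float32"))  # empty rhs
    out = F.matmul(a, b)                                   # deduced shape (10, 10)
    assert out.numpy().shape == (10, 10)
    assert np.all(out.numpy() == 0)                        # zero-filled output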