提交 b2827cb1 编写于 作者: M Megvii Engine Team

feat(opr): let Dot, MatrixMul and BatchedMatrixMul support empty input

GitOrigin-RevId: 10a3c5b106d4013f486d8b99593959d90c760885
上级 50f73877
...@@ -142,6 +142,26 @@ def test_matmul(): ...@@ -142,6 +142,26 @@ def test_matmul():
) )
@pytest.mark.parametrize(
"shape_a, shape_b", [((0,), (0,)), ((10, 0), (0, 10)), ((3, 10, 0), (3, 0, 10)),],
)
@pytest.mark.parametrize("is_symbolic", [None, True, False])
def test_matmul_empty_tensor(shape_a, shape_b, is_symbolic):
def func(a, b):
return F.matmul(a, b)
if is_symbolic is not None:
func = jit.trace(symbolic=is_symbolic)(func)
a = tensor(np.random.randn(*shape_a))
b = tensor(np.random.randn(*shape_b))
for _ in range(3):
out = func(a, b)
assert np.all(out.numpy() == 0)
if is_symbolic is None:
break
def test_interpolate(): def test_interpolate():
def linear_interpolate(): def linear_interpolate():
inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2)) inp = tensor(np.arange(1, 3, dtype=np.float32).reshape(1, 1, 2))
......
...@@ -45,6 +45,7 @@ MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param, ...@@ -45,6 +45,7 @@ MatrixMul::MatrixMul(VarNode* a, VarNode* b, const Param& param,
init_megdnn_opr(*this, param); init_megdnn_opr(*this, param);
m_policy = policy; m_policy = policy;
add_input({a, b}); add_input({a, b});
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
} }
SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param, SymbolVar MatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
...@@ -61,6 +62,15 @@ void MatrixMul::init_output_dtype() { ...@@ -61,6 +62,15 @@ void MatrixMul::init_output_dtype() {
output(0)->dtype(output_dtype); output(0)->dtype(output_dtype);
} }
MatrixMul::NodeProp* MatrixMul::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
ret->add_dep_type_existing_var(input(1),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
bool MatrixMul::check_layout(const TensorLayout& layout, int transpose) { bool MatrixMul::check_layout(const TensorLayout& layout, int transpose) {
mgb_assert(layout.ndim == 2, "input to MatrixMul must be 2-dim; got %s", mgb_assert(layout.ndim == 2, "input to MatrixMul must be 2-dim; got %s",
layout.to_string().c_str()); layout.to_string().c_str());
...@@ -138,6 +148,17 @@ void MatrixMul::scn_do_execute() { ...@@ -138,6 +148,17 @@ void MatrixMul::scn_do_execute() {
auto inp0 = input(0)->dev_tensor().as_megdnn(), auto inp0 = input(0)->dev_tensor().as_megdnn(),
inp1 = input(1)->dev_tensor().as_megdnn(), inp1 = input(1)->dev_tensor().as_megdnn(),
out = output(0)->dev_tensor().as_megdnn(); out = output(0)->dev_tensor().as_megdnn();
if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
if (!out.layout.is_empty()) {
if (!m_fill_opr) {
m_fill_opr = intl::get_megdnn_handle(comp_node())->
create_operator<megdnn::Fill>();
}
m_fill_opr->param() = 0;
m_fill_opr->exec(out, {});
}
return;
}
auto transpose = [](TensorLayout& layout, bool& trans) { auto transpose = [](TensorLayout& layout, bool& trans) {
if (!check_layout(layout, 0)) { if (!check_layout(layout, 0)) {
mgb_assert(check_layout(layout, 1)); mgb_assert(check_layout(layout, 1));
...@@ -193,6 +214,7 @@ BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param, ...@@ -193,6 +214,7 @@ BatchedMatrixMul::BatchedMatrixMul(VarNode* a, VarNode* b, const Param& param,
init_megdnn_opr(*this, param); init_megdnn_opr(*this, param);
m_policy = policy; m_policy = policy;
add_input({a, b}); add_input({a, b});
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
} }
SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param, SymbolVar BatchedMatrixMul::make(SymbolVar a, SymbolVar b, const Param& param,
...@@ -229,6 +251,15 @@ void BatchedMatrixMul::init_output_dtype() { ...@@ -229,6 +251,15 @@ void BatchedMatrixMul::init_output_dtype() {
output(0)->dtype(output_dtype); output(0)->dtype(output_dtype);
} }
BatchedMatrixMul::NodeProp* BatchedMatrixMul::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
ret->add_dep_type_existing_var(input(1),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
bool BatchedMatrixMul::check_layout(const TensorLayout& layout, bool BatchedMatrixMul::check_layout(const TensorLayout& layout,
bool transpose) { bool transpose) {
int lhs = (transpose) ? 2 : 1, rhs = (transpose) ? 1 : 2; int lhs = (transpose) ? 2 : 1, rhs = (transpose) ? 1 : 2;
...@@ -294,6 +325,17 @@ void BatchedMatrixMul::scn_do_execute() { ...@@ -294,6 +325,17 @@ void BatchedMatrixMul::scn_do_execute() {
auto inp0 = input(0)->dev_tensor().as_megdnn(), auto inp0 = input(0)->dev_tensor().as_megdnn(),
inp1 = input(1)->dev_tensor().as_megdnn(), inp1 = input(1)->dev_tensor().as_megdnn(),
out = output(0)->dev_tensor().as_megdnn(); out = output(0)->dev_tensor().as_megdnn();
if ((inp0.layout.is_empty() || inp1.layout.is_empty())) {
if (!out.layout.is_empty()) {
if (!m_fill_opr) {
m_fill_opr = intl::get_megdnn_handle(comp_node())->
create_operator<megdnn::Fill>();
}
m_fill_opr->param() = 0;
m_fill_opr->exec(out, {});
}
return;
}
auto transpose = [](TensorLayout& layout, bool& trans) { auto transpose = [](TensorLayout& layout, bool& trans) {
if (!check_layout(layout, false)) { if (!check_layout(layout, false)) {
mgb_assert(check_layout(layout, true)); mgb_assert(check_layout(layout, true));
...@@ -354,6 +396,7 @@ Dot::Dot(VarNode *opr0, VarNode *opr1, const OperatorNodeConfig &config): ...@@ -354,6 +396,7 @@ Dot::Dot(VarNode *opr0, VarNode *opr1, const OperatorNodeConfig &config):
{ {
init_megdnn_opr(*this, {}); init_megdnn_opr(*this, {});
add_input({opr0, opr1}, AddInputSortType::CUR_ADDED); add_input({opr0, opr1}, AddInputSortType::CUR_ADDED);
output(0)->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE);
static_assert(std::is_empty<Param>::value, "Dot param should be empty"); static_assert(std::is_empty<Param>::value, "Dot param should be empty");
mgb_assert(opr0->dtype().category() != DTypeCategory::QUANTIZED && mgb_assert(opr0->dtype().category() != DTypeCategory::QUANTIZED &&
opr1->dtype().category() != DTypeCategory::QUANTIZED, opr1->dtype().category() != DTypeCategory::QUANTIZED,
...@@ -406,10 +449,28 @@ void Dot::scn_do_execute() { ...@@ -406,10 +449,28 @@ void Dot::scn_do_execute() {
i1.layout.stride[0] = 0; i1.layout.stride[0] = 0;
} }
} }
if ((i0.layout.is_empty() || i1.layout.is_empty())) {
if (!m_fill_opr) {
m_fill_opr = intl::get_megdnn_handle(comp_node())->
create_operator<megdnn::Fill>();
}
m_fill_opr->param() = 0;
m_fill_opr->exec(output(0)->dev_tensor().as_megdnn(), {});
return;
}
megdnn_opr()->exec(i0, i1, output(0)->dev_tensor().as_megdnn(), megdnn_opr()->exec(i0, i1, output(0)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output(1))); intl::get_megdnn_workspace_from_var(output(1)));
} }
Dot::NodeProp* Dot::do_make_node_prop() const {
auto ret = Super::do_make_node_prop();
ret->add_dep_type_existing_var(input(0),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
ret->add_dep_type_existing_var(input(1),
NodeProp::DepType::VALUE_ALLOW_EMPTY);
return ret;
}
void Dot::add_input_layout_constraint() { void Dot::add_input_layout_constraint() {
auto check = [](const TensorLayout &ly) { auto check = [](const TensorLayout &ly) {
mgb_throw_if(ly.ndim != 1, GraphError, mgb_throw_if(ly.ndim != 1, GraphError,
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "megbrain/graph.h" #include "megbrain/graph.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs/general.h"
#include "megdnn/oprs/linalg.h" #include "megdnn/oprs/linalg.h"
namespace mgb { namespace mgb {
...@@ -40,6 +41,7 @@ private: ...@@ -40,6 +41,7 @@ private:
void add_input_layout_constraint() override; void add_input_layout_constraint() override;
void scn_do_execute() override; void scn_do_execute() override;
void init_output_dtype() override; void init_output_dtype() override;
NodeProp* do_make_node_prop() const override;
size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const TensorShapeArray& output_shapes)
const override; const override;
...@@ -47,6 +49,7 @@ private: ...@@ -47,6 +49,7 @@ private:
//! store the policy of all transpose situations //! store the policy of all transpose situations
megdnn::ExecutionPolicy m_cadidate_execution_policies[4]; megdnn::ExecutionPolicy m_cadidate_execution_policies[4];
std::unique_ptr<megdnn::Fill> m_fill_opr;
}; };
/*! /*!
...@@ -70,6 +73,7 @@ private: ...@@ -70,6 +73,7 @@ private:
void add_input_layout_constraint() override; void add_input_layout_constraint() override;
void init_output_dtype() override; void init_output_dtype() override;
void scn_do_execute() override; void scn_do_execute() override;
NodeProp* do_make_node_prop() const override;
size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const TensorShapeArray& output_shapes)
const override; const override;
...@@ -77,6 +81,7 @@ private: ...@@ -77,6 +81,7 @@ private:
static bool check_layout(const TensorLayout& layout, bool transpose); static bool check_layout(const TensorLayout& layout, bool transpose);
//! store the policy of all transpose situations //! store the policy of all transpose situations
megdnn::ExecutionPolicy m_cadidate_execution_policies[4]; megdnn::ExecutionPolicy m_cadidate_execution_policies[4];
std::unique_ptr<megdnn::Fill> m_fill_opr;
}; };
/*! /*!
...@@ -101,7 +106,9 @@ MGB_DEFINE_OPR_CLASS(Dot, cg::SingleCNOperatorNodeBaseT< ...@@ -101,7 +106,9 @@ MGB_DEFINE_OPR_CLASS(Dot, cg::SingleCNOperatorNodeBaseT<
void add_input_layout_constraint() override; void add_input_layout_constraint() override;
void scn_do_execute() override; void scn_do_execute() override;
void init_output_static_infer_desc() override; void init_output_static_infer_desc() override;
NodeProp* do_make_node_prop() const override;
void record_execute_deps(ExecDependencyArray &deps) override; void record_execute_deps(ExecDependencyArray &deps) override;
std::unique_ptr<megdnn::Fill> m_fill_opr;
}; };
MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(MatrixInverse); MGB_DEFINE_MEGDNN_OPR_WRAPPER_FWD1(MatrixInverse);
......
...@@ -94,7 +94,9 @@ void run_sgemm_test(bool transa, bool transb) { ...@@ -94,7 +94,9 @@ void run_sgemm_test(bool transa, bool transb) {
Checker(make_graph, fwd) Checker(make_graph, fwd)
.run({mkx(4, 6), mky(6, 2)}, opt) .run({mkx(4, 6), mky(6, 2)}, opt)
.run({mkx(2, 3), mky(3, 100)}, opt) .run({mkx(2, 3), mky(3, 100)}, opt)
.run({mkx(20, 3), mky(3, 20)}, opt); .run({mkx(20, 3), mky(3, 20)}, opt)
.run({mkx(10, 0), mky(0, 10)}, opt)
.run({mkx(0, 0), mky(0, 0)}, opt);
} }
#define FWD_BATCH_GEMM(dt_src, dt_dst) \ #define FWD_BATCH_GEMM(dt_src, dt_dst) \
...@@ -143,7 +145,9 @@ void run_batched_sgemm_test(bool transa, bool transb) { ...@@ -143,7 +145,9 @@ void run_batched_sgemm_test(bool transa, bool transb) {
Checker(make_graph, fwd) Checker(make_graph, fwd)
.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt) .run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
.run({mkx(64, 1, 2), mky(64, 2, 1)}, opt) .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
.run({mkx(1, 2, 3), mky(1, 3, 4)}, opt); .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt)
.run({mkx(3, 0, 2), mky(3, 2, 0)}, opt)
.run({mkx(64, 10, 0), mky(64, 0, 10)}, opt);
} }
auto gen_fp16 = [](HostTensorND& dest) { auto gen_fp16 = [](HostTensorND& dest) {
...@@ -198,6 +202,7 @@ void run_batched_hgemm_test(bool transa, bool transb) { ...@@ -198,6 +202,7 @@ void run_batched_hgemm_test(bool transa, bool transb) {
checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt) checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
.run({mkx(64, 1, 2), mky(64, 2, 1)}, opt) .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
.run({mkx(64, 10, 0), mky(64, 0, 10)}, opt)
.run({mkx(1, 2, 3), mky(1, 3, 4)}, opt); .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
} }
...@@ -236,6 +241,7 @@ void run_batched_igemm_test(bool transa, bool transb) { ...@@ -236,6 +241,7 @@ void run_batched_igemm_test(bool transa, bool transb) {
checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt) checker.run({mkx(3, 5, 7), mky(3, 7, 2)}, opt)
.run({mkx(64, 1, 2), mky(64, 2, 1)}, opt) .run({mkx(64, 1, 2), mky(64, 2, 1)}, opt)
.run({mkx(64, 10, 0), mky(64, 0, 10)}, opt)
.run({mkx(1, 2, 3), mky(1, 3, 4)}, opt); .run({mkx(1, 2, 3), mky(1, 3, 4)}, opt);
} }
...@@ -650,7 +656,8 @@ TEST(TestOprBlas, Dot) { ...@@ -650,7 +656,8 @@ TEST(TestOprBlas, Dot) {
.run({TensorShape{15}, TensorShape{1}}) .run({TensorShape{15}, TensorShape{1}})
.run({TensorShape{1}, TensorShape{16}}) .run({TensorShape{1}, TensorShape{16}})
.run({TensorShape{23}, TensorShape{23}}) .run({TensorShape{23}, TensorShape{23}})
.run({TensorShape{1000}, TensorShape{1000}}); .run({TensorShape{1000}, TensorShape{1000}})
.run({TensorShape{0}, TensorShape{0}});
} }
TEST(TestOprBlas, TransMatMul) { TEST(TestOprBlas, TransMatMul) {
......
...@@ -250,7 +250,6 @@ DEF_IMPL(void)::do_run(const ShapeInpArray& shapes, const RunOptions& opt) { ...@@ -250,7 +250,6 @@ DEF_IMPL(void)::do_run(const ShapeInpArray& shapes, const RunOptions& opt) {
for (size_t i = 0; i < nr_out; ++i) { for (size_t i = 0; i < nr_out; ++i) {
if (m_outputs_allow_grad[i]) { if (m_outputs_allow_grad[i]) {
auto nr = m_outputs_truth[i].shape().total_nr_elems(); auto nr = m_outputs_truth[i].shape().total_nr_elems();
mgb_assert(nr, "got empty output");
if (opt.cont_loss_p) { if (opt.cont_loss_p) {
m_loss_p[i]->resize({nr}); m_loss_p[i]->resize({nr});
auto ptr = m_loss_p[i]->template ptr<float>(); auto ptr = m_loss_p[i]->template ptr<float>();
......
...@@ -36,7 +36,7 @@ std::vector<HostTensorND> mgb::numerical_diff_pt2( ...@@ -36,7 +36,7 @@ std::vector<HostTensorND> mgb::numerical_diff_pt2(
resize(cur_inp->shape()); resize(cur_inp->shape());
auto dptr = dest.ptr<float>(); auto dptr = dest.ptr<float>();
mgb_assert(cur_inp->layout().is_contiguous()); mgb_assert(cur_inp->layout().is_contiguous() || cur_inp->layout().is_empty());
auto cur_inp_ptr = cur_inp->ptr<float>(); auto cur_inp_ptr = cur_inp->ptr<float>();
mgb::RealTimer timer; mgb::RealTimer timer;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册