From bab0caab724244342c05a89ca8f68b87b86f6dc0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 31 Mar 2021 17:08:12 +0800 Subject: [PATCH] fix(mgb/opr): add layout constraint on the input of Cumsum GitOrigin-RevId: d2d108e2786806aa3f4c44ebfc0bf5f094cdab42 --- src/opr/impl/misc.cpp | 4 ++++ src/opr/include/megbrain/opr/misc.h | 1 + 2 files changed, 5 insertions(+) diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp index f557e2b6..9d0820d8 100644 --- a/src/opr/impl/misc.cpp +++ b/src/opr/impl/misc.cpp @@ -132,6 +132,10 @@ void Cumsum::scn_do_execute() { intl::get_megdnn_workspace_from_var(output().back())); } +void Cumsum::add_input_layout_constraint() { + input(0)->add_layout_constraint_contiguous(); +} + void Cumsum::init_output_static_infer_desc() { using namespace cg::static_infer; auto infer_shape = [](TensorShape& dest, const InpVal& iv) { diff --git a/src/opr/include/megbrain/opr/misc.h b/src/opr/include/megbrain/opr/misc.h index 63d17b59..f3559d95 100644 --- a/src/opr/include/megbrain/opr/misc.h +++ b/src/opr/include/megbrain/opr/misc.h @@ -86,6 +86,7 @@ public: //! cumulative sum along given axis MGB_DEFINE_OPR_CLASS(Cumsum, cg::SingleCNOperatorNodeBaseT< mixin::MegDNNOprHolderImpl>) // { + void add_input_layout_constraint() override; public: Cumsum(VarNode *src, const Param ¶m, const OperatorNodeConfig &config); -- GitLab