diff --git a/src/opr/impl/misc.cpp b/src/opr/impl/misc.cpp
index 26bc0a333809cd21a9298ded5632784231c6fda0..ed35b122dac4be6832f4cd7722af3f1334aed3c0 100644
--- a/src/opr/impl/misc.cpp
+++ b/src/opr/impl/misc.cpp
@@ -166,23 +166,11 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(NvOf);
 
 NvOf::NvOf(VarNode* opr, const Param& param, const OperatorNodeConfig& config)
         : Super{opr->owner_graph(), config, "NvOf", {opr}}, m_param{param} {
-    constexpr size_t NDIM = 5;
     mgb_assert(opr->dtype() == dtype::Uint8());
     add_input({opr});
     //! NvOf hava only one output
     add_output(None);
 
-    mgb_log_debug("init nvof engine with precision: %u", m_param.precision);
-    auto input_shape = this->input()[0]->shape();
-
-    //! nvof input format: nthwc4
-    mgb_assert(input_shape.ndim == NDIM);
-    //! now only support RGBA format channel data
-    mgb_assert(input_shape[4] == 4);
-
-    for (size_t i = 0; i < NDIM; i++) {
-        vshape.push_back(input_shape[i]);
-    }
 }
 
 void NvOf::init_output_dtype() {
@@ -195,6 +183,10 @@ SymbolVar NvOf::make(SymbolVar opr, const Param& param,
 }
 
 void NvOf::scn_do_execute() {
+    auto input_shape = this->input()[0]->shape();
+    for (size_t i = 0; i < 5; i++) {
+        vshape.push_back(input_shape[i]);
+    }
     auto c = this->comp_node();
     //! comp_node may init on CUDA or CPU, eg: lar with --cpu
     //! if ON CUDA, need sync, caused by we use different stream
@@ -229,6 +221,10 @@ void NvOf::init_output_static_infer_desc() {
     using namespace cg::static_infer;
    auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
         auto ishp = iv.val.at(0).shape();
+        //! nvof input format: nthwc4
+        mgb_assert(ishp.ndim == 5);
+        //! now only support RGBA format channel data
+        mgb_assert(ishp[4] == 4);
         SmallVector<size_t> tv;
         tv.push_back(ishp[0]);
         tv.push_back(ishp[1] - 1);