提交 184e1311 编写于 作者: M Megvii Engine Team

fix(mgb/opr): move NVOF opr's shape inference to execute part

GitOrigin-RevId: 883c55e4a09e243e5e50e320ebec097695e88bc9
上级 d168cea4
...@@ -166,23 +166,11 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(NvOf); ...@@ -166,23 +166,11 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(NvOf);
NvOf::NvOf(VarNode* opr, const Param& param, const OperatorNodeConfig& config) NvOf::NvOf(VarNode* opr, const Param& param, const OperatorNodeConfig& config)
: Super{opr->owner_graph(), config, "NvOf", {opr}}, m_param{param} { : Super{opr->owner_graph(), config, "NvOf", {opr}}, m_param{param} {
constexpr size_t NDIM = 5;
mgb_assert(opr->dtype() == dtype::Uint8()); mgb_assert(opr->dtype() == dtype::Uint8());
add_input({opr}); add_input({opr});
//! NvOf hava only one output //! NvOf hava only one output
add_output(None); add_output(None);
mgb_log_debug("init nvof engine with precision: %u", m_param.precision); mgb_log_debug("init nvof engine with precision: %u", m_param.precision);
auto input_shape = this->input()[0]->shape();
//! nvof input format: nthwc4
mgb_assert(input_shape.ndim == NDIM);
//! now only support RGBA format channel data
mgb_assert(input_shape[4] == 4);
for (size_t i = 0; i < NDIM; i++) {
vshape.push_back(input_shape[i]);
}
} }
void NvOf::init_output_dtype() { void NvOf::init_output_dtype() {
...@@ -195,6 +183,10 @@ SymbolVar NvOf::make(SymbolVar opr, const Param& param, ...@@ -195,6 +183,10 @@ SymbolVar NvOf::make(SymbolVar opr, const Param& param,
} }
void NvOf::scn_do_execute() { void NvOf::scn_do_execute() {
auto input_shape = this->input()[0]->shape();
for (size_t i = 0; i < 5; i++) {
vshape.push_back(input_shape[i]);
}
auto c = this->comp_node(); auto c = this->comp_node();
//! comp_node may init on CUDA or CPU, eg: lar with --cpu //! comp_node may init on CUDA or CPU, eg: lar with --cpu
//! if ON CUDA, need sync, caused by we use different stream //! if ON CUDA, need sync, caused by we use different stream
...@@ -229,6 +221,10 @@ void NvOf::init_output_static_infer_desc() { ...@@ -229,6 +221,10 @@ void NvOf::init_output_static_infer_desc() {
using namespace cg::static_infer; using namespace cg::static_infer;
auto infer_shape = [](TensorShape& dest, const InpVal& iv) { auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
auto ishp = iv.val.at(0).shape(); auto ishp = iv.val.at(0).shape();
//! nvof input format: nthwc4
mgb_assert(ishp.ndim == 5);
//! now only support RGBA format channel data
mgb_assert(ishp[4] == 4);
SmallVector<size_t> tv; SmallVector<size_t> tv;
tv.push_back(ishp[0]); tv.push_back(ishp[0]);
tv.push_back(ishp[1] - 1); tv.push_back(ishp[1] - 1);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册