提交 29f7cdb8 编写于 作者: M Megvii Engine Team

fix(mgb/opr): correct nvof out shape computation

GitOrigin-RevId: 16bf086e92125cba867d9b935bb487363602014e
上级 71cc814e
......@@ -633,7 +633,8 @@ def nvof(src: Tensor, precision: int = 1) -> Tensor:
:src shape: input tensor with shape (n, t, h, w, c4).
:src dtype: uint8.
:param precision: 0: NV_OF_PERF_LEVEL_SLOW, 1: NV_OF_PERF_LEVEL_MEDIUM, 2: NV_OF_PERF_LEVEL_FAST.
:output shape: (n, t-1, h//4, w//4, c2).
:output shape: ``(n, t-1, (h+out_grid_size-1)//out_grid_size, (w+out_grid_size-1)//out_grid_size, c2)``.
By default, out_grid_size = 4.
:output dtype: int16.
.. code-block:: python
......
......@@ -224,6 +224,7 @@ void NvOf::scn_do_execute() {
void NvOf::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto infer_shape = [](TensorShape& dest, const InpVal& iv) {
auto out_grid_size = NV_OF_OUTPUT_VECTOR_GRID_SIZE_4;
auto ishp = iv.val.at(0).shape();
//! nvof input format: nthwc4
mgb_assert(ishp.ndim == 5);
......@@ -232,8 +233,8 @@ void NvOf::init_output_static_infer_desc() {
SmallVector<size_t> tv;
tv.push_back(ishp[0]);
tv.push_back(ishp[1] - 1);
tv.push_back(ishp[2] / 4);
tv.push_back(ishp[3] / 4);
tv.push_back((ishp[2] + out_grid_size - 1) / out_grid_size);
tv.push_back((ishp[3] + out_grid_size - 1) / out_grid_size);
tv.push_back(ishp[4] / 2);
dest = tv;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册