From abb7f6eff9503bcb407099430e52b8a1668c0a2e Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 17 Jul 2023 19:02:08 +0800
Subject: [PATCH] fix(src/atlas): support static input batch and dynamic output batch

GitOrigin-RevId: 78df430e68e417c64e1f5aae89bd39b1fb91cb60
---
Review notes (text between "---" and the diff is ignored by git am):
  * scn_do_execute: the tensor desc of output i is built from output(i), not
    output(0); aclCreateTensorDesc / aclmdlSetDatasetTensorDesc are checked.
  * get_output_var_shape: the per-output flag is overwritten when it already
    exists instead of being appended on every call.
  * acl_shape_to_mgb_shape_for_output: guard against a zero divisor and keep
    the byte size in its own variable.
  * The "index" hashes below are those of the original patch.

 src/opr/impl/atlas_runtime_op.cpp             | 96 ++++++++++++++++++-
 .../include/megbrain/opr/atlas_runtime_op.h   |  4 +
 2 files changed, 97 insertions(+), 3 deletions(-)

diff --git a/src/opr/impl/atlas_runtime_op.cpp b/src/opr/impl/atlas_runtime_op.cpp
index 0d15ad6e3..b8b1daf50 100644
--- a/src/opr/impl/atlas_runtime_op.cpp
+++ b/src/opr/impl/atlas_runtime_op.cpp
@@ -15,12 +15,48 @@ namespace {
 /**
  * \brief get mgb shape from acl shape, batch from mgb
+ *
+ * When acl reports a dynamic output batch (dims[0] == -1), the batch is
+ * derived from the byte size returned by aclmdlGetOutputSizeByIndex instead
+ * of being taken from \p batch.
+ *
+ * \param model_desc model description queried for the output byte size
+ * \param output_idx index of the output inside the model
+ * \param output_dtype_size size in bytes of one output element
+ * \param acl_shape output dims reported by acl
+ * \param batch batch written to dim 0 when the output batch is not dynamic
  */
-TensorShape acl_shape_to_mgb_shape_for_output(aclmdlIODims acl_shape, size_t batch) {
+TensorShape acl_shape_to_mgb_shape_for_output(
+        aclmdlDesc* model_desc, size_t output_idx, size_t output_dtype_size,
+        aclmdlIODims acl_shape, size_t batch) {
     TensorShape ret;
     ret.ndim = acl_shape.dimCount;
     for (size_t i = 0; i < ret.ndim; ++i) {
         ret[i] = acl_shape.dims[i];
     }
+    if (acl_shape.dims[0] == -1) {
+        // byte size of the whole output as returned by acl
+        const size_t output_bytes =
+                aclmdlGetOutputSizeByIndex(model_desc, output_idx);
+        // byte size of one batch item: dtype size times all non-batch dims
+        size_t chw = output_dtype_size;
+        for (size_t i = 1; i < ret.ndim; ++i) {
+            chw *= ret[i];
+        }
+        mgb_assert(
+                chw != 0,
+                "can not infer the dynamic batch of output %zu: the byte size of "
+                "one batch item is 0",
+                output_idx);
+        mgb_assert(
+                output_bytes % chw == 0,
+                "When the input batch is static and the output batch is dynamic, it is "
+                "necessary to reconfigure the output batch. The output size obtained "
+                "from the aclmdlGetOutputSizeByIndex interface should be evenly "
+                "divided by "
+                "shapes other than the batch. expect 0, but got %zu\n",
+                output_bytes % chw);
+        batch = output_bytes / chw;
+    }
     ret[0] = batch;
     return ret;
 }
@@ -332,7 +368,7 @@ void AtlasRuntimeOpr::scn_do_execute() {
     for (size_t i = 0; i < nr_outputs; i++) {
         auto value_pair = output_getter.get(batch, i);
         size_t output_size = value_pair.second;
-        if (enable_dynamic_batch) {
+        if (enable_dynamic_batch || m_dyn_batch_output[i]) {
             output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
         }
         aclDataBuffer* output_db =
@@ -343,6 +379,24 @@ void AtlasRuntimeOpr::scn_do_execute() {
                 "%zu:%s.",
                 i, output(i)->cname());
         aclmdlAddDatasetBuffer(model_outputs, output_db);
+
+        if (m_dyn_batch_output[i]) {
+            // describe output i by its own inferred shape (not that of output 0);
+            // the desc is read back from model_outputs after aclmdlExecute()
+            auto tensor_ndim = output(i)->shape().ndim;
+            std::vector<int64_t> tensor_shape(tensor_ndim, 0);
+            for (size_t j = 0; j < tensor_ndim; j++) {
+                tensor_shape[j] = output(i)->shape()[j];
+            }
+            aclTensorDesc* tensorDesc = aclCreateTensorDesc(
+                    aclmdlGetOutputDataType(m_model_desc, i), tensor_ndim,
+                    tensor_shape.data(), aclmdlGetOutputFormat(m_model_desc, i));
+            mgb_assert(
+                    tensorDesc != nullptr,
+                    "failed to create tensor desc for output %zu:%s.", i,
+                    output(i)->cname());
+            MGB_ATLAS_CHECK(aclmdlSetDatasetTensorDesc(model_outputs, tensorDesc, i));
+        }
     }
 
     MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs));
@@ -351,6 +405,33 @@ void AtlasRuntimeOpr::scn_do_execute() {
         MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
     }
     for (size_t i = 0; i < nr_outputs; ++i) {
+        if (m_dyn_batch_output[i]) {
+            // rebuild the output tensor on the same storage, using the shape
+            // read from the dataset tensor desc after execution
+            const DeviceTensorND old_dev_tensor = output(i)->dev_tensor();
+
+            auto new_output_desc = aclmdlGetDatasetTensorDesc(model_outputs, i);
+
+            TensorShape new_shape;
+            new_shape.ndim = aclGetTensorDescNumDims(new_output_desc);
+            mgb_assert(
+                    new_shape.ndim == old_dev_tensor.layout().ndim,
+                    "for static input batch and dynamic output batch, the output "
+                    "ndim should be consistent with the one before calling "
+                    "aclmdlExecute(), so expect %zu, but got %zu",
+                    old_dev_tensor.layout().ndim, new_shape.ndim);
+            for (size_t j = 0; j < new_shape.ndim; j++) {
+                new_shape.shape[j] = aclGetTensorDescDim(new_output_desc, j);
+            }
+
+            TensorLayout new_layout{
+                    new_shape, old_dev_tensor.dtype(), old_dev_tensor.format()};
+            DeviceTensorND new_dev_tensor{
+                    old_dev_tensor.comp_node(), new_layout, old_dev_tensor.dtype(),
+                    old_dev_tensor.format()};
+            new_dev_tensor.reset(old_dev_tensor.storage(), new_layout);
+            output(i)->force_assign_dev_tensor_from_tensor(new_dev_tensor);
+        }
         aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_outputs, i);
         MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
     }
@@ -387,7 +468,16 @@ void AtlasRuntimeOpr::get_output_var_shape(
     for (size_t i = 0; i < out_shape.size(); ++i) {
         aclmdlIODims output_dims;
         MGB_ATLAS_CHECK(aclmdlGetOutputDims(m_model_desc, i, &output_dims));
-        out_shape[i] = acl_shape_to_mgb_shape_for_output(output_dims, batch_size);
+        out_shape[i] = acl_shape_to_mgb_shape_for_output(
+                m_model_desc, i, output(i)->dtype().size(), output_dims, batch_size);
+        // overwrite the flag of output i when it already exists, so repeated
+        // calls do not keep appending to the vector
+        const bool is_dyn_batch_output = output_dims.dims[0] == -1;
+        if (i < m_dyn_batch_output.size()) {
+            m_dyn_batch_output[i] = is_dyn_batch_output;
+        } else {
+            m_dyn_batch_output.push_back(is_dyn_batch_output);
+        }
     }
 }
 
diff --git a/src/opr/include/megbrain/opr/atlas_runtime_op.h b/src/opr/include/megbrain/opr/atlas_runtime_op.h
index abd8c87ca..f8811a0e9 100644
--- a/src/opr/include/megbrain/opr/atlas_runtime_op.h
+++ b/src/opr/include/megbrain/opr/atlas_runtime_op.h
@@ -64,6 +64,10 @@ private:
     //! Atlas need a 64bit device tensor to hold dynamic batch state
     DeviceTensorND m_dyn_batch_tensor;
     SmallVector<size_t> m_dyn_batch_choices;
+    //! Used when the input batch is static and the output batch is dynamic. Different
+    //! from the case where the input batch is dynamic and the output batch is dynamic.
+    //! m_dyn_batch_output[i] is true when dims[0] of model output i is -1.
+    mutable SmallVector<bool> m_dyn_batch_output;
 };
 
 }  // namespace opr
-- 
GitLab