diff --git a/src/opr/impl/atlas_runtime_op.cpp b/src/opr/impl/atlas_runtime_op.cpp
index 0d15ad6e34ed1e24739a0b2f8f0cf78b10403eb3..b8b1daf50048e52026329a1ccd891831083ca3f2 100644
--- a/src/opr/impl/atlas_runtime_op.cpp
+++ b/src/opr/impl/atlas_runtime_op.cpp
@@ -15,12 +15,37 @@ namespace {
 /**
  * \brief get mgb shape from acl shape, batch from mgb
  */
-TensorShape acl_shape_to_mgb_shape_for_output(aclmdlIODims acl_shape, size_t batch) {
+TensorShape acl_shape_to_mgb_shape_for_output(
+        aclmdlDesc* model_desc, size_t output_idx, size_t output_dtype_size,
+        aclmdlIODims acl_shape, size_t batch) {
     TensorShape ret;
     ret.ndim = acl_shape.dimCount;
     for (size_t i = 0; i < ret.ndim; ++i) {
         ret[i] = acl_shape.dims[i];
     }
+    //! the model reports -1 for the output batch: derive the batch from the
+    //! output byte size divided by the byte size of everything but the batch
+    if (acl_shape.dimCount > 0 && acl_shape.dims[0] == -1) {
+        size_t output_bytes = aclmdlGetOutputSizeByIndex(model_desc, output_idx);
+        size_t chw = output_dtype_size;
+        for (size_t i = 1; i < ret.ndim; ++i) {
+            chw *= ret[i];
+        }
+        mgb_assert(
+                chw != 0,
+                "can not infer the dynamic batch of output %zu: the byte size of "
+                "the dims other than the batch is zero",
+                output_idx);
+        mgb_assert(
+                output_bytes % chw == 0,
+                "When the input batch is static and the output batch is dynamic, it is "
+                "necessary to reconfigure the output batch. The output size obtained "
+                "from the aclmdlGetOutputSizeByIndex interface should be evenly "
+                "divided by "
+                "shapes other than the batch. expect 0, but got %zu\n",
+                output_bytes % chw);
+        batch = output_bytes / chw;
+    }
     ret[0] = batch;
     return ret;
 }
@@ -332,7 +357,7 @@ void AtlasRuntimeOpr::scn_do_execute() {
     for (size_t i = 0; i < nr_outputs; i++) {
         auto value_pair = output_getter.get(batch, i);
         size_t output_size = value_pair.second;
-        if (enable_dynamic_batch) {
+        if (enable_dynamic_batch || m_dyn_batch_output[i]) {
             output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
         }
         aclDataBuffer* output_db =
@@ -343,6 +368,24 @@ void AtlasRuntimeOpr::scn_do_execute() {
                 "%zu:%s.",
                 i, output(i)->cname());
         aclmdlAddDatasetBuffer(model_outputs, output_db);
+
+        if (m_dyn_batch_output[i]) {
+            //! describe output i by its own inferred shape, not by output 0
+            auto tensor_ndim = output(i)->shape().ndim;
+            std::vector<int64_t> tensor_shape(tensor_ndim, 0);
+            for (size_t j = 0; j < tensor_ndim; j++) {
+                tensor_shape[j] = output(i)->shape()[j];
+            }
+            aclTensorDesc* tensorDesc = aclCreateTensorDesc(
+                    aclmdlGetOutputDataType(m_model_desc, i), tensor_ndim,
+                    tensor_shape.data(), aclmdlGetOutputFormat(m_model_desc, i));
+            mgb_assert(
+                    tensorDesc != nullptr,
+                    "failed to create tensor desc for output %zu:%s.", i,
+                    output(i)->cname());
+            MGB_ATLAS_CHECK(
+                    aclmdlSetDatasetTensorDesc(model_outputs, tensorDesc, i));
+        }
     }
 
     MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs));
@@ -351,6 +394,31 @@ void AtlasRuntimeOpr::scn_do_execute() {
         MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
     }
     for (size_t i = 0; i < nr_outputs; ++i) {
+        if (m_dyn_batch_output[i]) {
+            const DeviceTensorND old_dev_tensor = output(i)->dev_tensor();
+
+            auto new_output_desc = aclmdlGetDatasetTensorDesc(model_outputs, i);
+
+            TensorShape new_shape;
+            new_shape.ndim = aclGetTensorDescNumDims(new_output_desc);
+            mgb_assert(
+                    new_shape.ndim == old_dev_tensor.layout().ndim,
+                    "for static input batch and dynamic output batch, the output "
+                    "ndim should be consistent with the one before calling "
+                    "aclmdlExecute(), so expect %zu, but got %zu",
+                    old_dev_tensor.layout().ndim, new_shape.ndim);
+            for (size_t j = 0; j < new_shape.ndim; j++) {
+                new_shape.shape[j] = aclGetTensorDescDim(new_output_desc, j);
+            }
+
+            TensorLayout new_layout{
+                    new_shape, old_dev_tensor.dtype(), old_dev_tensor.format()};
+            DeviceTensorND new_dev_tensor{
+                    old_dev_tensor.comp_node(), new_layout, old_dev_tensor.dtype(),
+                    old_dev_tensor.format()};
+            new_dev_tensor.reset(old_dev_tensor.storage(), new_layout);
+            output(i)->force_assign_dev_tensor_from_tensor(new_dev_tensor);
+        }
         aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_outputs, i);
         MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
     }
@@ -387,7 +455,17 @@ void AtlasRuntimeOpr::get_output_var_shape(
     for (size_t i = 0; i < out_shape.size(); ++i) {
         aclmdlIODims output_dims;
         MGB_ATLAS_CHECK(aclmdlGetOutputDims(m_model_desc, i, &output_dims));
-        out_shape[i] = acl_shape_to_mgb_shape_for_output(output_dims, batch_size);
+        out_shape[i] = acl_shape_to_mgb_shape_for_output(
+                m_model_desc, i, output(i)->dtype().size(), output_dims, batch_size);
+        //! overwrite the flag of output i rather than appending, so that the
+        //! vector does not grow if this function runs more than once
+        const bool dyn_batch_output =
+                output_dims.dimCount > 0 && output_dims.dims[0] == -1;
+        if (i < m_dyn_batch_output.size()) {
+            m_dyn_batch_output[i] = dyn_batch_output;
+        } else {
+            m_dyn_batch_output.push_back(dyn_batch_output);
+        }
     }
 }
 
diff --git a/src/opr/include/megbrain/opr/atlas_runtime_op.h b/src/opr/include/megbrain/opr/atlas_runtime_op.h
index abd8c87ca1c521c0c48cb438755030445ddaf365..f8811a0e936ca58b1666128a4578e7fda2291b26 100644
--- a/src/opr/include/megbrain/opr/atlas_runtime_op.h
+++ b/src/opr/include/megbrain/opr/atlas_runtime_op.h
@@ -64,6 +64,9 @@ private:
     //! Atlas need a 64bit device tensor to hold dynamic batch state
     DeviceTensorND m_dyn_batch_tensor;
     SmallVector m_dyn_batch_choices;
+    //! Used when the input batch is static and the output batch is dynamic. Different
+    //! from the case where the input batch is dynamic and the output batch is dynamic
+    mutable SmallVector<bool> m_dyn_batch_output;
 };
 
 }  // namespace opr