提交 abb7f6ef 编写于 作者: M Megvii Engine Team

fix(src/atlas): support static input batch and dynamic output batch

GitOrigin-RevId: 78df430e68e417c64e1f5aae89bd39b1fb91cb60
上级 de084f92
...@@ -15,12 +15,30 @@ namespace { ...@@ -15,12 +15,30 @@ namespace {
/** /**
* \brief get mgb shape from acl shape, batch from mgb * \brief get mgb shape from acl shape, batch from mgb
*/ */
TensorShape acl_shape_to_mgb_shape_for_output(aclmdlIODims acl_shape, size_t batch) { TensorShape acl_shape_to_mgb_shape_for_output(
aclmdlDesc* model_desc, size_t output_idx, size_t output_dtype_size,
aclmdlIODims acl_shape, size_t batch) {
TensorShape ret; TensorShape ret;
ret.ndim = acl_shape.dimCount; ret.ndim = acl_shape.dimCount;
for (size_t i = 0; i < ret.ndim; ++i) { for (size_t i = 0; i < ret.ndim; ++i) {
ret[i] = acl_shape.dims[i]; ret[i] = acl_shape.dims[i];
} }
if (acl_shape.dims[0] == -1) {
batch = aclmdlGetOutputSizeByIndex(model_desc, output_idx);
size_t chw = output_dtype_size;
for (size_t i = 1; i < ret.ndim; ++i) {
chw *= ret[i];
}
mgb_assert(
batch % chw == 0,
"When the input batch is static and the output batch is dynamic, it is "
"necessary to reconfigure the output batch. The output size obtained "
"from the aclmdlGetOutputSizeByIndex interface should be evenly "
"divided by "
"shapes other than the batch. expect 0, but got %zu\n",
batch % chw);
batch /= chw;
}
ret[0] = batch; ret[0] = batch;
return ret; return ret;
} }
...@@ -332,7 +350,7 @@ void AtlasRuntimeOpr::scn_do_execute() { ...@@ -332,7 +350,7 @@ void AtlasRuntimeOpr::scn_do_execute() {
for (size_t i = 0; i < nr_outputs; i++) { for (size_t i = 0; i < nr_outputs; i++) {
auto value_pair = output_getter.get(batch, i); auto value_pair = output_getter.get(batch, i);
size_t output_size = value_pair.second; size_t output_size = value_pair.second;
if (enable_dynamic_batch) { if (enable_dynamic_batch || m_dyn_batch_output[i]) {
output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i); output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
} }
aclDataBuffer* output_db = aclDataBuffer* output_db =
...@@ -343,6 +361,18 @@ void AtlasRuntimeOpr::scn_do_execute() { ...@@ -343,6 +361,18 @@ void AtlasRuntimeOpr::scn_do_execute() {
"%zu:%s.", "%zu:%s.",
i, output(i)->cname()); i, output(i)->cname());
aclmdlAddDatasetBuffer(model_outputs, output_db); aclmdlAddDatasetBuffer(model_outputs, output_db);
if (m_dyn_batch_output[i]) {
auto tensor_ndim = output(0)->shape().ndim;
std::vector<int64_t> tensor_shape(tensor_ndim, 0);
for (size_t j = 0; j < tensor_ndim; j++) {
tensor_shape[j] = output(0)->shape()[j];
}
aclTensorDesc* tensorDesc = aclCreateTensorDesc(
aclmdlGetOutputDataType(m_model_desc, i), tensor_ndim,
tensor_shape.data(), aclmdlGetOutputFormat(m_model_desc, i));
aclmdlSetDatasetTensorDesc(model_outputs, tensorDesc, i);
}
} }
MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs)); MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs));
...@@ -351,6 +381,31 @@ void AtlasRuntimeOpr::scn_do_execute() { ...@@ -351,6 +381,31 @@ void AtlasRuntimeOpr::scn_do_execute() {
MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr)); MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
} }
for (size_t i = 0; i < nr_outputs; ++i) { for (size_t i = 0; i < nr_outputs; ++i) {
if (m_dyn_batch_output[i]) {
const DeviceTensorND old_dev_tensor = output(i)->dev_tensor();
auto new_output_desc = aclmdlGetDatasetTensorDesc(model_outputs, i);
TensorShape new_shape;
new_shape.ndim = aclGetTensorDescNumDims(new_output_desc);
mgb_assert(
new_shape.ndim == old_dev_tensor.layout().ndim,
"for static input batch and dynamic output batch, the output "
"ndim should be consistent with the one before calling "
"aclmdlExecute(), so expect %zu, but got %zu",
old_dev_tensor.layout().ndim, new_shape.ndim);
for (size_t j = 0; j < new_shape.ndim; j++) {
new_shape.shape[j] = aclGetTensorDescDim(new_output_desc, j);
}
TensorLayout new_layout{
new_shape, old_dev_tensor.dtype(), old_dev_tensor.format()};
DeviceTensorND new_dev_tensor{
old_dev_tensor.comp_node(), new_layout, old_dev_tensor.dtype(),
old_dev_tensor.format()};
new_dev_tensor.reset(old_dev_tensor.storage(), new_layout);
output(i)->force_assign_dev_tensor_from_tensor(new_dev_tensor);
}
aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_outputs, i); aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_outputs, i);
MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr)); MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
} }
...@@ -387,7 +442,9 @@ void AtlasRuntimeOpr::get_output_var_shape( ...@@ -387,7 +442,9 @@ void AtlasRuntimeOpr::get_output_var_shape(
for (size_t i = 0; i < out_shape.size(); ++i) { for (size_t i = 0; i < out_shape.size(); ++i) {
aclmdlIODims output_dims; aclmdlIODims output_dims;
MGB_ATLAS_CHECK(aclmdlGetOutputDims(m_model_desc, i, &output_dims)); MGB_ATLAS_CHECK(aclmdlGetOutputDims(m_model_desc, i, &output_dims));
out_shape[i] = acl_shape_to_mgb_shape_for_output(output_dims, batch_size); out_shape[i] = acl_shape_to_mgb_shape_for_output(
m_model_desc, i, output(i)->dtype().size(), output_dims, batch_size);
m_dyn_batch_output.push_back(output_dims.dims[0] == -1);
} }
} }
......
...@@ -64,6 +64,9 @@ private: ...@@ -64,6 +64,9 @@ private:
//! Atlas need a 64bit device tensor to hold dynamic batch state //! Atlas need a 64bit device tensor to hold dynamic batch state
DeviceTensorND m_dyn_batch_tensor; DeviceTensorND m_dyn_batch_tensor;
SmallVector<size_t> m_dyn_batch_choices; SmallVector<size_t> m_dyn_batch_choices;
//! Per-output flags for the case where the input batch is static but the output
//! batch is dynamic (the model reports -1 for dim 0 of that output). This is a
//! separate case from the one where both the input and output batch are dynamic.
mutable SmallVector<bool> m_dyn_batch_output;
}; };
} // namespace opr } // namespace opr
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册