提交 abb7f6ef 编写于 作者: M Megvii Engine Team

fix(src/atlas): support static input batch and dynamic output batch

GitOrigin-RevId: 78df430e68e417c64e1f5aae89bd39b1fb91cb60
上级 de084f92
......@@ -15,12 +15,30 @@ namespace {
/**
* \brief get mgb shape from acl shape, batch from mgb
*/
TensorShape acl_shape_to_mgb_shape_for_output(aclmdlIODims acl_shape, size_t batch) {
TensorShape acl_shape_to_mgb_shape_for_output(
aclmdlDesc* model_desc, size_t output_idx, size_t output_dtype_size,
aclmdlIODims acl_shape, size_t batch) {
TensorShape ret;
ret.ndim = acl_shape.dimCount;
for (size_t i = 0; i < ret.ndim; ++i) {
ret[i] = acl_shape.dims[i];
}
if (acl_shape.dims[0] == -1) {
batch = aclmdlGetOutputSizeByIndex(model_desc, output_idx);
size_t chw = output_dtype_size;
for (size_t i = 1; i < ret.ndim; ++i) {
chw *= ret[i];
}
mgb_assert(
batch % chw == 0,
"When the input batch is static and the output batch is dynamic, it is "
"necessary to reconfigure the output batch. The output size obtained "
"from the aclmdlGetOutputSizeByIndex interface should be evenly "
"divided by "
"shapes other than the batch. expect 0, but got %zu\n",
batch % chw);
batch /= chw;
}
ret[0] = batch;
return ret;
}
......@@ -332,7 +350,7 @@ void AtlasRuntimeOpr::scn_do_execute() {
for (size_t i = 0; i < nr_outputs; i++) {
auto value_pair = output_getter.get(batch, i);
size_t output_size = value_pair.second;
if (enable_dynamic_batch) {
if (enable_dynamic_batch || m_dyn_batch_output[i]) {
output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
}
aclDataBuffer* output_db =
......@@ -343,6 +361,18 @@ void AtlasRuntimeOpr::scn_do_execute() {
"%zu:%s.",
i, output(i)->cname());
aclmdlAddDatasetBuffer(model_outputs, output_db);
if (m_dyn_batch_output[i]) {
auto tensor_ndim = output(0)->shape().ndim;
std::vector<int64_t> tensor_shape(tensor_ndim, 0);
for (size_t j = 0; j < tensor_ndim; j++) {
tensor_shape[j] = output(0)->shape()[j];
}
aclTensorDesc* tensorDesc = aclCreateTensorDesc(
aclmdlGetOutputDataType(m_model_desc, i), tensor_ndim,
tensor_shape.data(), aclmdlGetOutputFormat(m_model_desc, i));
aclmdlSetDatasetTensorDesc(model_outputs, tensorDesc, i);
}
}
MGB_ATLAS_CHECK(aclmdlExecute(m_model_id, model_inputs, model_outputs));
......@@ -351,6 +381,31 @@ void AtlasRuntimeOpr::scn_do_execute() {
MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
}
for (size_t i = 0; i < nr_outputs; ++i) {
if (m_dyn_batch_output[i]) {
const DeviceTensorND old_dev_tensor = output(i)->dev_tensor();
auto new_output_desc = aclmdlGetDatasetTensorDesc(model_outputs, i);
TensorShape new_shape;
new_shape.ndim = aclGetTensorDescNumDims(new_output_desc);
mgb_assert(
new_shape.ndim == old_dev_tensor.layout().ndim,
"for static input batch and dynamic output batch, the output "
"ndim should be consistent with the one before calling "
"aclmdlExecute(), so expect %zu, but got %zu",
old_dev_tensor.layout().ndim, new_shape.ndim);
for (size_t j = 0; j < new_shape.ndim; j++) {
new_shape.shape[j] = aclGetTensorDescDim(new_output_desc, j);
}
TensorLayout new_layout{
new_shape, old_dev_tensor.dtype(), old_dev_tensor.format()};
DeviceTensorND new_dev_tensor{
old_dev_tensor.comp_node(), new_layout, old_dev_tensor.dtype(),
old_dev_tensor.format()};
new_dev_tensor.reset(old_dev_tensor.storage(), new_layout);
output(i)->force_assign_dev_tensor_from_tensor(new_dev_tensor);
}
aclDataBuffer* db_ptr = aclmdlGetDatasetBuffer(model_outputs, i);
MGB_ATLAS_CHECK(aclDestroyDataBuffer(db_ptr));
}
......@@ -387,7 +442,9 @@ void AtlasRuntimeOpr::get_output_var_shape(
for (size_t i = 0; i < out_shape.size(); ++i) {
aclmdlIODims output_dims;
MGB_ATLAS_CHECK(aclmdlGetOutputDims(m_model_desc, i, &output_dims));
out_shape[i] = acl_shape_to_mgb_shape_for_output(output_dims, batch_size);
out_shape[i] = acl_shape_to_mgb_shape_for_output(
m_model_desc, i, output(i)->dtype().size(), output_dims, batch_size);
m_dyn_batch_output.push_back(output_dims.dims[0] == -1);
}
}
......
......@@ -64,6 +64,9 @@ private:
//! Atlas need a 64bit device tensor to hold dynamic batch state
DeviceTensorND m_dyn_batch_tensor;
SmallVector<size_t> m_dyn_batch_choices;
//! Used when the input batch is static and the output batch is dynamic. Different
//! from the case where the input batch is dynamic and the output batch is dynamic
mutable SmallVector<bool> m_dyn_batch_output;
};
} // namespace opr
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册