提交 e92670e8 编写于 作者: M Megvii Engine Team

fix(mgb/atlas): when batchsize more than model max batchsize

GitOrigin-RevId: 63fe79eaa98e97b485ac731e87ce1885a190c271
上级 147dbf8a
......@@ -278,6 +278,9 @@ void AtlasRuntimeOpr::scn_do_execute() {
for (size_t i = 0; i < output().size(); i++) {
auto output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i);
auto ovar = output(i);
output_size = std::max<size_t>(
output_size,
ovar->dtype().size(ovar->shape().total_nr_elems()));
ovar->shape_alloc(ovar->shape(), output_size);
}
}
......
......@@ -65,7 +65,7 @@ TEST(TestOprAtlas, Basic) {
}
TEST(TestOprAtlas, DynamicBatch) {
for (size_t batch : {1, 6}) {
for (size_t batch : {1, 6, 20}) {
HostTensorGenerator<> gen;
const auto& graph = ComputingGraph::make();
const auto& host_x = gen({batch, 3, 16, 16});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册