From e92670e820003207d431a1b859a808d8d8089582 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 2 Dec 2020 17:32:18 +0800 Subject: [PATCH] fix(mgb/atlas): when batchsize more than model max batchsize GitOrigin-RevId: 63fe79eaa98e97b485ac731e87ce1885a190c271 --- src/opr/impl/atlas_runtime_op.cpp | 3 +++ src/opr/test/atlas_runtime_op.cpp | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/src/opr/impl/atlas_runtime_op.cpp b/src/opr/impl/atlas_runtime_op.cpp index 7853545a..46d3e437 100644 --- a/src/opr/impl/atlas_runtime_op.cpp +++ b/src/opr/impl/atlas_runtime_op.cpp @@ -278,6 +278,9 @@ void AtlasRuntimeOpr::scn_do_execute() { for (size_t i = 0; i < output().size(); i++) { auto output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i); auto ovar = output(i); + output_size = std::max( + output_size, + ovar->dtype().size(ovar->shape().total_nr_elems())); ovar->shape_alloc(ovar->shape(), output_size); } } diff --git a/src/opr/test/atlas_runtime_op.cpp b/src/opr/test/atlas_runtime_op.cpp index 19ab0f24..b837da44 100644 --- a/src/opr/test/atlas_runtime_op.cpp +++ b/src/opr/test/atlas_runtime_op.cpp @@ -65,7 +65,7 @@ TEST(TestOprAtlas, Basic) { } TEST(TestOprAtlas, DynamicBatch) { - for (size_t batch : {1, 6}) { + for (size_t batch : {1, 6, 20}) { HostTensorGenerator<> gen; const auto& graph = ComputingGraph::make(); const auto& host_x = gen({batch, 3, 16, 16}); -- GitLab