diff --git a/src/opr/impl/atlas_runtime_op.cpp b/src/opr/impl/atlas_runtime_op.cpp index 7853545a2ec9f28fc27fd9e814637ca04ec38246..46d3e4373ba83ba006dd476650cbf063174d17b6 100644 --- a/src/opr/impl/atlas_runtime_op.cpp +++ b/src/opr/impl/atlas_runtime_op.cpp @@ -278,6 +278,9 @@ void AtlasRuntimeOpr::scn_do_execute() { for (size_t i = 0; i < output().size(); i++) { auto output_size = aclmdlGetOutputSizeByIndex(m_model_desc, i); auto ovar = output(i); + output_size = std::max( + output_size, + ovar->dtype().size(ovar->shape().total_nr_elems())); ovar->shape_alloc(ovar->shape(), output_size); } } diff --git a/src/opr/test/atlas_runtime_op.cpp b/src/opr/test/atlas_runtime_op.cpp index 19ab0f2425907d641917acf03a7b8c60fa0298e3..b837da44081d610dbcd7d5461c10467d7f9ccacc 100644 --- a/src/opr/test/atlas_runtime_op.cpp +++ b/src/opr/test/atlas_runtime_op.cpp @@ -65,7 +65,7 @@ TEST(TestOprAtlas, Basic) { } TEST(TestOprAtlas, DynamicBatch) { - for (size_t batch : {1, 6}) { + for (size_t batch : {1, 6, 20}) { HostTensorGenerator<> gen; const auto& graph = ComputingGraph::make(); const auto& host_x = gen({batch, 3, 16, 16});