diff --git a/src/jit/impl/mlir/executable_cuda.cpp b/src/jit/impl/mlir/executable_cuda.cpp
index 638ec697ae2832ac2f25fb9e0f3beee12aa454ae..4f7c8629e96fb37a664273c4e01d07c4adc5e285 100644
--- a/src/jit/impl/mlir/executable_cuda.cpp
+++ b/src/jit/impl/mlir/executable_cuda.cpp
@@ -36,13 +36,14 @@ template <int out_dim, typename T>
 void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
                       int block_size) {
     auto&& args = fusion_opr->args();
-    std::vector<StridedMemRefType<T, out_dim>> param_holders;
+    size_t num_memrefs = args.inputs.size() + args.outputs.size();
+    std::vector<StridedMemRefType<T, out_dim>> param_holders(num_memrefs);
     std::vector<void*> params;
 
     auto set_params = [&param_holders, &params](
-                              void* ptr, const megdnn::TensorLayout& layout) {
-        param_holders.push_back(StridedMemRefType<T, out_dim>{});
-        StridedMemRefType<T, out_dim>& desc = param_holders.back();
+                              size_t idx, void* ptr,
+                              const megdnn::TensorLayout& layout) {
+        auto& desc = param_holders[idx];
         desc.basePtr = static_cast<T*>(ptr);
         params.push_back(&(desc.basePtr));
         desc.data = static_cast<T*>(ptr);
@@ -56,9 +57,12 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
             params.push_back(&(desc.strides[i]));
         }
     };
+
+    size_t idx = 0;
     for (const auto& arg : args.inputs) {
-        set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
+        set_params(idx++, arg.from->dev_tensor().raw_ptr(), arg.from->layout());
     }
+
     int64_t nr_elements = 0;
     for (const auto& arg : args.outputs) {
         if (nr_elements == 0) {
@@ -73,8 +77,13 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
                        arg.from->layout().to_string().c_str());
         }
 
-        set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
+        set_params(idx++, arg.from->dev_tensor().raw_ptr(), arg.from->layout());
     }
+
+    mgb_assert(param_holders.size() == num_memrefs,
+               "calling push_back method of param_holders is unsafe as it "
+               "might cause reallocation of std::vector");
+
     const CompNodeEnv& env =
             CompNodeEnv::from_comp_node(fusion_opr->comp_node());
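
Why the patch replaces push_back with a pre-sized vector: set_params stores pointers to param_holders elements (&desc.basePtr, &desc.data, &desc.sizes[i], ...) in params, and any later push_back on param_holders can reallocate its buffer, leaving every previously stored pointer dangling before the kernel is launched. The standalone sketch below (not part of the patch; MemRefDesc and all other names are illustrative stand-ins) reproduces the hazard and the fix with plain std::vector:

// sketch_dangling_push_back.cpp -- illustrative only, not MegEngine code.
// Mirrors the bug fixed above: pointers into a std::vector become invalid
// once push_back grows the vector past its capacity and reallocates.
#include <cstdio>
#include <vector>

struct MemRefDesc {  // stand-in for StridedMemRefType<T, out_dim>
    float* basePtr;
};

int main() {
    float storage[4] = {0.f, 1.f, 2.f, 3.f};

    // Buggy pattern (old code): grow `holders` while collecting pointers
    // into its elements.
    std::vector<MemRefDesc> holders;
    std::vector<void*> params;
    for (int i = 0; i < 4; ++i) {
        holders.push_back(MemRefDesc{&storage[i]});
        params.push_back(&holders.back().basePtr);  // may dangle after the
                                                    // next push_back
    }
    // Dereferencing params[0] here is undefined behavior if any push_back
    // reallocated the buffer that params[0] points into.

    // Fixed pattern (new code): size the vector up front so element
    // addresses never change while `params` is being filled.
    std::vector<MemRefDesc> safe_holders(4);
    std::vector<void*> safe_params;
    for (int i = 0; i < 4; ++i) {
        safe_holders[i].basePtr = &storage[i];
        safe_params.push_back(&safe_holders[i].basePtr);  // stable address
    }
    std::printf("ok: %p\n", safe_params[0]);
    return 0;
}

The mgb_assert added at the end of the patch documents the same invariant: param_holders must never grow after construction, so its element addresses stay valid for the lifetime of params.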