From f7731bd4370d4044ee5b6424668bede031c2aedc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 26 Nov 2020 17:08:52 +0800 Subject: [PATCH] fix(mgb/jit): fix a pointer bug in mlir executable_cuda GitOrigin-RevId: 3ec79b760233cee3d7ce9ecc367559c1336c0fec --- src/jit/impl/mlir/executable_cuda.cpp | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/jit/impl/mlir/executable_cuda.cpp b/src/jit/impl/mlir/executable_cuda.cpp index 638ec697..4f7c8629 100644 --- a/src/jit/impl/mlir/executable_cuda.cpp +++ b/src/jit/impl/mlir/executable_cuda.cpp @@ -36,13 +36,14 @@ template void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, int block_size) { auto&& args = fusion_opr->args(); - std::vector> param_holders; + size_t num_memrefs = args.inputs.size() + args.outputs.size(); + std::vector> param_holders(num_memrefs); std::vector params; auto set_params = [¶m_holders, ¶ms]( - void* ptr, const megdnn::TensorLayout& layout) { - param_holders.push_back(StridedMemRefType{}); - StridedMemRefType& desc = param_holders.back(); + size_t idx, void* ptr, + const megdnn::TensorLayout& layout) { + auto& desc = param_holders[idx]; desc.basePtr = static_cast(ptr); params.push_back(&(desc.basePtr)); desc.data = static_cast(ptr); @@ -56,9 +57,12 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, params.push_back(&(desc.strides[i])); } }; + + size_t idx = 0; for (const auto& arg : args.inputs) { - set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout()); + set_params(idx++, arg.from->dev_tensor().raw_ptr(), arg.from->layout()); } + int64_t nr_elements = 0; for (const auto& arg : args.outputs) { if (nr_elements == 0) { @@ -73,8 +77,13 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func, arg.from->layout().to_string().c_str()); } - set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout()); + set_params(idx++, arg.from->dev_tensor().raw_ptr(), arg.from->layout()); } + + mgb_assert(param_holders.size() == num_memrefs, + "calling push_back method of param_holders is unsafe as it " + "might cause reallocation of std::vector"); + const CompNodeEnv& env = CompNodeEnv::from_comp_node(fusion_opr->comp_node()); -- GitLab