提交 f7731bd4 编写于 作者: M Megvii Engine Team

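fix(mgb/jit): fix a dangling-pointer bug in mlir executable_cuda (push_back on param_holders could reallocate the vector and invalidate the kernel-argument pointers stored in params; the vector is now presized and indexed instead)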

GitOrigin-RevId: 3ec79b760233cee3d7ce9ecc367559c1336c0fec
上级 810d8cba
......@@ -36,13 +36,14 @@ template <int out_dim, typename ctype>
void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
int block_size) {
auto&& args = fusion_opr->args();
std::vector<StridedMemRefType<ctype, out_dim>> param_holders;
size_t num_memrefs = args.inputs.size() + args.outputs.size();
std::vector<StridedMemRefType<ctype, out_dim>> param_holders(num_memrefs);
std::vector<void*> params;
auto set_params = [&param_holders, &params](
void* ptr, const megdnn::TensorLayout& layout) {
param_holders.push_back(StridedMemRefType<ctype, out_dim>{});
StridedMemRefType<ctype, out_dim>& desc = param_holders.back();
size_t idx, void* ptr,
const megdnn::TensorLayout& layout) {
auto& desc = param_holders[idx];
desc.basePtr = static_cast<ctype*>(ptr);
params.push_back(&(desc.basePtr));
desc.data = static_cast<ctype*>(ptr);
......@@ -56,9 +57,12 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
params.push_back(&(desc.strides[i]));
}
};
size_t idx = 0;
for (const auto& arg : args.inputs) {
set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
set_params(idx++, arg.from->dev_tensor().raw_ptr(), arg.from->layout());
}
int64_t nr_elements = 0;
for (const auto& arg : args.outputs) {
if (nr_elements == 0) {
......@@ -73,8 +77,13 @@ void setup_and_launch(const JITExecutor* fusion_opr, CUfunction func,
arg.from->layout().to_string().c_str());
}
set_params(arg.from->dev_tensor().raw_ptr(), arg.from->layout());
set_params(idx++, arg.from->dev_tensor().raw_ptr(), arg.from->layout());
}
mgb_assert(param_holders.size() == num_memrefs,
"calling push_back method of param_holders is unsafe as it "
"might cause reallocation of std::vector");
const CompNodeEnv& env =
CompNodeEnv::from_comp_node(fusion_opr->comp_node());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册