提交 0c3699ff 编写于 作者: M Megvii Engine Team 提交者: XindaH

test(imperative/python): fix testcase for magicmind runtime module

GitOrigin-RevId: baf2f72f01f05e87cbfe181c879785e7d24d6d1e
上级 f398c8e6
......@@ -20,3 +20,4 @@ ci/resource/dump/batch_conv_bias_with_policy_8.8.0.mdl filter=lfs diff=lfs merge
ci/resource/prof/model_with_err_assert.mdl filter=lfs diff=lfs merge=lfs -text
ci/resource/prof/test_mge.mge filter=lfs diff=lfs merge=lfs -text
lite/test/resource/lite/ax_models/64-58063ce2.axe filter=lfs diff=lfs merge=lfs -text
imperative/python/test/unit/module/MagicMindRuntimeOprTest.GraphShapeMutable.mlu filter=lfs diff=lfs merge=lfs -text
......@@ -13,6 +13,7 @@ from ..functional.external import (
atlas_runtime_opr,
cambricon_runtime_opr,
extern_opr_subgraph,
magicmind_runtime_opr,
tensorrt_runtime_opr,
)
from .module import Module
......@@ -131,6 +132,7 @@ class AtlasRuntimeSubgraph(Module):
def forward(self, *inputs):
return atlas_runtime_opr(inputs, data=self._data)
class MagicMindRuntimeSubgraph(Module):
r"""Load a serialized MagicMindRuntime subgraph.
......@@ -151,6 +153,3 @@ class MagicMindRuntimeSubgraph(Module):
def forward(self, *inputs):
return magicmind_runtime_opr(inputs, data=self._data)
......@@ -267,7 +267,7 @@ void MagicMindRuntimeOpr::get_output_var_shape(
mgb_assert(
tensor != nullptr, "failed to find input tensor(name:%s)",
iname.c_str());
MM_CHECK(tensor->SetDimensions(mgb_shape_to_mm_dims(input(i)->shape())));
MM_CHECK(tensor->SetDimensions(mgb_shape_to_mm_dims(inp_shape[i])));
}
if (Status::OK() == m_context->InferOutputShape(inputs, outputs)) {
size_t nr_outputs = output().size();
......@@ -283,7 +283,7 @@ void MagicMindRuntimeOpr::get_output_var_shape(
}
std::vector<Dims> shape(inp_shape.size());
for (size_t i = 0; i < nr_inputs; ++i) {
shape[i] = mgb_shape_to_mm_dims(input(i)->shape());
shape[i] = mgb_shape_to_mm_dims(inp_shape[i]);
}
size_t wk_size = 0;
MM_CHECK(m_engine->QueryContextMaxWorkspaceSize(shape, &wk_size));
......
......@@ -390,7 +390,11 @@ void CompNodeEnv::init_cnrt(
MGB_CNRT_CHECK(cnrtGetDeviceInfo(&m_cnrt_env.device_info, dev));
// FIXME: doc doesn't describe the aligment requirement for device memory
// address
#if CNRT_MAJOR_VERSION >= 5
m_property.mem_alignment = 256u;
#else
m_property.mem_alignment = 1u;
#endif
// ensure exception safe
bool queue_created = false;
MGB_MARK_USED_VAR(queue_created);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册