提交 a240d558 编写于 作者: M Megvii Engine Team

fix(mge/dnn): fix rng and topk oom in distributed training

GitOrigin-RevId: 9841d1219e113f2d8ef8e4231248ec77a997163a
上级 c3c1e46d
......@@ -379,6 +379,8 @@ void TopK::init_output_static_infer_desc() {
}
auto infer_workspace = [this](TensorShape& dst, const InpVal& iv) {
// activate comp_node for the cuda kernel launch in get_workspace_in_bytes
comp_node().activate();
auto k = iv.val[3].value().ptr<int>()[0];
auto size = megdnn_opr()->get_workspace_in_bytes(
k, {iv.val[0].shape(), input(0)->dtype()},
......
......@@ -60,6 +60,8 @@ cg::OperatorNodeBase::NodeProp* RNGOprBase::do_make_node_prop() const {
/*!
 * \brief make sure m_megdnn_opr holds an operator for the current comp node
 *
 * A new operator is created through create_megdnn_opr() when none is cached
 * yet, or when the cached one reports a comp node different from comp_node().
 * Otherwise the cached operator is left untouched.
 */
void RNGOprBase::ensure_megdnn_opr() {
    const bool need_create =
            !m_megdnn_opr || m_megdnn_opr.comp_node() != comp_node();
    if (!need_create) {
        return;
    }

    // the comp node has to be activated before create_megdnn_opr(), which
    // calls curandCreateGenerator
    comp_node().activate();
    m_megdnn_opr = create_megdnn_opr();
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册