提交 ea4e6ab9 编写于 作者: M Megvii Engine Team

fix(mgb/opr): fix shape cache of NvOF

GitOrigin-RevId: 456ba478e9f5be21027f73f302291b114cc2c2d6
上级 3228fb75
......@@ -215,8 +215,9 @@ SymbolVar NvOf::make(
void NvOf::scn_do_execute() {
auto input_shape = this->input()[0]->shape();
std::vector<size_t> t_shape;
for (size_t i = 0; i < 5; i++) {
vshape.push_back(input_shape[i]);
t_shape.push_back(input_shape[i]);
}
auto c = this->comp_node();
//! comp_node may init on CUDA or CPU, eg: lar with --cpu
......@@ -232,7 +233,8 @@ void NvOf::scn_do_execute() {
//! create NvOF engine at same device id of comp_node, can not get
//! comp_node device id, when NvOf:NvOf, so init at scn_do_execute
std::lock_guard<std::mutex> lock(m_lock);
if (init_flag == false) {
if (init_flag == false || vshape != t_shape) {
vshape = t_shape;
//! nvof sdk do not imp p2p copy, so init nvof engine on the same
//! device with mgb comp_node
nv_flow_extractor = std::make_shared<NVFlowExtractor>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册