未验证 提交 af6c8ef3 编写于 作者: W wangguanzhong 提交者: GitHub

keep device in export (#6158)

上级 b0482b8b
......@@ -58,7 +58,9 @@ MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack']
def _prune_input_spec(input_spec, program, targets):
# try to prune static program to figure out pruned input spec
# so we perform following operations in static mode
device = paddle.get_device()
paddle.enable_static()
paddle.set_device(device)
pruned_input_spec = [{}]
program = program.clone()
program = program._prune(targets=targets)
......@@ -69,7 +71,7 @@ def _prune_input_spec(input_spec, program, targets):
pruned_input_spec[0][name] = spec
except Exception:
pass
paddle.disable_static()
paddle.disable_static(place=device)
return pruned_input_spec
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册