未验证 提交 2a4c8f4b 编写于 作者: W wangguanzhong 提交者: GitHub

keep device in export (#6159)

上级 441beb50
......@@ -55,7 +55,9 @@ MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT']
def _prune_input_spec(input_spec, program, targets):
# try to prune static program to figure out pruned input spec
# so we perform following operations in static mode
device = paddle.get_device()
paddle.enable_static()
paddle.set_device(device)
pruned_input_spec = [{}]
program = program.clone()
program = program._prune(targets=targets)
......@@ -66,7 +68,7 @@ def _prune_input_spec(input_spec, program, targets):
pruned_input_spec[0][name] = spec
except Exception:
pass
paddle.disable_static()
paddle.disable_static(place=device)
return pruned_input_spec
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册