未验证 提交 af6c8ef3 编写于 作者: W wangguanzhong 提交者: GitHub

keep device in export (#6158)

上级 b0482b8b
...@@ -58,7 +58,9 @@ MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack'] ...@@ -58,7 +58,9 @@ MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack']
def _prune_input_spec(input_spec, program, targets): def _prune_input_spec(input_spec, program, targets):
# try to prune static program to figure out pruned input spec # try to prune static program to figure out pruned input spec
# so we perform following operations in static mode # so we perform following operations in static mode
device = paddle.get_device()
paddle.enable_static() paddle.enable_static()
paddle.set_device(device)
pruned_input_spec = [{}] pruned_input_spec = [{}]
program = program.clone() program = program.clone()
program = program._prune(targets=targets) program = program._prune(targets=targets)
...@@ -69,7 +71,7 @@ def _prune_input_spec(input_spec, program, targets): ...@@ -69,7 +71,7 @@ def _prune_input_spec(input_spec, program, targets):
pruned_input_spec[0][name] = spec pruned_input_spec[0][name] = spec
except Exception: except Exception:
pass pass
paddle.disable_static() paddle.disable_static(place=device)
return pruned_input_spec return pruned_input_spec
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册