From f18e57984b2953320c5317eabcffca03080b36ed Mon Sep 17 00:00:00 2001 From: wangguanzhong Date: Thu, 9 Jun 2022 17:28:16 +0800 Subject: [PATCH] keep device in export (#6157) --- ppdet/engine/export_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 12fb1ea6f..6af8b0f47 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -58,7 +58,9 @@ MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack'] def _prune_input_spec(input_spec, program, targets): # try to prune static program to figure out pruned input spec # so we perform following operations in static mode + device = paddle.get_device() paddle.enable_static() + paddle.set_device(device) pruned_input_spec = [{}] program = program.clone() program = program._prune(targets=targets) @@ -69,7 +71,7 @@ def _prune_input_spec(input_spec, program, targets): pruned_input_spec[0][name] = spec except Exception: pass - paddle.disable_static() + paddle.disable_static(place=device) return pruned_input_spec -- GitLab