diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index 12fb1ea6ffe369fef21274eb9a5221cf4e221812..6af8b0f4757dca4d9b0e0ba76cffc1ff3308b9de 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -58,7 +58,9 @@ MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT', 'ByteTrack'] def _prune_input_spec(input_spec, program, targets): # try to prune static program to figure out pruned input spec # so we perform following operations in static mode + device = paddle.get_device() paddle.enable_static() + paddle.set_device(device) pruned_input_spec = [{}] program = program.clone() program = program._prune(targets=targets) @@ -69,7 +71,7 @@ def _prune_input_spec(input_spec, program, targets): pruned_input_spec[0][name] = spec except Exception: pass - paddle.disable_static() + paddle.disable_static(place=device) return pruned_input_spec