diff --git a/ppdet/engine/export_utils.py b/ppdet/engine/export_utils.py index e1cf64638089c3e98fb7d08af7a3e39246cf6c48..9608699f1410f2a7b7e3116f651c9ce11518f5b8 100644 --- a/ppdet/engine/export_utils.py +++ b/ppdet/engine/export_utils.py @@ -55,7 +55,9 @@ MOT_ARCH = ['DeepSORT', 'JDE', 'FairMOT'] def _prune_input_spec(input_spec, program, targets): # try to prune static program to figure out pruned input spec # so we perform following operations in static mode + device = paddle.get_device() paddle.enable_static() + paddle.set_device(device) pruned_input_spec = [{}] program = program.clone() program = program._prune(targets=targets) @@ -66,7 +68,7 @@ def _prune_input_spec(input_spec, program, targets): pruned_input_spec[0][name] = spec except Exception: pass - paddle.disable_static() + paddle.disable_static(place=device) return pruned_input_spec