From bec79a0084cbace63257fdb8e99273a48707d67c Mon Sep 17 00:00:00 2001 From: littletomatodonkey <2120160898@bit.edu.cn> Date: Thu, 12 Nov 2020 21:43:40 +0800 Subject: [PATCH] update static method by trying-catch (#391) --- ppcls/utils/check.py | 8 ++++++++ tools/eval.py | 3 ++- tools/eval_multi_platform.py | 3 ++- tools/export_model.py | 3 ++- tools/infer/infer.py | 3 ++- tools/infer/py_infer.py | 9 ++++++++- tools/train.py | 3 ++- tools/train_multi_platform.py | 3 ++- 8 files changed, 28 insertions(+), 7 deletions(-) diff --git a/ppcls/utils/check.py b/ppcls/utils/check.py index c8f13eb4..a5746846 100644 --- a/ppcls/utils/check.py +++ b/ppcls/utils/check.py @@ -19,6 +19,7 @@ from __future__ import print_function import os import sys +import paddle import paddle.fluid as fluid from ppcls.modeling import get_architectures @@ -134,3 +135,10 @@ def check_function_params(config, key): ('params is required in {} config'.format(key)) assert isinstance(params, dict), \ ('the params in {} config should be a dict'.format(key)) + + +def enable_static_mode(): + try: + paddle.enable_static() + except: + pass diff --git a/tools/eval.py b/tools/eval.py index f7452974..7129669e 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -32,6 +32,7 @@ import program from ppcls.data import Reader from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model +from ppcls.utils.check import enable_static_mode from paddle.fluid.incubate.fleet.collective import fleet from paddle.fluid.incubate.fleet.base import role_maker @@ -84,6 +85,6 @@ def main(args): if __name__ == '__main__': - paddle.enable_static() + enable_static_mode() args = parse_args() main(args) diff --git a/tools/eval_multi_platform.py b/tools/eval_multi_platform.py index ad640afb..78501f5a 100644 --- a/tools/eval_multi_platform.py +++ b/tools/eval_multi_platform.py @@ -31,6 +31,7 @@ import program from ppcls.data import Reader from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model +from ppcls.utils.check import enable_static_mode def parse_args(): @@ -77,6 +78,6 @@ def main(args): if __name__ == '__main__': - paddle.enable_static() + enable_static_mode() args = parse_args() main(args) diff --git a/tools/export_model.py b/tools/export_model.py index 538287a3..d911856d 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -21,6 +21,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) import argparse from ppcls.modeling import architectures +from ppcls.utils.check import enable_static_mode import paddle import paddle.fluid as fluid @@ -82,5 +83,5 @@ def main(): if __name__ == "__main__": - paddle.enable_static() + enable_static_mode() main() diff --git a/tools/infer/infer.py b/tools/infer/infer.py index dcff66ea..0bb6fd41 100644 --- a/tools/infer/infer.py +++ b/tools/infer/infer.py @@ -24,6 +24,7 @@ import paddle import paddle.fluid as fluid from ppcls.modeling import architectures +from ppcls.utils.check import enable_static_mode import utils @@ -145,5 +146,5 @@ def main(): if __name__ == "__main__": - paddle.enable_static() + enable_static_mode() main() diff --git a/tools/infer/py_infer.py b/tools/infer/py_infer.py index caaba405..e0c8f199 100644 --- a/tools/infer/py_infer.py +++ b/tools/infer/py_infer.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) import utils import argparse @@ -19,6 +24,8 @@ import numpy as np import paddle import paddle.fluid as fluid +from ppcls.utils.check import enable_static_mode + def parse_args(): def str2bool(v): @@ -100,5 +107,5 @@ def main(): if __name__ == "__main__": - paddle.enable_static() + enable_static_mode() main() diff --git a/tools/train.py b/tools/train.py index 89c7dea9..35b5fb9f 100644 --- a/tools/train.py +++ b/tools/train.py @@ -33,6 +33,7 @@ from paddle.fluid.incubate.fleet.collective import fleet from ppcls.data import Reader from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model, save_model +from ppcls.utils.check import enable_static_mode from ppcls.utils import logger import program @@ -155,6 +156,6 @@ def main(args): if __name__ == '__main__': - paddle.enable_static() + enable_static_mode() args = parse_args() main(args) diff --git a/tools/train_multi_platform.py b/tools/train_multi_platform.py index 8ebee57a..6362d8be 100644 --- a/tools/train_multi_platform.py +++ b/tools/train_multi_platform.py @@ -31,6 +31,7 @@ import paddle.fluid as fluid from ppcls.data import Reader from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model, save_model +from ppcls.utils.check import enable_static_mode from ppcls.utils import logger import program @@ -164,6 +165,6 @@ def main(args): if __name__ == '__main__': - paddle.enable_static() + enable_static_mode() args = parse_args() main(args) -- GitLab