diff --git a/ppcls/utils/check.py b/ppcls/utils/check.py index c8f13eb4ab97cd8d412aba9e50f9368015278e41..a5746846d9a14545d08bedf661674245689ff1a7 100644 --- a/ppcls/utils/check.py +++ b/ppcls/utils/check.py @@ -19,6 +19,7 @@ from __future__ import print_function import os import sys +import paddle import paddle.fluid as fluid from ppcls.modeling import get_architectures @@ -134,3 +135,10 @@ def check_function_params(config, key): ('params is required in {} config'.format(key)) assert isinstance(params, dict), \ ('the params in {} config should be a dict'.format(key)) + + +def enable_static_mode(): + try: + paddle.enable_static() + except: + pass diff --git a/tools/eval.py b/tools/eval.py index f74529747f684b578ca76fc93baa2037abe46b52..7129669e5fcafc0b469bd051ce115e873669fefa 100644 --- a/tools/eval.py +++ b/tools/eval.py @@ -32,6 +32,7 @@ import program from ppcls.data import Reader from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model +from ppcls.utils.check import enable_static_mode from paddle.fluid.incubate.fleet.collective import fleet from paddle.fluid.incubate.fleet.base import role_maker @@ -84,6 +85,6 @@ def main(args): if __name__ == '__main__': - paddle.enable_static() + enable_static_mode() args = parse_args() main(args) diff --git a/tools/eval_multi_platform.py b/tools/eval_multi_platform.py index ad640afb8336cf2af0549c4efd75ce3b462ed50b..78501f5ae5673ddfc7f0ec3fecf87b818f55ae5e 100644 --- a/tools/eval_multi_platform.py +++ b/tools/eval_multi_platform.py @@ -31,6 +31,7 @@ import program from ppcls.data import Reader from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model +from ppcls.utils.check import enable_static_mode def parse_args(): @@ -77,6 +78,6 @@ def main(args): if __name__ == '__main__': - paddle.enable_static() + enable_static_mode() args = parse_args() main(args) diff --git a/tools/export_model.py b/tools/export_model.py index 538287a3671f0e439ce62d3e80f362b01045baa2..d911856d8c32e0d9596de2d73ebf4a1bc7528c84 100644 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -21,6 +21,7 @@ sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) import argparse from ppcls.modeling import architectures +from ppcls.utils.check import enable_static_mode import paddle import paddle.fluid as fluid @@ -82,5 +83,5 @@ def main(): if __name__ == "__main__": - paddle.enable_static() + enable_static_mode() main() diff --git a/tools/infer/infer.py b/tools/infer/infer.py index dcff66ea6c9599f0ef912b567c15e39d368504e0..0bb6fd41c821f7d313a4e1d95fed252d7e34b316 100644 --- a/tools/infer/infer.py +++ b/tools/infer/infer.py @@ -24,6 +24,7 @@ import paddle import paddle.fluid as fluid from ppcls.modeling import architectures +from ppcls.utils.check import enable_static_mode import utils @@ -145,5 +146,5 @@ def main(): if __name__ == "__main__": - paddle.enable_static() + enable_static_mode() main() diff --git a/tools/infer/py_infer.py b/tools/infer/py_infer.py index caaba405f3b53e90af2fdb834acd13ea20ff95d2..e0c8f1990434079be2abd2540bcb25b1d79b76d2 100644 --- a/tools/infer/py_infer.py +++ b/tools/infer/py_infer.py @@ -11,6 +11,11 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os +import sys +__dir__ = os.path.dirname(os.path.abspath(__file__)) +sys.path.append(__dir__) +sys.path.append(os.path.abspath(os.path.join(__dir__, '../'))) import utils import argparse @@ -19,6 +24,8 @@ import numpy as np import paddle import paddle.fluid as fluid +from ppcls.utils.check import enable_static_mode + def parse_args(): def str2bool(v): @@ -100,5 +107,5 @@ def main(): if __name__ == "__main__": - paddle.enable_static() + enable_static_mode() main() diff --git a/tools/train.py b/tools/train.py index 89c7dea9dc0c26fe669ae2c106ed051887a4879b..35b5fb9fb079bceb78b67cbd29b178afc1c8a74c 100644 --- a/tools/train.py +++ b/tools/train.py @@ -33,6 +33,7 @@ from paddle.fluid.incubate.fleet.collective import fleet from ppcls.data import Reader from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model, save_model +from ppcls.utils.check import enable_static_mode from ppcls.utils import logger import program @@ -155,6 +156,6 @@ def main(args): if __name__ == '__main__': - paddle.enable_static() + enable_static_mode() args = parse_args() main(args) diff --git a/tools/train_multi_platform.py b/tools/train_multi_platform.py index 8ebee57a561c899548e0958527383de55231ce08..6362d8beae1e6d58997e4dce85c6287cac9bc164 100644 --- a/tools/train_multi_platform.py +++ b/tools/train_multi_platform.py @@ -31,6 +31,7 @@ import paddle.fluid as fluid from ppcls.data import Reader from ppcls.utils.config import get_config from ppcls.utils.save_load import init_model, save_model +from ppcls.utils.check import enable_static_mode from ppcls.utils import logger import program @@ -164,6 +165,6 @@ def main(args): if __name__ == '__main__': - paddle.enable_static() + enable_static_mode() args = parse_args() main(args)