diff --git a/tsm/README.md b/tsm/README.md index 918a7e01cd5f0f56074c4c647f3817711ac86891..68e7c37b012ddb863119d52095390bc240bac3a1 100644 --- a/tsm/README.md +++ b/tsm/README.md @@ -52,7 +52,13 @@ TSM的训练数据采用由DeepMind公布的Kinetics-400动作识别数据集。 ### 模型训练 -数据准备完毕后,可以通过如下方式启动训练和评估,如下脚本会自动每epoch交替进行训练和模型评估,并将checkpoint默认保存在`tsm_checkpoint`目录下。 +数据准备完毕后,可使用`main.py`脚本启动训练和评估,如下脚本会自动每epoch交替进行训练和模型评估,并将checkpoint默认保存在`tsm_checkpoint`目录下。 + +`main.py`脚本参数可通过如下命令查询 + +```shell +python main.py --help +``` #### 静态图训练 diff --git a/tsm/check.py b/tsm/check.py new file mode 100644 index 0000000000000000000000000000000000000000..16c07568c7f1a0319f791cb39244494f3ddf9f12 --- /dev/null +++ b/tsm/check.py @@ -0,0 +1,62 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import sys + +import paddle.fluid as fluid + +import logging +logger = logging.getLogger(__name__) + +__all__ = ['check_gpu', 'check_version'] + + +def check_gpu(use_gpu): + """ + Log error and exit when set use_gpu=true in paddlepaddle + cpu version. + """ + err = "Config use_gpu cannot be set as true while you are " \ + "using paddlepaddle cpu version ! \nPlease try: \n" \ + "\t1. Install paddlepaddle-gpu to run model on GPU \n" \ + "\t2. Set use_gpu as false in config file to run " \ + "model on CPU" + + try: + if use_gpu and not fluid.is_compiled_with_cuda(): + logger.error(err) + sys.exit(1) + except Exception as e: + pass + + +def check_version(version='1.7.0'): + """ + Log error and exit when the installed version of paddlepaddle is + not satisfied. + """ + err = "PaddlePaddle version {} or higher is required, " \ + "or a suitable develop version is satisfied as well. \n" \ + "Please make sure the version is good with your code." \ + .format(version) + + try: + fluid.require_version(version) + except Exception as e: + logger.error(err) + sys.exit(1)