import os
import time
import paddle

from paddle.distributed import ParallelEnv

from .logger import setup_logger


def setup(args, cfg):
    """Prepare the run configuration, logging and the Paddle execution device.

    Args:
        args: parsed command-line arguments; only ``args.evaluate_only`` is
            read.
        cfg: mutable config object. ``cfg.output_dir`` and ``cfg.model.name``
            are read. ``cfg.isTrain`` is set to ``False`` when evaluating only.
            ``cfg.timestamp`` is set, and ``cfg.output_dir`` is replaced by a
            per-run subdirectory.
    """
    if args.evaluate_only:
        cfg.isTrain = False

    # Minute-resolution suffix such as '-2020-08-01-12-30'. It is appended to
    # the model name so runs started in different minutes get distinct output
    # directories. Runs started within the same minute share one.
    cfg.timestamp = time.strftime('-%Y-%m-%d-%H-%M', time.localtime())
    cfg.output_dir = os.path.join(cfg.output_dir,
                                  str(cfg.model.name) + cfg.timestamp)

    logger = setup_logger(cfg.output_dir)

    logger.info('Configs: {}'.format(cfg))

    # Query the distributed environment once rather than constructing
    # ParallelEnv twice. With more than one rank, use this process's assigned
    # device id; a single-process run always uses GPU 0.
    env = ParallelEnv()
    place = (paddle.CUDAPlace(env.dev_id)
             if env.nranks > 1 else paddle.CUDAPlace(0))

    # Switch Paddle to dynamic-graph (imperative) mode on the selected device.
    paddle.disable_static(place)