config.py 2.7 KB
Newer Older
Z
Zeyu Chen 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2019  PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

W
wuzewu 已提交
15
import time
Z
Zeyu Chen 已提交
16

Z
Zeyu Chen 已提交
17
from datetime import datetime
W
wuzewu 已提交
18 19
from paddlehub.finetune.strategy import DefaultStrategy
from paddlehub.common.logger import logger
20

21 22

class RunConfig(object):
    """Configuration for PaddleHub fine-tuning.

    Stores run-level settings and exposes each one through a read-only
    property of the same name (without the leading underscore).

    Args:
        log_interval: Exposed as ``log_interval``. Defaults to 10
            (presumably measured in training steps -- confirm with the
            trainer that consumes it).
        eval_interval: Exposed as ``eval_interval``. Defaults to 100.
        save_ckpt_interval: Exposed as ``save_ckpt_interval``. Defaults
            to None.
        use_cuda: Exposed as ``use_cuda``. Defaults to False.
        checkpoint_dir: Checkpoint directory. When None, a name of the form
            ``ckpt_YYYYmmddHHMMSS`` built from the local time at
            construction (second resolution) is used instead.
        num_epoch: Exposed as ``num_epoch``. Defaults to 10.
        batch_size: Exposed as ``batch_size``. Defaults to None.
        enable_memory_optim: Exposed as ``enable_memory_optim``. Defaults
            to True.
        strategy: Fine-tuning strategy. When None, a new
            ``DefaultStrategy()`` instance is created for this config.

    The chosen checkpoint directory is logged at INFO level on construction.
    """

    def __init__(self,
                 log_interval=10,
                 eval_interval=100,
                 save_ckpt_interval=None,
                 use_cuda=False,
                 checkpoint_dir=None,
                 num_epoch=10,
                 batch_size=None,
                 enable_memory_optim=True,
                 strategy=None):
        """Construct a fine-tune config; see the class docstring for args."""
        self._log_interval = log_interval
        self._eval_interval = eval_interval
        self._save_ckpt_interval = save_ckpt_interval
        self._use_cuda = use_cuda
        self._num_epoch = num_epoch
        self._batch_size = batch_size
        # A fresh DefaultStrategy per config, so configs never share one
        # strategy object through a default argument.
        self._strategy = DefaultStrategy() if strategy is None else strategy
        self._enable_memory_optim = enable_memory_optim

        # Single assignment: an explicit directory wins, otherwise fall back
        # to a timestamped name.
        if checkpoint_dir is None:
            checkpoint_dir = self._default_checkpoint_dir()
        self._checkpoint_dir = checkpoint_dir
        logger.info("Checkpoint dir: {}".format(self._checkpoint_dir))

    @staticmethod
    def _default_checkpoint_dir(timestamp=None):
        """Return ``"ckpt_"`` followed by a local-time stamp.

        Args:
            timestamp: Seconds since the epoch. Defaults to ``time.time()``.
                Truncated to whole seconds before formatting.

        Returns:
            A string such as ``"ckpt_20190506070809"``
            (format ``%Y%m%d%H%M%S`` in local time).
        """
        if timestamp is None:
            timestamp = time.time()
        local_time = time.localtime(int(timestamp))
        return "ckpt_" + time.strftime("%Y%m%d%H%M%S", local_time)

    @property
    def log_interval(self):
        """Value passed as ``log_interval``."""
        return self._log_interval

    @property
    def eval_interval(self):
        """Value passed as ``eval_interval``."""
        return self._eval_interval

    @property
    def save_ckpt_interval(self):
        """Value passed as ``save_ckpt_interval``."""
        return self._save_ckpt_interval

    @property
    def use_cuda(self):
        """Value passed as ``use_cuda``."""
        return self._use_cuda

    @property
    def checkpoint_dir(self):
        """Explicit checkpoint directory, or the timestamped default."""
        return self._checkpoint_dir

    @property
    def num_epoch(self):
        """Value passed as ``num_epoch``."""
        return self._num_epoch

    @property
    def batch_size(self):
        """Value passed as ``batch_size``."""
        return self._batch_size

    @property
    def strategy(self):
        """The given strategy, or the ``DefaultStrategy`` created for it."""
        return self._strategy

    @property
    def enable_memory_optim(self):
        """Value passed as ``enable_memory_optim``."""
        return self._enable_memory_optim