util.py 10.3 KB
Newer Older
T
tangwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

T
tangwei 已提交
15
import datetime
X
xiexionghang 已提交
16 17
import os
import time
T
tangwei 已提交
18

T
tangwei 已提交
19
from paddle import fluid
T
tangwei 已提交
20

21
from paddlerec.core.utils import fs as fs
T
tangwei 已提交
22 23


T
tangwei 已提交
24
def save_program_proto(path, program=None):
    """Write the serialized protobuf description of a fluid Program to disk.

    Args:
        path: destination file; opened in binary write mode (truncated).
        program: Program whose ``desc`` is serialized. When None, the
            current ``fluid.default_main_program()`` is used instead.
    """
    target = fluid.default_main_program() if program is None else program
    with open(path, "wb") as out_file:
        out_file.write(target.desc.serialize_to_string())


T
tangwei 已提交
35 36 37 38 39 40 41 42 43
def str2bool(v):
    """Convert a boolean-like value to ``bool``.

    ``bool`` inputs are returned unchanged. Strings are matched
    case-insensitively: 'yes', 'true', 't', 'y', '1' give True and
    'no', 'false', 'f', 'n', '0' give False.

    Raises:
        ValueError: when a string matches neither group.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    truthy = ('yes', 'true', 't', 'y', '1')
    falsy = ('no', 'false', 'f', 'n', '0')
    if lowered in truthy:
        return True
    if lowered in falsy:
        return False
    raise ValueError('Boolean value expected.')
T
tangwei 已提交
44

X
xiexionghang 已提交
45

T
tangwei 已提交
46 47 48 49 50 51 52 53 54 55 56 57 58 59
def run_which(command):
    """Locate an executable on PATH, like the shell ``which`` command.

    Args:
        command: name of the executable to look up.

    Returns:
        The path of the executable as a string, or None when it is not found.
    """
    # shutil.which needs no shell (no quoting/injection issues) and reports
    # "not found" as None. The previous version looked for the
    # "/usr/bin/which: no ... in" message on stdout, but `which` writes that
    # to stderr (which os.popen does not capture), so a missing command came
    # back as "" instead of None.
    import shutil
    return shutil.which(command)


def run_shell_cmd(command):
    """Run *command* through the shell and return its stripped stdout.

    Args:
        command: shell command line (str). It is interpreted by the shell,
            so it must not contain untrusted input.

    Returns:
        Captured standard output with surrounding whitespace removed.
        Standard error is not captured and the exit status is ignored.
    """
    assert command is not None and isinstance(command, str)
    # Close the pipe explicitly so the child process is reaped immediately
    # rather than whenever the file object happens to be garbage-collected.
    with os.popen(command) as pipe:
        return pipe.read().strip()


X
xiexionghang 已提交
60
def get_env_value(env_name):
    """
    get os environment value

    Args:
        env_name: name of the environment variable.

    Returns:
        The variable's value with surrounding whitespace stripped, or ""
        when the variable is not set.
    """
    # Read os.environ directly instead of running `echo -n ${name}` in a
    # shell: the shell version word-splits and glob-expands the value
    # (collapsing internal whitespace, expanding '*'), relies on `echo -n`
    # being supported, and executes anything injected through env_name.
    return os.environ.get(env_name, "").strip()

X
xiexionghang 已提交
66

X
xiexionghang 已提交
67
def now_time_str():
    """Build a log prefix carrying the current local time.

    Returns:
        A string made of a leading newline, the local time formatted as
        ``%Y-%m-%d %H:%M:%S``, and the literal suffix ``[0]:``.
    """
    stamp = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    return "\n{}[0]:".format(stamp)
X
xiexionghang 已提交
72

X
xiexionghang 已提交
73

X
xiexionghang 已提交
74
def get_absolute_path(path, params):
    """Qualify an AFS/HDFS path with its file-system name when needed.

    Args:
        path: file path. Paths beginning with ``afs:`` or ``hdfs:`` are
            treated as remote; any other path is returned unchanged.
        params: dict; its optional ``'fs_name'`` entry is prepended to a
            remote path that does not already contain a ``host:port`` part.

    Returns:
        The qualified path. *path* is returned unchanged when it is not
        remote, already contains ``host:port``, or ``params`` has no
        ``'fs_name'``.
    """
    if not path.startswith(('afs:', 'hdfs:')):
        return path
    # Split only on the first "fs:" so a later occurrence stays in sub_path.
    sub_path = path.split('fs:', 1)[1]
    if ':' in sub_path:  # already qualified, such as afs://xxx:port/xxxx
        return path
    if 'fs_name' in params:
        return params['fs_name'] + sub_path
    # Previously this case fell off the end and returned None.
    return path

X
xiexionghang 已提交
86

X
xiexionghang 已提交
87
def make_datetime(date_str, fmt=None):
    """
    create a datetime instance by date_string
    Args:
        date_str: such as "20200114", "2020-01-14" or "202001141230"
        fmt: strptime format. When None it is inferred from len(date_str):
            8 -> "%Y%m%d", 10 -> "%Y-%m-%d", 12 -> "%Y%m%d%H%M"
    Return:
        datetime
    Raises:
        ValueError: if fmt is None and the length matches no known format,
            or if date_str does not match the format.
    """
    if fmt is None:
        inferred_formats = {8: '%Y%m%d', 10: '%Y-%m-%d', 12: '%Y%m%d%H%M'}
        fmt = inferred_formats.get(len(date_str))
        if fmt is None:
            # Previously fell through to strptime(date_str, None) -> TypeError.
            raise ValueError("cannot infer datetime format for %r; "
                             "pass fmt explicitly" % (date_str, ))
    return datetime.datetime.strptime(date_str, fmt)


def rank0_print(log_str):
    """Forward *log_str* to ``print_log`` with params ``{'master': True}``.

    NOTE(review): ``print_log`` is neither defined nor imported in the
    visible part of this module. Confirm it exists, otherwise this raises
    NameError. Any rank-0-only filtering implied by the name would happen
    inside ``print_log``.
    """
    master_only = {'master': True}
    print_log(log_str, master_only)

X
xiexionghang 已提交
109

X
xiexionghang 已提交
110
def print_cost(cost, params):
    """Format *cost* with ``params['log_format']``, log it, and return it.

    Args:
        cost: value substituted into the %-style ``log_format`` string.
        params: dict that must contain ``'log_format'``; the whole dict is
            also forwarded to ``print_log``.

    Returns:
        The formatted log string.

    NOTE(review): ``print_log`` is not defined or imported in the visible
    part of this module; confirm it exists.
    """
    message = params['log_format'] % cost
    print_log(message, params)
    return message
T
tangwei 已提交
116

X
xiexionghang 已提交
117

X
xiexionghang 已提交
118
class CostPrinter(object):
    """
    For count cost time && print cost log

    Measures wall-clock time from construction (or the last ``reset``) until
    ``done`` and passes the elapsed seconds to a callback. If ``done`` was
    never called, it runs when the object is finalized. The object can also
    be used as a context manager, in which case ``done`` runs on exit.
    """

    def __init__(self, callback, callback_params):
        """
        Args:
            callback: callable invoked as ``callback(cost, callback_params)``
                with ``cost`` in seconds; its return value is returned by
                ``done``.
            callback_params: object forwarded unchanged to ``callback``.
        """
        self.reset(callback, callback_params)

    def __del__(self):
        """Report the timing on finalization if ``done`` was not called."""
        # If reset() never ran (e.g. the instance was created via __new__),
        # there is no _done attribute; do not raise from __del__ then.
        if not getattr(self, '_done', True):
            self.done()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        if not self._done:
            self.done()
        # Never suppress an exception raised inside the with-block.
        return False

    def reset(self, callback, callback_params):
        """Restart the timer with the given callback and params."""
        self._done = False
        self._callback = callback
        self._callback_params = callback_params
        self._begin_time = time.time()

    def done(self):
        """Stop timing and report the cost through the callback.

        Returns:
            (cost, log_str): seconds elapsed since the last reset and the
            callback's return value.
        """
        cost = time.time() - self._begin_time
        log_str = self._callback(cost, self._callback_params)  # cost(s)
        self._done = True
        return cost, log_str

X
xiexionghang 已提交
153 154

class PathGenerator(object):
    """
    generate path with template & runtime variables
    """

    def __init__(self, config):
        """
        Args:
            config: dict; its optional 'templates' list is registered via
                add_path_template.
        """
        self._templates = {}
        self.add_path_template(config)

    def add_path_template(self, config):
        """Register templates from ``config['templates']``.

        Each entry is a dict with 'name' and 'template' keys. A later entry
        with an existing name replaces the earlier template. A config without
        'templates' registers nothing.
        """
        if 'templates' in config:
            for template in config['templates']:
                self._templates[template['name']] = template['template']

    def generate_path(self, template_name, param):
        """Render the template registered under *template_name*.

        Args:
            template_name: name registered through add_path_template.
            param: dict of values substituted with ``str.format``. If it has
                a 'time_format' entry, that object's ``strftime`` is applied
                to the template first to fill in %-directives.

        Returns:
            The rendered path, or "" when *template_name* is unknown.
        """
        if template_name not in self._templates:
            return ""
        template = self._templates[template_name]
        if 'time_format' in param:
            # Expand strftime directives (e.g. %Y%m%d) before str.format.
            template = param['time_format'].strftime(template)
        return template.format(**param)

X
xiexionghang 已提交
186

X
xiexionghang 已提交
187
class TimeTrainPass(object):
    """
    timely pass
    define pass time_interval && start_time && end_time

    A day is split into passes of ``train_time_interval`` minutes. Pass ids
    are 1-based within a day, and 0 means "before the first pass". Progress
    can be appended to, and resumed from, an optional pass donefile.
    """

    def __init__(self, global_config):
        """
        Args:
            global_config: dict with an 'epoch' section that provides 'days',
                'checkpoint_interval', 'dump_inference_interval' and
                'train_time_interval' (minutes of data per pass), and
                optionally 'pass_donefile_name'. When the donefile is
                configured, 'output_path' and 'io' ('afs' or 'local_fs')
                are read as well.

        'days' is either "BEGIN+N" (run N days from BEGIN; an empty N means
        36500 days) or a whitespace-separated day list such as
        "{20191001..20191031}" whose first and last entries bound the range.
        """
        self._config = global_config['epoch']
        if '+' in self._config['days']:
            day_str = self._config['days'].replace(' ', '')
            day_fields = day_str.split('+')
            self._begin_day = make_datetime(day_fields[0].strip())
            if len(day_fields) == 1 or len(day_fields[1]) == 0:
                # 100 years, meaning to continuous running
                self._end_day = self._begin_day + datetime.timedelta(
                    days=36500)
            else:
                # example: 20200212+10
                run_day = int(day_fields[1].strip())
                self._end_day = self._begin_day + datetime.timedelta(
                    days=run_day)
        else:
            # example: {20191001..20191031}
            # Expanded in Python instead of `echo` in /bin/sh: brace expansion
            # is a bash extension (POSIX sh such as dash leaves the braces
            # untouched), and the shell would run anything else in the string.
            first_day, last_day = self._day_range_bounds(self._config['days'])
            self._begin_day = make_datetime(first_day)
            self._end_day = make_datetime(last_day)
        self._checkpoint_interval = self._config['checkpoint_interval']
        self._dump_inference_interval = self._config['dump_inference_interval']
        self._interval_per_pass = self._config[
            'train_time_interval']  # train N min data per pass

        self._pass_id = 0
        self._inference_pass_id = 0
        # Defaults so save_train_progress works before any checkpoint has been
        # recorded or restored from the donefile.
        self._checkpoint_pass_id = 0
        self._checkpoint_model_path = ""
        self._train_pass_donefile = None
        self._pass_donefile_handler = None
        if 'pass_donefile_name' in self._config:
            self._train_pass_donefile = global_config[
                'output_path'] + '/' + self._config['pass_donefile_name']
            if fs.is_afs_path(self._train_pass_donefile):
                self._pass_donefile_handler = fs.FileHandler(global_config[
                    'io']['afs'])
            else:
                self._pass_donefile_handler = fs.FileHandler(global_config[
                    'io']['local_fs'])

            # Resume from the last record of the donefile, if there is one.
            last_done = self._pass_donefile_handler.cat(
                self._train_pass_donefile).strip().split('\n')[-1]
            done_fields = last_done.split('\t')
            # Record layout written by save_train_progress: day, base_key,
            # checkpoint_model_path, checkpoint_pass_id, pass_id.
            if len(done_fields) > 4:
                self._base_key = done_fields[1]
                self._checkpoint_model_path = done_fields[2]
                self._checkpoint_pass_id = int(done_fields[3])
                self._inference_pass_id = int(done_fields[4])
                self.init_pass_by_id(done_fields[0], self._checkpoint_pass_id)

    @staticmethod
    def _day_range_bounds(days_expr):
        """Return (first_day, last_day) of a whitespace-separated day list.

        A token may contain one bash-style brace group: a numeric range
        ``pre{A..B}post`` (only its endpoints matter, giving ``preApost`` and
        ``preBpost``) or a comma list ``pre{a,b}post``. Any other token is
        taken literally.

        Raises:
            ValueError: if the expression contains no days.
        """
        import re
        days = []
        for token in days_expr.split():
            group = re.match(r'^(.*?)\{([^{}]*)\}(.*)$', token)
            if group is None:
                days.append(token)
                continue
            prefix, body, suffix = group.groups()
            numeric_range = re.match(r'^(\d+)\.\.(\d+)$', body)
            if numeric_range is not None:
                items = list(numeric_range.groups())
            elif ',' in body:
                items = body.split(',')
            else:
                # bash leaves "{x}" without ',' or '..' unexpanded.
                days.append(token)
                continue
            days.extend(prefix + item + suffix for item in items)
        if not days:
            raise ValueError("no days found in %r" % (days_expr, ))
        return days[0], days[-1]

    def max_pass_num_day(self):
        """Return 24 * 60 / train_time_interval.

        True division, so the result is a float on Python 3 (e.g. 24.0).
        """
        return 24 * 60 / self._interval_per_pass

    def save_train_progress(self, day, pass_id, base_key, model_path,
                            is_checkpoint):
        """Append a progress record to the pass donefile.

        Args:
            day: date string written as the first field.
            pass_id: current pass id (written as an integer).
            base_key: written verbatim as the second field.
            model_path: model path, remembered as the checkpoint path when
                is_checkpoint is true.
            is_checkpoint: when true, this pass becomes the latest checkpoint.

        The record is ``day, base_key, checkpoint_model_path,
        checkpoint_pass_id, pass_id`` joined by tabs. Nothing is written when
        no 'pass_donefile_name' was configured.
        """
        if is_checkpoint:
            self._checkpoint_pass_id = pass_id
            self._checkpoint_model_path = model_path
        if self._pass_donefile_handler is None:
            return
        # %d for the checkpoint pass id too, so a float id such as 3.0 is
        # written as "3" and int() can parse it back when resuming.
        done_content = "%s\t%s\t%s\t%d\t%d\n" % (
            day, base_key, self._checkpoint_model_path,
            self._checkpoint_pass_id, pass_id)
        self._pass_donefile_handler.write(done_content,
                                          self._train_pass_donefile, 'a')

    def init_pass_by_id(self, date_str, pass_id):
        """
        init pass context with pass_id
        Args:
            date_str: example "20200110"
            pass_id(int): pass_id of date; values < 1 are treated as 0
        """
        date_time = make_datetime(date_str)
        if pass_id < 1:
            pass_id = 0
        if (date_time - self._begin_day).total_seconds() > 0:
            self._begin_day = date_time
        self._pass_id = pass_id
        # Pass N starts (N - 1) intervals after midnight, so pass 0 sits one
        # interval before the day begins.
        mins = self._interval_per_pass * (pass_id - 1)
        self._current_train_time = date_time + datetime.timedelta(minutes=mins)

    def init_pass_by_time(self, datetime_str):
        """
        init pass context with datetime
        Args:
            datetime_str: example "202001100000" -> "%Y%m%d%H%M"
        """
        self._set_train_time(make_datetime(datetime_str))

    def _set_train_time(self, train_time):
        """Make *train_time* (truncated to the minute) current and derive the
        1-based pass id of the day from its hour and minute."""
        train_time = train_time.replace(second=0, microsecond=0)
        self._current_train_time = train_time
        minutes = train_time.hour * 60 + train_time.minute
        # Floor division keeps the pass id an int on Python 3. True division
        # gave e.g. 3.0, which used to be written to the donefile as "3.0"
        # and then broke int() when resuming.
        self._pass_id = int(minutes // self._interval_per_pass) + 1

    def current_pass(self):
        """Return the current pass id (0 before the first pass)."""
        return self._pass_id

    def next(self):
        """Advance to the next pass.

        Returns:
            False when the next pass would start after the end day (state is
            left unchanged), True otherwise.
        """
        has_next = True
        old_pass_id = self._pass_id
        if self._pass_id < 1:
            self._set_train_time(self._begin_day)
        else:
            next_time = self._current_train_time + datetime.timedelta(
                minutes=self._interval_per_pass)
            if (next_time - self._end_day).total_seconds() > 0:
                has_next = False
            else:
                self._set_train_time(next_time)
        # A pass id smaller than before means the day rolled over.
        if has_next and (self._inference_pass_id < self._pass_id or
                         self._pass_id < old_pass_id):
            self._inference_pass_id = self._pass_id - 1
        return has_next

    def is_checkpoint_pass(self, pass_id):
        """Return whether *pass_id* should save a checkpoint.

        Ids < 1 always do. The id equal to max_pass_num_day() never does.
        Otherwise every checkpoint_interval-th pass does.
        """
        if pass_id < 1:
            return True
        if pass_id == self.max_pass_num_day():
            return False
        if pass_id % self._checkpoint_interval == 0:
            return True
        return False

    def need_dump_inference(self, pass_id):
        """Return True when *pass_id* is past the last inference pass and is
        a multiple of dump_inference_interval."""
        return (self._inference_pass_id < pass_id and
                pass_id % self._dump_inference_interval == 0)

    def date(self, delta_day=0):
        """
        get train date
        Args:
            delta_day(int): n day afer current_train_date
        Return:
            date(current_train_time + delta_day) formatted as "%Y%m%d"
        """
        return (self._current_train_time + datetime.timedelta(days=delta_day)
                ).strftime("%Y%m%d")

    def timestamp(self, delta_day=0):
        """Return ``(current_train_time + delta_day days).timestamp()``.

        The datetime is naive here, so ``timestamp()`` interprets it as local
        time.
        """
        return (self._current_train_time + datetime.timedelta(days=delta_day)
                ).timestamp()