# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import, division, print_function

import time
import paddle
from ppcls.engine.train.utils import update_loss, update_metric, log_info
from ppcls.utils import profiler


def train_epoch(engine, epoch_id, print_batch_step):
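    """Train the model for one epoch.

    Iterates ``engine.train_dataloader`` (up to ``engine.max_iter`` batches),
    runs forward/loss (optionally under AMP), accumulates gradients and steps
    the optimizer(s) every ``engine.update_freq`` iterations, then steps any
    per-epoch LR schedulers at the end of the epoch.

    Args:
        engine: ppcls Engine holding the model, dataloader, loss, optimizers,
            LR schedulers and config.
        epoch_id: index of the current epoch, used for logging.
        print_batch_step: log every ``print_batch_step`` iterations.
    """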
    tic = time.time()
    for iter_id, batch in enumerate(engine.train_dataloader):
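        # stop once the epoch's iteration budget has been consumed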
        if iter_id >= engine.max_iter:
            break
        profiler.add_profiler_step(engine.config["profiler_options"])
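        # reset timing stats after the first few iterations so warm-up
        # overhead (e.g. dataloader startup) does not skew the averages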
        if iter_id == 5:
            for key in engine.time_info:
                engine.time_info[key].reset()
        engine.time_info["reader_cost"].update(time.time() - tic)

        batch_size = batch[0].shape[0]
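        # reshape labels to [batch_size, -1] for single-label tasks;
        # multi-label targets keep their original layout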
        if not engine.config["Global"].get("use_multilabel", False):
            batch[1] = batch[1].reshape([batch_size, -1])
        engine.global_step += 1

        # forward pass and loss computation
        if engine.amp:
            amp_level = engine.config["AMP"].get("level", "O1").upper()
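            # ops in the custom black list always run in FP32 for
            # numerical stability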
            with paddle.amp.auto_cast(
                    custom_black_list={
                        "flatten_contiguous_range", "greater_than"
                    },
                    level=amp_level):
                out = forward(engine, batch)
                loss_dict = engine.train_loss_func(out, batch[1])
        else:
            out = forward(engine, batch)
            loss_dict = engine.train_loss_func(out, batch[1])

        # loss; divide by update_freq so gradients accumulated over
        # several iterations average rather than sum
        loss = loss_dict["loss"] / engine.update_freq

        # backward & step opt
        if engine.amp:
            scaled = engine.scaler.scale(loss)
            scaled.backward()
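            # with gradient accumulation, only step the optimizer(s) every
            # `update_freq` iterations; scaler.minimize unscales the
            # gradients and skips the step if they contain inf/nan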
            if (iter_id + 1) % engine.update_freq == 0:
                for i in range(len(engine.optimizer)):
                    engine.scaler.minimize(engine.optimizer[i], scaled)
        else:
            loss.backward()
            if (iter_id + 1) % engine.update_freq == 0:
                for i in range(len(engine.optimizer)):
                    engine.optimizer[i].step()

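        # once per effective (accumulated) step: clear gradients, step any
        # per-iteration LR schedulers, and update the EMA weights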
        if (iter_id + 1) % engine.update_freq == 0:
            # clear grad
            for i in range(len(engine.optimizer)):
                engine.optimizer[i].clear_grad()
            # step per-iteration ("by step") LR schedulers
            for i in range(len(engine.lr_sch)):
                if not getattr(engine.lr_sch[i], "by_epoch", False):
                    engine.lr_sch[i].step()
            # update ema
            if engine.ema:
                engine.model_ema.update(engine.model)

        # the rest of the loop body is for logging only
        # update metrics for the logger
        update_metric(engine, out, batch, batch_size)
        # update loss for the logger
        update_loss(engine, loss_dict, batch_size)
        engine.time_info["batch_cost"].update(time.time() - tic)
        if iter_id % print_batch_step == 0:
            log_info(engine, batch_size, epoch_id, iter_id)
        tic = time.time()

    # step per-epoch ("by epoch") LR schedulers
    for i in range(len(engine.lr_sch)):
        if getattr(engine.lr_sch[i], "by_epoch", False):
            engine.lr_sch[i].step()


def forward(engine, batch):
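    """Run the model forward; recognition models (``engine.is_rec``) also
    consume the labels, e.g. for margin-based heads."""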
    if not engine.is_rec:
        return engine.model(batch[0])
    else:
        return engine.model(batch[0], batch[1])
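
# A minimal usage sketch (assumed caller, not part of this file): the ppcls
# Engine typically drives one call per epoch, roughly:
#
#     for epoch_id in range(1, engine.config["Global"]["epochs"] + 1):
#         train_epoch(engine, epoch_id, print_batch_step)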