# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data.reader import create_operators


class Trainer(object):
    """Config-driven driver that builds the model and runs train / eval / infer."""
    def __init__(self, config, mode="train"):
        """Set up the device, the distributed context, the model and logging.

        Args:
            config: nested, dict-like configuration. The ``Global`` and
                ``Arch`` sections are read here. ``Global.distributed`` is
                overwritten with the detected world-size state, and
                ``Arch.Head.class_num`` is overwritten with
                ``Global.class_num`` when a ``Head`` section is present.
            mode: stored as ``self.mode``; not otherwise used by the
                constructor.
        """
        self.mode = mode
        self.config = config
        global_cfg = self.config["Global"]
        self.output_dir = global_cfg['output_dir']

        # set device
        assert global_cfg["device"] in ["cpu", "gpu", "xpu"]
        self.device = paddle.set_device(global_cfg["device"])

        # set dist: more than one process means a distributed run
        global_cfg["distributed"] = paddle.distributed.get_world_size() != 1
        if global_cfg["distributed"]:
            dist.init_parallel_env()

        # an Arch with a Head section is flagged via ``is_rec``; its head
        # takes the class count from the Global section
        arch_cfg = self.config["Arch"]
        self.is_rec = "Head" in arch_cfg
        if self.is_rec:
            arch_cfg["Head"]["class_num"] = global_cfg["class_num"]

        self.model = build_model(arch_cfg)

        if global_cfg["pretrained_model"] is not None:
            load_dygraph_pretrain(self.model, global_cfg["pretrained_model"])

        if global_cfg["distributed"]:
            self.model = paddle.DataParallel(self.model)

        self.vdl_writer = None
        if global_cfg['use_visualdl']:
            from visualdl import LogWriter
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            # exist_ok avoids the check-then-create race when several
            # processes start at the same time and share output_dir
            os.makedirs(vdl_writer_path, exist_ok=True)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))

        # members that train() / eval() build lazily on first use
        self.train_dataloader = None
        self.eval_dataloader = None
        self.gallery_dataloader = None
        self.query_dataloader = None
        self.eval_mode = global_cfg.get("eval_mode", "classification")
        self.train_loss_func = None
        self.eval_loss_func = None
        self.train_metric_func = None
        self.eval_metric_func = None

    def train(self):
        """Run the training loop described by the ``Global`` config section.

        On first use this builds the train loss, the train metric and the
        train dataloader. It then builds the optimizer and LR scheduler and,
        when ``Global.checkpoints`` is set, restores state through
        ``init_model`` and resumes after the restored epoch.

        Per batch: forward pass, loss, optional metric, running averages,
        a log line every ``print_batch_step`` iterations, then backward,
        optimizer step, gradient clear and LR scheduler step.

        Per epoch: logs the running averages, optionally evaluates and saves
        ``best_model`` when the returned metric improves, and saves a
        periodic checkpoint every ``save_interval`` epochs.
        """
        # build train loss and metric info on first use
        if self.train_loss_func is None:
            loss_cfg = self.config.get("Loss", None)
            if loss_cfg is not None:
                loss_cfg = loss_cfg["Train"]
            self.train_loss_func = build_loss(loss_cfg)
        if self.train_metric_func is None:
            metric_cfg = self.config.get("Metric", None)
            if metric_cfg is not None:
                metric_cfg = metric_cfg["Train"]
            self.train_metric_func = build_metrics(metric_cfg)

        if self.train_dataloader is None:
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device)

        step_each_epoch = len(self.train_dataloader)
        global_cfg = self.config["Global"]

        optimizer, lr_sch = build_optimizer(
            self.config["Optimizer"], global_cfg["epochs"], step_each_epoch,
            self.model.parameters())

        print_batch_step = global_cfg['print_batch_step']
        save_interval = global_cfg["save_interval"]

        best_metric = {"metric": 0.0, "epoch": 0}
        # key: loss / metric name, val: its AverageMeter for the epoch
        output_info = dict()
        # global iter counter
        global_step = 0

        def _track(values, num_samples):
            # fold one batch worth of loss / metric tensors into the meters
            for name in values:
                if name not in output_info:
                    output_info[name] = AverageMeter(name, '7.5f')
                output_info[name].update(values[name].numpy()[0],
                                         num_samples)

        def _summary():
            # "name: avg" pairs for every meter seen so far this epoch
            return ", ".join("{}: {:.5f}".format(name, output_info[name].avg)
                             for name in output_info)

        if global_cfg["checkpoints"] is not None:
            metric_info = init_model(global_cfg, self.model, optimizer)
            if metric_info is not None:
                best_metric.update(metric_info)

        first_epoch = best_metric["epoch"] + 1
        last_epoch = global_cfg["epochs"]
        for epoch_id in range(first_epoch, last_epoch + 1):
            acc = 0.0
            for iter_id, batch in enumerate(self.train_dataloader()):
                batch_size = batch[0].shape[0]
                labels = batch[1].numpy().astype("int64").reshape([-1, 1])
                batch[1] = paddle.to_tensor(labels)
                global_step += 1

                # image input; the model also receives labels when is_rec
                if self.is_rec:
                    out = self.model(batch[0], batch[1])
                else:
                    out = self.model(batch[0])

                # calc loss
                loss_dict = self.train_loss_func(out, batch[1])
                _track(loss_dict, batch_size)

                # calc metric
                if self.train_metric_func is not None:
                    _track(self.train_metric_func(out, batch[-1]), batch_size)

                if iter_id % print_batch_step == 0:
                    lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
                    logger.info("[Train][Epoch {}][Iter: {}/{}]{}, {}".format(
                        epoch_id, iter_id,
                        len(self.train_dataloader), lr_msg, _summary()))

                # step opt and lr
                loss_dict["loss"].backward()
                optimizer.step()
                optimizer.clear_grad()
                lr_sch.step()

            logger.info("[Train][Epoch {}][Avg]{}".format(epoch_id,
                                                          _summary()))
            output_info.clear()

            # eval model and save model if possible
            eval_every = global_cfg["eval_during_train"]
            if eval_every and epoch_id % eval_every == 0:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model")
                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                epoch_state = {"metric": acc, "epoch": epoch_id}
                save_load.save_model(
                    self.model,
                    optimizer,
                    epoch_state,
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="ppcls_epoch_{}".format(epoch_id))

    def build_avg_metrics(self, info_dict):
        """Return a dict with one fresh ``AverageMeter`` per key of *info_dict*."""
        meters = {}
        for name in info_dict:
            meters[name] = AverageMeter(name, '7.5f')
        return meters

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        """Evaluate the model according to ``self.eval_mode``.

        Builds the eval loss, dataloaders and metric on first use, then
        dispatches to ``eval_cls`` ("classification") or ``eval_retrieval``
        ("retrieval"). The model is put back into train mode before
        returning.

        Args:
            epoch_id: epoch number forwarded to the mode-specific evaluator.

        Returns:
            Whatever the mode-specific evaluator returns, or ``None`` (after
            a warning) when ``eval_mode`` is neither of the two known modes.
        """
        self.model.eval()

        if self.eval_loss_func is None:
            loss_cfg = self.config.get("Loss", None)
            if loss_cfg is not None:
                self.eval_loss_func = build_loss(loss_cfg["Eval"])

        mode = self.eval_mode
        if mode == "classification":
            if self.eval_dataloader is None:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device)

            # without a Metric section the metric function stays unset
            if self.eval_metric_func is None:
                metric_cfg = self.config.get("Metric", None)
                if metric_cfg is not None:
                    self.eval_metric_func = build_metrics(metric_cfg["Eval"])

            eval_result = self.eval_cls(epoch_id)

        elif mode == "retrieval":
            retrieval_data_cfg = None
            if self.gallery_dataloader is None:
                retrieval_data_cfg = self.config["DataLoader"]["Eval"]
                self.gallery_dataloader = build_dataloader(
                    retrieval_data_cfg, "Gallery", self.device)

            if self.query_dataloader is None:
                retrieval_data_cfg = self.config["DataLoader"]["Eval"]
                self.query_dataloader = build_dataloader(
                    retrieval_data_cfg, "Query", self.device)

            # build metric info; Recallk is used when no Metric is configured
            if self.eval_metric_func is None:
                metric_cfg = self.config.get("Metric", None)
                if metric_cfg is None:
                    metric_cfg = [{"name": "Recallk", "topk": (1, 5)}]
                else:
                    metric_cfg = metric_cfg["Eval"]
                self.eval_metric_func = build_metrics(metric_cfg)

            eval_result = self.eval_retrieval(epoch_id)

        else:
            logger.warning("Invalid eval mode: {}".format(self.eval_mode))
            eval_result = None

        self.model.train()
        return eval_result

    def eval_cls(self, epoch_id=0):
        """Evaluate on ``self.eval_dataloader`` in classification mode.

        For every batch the model is run, the optional eval loss and eval
        metric are computed and folded into running averages, and a log line
        with the latest values is written every ``print_batch_step``
        iterations. When more than one process is running, each metric value
        is sum-reduced across processes and divided by the world size before
        it is accumulated.

        Args:
            epoch_id: epoch number, used only in log messages.

        Returns:
            ``-1`` when no eval metric function is configured or when the
            dataloader produced no batch; otherwise the running average of
            the first key returned by the metric function.
        """
        output_info = dict()
        print_batch_step = self.config["Global"]["print_batch_step"]

        def _track(name, tensor, num_samples):
            # create the meter on first sight, then fold in this batch
            if name not in output_info:
                output_info[name] = AverageMeter(name, '7.5f')
            output_info[name].update(tensor.numpy()[0], num_samples)

        metric_key = None
        for iter_id, batch in enumerate(self.eval_dataloader()):
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0]).astype("float32")
            batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
            # image input; the model also receives labels when is_rec
            if self.is_rec:
                out = self.model(batch[0], batch[1])
            else:
                out = self.model(batch[0])
            # calc loss
            if self.eval_loss_func is not None:
                loss_dict = self.eval_loss_func(out, batch[-1])
                for key in loss_dict:
                    _track(key, loss_dict[key], batch_size)
            # calc metric
            if self.eval_metric_func is not None:
                metric_dict = self.eval_metric_func(out, batch[-1])
                world_size = paddle.distributed.get_world_size()
                if world_size > 1:
                    # average each metric over all processes
                    for key in metric_dict:
                        paddle.distributed.all_reduce(
                            metric_dict[key],
                            op=paddle.distributed.ReduceOp.SUM)
                        metric_dict[key] = metric_dict[key] / world_size
                for key in metric_dict:
                    if metric_key is None:
                        metric_key = key
                    _track(key, metric_dict[key], batch_size)

            if iter_id % print_batch_step == 0:
                # per-iteration line shows the latest value, not the average
                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}".format(
                    epoch_id, iter_id, len(self.eval_dataloader), metric_msg))

        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # do not try to save best model
        if self.eval_metric_func is None:
            return -1
        if metric_key is None:
            # no batch was evaluated, so there is no meter to read;
            # previously this fell through to a KeyError on output_info[None]
            logger.warning(
                "No metric was computed during eval (empty dataloader?)")
            return -1
        # return 1st metric in the dict
        return output_info[metric_key].avg

    def eval_retrieval(self, epoch_id=0):
        """Evaluate in retrieval mode using gallery and query features.

        Features for both sets come from ``_cal_feature``. The query/gallery
        similarity matrix is computed in row blocks of
        ``Global.sim_block_size`` queries (default 64), the blocks are
        stacked, and the stacked matrix is passed to ``self.eval_metric_func``
        together with the query and gallery image ids.

        Args:
            epoch_id: epoch number, used only in the log message.

        Returns:
            The value of the first key in the dict returned by the metric
            function, or ``0.`` when no metric function is set or it
            returns an empty dict.
        """
        self.model.eval()
        # step1. build gallery and query features
        gallery_feas, gallery_img_id, gallery_camera_id = self._cal_feature(
            name='gallery')
        query_feas, query_img_id, query_camera_id = self._cal_feature(
            name='query')

        if self.eval_metric_func is None:
            # nothing to score with; 0. is the placeholder value the
            # original fallback dict carried
            return 0.

        # step2. do evaluation
        sim_block_size = self.config["Global"].get("sim_block_size", 64)
        num_query = len(query_feas)
        sections = [sim_block_size] * (num_query // sim_block_size)
        if num_query % sim_block_size:
            sections.append(num_query % sim_block_size)
        fea_blocks = paddle.split(query_feas, num_or_sections=sections)

        # camera ids are only usable when both sets provide them
        use_camera_mask = (query_camera_id is not None and
                           gallery_camera_id is not None)
        if use_camera_mask:
            camera_id_blocks = paddle.split(
                query_camera_id, num_or_sections=sections)
            # _cal_feature reshapes ids to a column ([-1, 1]); transpose to
            # a row so the comparison below broadcasts to (queries, gallery)
            gallery_camera_row = paddle.transpose(gallery_camera_id, [1, 0])

        similarity_blocks = []
        for block_idx, block_fea in enumerate(fea_blocks):
            similarity_matrix = paddle.matmul(
                block_fea, gallery_feas, transpose_y=True)
            if use_camera_mask:
                camera_id_mask = (
                    camera_id_blocks[block_idx] != gallery_camera_row)
                # NOTE(review): same-camera pairs are zeroed instead of
                # removed with masked_select, which would flatten the block
                # to 1-D and break the row-wise stacking below. Confirm the
                # configured metric treats a zero score as "not retrieved".
                similarity_matrix = similarity_matrix * camera_id_mask.astype(
                    similarity_matrix.dtype)
            similarity_blocks.append(similarity_matrix)

        # stack once at the end instead of re-concatenating every block
        cum_similarity_matrix = (paddle.concat(similarity_blocks, axis=0)
                                 if similarity_blocks else None)

        # calc metric
        metric_dict = self.eval_metric_func(cum_similarity_matrix,
                                            query_img_id, gallery_img_id)

        # the first key returned by the metric function is the one reported
        # back to the caller
        metric_key = None
        metric_msgs = []
        for key in metric_dict:
            if metric_key is None:
                metric_key = key
            metric_msgs.append("{}: {:.5f}".format(key,
                                                   float(metric_dict[key])))
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id,
                                                     ", ".join(metric_msgs)))

        if metric_key is None:
            return 0.
        return metric_dict[metric_key]

    def _cal_feature(self, name='gallery'):
        """Run the model over the gallery or query set and collect features.

        Each batch is converted to tensors, the second element (and the
        third, when a batch has three elements) is reshaped to a column, and
        the ``"features"`` entry of the model output is taken. Features are
        L2-normalised along axis 1 unless ``Global.feature_normalize`` is
        false. With more than one process, the per-process results are
        all-gathered and concatenated.

        Args:
            name: ``'gallery'`` or ``'query'``; selects the dataloader.

        Returns:
            Tuple ``(all_feas, all_image_id, all_camera_id)``. The last item
            is ``None`` when no batch carried a third element.

        Raises:
            RuntimeError: if *name* is neither ``'gallery'`` nor ``'query'``.
        """
        if name not in ('gallery', 'query'):
            raise RuntimeError("Only support gallery or query dataset")
        dataloader = (self.gallery_dataloader
                      if name == 'gallery' else self.query_dataloader)

        all_feas = None
        all_image_id = None
        all_camera_id = None
        has_cam_id = False

        for batch in dataloader():  # load is very time-consuming
            tensors = [paddle.to_tensor(item) for item in batch]
            tensors[1] = tensors[1].reshape([-1, 1])
            if len(tensors) == 3:
                has_cam_id = True
                tensors[2] = tensors[2].reshape([-1, 1])
            batch_feas = self.model(tensors[0], tensors[1])["features"]

            # do norm
            if self.config["Global"].get("feature_normalize", True):
                squared_sum = paddle.sum(
                    paddle.square(batch_feas), axis=1, keepdim=True)
                batch_feas = paddle.divide(batch_feas,
                                           paddle.sqrt(squared_sum))

            if all_feas is not None:
                all_feas = paddle.concat([all_feas, batch_feas])
                all_image_id = paddle.concat([all_image_id, tensors[1]])
                if has_cam_id:
                    all_camera_id = paddle.concat(
                        [all_camera_id, tensors[2]])
            else:
                # first batch seeds the accumulators
                all_feas = batch_feas
                if has_cam_id:
                    all_camera_id = tensors[2]
                all_image_id = tensors[1]

        if paddle.distributed.get_world_size() > 1:
            gathered_feas, gathered_img_ids, gathered_cam_ids = [], [], []
            paddle.distributed.all_gather(gathered_feas, all_feas)
            paddle.distributed.all_gather(gathered_img_ids, all_image_id)
            all_feas = paddle.concat(gathered_feas, axis=0)
            all_image_id = paddle.concat(gathered_img_ids, axis=0)
            if has_cam_id:
                paddle.distributed.all_gather(gathered_cam_ids,
                                              all_camera_id)
                all_camera_id = paddle.concat(gathered_cam_ids, axis=0)

        logger.info("Build {} done, all feat shape: {}, begin to eval..".
                    format(name, all_feas.shape))
        return all_feas, all_image_id, all_camera_id

    @paddle.no_grad()
    def infer(self, ):
        """Run inference on the configured images and print the results.

        The image list from ``Infer.infer_imgs`` is split across processes
        by taking every ``world_size``-th entry starting at this process's
        rank. Each file is read as bytes, passed through the operators built
        from ``Infer.transforms``, and batched up to ``Infer.batch_size``.
        Every full batch, and the final partial one, is run through the
        model and ``Infer.PostProcess``, and the result is printed.
        """
        world_size = paddle.distributed.get_world_size()
        rank = paddle.distributed.get_rank()
        infer_cfg = self.config["Infer"]
        # data split: this process handles every world_size-th image
        image_list = get_image_list(infer_cfg["infer_imgs"])[rank::world_size]

        preprocess_func = create_operators(infer_cfg["transforms"])
        postprocess_func = build_postprocess(infer_cfg["PostProcess"])

        batch_size = infer_cfg["batch_size"]

        self.model.eval()

        pending_data = []
        pending_files = []
        total = len(image_list)
        for position, image_file in enumerate(image_list, start=1):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in preprocess_func:
                x = process(x)
            pending_data.append(x)
            pending_files.append(image_file)

            # keep filling until the batch is full or the list is exhausted
            if len(pending_data) < batch_size and position != total:
                continue

            batch_tensor = paddle.to_tensor(pending_data)
            out = self.model(batch_tensor)
            result = postprocess_func(out, pending_files)
            print(result)
            pending_data.clear()
            pending_files.clear()