# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
W
Walter 已提交
20

L
littletomatodonkey 已提交
21 22 23
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

littletomatodonkey's avatar
littletomatodonkey 已提交
24
import time
25
import platform
littletomatodonkey's avatar
littletomatodonkey 已提交
26
import datetime
L
littletomatodonkey 已提交
27 28 29 30
import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist
littletomatodonkey's avatar
littletomatodonkey 已提交
31
from visualdl import LogWriter
L
littletomatodonkey 已提交
32 33 34 35

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
L
littletomatodonkey 已提交
36 37
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
L
littletomatodonkey 已提交
38 39
from ppcls.data import build_dataloader
from ppcls.arch import build_model
A
Aurelius84 已提交
40
from ppcls.arch import apply_to_static
W
weishengyu 已提交
41
from ppcls.loss import build_loss
W
weishengyu 已提交
42
from ppcls.metric import build_metrics
L
littletomatodonkey 已提交
43
from ppcls.optimizer import build_optimizer
C
cuicheng01 已提交
44
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
45
from ppcls.utils.save_load import init_model
L
littletomatodonkey 已提交
46 47
from ppcls.utils import save_load

48 49
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
littletomatodonkey's avatar
littletomatodonkey 已提交
50
from ppcls.data import create_operators
51

L
littletomatodonkey 已提交
52 53

class Trainer(object):
54
    def __init__(self, config, mode="train"):
        """Set up logging, device, distributed env, model and trainer state.

        Args:
            config: Nested dict-like configuration. This constructor reads the
                ``Global`` and ``Arch`` sections and, when present, ``AMP``.
            mode: Run mode name. It names the log file (``<mode>.log``) and
                VisualDL logging is only enabled when it equals ``"train"``.
        """
        self.mode = mode
        self.config = config
        self.output_dir = self.config['Global']['output_dir']

        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(name='root', log_file=log_file)
        print_config(config)
        # set device
        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        # set dist
        self.config["Global"][
            "distributed"] = paddle.distributed.get_world_size() != 1
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()

        # An "Arch.Head" entry makes forward() pass the label to the model too.
        self.is_rec = "Head" in self.config["Arch"]

        self.model = build_model(self.config["Arch"])
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        pretrained_model = self.config["Global"]["pretrained_model"]
        if pretrained_model is not None:
            if pretrained_model.startswith("http"):
                load_dygraph_pretrain_from_url(self.model, pretrained_model)
            else:
                load_dygraph_pretrain(self.model, pretrained_model)

        if self.config["Global"]["distributed"]:
            self.model = paddle.DataParallel(self.model)

        self.vdl_writer = None
        if self.config['Global']['use_visualdl'] and mode == "train":
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            # exist_ok avoids a check-then-create race when several processes
            # (e.g. distributed ranks) create the directory at the same time.
            os.makedirs(vdl_writer_path, exist_ok=True)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))
        # init members; dataloaders are built lazily by train()/eval()
        self.train_dataloader = None
        self.eval_dataloader = None
        self.gallery_dataloader = None
        self.query_dataloader = None
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        # AMP is enabled by the mere presence of an "AMP" section, even if the
        # section itself is empty (None); in that case defaults are used.
        self.amp = "AMP" in self.config
        if self.amp and self.config["AMP"] is not None:
            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
            self.use_dynamic_loss_scaling = self.config["AMP"].get(
                "use_dynamic_loss_scaling", False)
        else:
            self.scale_loss = 1.0
            self.use_dynamic_loss_scaling = False
        if self.amp:
            AMP_RELATED_FLAGS_SETTING = {
                'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
                'FLAGS_max_inplace_grad_add': 8,
            }
            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        # loss / metric callables are built lazily by train()/eval()
        self.train_loss_func = None
        self.eval_loss_func = None
        self.train_metric_func = None
        self.eval_metric_func = None
        self.use_dali = self.config['Global'].get("use_dali", False)
L
littletomatodonkey 已提交
126 127 128

    def train(self):
        """Run the training loop for ``Global.epochs`` epochs.

        Lazily builds the train loss, train metric and train dataloader, then
        builds the optimizer and LR scheduler. When ``Global.checkpoints`` is
        set, ``init_model`` is called and any metric info it returns seeds
        ``best_metric``, so the epoch loop starts after the recorded epoch.

        Per epoch it iterates the dataloader, performs forward/backward and
        optimizer steps (through a ``GradScaler`` when AMP is enabled), logs
        running averages every ``Global.print_batch_step`` iterations,
        optionally evaluates (``Global.eval_during_train`` /
        ``Global.eval_interval``) and saves ``best_model``, ``epoch_<n>`` and
        ``latest`` checkpoints.
        """
        global_config = self.config["Global"]

        # build train loss and metric info
        if self.train_loss_func is None:
            loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(loss_info)
        if self.train_metric_func is None:
            metric_config = self.config.get("Metric")
            if metric_config is not None:
                metric_config = metric_config.get("Train")
                if metric_config is not None:
                    self.train_metric_func = build_metrics(metric_config)

        if self.train_dataloader is None:
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device, self.use_dali)

        step_each_epoch = len(self.train_dataloader)
        total_epochs = global_config["epochs"]

        optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
                                            total_epochs, step_each_epoch,
                                            self.model.parameters())

        print_batch_step = global_config['print_batch_step']
        save_interval = global_config["save_interval"]
        model_name = self.config["Arch"]["name"]
        # When batch_transform_ops is configured, the loss receives every
        # batch element after the image (batch[1:]) instead of just batch[1].
        use_batch_transform = bool(self.config["DataLoader"]["Train"][
            "dataset"].get("batch_transform_ops", None))

        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # key: loss / metric name
        # val: AverageMeter tracking that value over the current epoch
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        global_step = 0

        if global_config["checkpoints"] is not None:
            metric_info = init_model(global_config, self.model, optimizer)
            if metric_info is not None:
                best_metric.update(metric_info)

        # for amp training
        scaler = None
        if self.amp:
            scaler = paddle.amp.GradScaler(
                init_loss_scaling=self.scale_loss,
                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)

        tic = time.time()
        # NOTE(review): on Windows the last batch of every epoch is skipped;
        # presumably a workaround for a dataloader issue -- confirm.
        max_iter = len(self.train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(self.train_dataloader)
        for epoch_id in range(best_metric["epoch"] + 1, total_epochs + 1):
            acc = 0.0
            train_dataloader = (self.train_dataloader if self.use_dali else
                                self.train_dataloader())
            for iter_id, batch in enumerate(train_dataloader):
                if iter_id >= max_iter:
                    break
                # Restart the timers at iteration 5 (presumably to drop
                # warm-up iterations from the reported averages).
                if iter_id == 5:
                    for key in time_info:
                        time_info[key].reset()
                time_info["reader_cost"].update(time.time() - tic)
                if self.use_dali:
                    batch = [
                        paddle.to_tensor(batch[0]['data']),
                        paddle.to_tensor(batch[0]['label'])
                    ]
                batch_size = batch[0].shape[0]
                batch[1] = batch[1].reshape([-1, 1]).astype("int64")

                global_step += 1
                targets = batch[1:] if use_batch_transform else batch[1]
                # Forward pass and loss, computed exactly once per step. Under
                # AMP both run inside auto_cast so the loss is not rebuilt a
                # second time outside the mixed-precision context.
                if self.amp:
                    with paddle.amp.auto_cast(custom_black_list={
                            "flatten_contiguous_range", "greater_than"
                    }):
                        out = self.forward(batch)
                        loss_dict = self.train_loss_func(out, targets)
                else:
                    out = self.forward(batch)
                    loss_dict = self.train_loss_func(out, targets)

                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
                # calc metric
                if self.train_metric_func is not None:
                    metric_dict = self.train_metric_func(out, batch[-1])
                    for key in metric_dict:
                        if key not in output_info:
                            output_info[key] = AverageMeter(key, '7.5f')
                        output_info[key].update(metric_dict[key].numpy()[0],
                                                batch_size)

                # step opt and lr
                if self.amp:
                    scaled = scaler.scale(loss_dict["loss"])
                    scaled.backward()
                    scaler.minimize(optimizer, scaled)
                else:
                    loss_dict["loss"].backward()
                    optimizer.step()
                optimizer.clear_grad()
                lr_sch.step()

                time_info["batch_cost"].update(time.time() - tic)

                if iter_id % print_batch_step == 0:
                    lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
                    metric_msg = ", ".join([
                        "{}: {:.5f}".format(key, output_info[key].avg)
                        for key in output_info
                    ])
                    time_msg = "s, ".join([
                        "{}: {:.5f}".format(key, time_info[key].avg)
                        for key in time_info
                    ])

                    ips_msg = "ips: {:.5f} images/sec".format(
                        batch_size / time_info["batch_cost"].avg)
                    # remaining iterations (this and later epochs) times the
                    # average batch cost
                    eta_sec = ((total_epochs - epoch_id + 1
                                ) * len(self.train_dataloader) - iter_id
                               ) * time_info["batch_cost"].avg
                    eta_msg = "eta: {:s}".format(
                        str(datetime.timedelta(seconds=int(eta_sec))))
                    logger.info(
                        "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
                        format(epoch_id, total_epochs, iter_id,
                               len(self.train_dataloader), lr_msg, metric_msg,
                               time_msg, ips_msg, eta_msg))

                    logger.scaler(
                        name="lr",
                        value=lr_sch.get_lr(),
                        step=global_step,
                        writer=self.vdl_writer)
                    for key in output_info:
                        logger.scaler(
                            name="train_{}".format(key),
                            value=output_info[key].avg,
                            step=global_step,
                            writer=self.vdl_writer)
                tic = time.time()
            if self.use_dali:
                self.train_dataloader.reset()
            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, output_info[key].avg)
                for key in output_info
            ])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, total_epochs, metric_msg))
            output_info.clear()

            # eval model and save model if possible
            if global_config["eval_during_train"] and epoch_id % global_config[
                    "eval_interval"] == 0:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=model_name,
                        prefix="best_model")
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)

                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=model_name,
                    prefix="epoch_{}".format(epoch_id))
                # save the latest model
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=model_name,
                    prefix="latest")

        if self.vdl_writer is not None:
            self.vdl_writer.close()

L
littletomatodonkey 已提交
342 343 344 345 346 347
    def build_avg_metrics(self, info_dict):
        """Return a dict with one fresh ``AverageMeter`` per key of *info_dict*.

        Each meter is created with the key as its name and the ``'7.5f'``
        format string.
        """
        meters = {}
        for name in info_dict:
            meters[name] = AverageMeter(name, '7.5f')
        return meters

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        """Evaluate the model according to ``self.eval_mode``.

        Lazily builds the eval loss, the dataloaders and the eval metric that
        the selected mode needs, then dispatches to ``eval_cls`` or
        ``eval_retrieval``. The model is put back into train mode before
        returning.

        Args:
            epoch_id: Epoch number, only forwarded to the mode-specific
                evaluation method (where it is used in log messages).

        Returns:
            The value returned by ``eval_cls`` / ``eval_retrieval``, or
            ``None`` (after a warning) when ``eval_mode`` is not recognised.
        """
        self.model.eval()
        if self.eval_loss_func is None:
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
        if self.eval_mode == "classification":
            if self.eval_dataloader is None:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device,
                    self.use_dali)

            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric")
                if metric_config is not None:
                    metric_config = metric_config.get("Eval")
                    if metric_config is not None:
                        self.eval_metric_func = build_metrics(metric_config)

            eval_result = self.eval_cls(epoch_id)

        elif self.eval_mode == "retrieval":
            if self.gallery_dataloader is None:
                self.gallery_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Gallery", self.device,
                    self.use_dali)

            if self.query_dataloader is None:
                self.query_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Query", self.device,
                    self.use_dali)
            # build metric info
            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric", None)
                if metric_config is not None:
                    # .get() instead of ["Eval"]: a Metric section that only
                    # configures other splits must not raise KeyError here.
                    metric_config = metric_config.get("Eval")
                if metric_config is None:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                self.eval_metric_func = build_metrics(metric_config)
            eval_result = self.eval_retrieval(epoch_id)
        else:
            logger.warning("Invalid eval mode: {}".format(self.eval_mode))
            eval_result = None
        self.model.train()
        return eval_result
W
weishengyu 已提交
393

W
weishengyu 已提交
394
    def forward(self, batch):
W
Walter 已提交
395
        if not self.is_rec:
W
weishengyu 已提交
396
            out = self.model(batch[0])
W
Walter 已提交
397
        else:
W
weishengyu 已提交
398
            out = self.model(batch[0], batch[1])
W
Walter 已提交
399 400
        return out

littletomatodonkey's avatar
littletomatodonkey 已提交
401
    @paddle.no_grad()
    def eval_cls(self, epoch_id=0):
        """Evaluate on ``self.eval_dataloader`` and return the first metric.

        For every batch the eval loss (if any) and eval metrics (if any) are
        accumulated in ``AverageMeter`` objects. When more than one process is
        running, each metric is summed across processes with ``all_reduce``
        and divided by the world size before being accumulated.

        Args:
            epoch_id: Epoch number, used only in log messages.

        Returns:
            The running average of the first metric produced by
            ``self.eval_metric_func``, or ``-1`` when no metric function is
            configured or no batch was evaluated.
        """
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        print_batch_step = self.config["Global"]["print_batch_step"]

        metric_key = None
        tic = time.time()
        eval_dataloader = (self.eval_dataloader if self.use_dali else
                           self.eval_dataloader())
        # NOTE(review): on Windows the last batch is skipped; presumably a
        # workaround for a dataloader issue -- confirm.
        max_iter = len(self.eval_dataloader) - 1 if platform.system(
        ) == "Windows" else len(self.eval_dataloader)
        for iter_id, batch in enumerate(eval_dataloader):
            if iter_id >= max_iter:
                break
            # Restart the timers at iteration 5 (presumably to drop warm-up
            # iterations from the reported averages).
            if iter_id == 5:
                for key in time_info:
                    time_info[key].reset()
            if self.use_dali:
                batch = [
                    paddle.to_tensor(batch[0]['data']),
                    paddle.to_tensor(batch[0]['label'])
                ]
            time_info["reader_cost"].update(time.time() - tic)
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0]).astype("float32")
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            # image input
            out = self.forward(batch)
            # calc loss
            if self.eval_loss_func is not None:
                loss_dict = self.eval_loss_func(out, batch[-1])
                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
            # calc metric
            if self.eval_metric_func is not None:
                metric_dict = self.eval_metric_func(out, batch[-1])
                world_size = paddle.distributed.get_world_size()
                if world_size > 1:
                    # average every metric over all processes
                    for key in metric_dict:
                        paddle.distributed.all_reduce(
                            metric_dict[key],
                            op=paddle.distributed.ReduceOp.SUM)
                        metric_dict[key] = metric_dict[key] / world_size
                for key in metric_dict:
                    # remember the first metric; it is the return value
                    if metric_key is None:
                        metric_key = key
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')

                    output_info[key].update(metric_dict[key].numpy()[0],
                                            batch_size)

            time_info["batch_cost"].update(time.time() - tic)

            if iter_id % print_batch_step == 0:
                time_msg = "s, ".join([
                    "{}: {:.5f}".format(key, time_info[key].avg)
                    for key in time_info
                ])

                ips_msg = "ips: {:.5f} images/sec".format(
                    batch_size / time_info["batch_cost"].avg)

                # per-iteration log shows the latest value, not the average
                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                    epoch_id, iter_id,
                    len(self.eval_dataloader), metric_msg, time_msg, ips_msg))

            tic = time.time()
        if self.use_dali:
            self.eval_dataloader.reset()
        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # do not try to save best model
        # metric_key stays None when the loop body never ran (for instance a
        # single-batch loader on Windows, where max_iter is 0); returning -1
        # avoids a KeyError on output_info[None].
        if self.eval_metric_func is None or metric_key is None:
            return -1
        # return 1st metric in the dict
        return output_info[metric_key].avg
495

W
weishengyu 已提交
496 497 498
    def eval_retrieval(self, epoch_id=0):
        """Evaluate retrieval: query features against gallery features.

        Gallery and query features are extracted with ``_cal_feature``. The
        query set is then processed in blocks of ``Global.sim_block_size``
        rows (default 64): for each block a similarity matrix against the
        whole gallery is computed and passed to ``self.eval_metric_func``.
        Block results are combined as a sum weighted by block size divided by
        the number of queries.

        Args:
            epoch_id: Epoch number, used only in the log message.

        Returns:
            The value of the first metric in the combined metric dict.
        """
        self.model.eval()
        # step1. build gallery and query features
        gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature(
            name='gallery')
        query_feas, query_img_id, query_query_id = self._cal_feature(
            name='query')

        # step2. do evaluation
        block_size = self.config["Global"].get("sim_block_size", 64)
        num_queries = len(query_feas)
        full_blocks, leftover = divmod(num_queries, block_size)
        sections = [block_size] * full_blocks
        if leftover:
            sections.append(leftover)

        has_query_id = query_query_id is not None
        fea_blocks = paddle.split(query_feas, num_or_sections=sections)
        if has_query_id:
            query_id_blocks = paddle.split(
                query_query_id, num_or_sections=sections)
        image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)

        if self.eval_metric_func is None:
            metric_dict = {None: 0.}
        else:
            metric_dict = dict()
            for block_idx, block_fea in enumerate(fea_blocks):
                image_id_block = image_id_blocks[block_idx]
                similarity_matrix = paddle.matmul(
                    block_fea, gallery_feas, transpose_y=True)
                keep_mask = None
                if has_query_id:
                    # A gallery entry is kept for a query row unless both its
                    # unique id and its image id equal the query's.
                    query_id_mask = (
                        query_id_blocks[block_idx] != gallery_unique_id.t())
                    image_id_mask = (image_id_block != gallery_img_id.t())
                    keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
                    similarity_matrix = similarity_matrix * keep_mask.astype(
                        "float32")

                metric_tmp = self.eval_metric_func(
                    similarity_matrix, image_id_block, gallery_img_id,
                    keep_mask)

                for key in metric_tmp:
                    # weight each block by its share of the query set
                    weighted = metric_tmp[key] * block_fea.shape[
                        0] / num_queries
                    if key in metric_dict:
                        metric_dict[key] += weighted
                    else:
                        metric_dict[key] = weighted

        metric_msg = ", ".join("{}: {:.5f}".format(name, value)
                               for name, value in metric_dict.items())
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # the first metric in the dict is the one reported to the caller
        metric_key = next(iter(metric_dict), None)
        return metric_dict[metric_key]
W
weishengyu 已提交
557 558 559 560

    def _cal_feature(self, name='gallery'):
        """Extract features and ids for the whole gallery or query set.

        Every batch is run through ``forward`` and ``out["features"]`` is
        collected; features are L2-normalised along axis 1 unless
        ``Global.feature_normalize`` is false. When a batch has a third
        element it is collected as the unique id. With more than one process
        the per-process results are combined with ``all_gather``.

        Args:
            name: ``'gallery'`` or ``'query'``; selects the dataloader.

        Returns:
            Tuple ``(all_feas, all_image_id, all_unique_id)``;
            ``all_unique_id`` is ``None`` when batches carry no third element.

        Raises:
            RuntimeError: If *name* is not supported, or if the dataloader
                produced no batch to extract features from.
        """
        if name == 'gallery':
            dataloader = self.gallery_dataloader
        elif name == 'query':
            dataloader = self.query_dataloader
        else:
            raise RuntimeError("Only support gallery or query dataset")

        # Per-batch results are collected and concatenated once at the end;
        # concatenating onto a growing tensor every batch re-copies all
        # previously gathered rows each time.
        feas_chunks = []
        image_id_chunks = []
        unique_id_chunks = []

        print_batch_step = self.config["Global"]["print_batch_step"]
        feature_normalize = self.config["Global"].get("feature_normalize",
                                                      True)
        # NOTE(review): on Windows the last batch is skipped; presumably a
        # workaround for a dataloader issue -- confirm.
        max_iter = len(dataloader) - 1 if platform.system(
        ) == "Windows" else len(dataloader)
        dataloader_tmp = dataloader if self.use_dali else dataloader()
        for idx, batch in enumerate(
                dataloader_tmp):  # load is very time-consuming
            if idx >= max_iter:
                break
            if idx % print_batch_step == 0:
                logger.info(
                    f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
                )
            if self.use_dali:
                batch = [
                    paddle.to_tensor(batch[0]['data']),
                    paddle.to_tensor(batch[0]['label'])
                ]
            batch = [paddle.to_tensor(x) for x in batch]
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            if len(batch) == 3:
                batch[2] = batch[2].reshape([-1, 1]).astype("int64")
                unique_id_chunks.append(batch[2])
            out = self.forward(batch)
            batch_feas = out["features"]

            # do norm
            if feature_normalize:
                feas_norm = paddle.sqrt(
                    paddle.sum(paddle.square(batch_feas), axis=1,
                               keepdim=True))
                batch_feas = paddle.divide(batch_feas, feas_norm)

            feas_chunks.append(batch_feas)
            image_id_chunks.append(batch[1])
        if self.use_dali:
            dataloader_tmp.reset()

        if not feas_chunks:
            # Fail with a clear message instead of an AttributeError on None.
            raise RuntimeError(
                "No batch was processed for the {} dataset, cannot build "
                "features".format(name))

        has_unique_id = bool(unique_id_chunks)
        all_feas = paddle.concat(feas_chunks)
        all_image_id = paddle.concat(image_id_chunks)
        all_unique_id = paddle.concat(
            unique_id_chunks) if has_unique_id else None

        if paddle.distributed.get_world_size() > 1:
            feat_list = []
            img_id_list = []
            unique_id_list = []
            paddle.distributed.all_gather(feat_list, all_feas)
            paddle.distributed.all_gather(img_id_list, all_image_id)
            all_feas = paddle.concat(feat_list, axis=0)
            all_image_id = paddle.concat(img_id_list, axis=0)
            if has_unique_id:
                paddle.distributed.all_gather(unique_id_list, all_unique_id)
                all_unique_id = paddle.concat(unique_id_list, axis=0)

        logger.info("Build {} done, all feat shape: {}, begin to eval..".
                    format(name, all_feas.shape))
        return all_feas, all_image_id, all_unique_id
W
weishengyu 已提交
628

629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655
    @paddle.no_grad()
    def infer(self):
        """Run inference on the images listed by ``Infer.infer_imgs``.

        The image list is strided across processes (each rank takes every
        ``world_size``-th image starting at its own rank). Images are read as
        raw bytes, passed through the ``Infer.transforms`` operators, grouped
        into batches of ``Infer.batch_size``, run through ``forward`` and
        post-processed; each batch result is printed.
        """
        world_size = paddle.distributed.get_world_size()
        rank = paddle.distributed.get_rank()
        all_images = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_list = all_images[rank::world_size]

        infer_config = self.config["Infer"]
        preprocess_func = create_operators(infer_config["transforms"])
        postprocess_func = build_postprocess(infer_config["PostProcess"])
        batch_size = infer_config["batch_size"]

        self.model.eval()

        last_idx = len(image_list) - 1
        batch_data, image_file_list = [], []
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)

            # keep filling until the batch is full or the images run out
            if len(batch_data) < batch_size and idx != last_idx:
                continue

            batch_tensor = paddle.to_tensor(batch_data)
            out = self.forward([batch_tensor])
            if isinstance(out, list):
                out = out[0]
            result = postprocess_func(out, image_file_list)
            print(result)
            batch_data.clear()
            image_file_list.clear()