# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

import time
import platform
import datetime
import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist
from visualdl import LogWriter

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators


class Trainer(object):
    def __init__(self, config, mode="train"):
        self.mode = mode
        self.config = config
        self.output_dir = self.config['Global']['output_dir']

        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(name='root', log_file=log_file)
        print_config(config)
        # set device
        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        # set dist
        self.config["Global"][
            "distributed"] = paddle.distributed.get_world_size() != 1
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()

        if "Head" in self.config["Arch"]:
            self.is_rec = True
        else:
            self.is_rec = False

        self.model = build_model(self.config["Arch"])
        if "return_pattern" in self.config["Arch"]:
            self.return_inter = True
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        if self.config["Global"]["pretrained_model"] is not None:
            if self.config["Global"]["pretrained_model"].startswith("http"):
                load_dygraph_pretrain_from_url(
                    self.model, self.config["Global"]["pretrained_model"])
            else:
                load_dygraph_pretrain(
                    self.model, self.config["Global"]["pretrained_model"])

        if self.config["Global"]["distributed"]:
            self.model = paddle.DataParallel(self.model)

        self.vdl_writer = None
        if self.config['Global']['use_visualdl'] and mode == "train":
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))
        # init members
        self.train_dataloader = None
        self.eval_dataloader = None
        self.gallery_dataloader = None
        self.query_dataloader = None
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.amp = "AMP" in self.config
        if self.amp and self.config["AMP"] is not None:
            self.scale_loss = self.config["AMP"].get("scale_loss", 1.0)
            self.use_dynamic_loss_scaling = self.config["AMP"].get(
                "use_dynamic_loss_scaling", False)
        else:
            self.scale_loss = 1.0
            self.use_dynamic_loss_scaling = False
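        # performance-related framework flags typically enabled for AMP training on GPU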
        if self.amp:
            AMP_RELATED_FLAGS_SETTING = {
                'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
                'FLAGS_max_inplace_grad_add': 8,
            }
            paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        self.train_loss_func = None
        self.eval_loss_func = None
        self.train_metric_func = None
        self.eval_metric_func = None
        self.use_dali = self.config['Global'].get("use_dali", False)

    def train(self):
        # build train loss and metric info
        if self.train_loss_func is None:
            loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(loss_info)
        if self.train_metric_func is None:
            metric_config = self.config.get("Metric")
            if metric_config is not None:
                metric_config = metric_config.get("Train")
                if metric_config is not None:
                    self.train_metric_func = build_metrics(metric_config)

        if self.train_dataloader is None:
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device, self.use_dali)

        step_each_epoch = len(self.train_dataloader)

        optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
                                            self.config["Global"]["epochs"],
                                            step_each_epoch,
                                            self.model.parameters())

        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]

        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # output_info maps each loss/metric name to an AverageMeter
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        global_step = 0

        if self.config["Global"]["checkpoints"] is not None:
            metric_info = init_model(self.config["Global"], self.model,
                                     optimizer)
            if metric_info is not None:
                best_metric.update(metric_info)

        # for amp training
        if self.amp:
            scaler = paddle.amp.GradScaler(
                init_loss_scaling=self.scale_loss,
                use_dynamic_loss_scaling=self.use_dynamic_loss_scaling)

        tic = time.time()
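        # skip the last iteration on Windows, where the dataloader's final batch can be problematic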
        max_iter = len(self.train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(self.train_dataloader)
        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
            acc = 0.0
            train_dataloader = self.train_dataloader if self.use_dali else self.train_dataloader(
            )
            for iter_id, batch in enumerate(train_dataloader):
                if iter_id >= max_iter:
                    break
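                # reset the timing meters after a short warm-up so averages exclude startup cost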
                if iter_id == 5:
                    for key in time_info:
                        time_info[key].reset()
                time_info["reader_cost"].update(time.time() - tic)
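                # DALI yields dict-style batches; unpack them into the [data, label] list used below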
                if self.use_dali:
                    batch = [
                        paddle.to_tensor(batch[0]['data']),
                        paddle.to_tensor(batch[0]['label'])
                    ]
                batch_size = batch[0].shape[0]
                batch[1] = batch[1].reshape([-1, 1]).astype("int64")

                global_step += 1
                # image input
                if self.amp:
                    with paddle.amp.auto_cast(custom_black_list={
                            "flatten_contiguous_range", "greater_than"
                    }):
                        out = self.forward(batch)
                else:
                    out = self.forward(batch)

                # calc loss
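                # batch transforms (e.g. mixup) add extra targets, so pass every label field to the loss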
                if self.config["DataLoader"]["Train"]["dataset"].get(
                        "batch_transform_ops", None):
                    loss_dict = self.train_loss_func(out, batch[1:])
                else:
                    loss_dict = self.train_loss_func(out, batch[1])

                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
                # calc metric
                if self.train_metric_func is not None:
                    metric_dict = self.train_metric_func(out, batch[-1])
                    for key in metric_dict:
                        if key not in output_info:
                            output_info[key] = AverageMeter(key, '7.5f')
                        output_info[key].update(metric_dict[key].numpy()[0],
                                                batch_size)

                # step opt and lr
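                # under AMP, scale the loss before backward; scaler.minimize() unscales and applies the update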
                if self.amp:
                    scaled = scaler.scale(loss_dict["loss"])
                    scaled.backward()
                    scaler.minimize(optimizer, scaled)
                else:
                    loss_dict["loss"].backward()
                    optimizer.step()
                optimizer.clear_grad()
                lr_sch.step()

                time_info["batch_cost"].update(time.time() - tic)

                if iter_id % print_batch_step == 0:
                    lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
                    metric_msg = ", ".join([
                        "{}: {:.5f}".format(key, output_info[key].avg)
                        for key in output_info
                    ])
                    time_msg = "s, ".join([
                        "{}: {:.5f}".format(key, time_info[key].avg)
                        for key in time_info
                    ])

                    ips_msg = "ips: {:.5f} images/sec".format(
                        batch_size / time_info["batch_cost"].avg)
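                    # ETA = remaining iterations across all epochs * average batch cost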
                    eta_sec = ((self.config["Global"]["epochs"] - epoch_id + 1
                                ) * len(self.train_dataloader) - iter_id
                               ) * time_info["batch_cost"].avg
                    eta_msg = "eta: {:s}".format(
                        str(datetime.timedelta(seconds=int(eta_sec))))
                    logger.info(
                        "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
                        format(epoch_id, self.config["Global"][
                            "epochs"], iter_id,
                               len(self.train_dataloader), lr_msg, metric_msg,
                               time_msg, ips_msg, eta_msg))

                    logger.scaler(
                        name="lr",
                        value=lr_sch.get_lr(),
                        step=global_step,
                        writer=self.vdl_writer)
                    for key in output_info:
                        logger.scaler(
                            name="train_{}".format(key),
                            value=output_info[key].avg,
                            step=global_step,
                            writer=self.vdl_writer)
                tic = time.time()
            if self.use_dali:
                self.train_dataloader.reset()
            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, output_info[key].avg)
                for key in output_info
            ])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, self.config["Global"]["epochs"], metric_msg))
            output_info.clear()

            # eval model and save model if possible
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model")
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)

                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="epoch_{}".format(epoch_id))
                # save the latest model
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="latest")

        if self.vdl_writer is not None:
            self.vdl_writer.close()

    def build_avg_metrics(self, info_dict):
        return {key: AverageMeter(key, '7.5f') for key in info_dict}

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        self.model.eval()
        if self.eval_loss_func is None:
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
        if self.eval_mode == "classification":
            if self.eval_dataloader is None:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device,
                    self.use_dali)

            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric")
                if metric_config is not None:
                    metric_config = metric_config.get("Eval")
                    if metric_config is not None:
                        self.eval_metric_func = build_metrics(metric_config)

            eval_result = self.eval_cls(epoch_id)

        elif self.eval_mode == "retrieval":
            if self.gallery_dataloader is None:
                self.gallery_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Gallery", self.device,
                    self.use_dali)

            if self.query_dataloader is None:
                self.query_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Query", self.device,
                    self.use_dali)
            # build metric info
            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric", None)
                if metric_config is None:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                else:
                    metric_config = metric_config["Eval"]
                self.eval_metric_func = build_metrics(metric_config)
            eval_result = self.eval_retrieval(epoch_id)
        else:
            logger.warning("Invalid eval mode: {}".format(self.eval_mode))
            eval_result = None
        self.model.train()
        return eval_result

    def forward(self, batch):
        if self.return_inter:
            return_dict = {}
        else:
            return_dict = None
        if not self.is_rec:
            out = self.model(batch[0], return_dict=return_dict)
        else:
            out = self.model(batch[0], batch[1], return_dict=return_dict)
        return out

    @paddle.no_grad()
    def eval_cls(self, epoch_id=0):
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        print_batch_step = self.config["Global"]["print_batch_step"]

        metric_key = None
        tic = time.time()
        eval_dataloader = self.eval_dataloader if self.use_dali else self.eval_dataloader(
        )
        max_iter = len(self.eval_dataloader) - 1 if platform.system(
        ) == "Windows" else len(self.eval_dataloader)
        for iter_id, batch in enumerate(eval_dataloader):
            if iter_id >= max_iter:
                break
            if iter_id == 5:
                for key in time_info:
                    time_info[key].reset()
            if self.use_dali:
                batch = [
                    paddle.to_tensor(batch[0]['data']),
                    paddle.to_tensor(batch[0]['label'])
                ]
            time_info["reader_cost"].update(time.time() - tic)
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0]).astype("float32")
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            # image input
            out = self.forward(batch)
            # calc loss
            if self.eval_loss_func is not None:
                loss_dict = self.eval_loss_func(out, batch[-1])
                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
            # calc metric
            if self.eval_metric_func is not None:
                metric_dict = self.eval_metric_func(out, batch[-1])
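                # multi-card eval: all-reduce each metric and average it over the world size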
                if paddle.distributed.get_world_size() > 1:
                    for key in metric_dict:
                        paddle.distributed.all_reduce(
                            metric_dict[key],
                            op=paddle.distributed.ReduceOp.SUM)
                        metric_dict[key] = metric_dict[
                            key] / paddle.distributed.get_world_size()
                for key in metric_dict:
                    if metric_key is None:
                        metric_key = key
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')

                    output_info[key].update(metric_dict[key].numpy()[0],
                                            batch_size)

            time_info["batch_cost"].update(time.time() - tic)

            if iter_id % print_batch_step == 0:
                time_msg = "s, ".join([
                    "{}: {:.5f}".format(key, time_info[key].avg)
                    for key in time_info
                ])

                ips_msg = "ips: {:.5f} images/sec".format(
                    batch_size / time_info["batch_cost"].avg)

                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                    epoch_id, iter_id,
                    len(self.eval_dataloader), metric_msg, time_msg, ips_msg))

            tic = time.time()
        if self.use_dali:
            self.eval_dataloader.reset()
        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # without a metric func there is nothing to compare, so never report a "best" model
        if self.eval_metric_func is None:
            return -1
        # return 1st metric in the dict
        return output_info[metric_key].avg

    def eval_retrieval(self, epoch_id=0):
        self.model.eval()
        # step1. build gallery
        gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature(
            name='gallery')
        query_feas, query_img_id, query_query_id = self._cal_feature(
            name='query')

        # step2. do evaluation
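        # the full query-gallery similarity matrix can be huge, so queries are processed in blocks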
        sim_block_size = self.config["Global"].get("sim_block_size", 64)
        sections = [sim_block_size] * (len(query_feas) // sim_block_size)
        if len(query_feas) % sim_block_size:
            sections.append(len(query_feas) % sim_block_size)
        fea_blocks = paddle.split(query_feas, num_or_sections=sections)
        if query_query_id is not None:
            query_id_blocks = paddle.split(
                query_query_id, num_or_sections=sections)
        image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
        metric_key = None

        if self.eval_metric_func is None:
            metric_dict = {metric_key: 0.}
        else:
            metric_dict = dict()
            for block_idx, block_fea in enumerate(fea_blocks):
                similarity_matrix = paddle.matmul(
                    block_fea, gallery_feas, transpose_y=True)
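                # mask out gallery entries identical to the query (same unique id and same image id)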
                if query_query_id is not None:
                    query_id_block = query_id_blocks[block_idx]
                    query_id_mask = (query_id_block != gallery_unique_id.t())

                    image_id_block = image_id_blocks[block_idx]
                    image_id_mask = (image_id_block != gallery_img_id.t())

                    keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
                    similarity_matrix = similarity_matrix * keep_mask.astype(
                        "float32")
                else:
                    keep_mask = None

                metric_tmp = self.eval_metric_func(similarity_matrix,
                                                   image_id_blocks[block_idx],
                                                   gallery_img_id, keep_mask)

                for key in metric_tmp:
                    if key not in metric_dict:
                        metric_dict[key] = metric_tmp[key] * block_fea.shape[
                            0] / len(query_feas)
                    else:
                        metric_dict[key] += metric_tmp[key] * block_fea.shape[
                            0] / len(query_feas)

        metric_info_list = []
        for key in metric_dict:
            if metric_key is None:
                metric_key = key
            metric_info_list.append("{}: {:.5f}".format(key, metric_dict[key]))
        metric_msg = ", ".join(metric_info_list)
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        return metric_dict[metric_key]

    def _cal_feature(self, name='gallery'):
        all_feas = None
        all_image_id = None
        all_unique_id = None
        if name == 'gallery':
            dataloader = self.gallery_dataloader
        elif name == 'query':
            dataloader = self.query_dataloader
        else:
            raise RuntimeError("Only support gallery or query dataset")

        has_unique_id = False
        max_iter = len(dataloader) - 1 if platform.system(
        ) == "Windows" else len(dataloader)
        dataloader_tmp = dataloader if self.use_dali else dataloader()
        for idx, batch in enumerate(
                dataloader_tmp):  # data loading dominates the runtime of this loop
            if idx >= max_iter:
                break
            if idx % self.config["Global"]["print_batch_step"] == 0:
                logger.info(
                    f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
                )
            if self.use_dali:
                batch = [
                    paddle.to_tensor(batch[0]['data']),
                    paddle.to_tensor(batch[0]['label'])
                ]
            batch = [paddle.to_tensor(x) for x in batch]
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            if len(batch) == 3:
                has_unique_id = True
                batch[2] = batch[2].reshape([-1, 1]).astype("int64")
            out = self.model(batch[0], batch[1])
            batch_feas = out["features"]

            # L2-normalize features so the dot products used later equal cosine similarity
            if self.config["Global"].get("feature_normalize", True):
                feas_norm = paddle.sqrt(
                    paddle.sum(paddle.square(batch_feas), axis=1,
                               keepdim=True))
                batch_feas = paddle.divide(batch_feas, feas_norm)

            if all_feas is None:
                all_feas = batch_feas
                if has_unique_id:
                    all_unique_id = batch[2]
                all_image_id = batch[1]
            else:
                all_feas = paddle.concat([all_feas, batch_feas])
                all_image_id = paddle.concat([all_image_id, batch[1]])
                if has_unique_id:
                    all_unique_id = paddle.concat([all_unique_id, batch[2]])
        if self.use_dali:
            dataloader_tmp.reset()
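        # distributed eval: gather features and ids from all ranks so each process holds the full set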
        if paddle.distributed.get_world_size() > 1:
            feat_list = []
            img_id_list = []
            unique_id_list = []
            paddle.distributed.all_gather(feat_list, all_feas)
            paddle.distributed.all_gather(img_id_list, all_image_id)
            all_feas = paddle.concat(feat_list, axis=0)
            all_image_id = paddle.concat(img_id_list, axis=0)
            if has_unique_id:
                paddle.distributed.all_gather(unique_id_list, all_unique_id)
                all_unique_id = paddle.concat(unique_id_list, axis=0)

        logger.info("Build {} done, all feat shape: {}, begin to eval..".
                    format(name, all_feas.shape))
        return all_feas, all_image_id, all_unique_id

    @paddle.no_grad()
    def infer(self, ):
        total_trainer = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # shard the image list across trainers; each rank processes a strided slice
        image_list = image_list[local_rank::total_trainer]

        preprocess_func = create_operators(self.config["Infer"]["transforms"])
        postprocess_func = build_postprocess(self.config["Infer"][
            "PostProcess"])

        batch_size = self.config["Infer"]["batch_size"]

        self.model.eval()

        batch_data = []
        image_file_list = []
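        # accumulate preprocessed images until a full batch (or the last image), then run one forward pass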
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                batch_tensor = paddle.to_tensor(batch_data)
                out = self.model(batch_tensor)
                if isinstance(out, list):
                    out = out[0]
                result = postprocess_func(out, image_file_list)
                print(result)
                batch_data.clear()
                image_file_list.clear()
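

# A minimal usage sketch (an illustrative assumption, not part of this module;
# the official entry point is PaddleClas' tools/train.py, which builds the
# config dict and drives Trainer in essentially this way):
#
#     from ppcls.utils.config import get_config
#     cfg = get_config("ppcls/configs/ImageNet/ResNet/ResNet50.yaml")
#     Trainer(cfg, mode="train").train()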