# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

import time
import datetime
import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist
from visualdl import LogWriter

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators


class Trainer(object):
    def __init__(self, config, mode="train"):
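        """Build the trainer from the parsed config: set up logging, the
        device, (optional) distributed training, the model and pretrained
        weights. `mode` ("train" or "eval") selects the log file name."""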
        self.mode = mode
        self.config = config
        self.output_dir = self.config['Global']['output_dir']

        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(name='root', log_file=log_file)
        print_config(config)
        # set device
        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        # set dist
        self.config["Global"][
            "distributed"] = paddle.distributed.get_world_size() != 1
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()

        if "Head" in self.config["Arch"]:
            self.is_rec = True
        else:
            self.is_rec = False

        self.model = build_model(self.config["Arch"])
A
Aurelius84 已提交
76 77
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        if self.config["Global"]["pretrained_model"] is not None:
            if self.config["Global"]["pretrained_model"].startswith("http"):
                load_dygraph_pretrain_from_url(
                    self.model, self.config["Global"]["pretrained_model"])
            else:
                load_dygraph_pretrain(
                    self.model, self.config["Global"]["pretrained_model"])

        if self.config["Global"]["distributed"]:
            self.model = paddle.DataParallel(self.model)

        self.vdl_writer = None
        if self.config['Global']['use_visualdl'] and mode == "train":
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))
        # init members
        self.train_dataloader = None
        self.eval_dataloader = None
        self.gallery_dataloader = None
        self.query_dataloader = None
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.train_loss_func = None
        self.eval_loss_func = None
        self.train_metric_func = None
        self.eval_metric_func = None

    def train(self):
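        """Run the training loop: build loss, metrics, dataloader and
        optimizer on first use, then iterate over epochs while logging
        smoothed losses/metrics/throughput and saving checkpoints (and,
        optionally, the best model by eval metric)."""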
        # build train loss and metric info
        if self.train_loss_func is None:
            loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(loss_info)
        if self.train_metric_func is None:
            metric_config = self.config.get("Metric")
            if metric_config is not None:
                metric_config = metric_config.get("Train")
                if metric_config is not None:
                    self.train_metric_func = build_metrics(metric_config)

        if self.train_dataloader is None:
            self.train_dataloader = build_dataloader(self.config["DataLoader"],
                                                     "Train", self.device)

        step_each_epoch = len(self.train_dataloader)

        optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
                                            self.config["Global"]["epochs"],
                                            step_each_epoch,
                                            self.model.parameters())

        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]

        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # key: loss/metric name; value: AverageMeter tracking its average
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        global_step = 0

        if self.config["Global"]["checkpoints"] is not None:
            metric_info = init_model(self.config["Global"], self.model,
                                     optimizer)
            if metric_info is not None:
                best_metric.update(metric_info)

        tic = time.time()
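        # resume from the epoch recorded in the checkpoint, if any
        # (best_metric["epoch"] is 0 on a fresh run, so training starts at 1)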
        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
            acc = 0.0
            for iter_id, batch in enumerate(self.train_dataloader()):
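                # timings of the first few iterations include warmup, so the
                # meters are reset at iter 5 before steady-state averaging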
                if iter_id == 5:
                    for key in time_info:
                        time_info[key].reset()
                time_info["reader_cost"].update(time.time() - tic)
                batch_size = batch[0].shape[0]
                batch[1] = batch[1].reshape([-1, 1]).astype("int64")

                global_step += 1
                # image input
                if not self.is_rec:
                    out = self.model(batch[0])
                else:
                    out = self.model(batch[0], batch[1])

                # calc loss
                loss_dict = self.train_loss_func(out, batch[1])

                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
                # calc metric
                if self.train_metric_func is not None:
                    metric_dict = self.train_metric_func(out, batch[-1])
                    for key in metric_dict:
                        if key not in output_info:
                            output_info[key] = AverageMeter(key, '7.5f')
                        output_info[key].update(metric_dict[key].numpy()[0],
                                                batch_size)

                # step opt and lr
                loss_dict["loss"].backward()
                optimizer.step()
                optimizer.clear_grad()
                lr_sch.step()

                time_info["batch_cost"].update(time.time() - tic)

                if iter_id % print_batch_step == 0:
                    lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
                    metric_msg = ", ".join([
                        "{}: {:.5f}".format(key, output_info[key].avg)
                        for key in output_info
                    ])
                    time_msg = "s, ".join([
                        "{}: {:.5f}".format(key, time_info[key].avg)
                        for key in time_info
                    ])

                    ips_msg = "ips: {:.5f} images/sec".format(
                        batch_size / time_info["batch_cost"].avg)
                    # ETA: remaining iterations times the smoothed batch cost
                    eta_sec = ((self.config["Global"]["epochs"] - epoch_id + 1
                                ) * len(self.train_dataloader) - iter_id
                               ) * time_info["batch_cost"].avg
                    eta_msg = "eta: {:s}".format(
                        str(datetime.timedelta(seconds=int(eta_sec))))
                    logger.info(
                        "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
                        format(epoch_id, self.config["Global"][
                            "epochs"], iter_id,
                               len(self.train_dataloader), lr_msg, metric_msg,
                               time_msg, ips_msg, eta_msg))

                    logger.scaler(
                        name="lr",
                        value=lr_sch.get_lr(),
                        step=global_step,
                        writer=self.vdl_writer)
                    for key in output_info:
                        logger.scaler(
                            name="train_{}".format(key),
                            value=output_info[key].avg,
                            step=global_step,
                            writer=self.vdl_writer)
                tic = time.time()

            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, output_info[key].avg)
                for key in output_info
            ])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, self.config["Global"]["epochs"], metric_msg))
            output_info.clear()

            # eval model and save model if possible
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model")
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)

                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="epoch_{}".format(epoch_id))
                # save the latest model
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="latest")

        if self.vdl_writer is not None:
            self.vdl_writer.close()

    def build_avg_metrics(self, info_dict):
        return {key: AverageMeter(key, '7.5f') for key in info_dict}

    @paddle.no_grad()
    def eval(self, epoch_id=0):
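        """Dispatch to classification or retrieval evaluation according to
        Global.eval_mode, building the dataloaders, loss and metrics on
        first use. Returns a scalar metric, or None for an unknown mode."""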
        self.model.eval()
        if self.eval_loss_func is None:
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
        if self.eval_mode == "classification":
            if self.eval_dataloader is None:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device)

            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric")
                if metric_config is not None:
                    metric_config = metric_config.get("Eval")
                    if metric_config is not None:
                        self.eval_metric_func = build_metrics(metric_config)

            eval_result = self.eval_cls(epoch_id)

        elif self.eval_mode == "retrieval":
            if self.gallery_dataloader is None:
                self.gallery_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Gallery", self.device)

            if self.query_dataloader is None:
                self.query_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Query", self.device)
            # build metric info
            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric", None)
                if metric_config is None:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                else:
                    metric_config = metric_config["Eval"]
                self.eval_metric_func = build_metrics(metric_config)
            eval_result = self.eval_retrieval(epoch_id)
        else:
            logger.warning("Invalid eval mode: {}".format(self.eval_mode))
            eval_result = None
        self.model.train()
        return eval_result

    @paddle.no_grad()
    def eval_cls(self, epoch_id=0):
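        """Classification evaluation: averages loss and metrics over the
        eval dataloader, all-reducing metrics across workers in distributed
        runs. Returns the first metric's average, or -1 when no metric is
        configured."""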
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        print_batch_step = self.config["Global"]["print_batch_step"]

        metric_key = None
        tic = time.time()
        for iter_id, batch in enumerate(self.eval_dataloader()):
            if iter_id == 5:
                for key in time_info:
                    time_info[key].reset()

            time_info["reader_cost"].update(time.time() - tic)
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0]).astype("float32")
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            # image input
            if self.is_rec:
                out = self.model(batch[0], batch[1])
            else:
                out = self.model(batch[0])
            # calc loss
            if self.eval_loss_func is not None:
                loss_dict = self.eval_loss_func(out, batch[-1])
                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
            # calc metric
            if self.eval_metric_func is not None:
                metric_dict = self.eval_metric_func(out, batch[-1])
                if paddle.distributed.get_world_size() > 1:
                    for key in metric_dict:
                        # sum the metric across workers, then divide to average
                        paddle.distributed.all_reduce(
                            metric_dict[key],
                            op=paddle.distributed.ReduceOp.SUM)
                        metric_dict[key] = metric_dict[
                            key] / paddle.distributed.get_world_size()
                for key in metric_dict:
                    if metric_key is None:
                        metric_key = key
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')

                    output_info[key].update(metric_dict[key].numpy()[0],
                                            batch_size)

            time_info["batch_cost"].update(time.time() - tic)

            if iter_id % print_batch_step == 0:
                time_msg = "s, ".join([
                    "{}: {:.5f}".format(key, time_info[key].avg)
                    for key in time_info
                ])

                ips_msg = "ips: {:.5f} images/sec".format(
                    batch_size / time_info["batch_cost"].avg)

                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                    epoch_id, iter_id,
                    len(self.eval_dataloader), metric_msg, time_msg, ips_msg))

            tic = time.time()

        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # no metric configured: return -1 so this never counts as a best model
        if self.eval_metric_func is None:
            return -1
        # return 1st metric in the dict
        return output_info[metric_key].avg

    def eval_retrieval(self, epoch_id=0):
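        """Retrieval evaluation: extracts gallery and query features, then
        computes query-gallery similarities block by block and accumulates
        the configured metrics (Recallk with topk=(1, 5) by default)."""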
        self.model.eval()
        cum_similarity_matrix = None
        # step1. build gallery
        gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature(
            name='gallery')
        query_feas, query_img_id, query_query_id = self._cal_feature(
            name='query')

        # step2. do evaluation
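        # split the queries into blocks of sim_block_size so the full
        # query-gallery similarity matrix is never materialized at once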
        sim_block_size = self.config["Global"].get("sim_block_size", 64)
        sections = [sim_block_size] * (len(query_feas) // sim_block_size)
        if len(query_feas) % sim_block_size:
            sections.append(len(query_feas) % sim_block_size)
        fea_blocks = paddle.split(query_feas, num_or_sections=sections)
        if query_query_id is not None:
            query_id_blocks = paddle.split(
                query_query_id, num_or_sections=sections)
        image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
        metric_key = None

        if self.eval_metric_func is None:
            metric_dict = {metric_key: 0.}
        else:
            metric_dict = dict()
            for block_idx, block_fea in enumerate(fea_blocks):
                similarity_matrix = paddle.matmul(
                    block_fea, gallery_feas, transpose_y=True)
                if query_query_id is not None:
                    query_id_block = query_id_blocks[block_idx]
                    query_id_mask = (query_id_block != gallery_unique_id.t())

                    image_id_block = image_id_blocks[block_idx]
                    image_id_mask = (image_id_block != gallery_img_id.t())

                    # keep a pair unless it matches the query in both unique
                    # id and image id (i.e. the query itself); applied below
                    keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
                    similarity_matrix = similarity_matrix * keep_mask.astype(
                        "float32")
                else:
                    keep_mask = None

                metric_tmp = self.eval_metric_func(similarity_matrix,
                                                   image_id_blocks[block_idx],
                                                   gallery_img_id, keep_mask)

                # average the per-block metrics, weighted by block size
                for key in metric_tmp:
                    if key not in metric_dict:
                        metric_dict[key] = metric_tmp[key] * block_fea.shape[
                            0] / len(query_feas)
                    else:
                        metric_dict[key] += metric_tmp[key] * block_fea.shape[
                            0] / len(query_feas)

        metric_info_list = []
        for key in metric_dict:
            if metric_key is None:
                metric_key = key
            metric_info_list.append("{}: {:.5f}".format(key, metric_dict[key]))
        metric_msg = ", ".join(metric_info_list)
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        return metric_dict[metric_key]

    def _cal_feature(self, name='gallery'):
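        """Runs the model over the gallery or query dataloader and returns
        (features, image ids, unique ids). Features are L2-normalized by
        default and gathered from all workers in distributed runs; unique
        ids are None when the dataset yields no third field."""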
        all_feas = None
        all_image_id = None
        all_unique_id = None
        if name == 'gallery':
            dataloader = self.gallery_dataloader
        elif name == 'query':
            dataloader = self.query_dataloader
        else:
            raise RuntimeError("Only support gallery or query dataset")

        has_unique_id = False
        for idx, batch in enumerate(dataloader(
        )):  # data loading dominates the runtime of this loop
            if idx % self.config["Global"]["print_batch_step"] == 0:
                logger.info(
                    f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
                )
            batch = [paddle.to_tensor(x) for x in batch]
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            if len(batch) == 3:
                has_unique_id = True
                batch[2] = batch[2].reshape([-1, 1]).astype("int64")
            out = self.model(batch[0], batch[1])
            batch_feas = out["features"]

            # L2-normalize so the similarity matmul computes cosine similarity
            if self.config["Global"].get("feature_normalize", True):
                feas_norm = paddle.sqrt(
                    paddle.sum(paddle.square(batch_feas), axis=1,
                               keepdim=True))
                batch_feas = paddle.divide(batch_feas, feas_norm)

            if all_feas is None:
                all_feas = batch_feas
                if has_unique_id:
                    all_unique_id = batch[2]
                all_image_id = batch[1]
            else:
                all_feas = paddle.concat([all_feas, batch_feas])
                all_image_id = paddle.concat([all_image_id, batch[1]])
                if has_unique_id:
                    all_unique_id = paddle.concat([all_unique_id, batch[2]])

        # distributed eval: gather features and ids from all workers
        if paddle.distributed.get_world_size() > 1:
            feat_list = []
            img_id_list = []
            unique_id_list = []
            paddle.distributed.all_gather(feat_list, all_feas)
            paddle.distributed.all_gather(img_id_list, all_image_id)
            all_feas = paddle.concat(feat_list, axis=0)
            all_image_id = paddle.concat(img_id_list, axis=0)
            if has_unique_id:
                paddle.distributed.all_gather(unique_id_list, all_unique_id)
                all_unique_id = paddle.concat(unique_id_list, axis=0)

        logger.info("Build {} done, all feat shape: {}, begin to eval..".
                    format(name, all_feas.shape))
        return all_feas, all_image_id, all_unique_id

    @paddle.no_grad()
    def infer(self):
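        """Offline inference over Infer.infer_imgs: shards the image list
        across workers by rank, preprocesses and batches the images, runs
        the model and prints the postprocessed predictions."""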
        total_trainer = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_list = image_list[local_rank::total_trainer]

        preprocess_func = create_operators(self.config["Infer"]["transforms"])
        postprocess_func = build_postprocess(self.config["Infer"][
            "PostProcess"])

        batch_size = self.config["Infer"]["batch_size"]

        self.model.eval()

        batch_data = []
        image_file_list = []
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                batch_tensor = paddle.to_tensor(batch_data)
                out = self.model(batch_tensor)
                if isinstance(out, list):
                    out = out[0]
                result = postprocess_func(out, image_file_list)
                print(result)
                batch_data.clear()
                image_file_list.clear()