# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

import time
import platform
import datetime
import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist
from visualdl import LogWriter

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain, load_dygraph_pretrain_from_url
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators


class Trainer(object):
    def __init__(self, config, mode="train"):
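        """Build the model and set up device, distributed training, logging
        and the optional VisualDL writer for the given running mode.
        """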
        self.mode = mode
        self.config = config
        self.output_dir = self.config['Global']['output_dir']

        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(name='root', log_file=log_file)
        print_config(config)
        # set device
        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        # set dist
        self.config["Global"][
            "distributed"] = paddle.distributed.get_world_size() != 1
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()

        if "Head" in self.config["Arch"]:
            self.is_rec = True
        else:
            self.is_rec = False

        self.model = build_model(self.config["Arch"])
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        if self.config["Global"]["pretrained_model"] is not None:
            if self.config["Global"]["pretrained_model"].startswith("http"):
                load_dygraph_pretrain_from_url(
                    self.model, self.config["Global"]["pretrained_model"])
            else:
                load_dygraph_pretrain(
                    self.model, self.config["Global"]["pretrained_model"])

        if self.config["Global"]["distributed"]:
            self.model = paddle.DataParallel(self.model)

        self.vdl_writer = None
        if self.config['Global']['use_visualdl'] and mode == "train":
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))
        # init members
        self.train_dataloader = None
        self.eval_dataloader = None
        self.gallery_dataloader = None
        self.query_dataloader = None
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.train_loss_func = None
        self.eval_loss_func = None
        self.train_metric_func = None
        self.eval_metric_func = None

    def train(self):
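        """Run the full training loop: iterate over epochs and batches,
        log losses/metrics, and periodically evaluate and save checkpoints.
        """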
        # build train loss and metric info
        if self.train_loss_func is None:
            loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(loss_info)
        if self.train_metric_func is None:
            metric_config = self.config.get("Metric")
            if metric_config is not None:
                metric_config = metric_config.get("Train")
                if metric_config is not None:
                    self.train_metric_func = build_metrics(metric_config)

        if self.train_dataloader is None:
            self.train_dataloader = build_dataloader(self.config["DataLoader"],
                                                     "Train", self.device)

        step_each_epoch = len(self.train_dataloader)

        optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
                                            self.config["Global"]["epochs"],
                                            step_each_epoch,
                                            self.model.parameters())

        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]

        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # key: loss or metric name
        # val: AverageMeter tracking the running average of that value
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        global_step = 0

        if self.config["Global"]["checkpoints"] is not None:
            metric_info = init_model(self.config["Global"], self.model,
                                     optimizer)
            if metric_info is not None:
                best_metric.update(metric_info)

        tic = time.time()
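        # on Windows, drop the last batch to sidestep a multiprocess
        # dataloader issue on that platform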
        max_iter = len(self.train_dataloader) - 1 if platform.system(
        ) == "Windows" else len(self.train_dataloader)
        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
            acc = 0.0
            for iter_id, batch in enumerate(self.train_dataloader()):
                if iter_id >= max_iter:
                    break
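                # reset timing stats after the first few iterations so
                # warm-up overhead does not skew the averages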
                if iter_id == 5:
                    for key in time_info:
                        time_info[key].reset()
                time_info["reader_cost"].update(time.time() - tic)
                batch_size = batch[0].shape[0]
                batch[1] = batch[1].reshape([-1, 1]).astype("int64")

                global_step += 1
                # image input
                if not self.is_rec:
                    out = self.model(batch[0])
                else:
                    out = self.model(batch[0], batch[1])
                # calc loss
                if self.config["DataLoader"]["Train"]["dataset"].get(
                        "batch_transform_ops", None):
                    loss_dict = self.train_loss_func(out, batch[1:])
                else:
                    loss_dict = self.train_loss_func(out, batch[1])

                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
                # calc metric
                if self.train_metric_func is not None:
                    metric_dict = self.train_metric_func(out, batch[-1])
                    for key in metric_dict:
                        if key not in output_info:
                            output_info[key] = AverageMeter(key, '7.5f')
                        output_info[key].update(metric_dict[key].numpy()[0],
                                                batch_size)

                # step opt and lr
                loss_dict["loss"].backward()
                optimizer.step()
                optimizer.clear_grad()
                lr_sch.step()

                time_info["batch_cost"].update(time.time() - tic)

                if iter_id % print_batch_step == 0:
                    lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
                    metric_msg = ", ".join([
                        "{}: {:.5f}".format(key, output_info[key].avg)
                        for key in output_info
                    ])
                    time_msg = "s, ".join([
                        "{}: {:.5f}".format(key, time_info[key].avg)
                        for key in time_info
                    ])

                    ips_msg = "ips: {:.5f} images/sec".format(
                        batch_size / time_info["batch_cost"].avg)
                    eta_sec = ((self.config["Global"]["epochs"] - epoch_id + 1
                                ) * len(self.train_dataloader) - iter_id
                               ) * time_info["batch_cost"].avg
                    eta_msg = "eta: {:s}".format(
                        str(datetime.timedelta(seconds=int(eta_sec))))
                    logger.info(
                        "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
                        format(epoch_id, self.config["Global"][
                            "epochs"], iter_id,
                               len(self.train_dataloader), lr_msg, metric_msg,
                               time_msg, ips_msg, eta_msg))

                    logger.scaler(
                        name="lr",
                        value=lr_sch.get_lr(),
                        step=global_step,
                        writer=self.vdl_writer)
                    for key in output_info:
                        logger.scaler(
                            name="train_{}".format(key),
                            value=output_info[key].avg,
                            step=global_step,
                            writer=self.vdl_writer)
                tic = time.time()

            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, output_info[key].avg)
                for key in output_info
            ])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, self.config["Global"]["epochs"], metric_msg))
            output_info.clear()

            # eval model and save model if possible
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model")
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                logger.scaler(
                    name="eval_acc",
                    value=acc,
                    step=epoch_id,
                    writer=self.vdl_writer)

                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="epoch_{}".format(epoch_id))
                # save the latest model
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="latest")

        if self.vdl_writer is not None:
            self.vdl_writer.close()

    def build_avg_metrics(self, info_dict):
        return {key: AverageMeter(key, '7.5f') for key in info_dict}

    @paddle.no_grad()
    def eval(self, epoch_id=0):
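        """Evaluate the model and return a scalar metric, dispatching to
        classification or retrieval evaluation according to ``eval_mode``.
        """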
        self.model.eval()
        if self.eval_loss_func is None:
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
        if self.eval_mode == "classification":
            if self.eval_dataloader is None:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device)

            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric")
                if metric_config is not None:
                    metric_config = metric_config.get("Eval")
                    if metric_config is not None:
                        self.eval_metric_func = build_metrics(metric_config)

            eval_result = self.eval_cls(epoch_id)

        elif self.eval_mode == "retrieval":
            if self.gallery_dataloader is None:
                self.gallery_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Gallery", self.device)

            if self.query_dataloader is None:
                self.query_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Query", self.device)
            # build metric info
            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric", None)
                if metric_config is None:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                else:
                    metric_config = metric_config["Eval"]
                self.eval_metric_func = build_metrics(metric_config)
            eval_result = self.eval_retrieval(epoch_id)
        else:
            logger.warning("Invalid eval mode: {}".format(self.eval_mode))
            eval_result = None
        self.model.train()
        return eval_result

    @paddle.no_grad()
    def eval_cls(self, epoch_id=0):
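        """Classification evaluation: average loss and metrics over the eval
        dataloader and return the first metric (e.g. top-1 accuracy).
        """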
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        print_batch_step = self.config["Global"]["print_batch_step"]

        metric_key = None
        tic = time.time()
        max_iter = len(self.eval_dataloader) - 1 if platform.system(
        ) == "Windows" else len(self.eval_dataloader)
        for iter_id, batch in enumerate(self.eval_dataloader()):
            if iter_id >= max_iter:
                break
            if iter_id == 5:
                for key in time_info:
                    time_info[key].reset()

            time_info["reader_cost"].update(time.time() - tic)
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0]).astype("float32")
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            # image input
            if self.is_rec:
                out = self.model(batch[0], batch[1])
            else:
                out = self.model(batch[0])
            # calc loss
            if self.eval_loss_func is not None:
                loss_dict = self.eval_loss_func(out, batch[-1])
                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
            # calc metric
            if self.eval_metric_func is not None:
                metric_dict = self.eval_metric_func(out, batch[-1])
                if paddle.distributed.get_world_size() > 1:
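                    # average each metric across workers: all-reduce the sum,
                    # then divide by the world size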
                    for key in metric_dict:
                        paddle.distributed.all_reduce(
                            metric_dict[key],
                            op=paddle.distributed.ReduceOp.SUM)
                        metric_dict[key] = metric_dict[
                            key] / paddle.distributed.get_world_size()
                for key in metric_dict:
                    if metric_key is None:
                        metric_key = key
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')

                    output_info[key].update(metric_dict[key].numpy()[0],
                                            batch_size)

            time_info["batch_cost"].update(time.time() - tic)

            if iter_id % print_batch_step == 0:
                time_msg = "s, ".join([
                    "{}: {:.5f}".format(key, time_info[key].avg)
                    for key in time_info
                ])

                ips_msg = "ips: {:.5f} images/sec".format(
                    batch_size / time_info["batch_cost"].avg)

                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                    epoch_id, iter_id,
                    len(self.eval_dataloader), metric_msg, time_msg, ips_msg))

            tic = time.time()

        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # do not try to save best model
        if self.eval_metric_func is None:
            return -1
        # return 1st metric in the dict
        return output_info[metric_key].avg

    def eval_retrieval(self, epoch_id=0):
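        """Retrieval evaluation: extract gallery and query features, then
        compute query-gallery similarity block by block to limit peak memory.
        """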
        self.model.eval()
        cum_similarity_matrix = None
        # step1. build gallery
        gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature(
            name='gallery')
        query_feas, query_img_id, query_query_id = self._cal_feature(
            name='query')

        # step2. do evaluation
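        # split query features into fixed-size blocks so the full
        # query-gallery similarity matrix is never materialized at once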
        sim_block_size = self.config["Global"].get("sim_block_size", 64)
        sections = [sim_block_size] * (len(query_feas) // sim_block_size)
        if len(query_feas) % sim_block_size:
            sections.append(len(query_feas) % sim_block_size)
        fea_blocks = paddle.split(query_feas, num_or_sections=sections)
        if query_query_id is not None:
            query_id_blocks = paddle.split(
                query_query_id, num_or_sections=sections)
        image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
        metric_key = None

        if self.eval_metric_func is None:
            metric_dict = {metric_key: 0.}
        else:
            metric_dict = dict()
            for block_idx, block_fea in enumerate(fea_blocks):
                similarity_matrix = paddle.matmul(
                    block_fea, gallery_feas, transpose_y=True)
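                # mask out gallery entries that are the query image itself
                # (same unique id and same image id)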
                if query_query_id is not None:
                    query_id_block = query_id_blocks[block_idx]
                    query_id_mask = (query_id_block != gallery_unique_id.t())

                    image_id_block = image_id_blocks[block_idx]
                    image_id_mask = (image_id_block != gallery_img_id.t())

                    keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
                    similarity_matrix = similarity_matrix * keep_mask.astype(
                        "float32")
                else:
                    keep_mask = None

                metric_tmp = self.eval_metric_func(similarity_matrix,
                                                   image_id_blocks[block_idx],
                                                   gallery_img_id, keep_mask)

                for key in metric_tmp:
                    if key not in metric_dict:
                        metric_dict[key] = metric_tmp[key] * block_fea.shape[
                            0] / len(query_feas)
                    else:
                        metric_dict[key] += metric_tmp[key] * block_fea.shape[
                            0] / len(query_feas)

        metric_info_list = []
        for key in metric_dict:
            if metric_key is None:
                metric_key = key
            metric_info_list.append("{}: {:.5f}".format(key, metric_dict[key]))
        metric_msg = ", ".join(metric_info_list)
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        return metric_dict[metric_key]

    def _cal_feature(self, name='gallery'):
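        """Extract features (L2-normalized by default) for the gallery or
        query set, gathering results from all workers when distributed.
        """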
        all_feas = None
        all_image_id = None
        all_unique_id = None
        if name == 'gallery':
            dataloader = self.gallery_dataloader
        elif name == 'query':
            dataloader = self.query_dataloader
        else:
            raise RuntimeError("Only support gallery or query dataset")

        has_unique_id = False
        max_iter = len(dataloader) - 1 if platform.system(
        ) == "Windows" else len(dataloader)
        for idx, batch in enumerate(dataloader()):  # load is very time-consuming
            if idx >= max_iter:
                break
            if idx % self.config["Global"]["print_batch_step"] == 0:
                logger.info(
                    f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
                )
            batch = [paddle.to_tensor(x) for x in batch]
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            if len(batch) == 3:
                has_unique_id = True
                batch[2] = batch[2].reshape([-1, 1]).astype("int64")
            out = self.model(batch[0], batch[1])
            batch_feas = out["features"]

            # do norm
            if self.config["Global"].get("feature_normalize", True):
                feas_norm = paddle.sqrt(
                    paddle.sum(paddle.square(batch_feas), axis=1,
                               keepdim=True))
                batch_feas = paddle.divide(batch_feas, feas_norm)

            if all_feas is None:
                all_feas = batch_feas
                if has_unique_id:
                    all_unique_id = batch[2]
                all_image_id = batch[1]
            else:
                all_feas = paddle.concat([all_feas, batch_feas])
                all_image_id = paddle.concat([all_image_id, batch[1]])
                if has_unique_id:
                    all_unique_id = paddle.concat([all_unique_id, batch[2]])

        if paddle.distributed.get_world_size() > 1:
            feat_list = []
            img_id_list = []
            unique_id_list = []
            paddle.distributed.all_gather(feat_list, all_feas)
            paddle.distributed.all_gather(img_id_list, all_image_id)
            all_feas = paddle.concat(feat_list, axis=0)
            all_image_id = paddle.concat(img_id_list, axis=0)
            if has_unique_id:
                paddle.distributed.all_gather(unique_id_list, all_unique_id)
                all_unique_id = paddle.concat(unique_id_list, axis=0)

        logger.info("Build {} done, all feat shape: {}, begin to eval..".
                    format(name, all_feas.shape))
        return all_feas, all_image_id, all_unique_id

    @paddle.no_grad()
    def infer(self):
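        """Run inference on the images under ``Infer.infer_imgs``, sharding
        the image list across workers and printing post-processed results.
        """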
        total_trainer = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_list = image_list[local_rank::total_trainer]

        preprocess_func = create_operators(self.config["Infer"]["transforms"])
        postprocess_func = build_postprocess(self.config["Infer"][
            "PostProcess"])

        batch_size = self.config["Infer"]["batch_size"]

        self.model.eval()

        batch_data = []
        image_file_list = []
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                batch_tensor = paddle.to_tensor(batch_data)
                out = self.model(batch_tensor)
                if isinstance(out, list):
                    out = out[0]
                result = postprocess_func(out, image_file_list)
                print(result)
                batch_data.clear()
                image_file_list.clear()