# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

import time
import datetime
import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators


class Trainer(object):
    def __init__(self, config, mode="train"):
        self.mode = mode
        self.config = config
        self.output_dir = self.config['Global']['output_dir']

        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(name='root', log_file=log_file)
        print_config(config)
        # set device
        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        # set dist
        self.config["Global"][
            "distributed"] = paddle.distributed.get_world_size() != 1
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()

        # a "Head" section in the Arch config marks a recognition model,
        # whose forward pass also takes the label as input
        if "Head" in self.config["Arch"]:
            self.config["Arch"]["Head"]["class_num"] = self.config["Global"][
                "class_num"]
            self.is_rec = True
        else:
            self.is_rec = False

        self.model = build_model(self.config["Arch"])
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(self.model,
                                  self.config["Global"]["pretrained_model"])

        if self.config["Global"]["distributed"]:
            self.model = paddle.DataParallel(self.model)

        self.vdl_writer = None
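        # optionally create a VisualDL writer to log training curves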
        if self.config['Global']['use_visualdl']:
            from visualdl import LogWriter
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))
        # init members
        self.train_dataloader = None
        self.eval_dataloader = None
        self.gallery_dataloader = None
        self.query_dataloader = None
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.train_loss_func = None
        self.eval_loss_func = None
        self.train_metric_func = None
        self.eval_metric_func = None

    def train(self):
        # build train loss and metric info
        if self.train_loss_func is None:
            loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(loss_info)
        if self.train_metric_func is None:
            metric_config = self.config.get("Metric")
            if metric_config is not None:
                metric_config = metric_config.get("Train")
                if metric_config is not None:
                    self.train_metric_func = build_metrics(metric_config)

        if self.train_dataloader is None:
            self.train_dataloader = build_dataloader(self.config["DataLoader"],
                                                     "Train", self.device)

        step_each_epoch = len(self.train_dataloader)

        # build the optimizer and lr scheduler; the schedule needs the total
        # epoch count and steps per epoch to lay out its decay
        optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
                                            self.config["Global"]["epochs"],
                                            step_each_epoch,
                                            self.model.parameters())

        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]

        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # key: loss/metric name, val: AverageMeter tracking its value
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        global_step = 0

        # resume from a checkpoint when one is given; init_model also returns
        # the metric info saved with it
        if self.config["Global"]["checkpoints"] is not None:
            metric_info = init_model(self.config["Global"], self.model,
                                     optimizer)
            if metric_info is not None:
                best_metric.update(metric_info)

        tic = time.time()
        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
            acc = 0.0
            for iter_id, batch in enumerate(self.train_dataloader()):
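                # reset the timers after the first iterations so warmup costs
                # (e.g. reader start-up) do not skew the averages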
                if iter_id == 5:
                    for key in time_info:
                        time_info[key].reset()
                time_info["reader_cost"].update(time.time() - tic)
                batch_size = batch[0].shape[0]
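                # reshape labels to [N, 1] int64 before computing loss/metrics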
                batch[1] = batch[1].reshape([-1, 1]).astype("int64")

                global_step += 1
                # image input
                if not self.is_rec:
                    out = self.model(batch[0])
                else:
                    out = self.model(batch[0], batch[1])

                # calc loss
                loss_dict = self.train_loss_func(out, batch[1])

                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
                # calc metric
                if self.train_metric_func is not None:
                    metric_dict = self.train_metric_func(out, batch[-1])
                    for key in metric_dict:
                        if key not in output_info:
                            output_info[key] = AverageMeter(key, '7.5f')
                        output_info[key].update(metric_dict[key].numpy()[0],
                                                batch_size)

                # step opt and lr
                loss_dict["loss"].backward()
                optimizer.step()
                optimizer.clear_grad()
                lr_sch.step()

                time_info["batch_cost"].update(time.time() - tic)

                if iter_id % print_batch_step == 0:
                    lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
                    metric_msg = ", ".join([
                        "{}: {:.5f}".format(key, output_info[key].avg)
                        for key in output_info
                    ])
                    time_msg = "s, ".join([
                        "{}: {:.5f}".format(key, time_info[key].avg)
                        for key in time_info
                    ])

                    ips_msg = "ips: {:.5f} images/sec".format(
                        batch_size / time_info["batch_cost"].avg)
                    eta_sec = ((self.config["Global"]["epochs"] - epoch_id + 1
                                ) * len(self.train_dataloader) - iter_id
                               ) * time_info["batch_cost"].avg
                    eta_msg = "eta: {:s}".format(
                        str(datetime.timedelta(seconds=int(eta_sec))))
                    logger.info(
                        "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
                        format(epoch_id, self.config["Global"]["epochs"],
                               iter_id, len(self.train_dataloader), lr_msg,
                               metric_msg, time_msg, ips_msg, eta_msg))
                tic = time.time()

            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, output_info[key].avg)
                for key in output_info
            ])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, self.config["Global"]["epochs"], metric_msg))
            output_info.clear()

            # eval model and save model if possible
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_interval"] == 0:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model")
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="epoch_{}".format(epoch_id))
                # save the latest model
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="latest")

    def build_avg_metrics(self, info_dict):
        return {key: AverageMeter(key, '7.5f') for key in info_dict}

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        self.model.eval()
        if self.eval_loss_func is None:
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
        if self.eval_mode == "classification":
            if self.eval_dataloader is None:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device)

            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric")
                if metric_config is not None:
                    metric_config = metric_config.get("Eval")
                    if metric_config is not None:
                        self.eval_metric_func = build_metrics(metric_config)

            eval_result = self.eval_cls(epoch_id)

        elif self.eval_mode == "retrieval":
            if self.gallery_dataloader is None:
                self.gallery_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Gallery", self.device)

            if self.query_dataloader is None:
                self.query_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Query", self.device)
            # build metric info
            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric", None)
                if metric_config is None:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                else:
                    metric_config = metric_config["Eval"]
                self.eval_metric_func = build_metrics(metric_config)
            eval_result = self.eval_retrieval(epoch_id)
        else:
            logger.warning("Invalid eval mode: {}".format(self.eval_mode))
            eval_result = None
        self.model.train()
        return eval_result

    @paddle.no_grad()
    def eval_cls(self, epoch_id=0):
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        print_batch_step = self.config["Global"]["print_batch_step"]

        metric_key = None
        tic = time.time()
        for iter_id, batch in enumerate(self.eval_dataloader()):
            if iter_id == 5:
                for key in time_info:
                    time_info[key].reset()

            time_info["reader_cost"].update(time.time() - tic)
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0]).astype("float32")
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            # image input
            if self.is_rec:
                out = self.model(batch[0], batch[1])
            else:
                out = self.model(batch[0])
            # calc loss
            if self.eval_loss_func is not None:
                loss_dict = self.eval_loss_func(out, batch[-1])
                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
            # calc metric
            if self.eval_metric_func is not None:
                metric_dict = self.eval_metric_func(out, batch[-1])
                if paddle.distributed.get_world_size() > 1:
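                    # sum each metric across workers, then divide by the
                    # world size to get the global average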
                    for key in metric_dict:
                        paddle.distributed.all_reduce(
                            metric_dict[key],
                            op=paddle.distributed.ReduceOp.SUM)
                        metric_dict[key] = metric_dict[
                            key] / paddle.distributed.get_world_size()
                for key in metric_dict:
                    if metric_key is None:
                        metric_key = key
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')

                    output_info[key].update(metric_dict[key].numpy()[0],
                                            batch_size)

            time_info["batch_cost"].update(time.time() - tic)

            if iter_id % print_batch_step == 0:
                time_msg = "s, ".join([
                    "{}: {:.5f}".format(key, time_info[key].avg)
                    for key in time_info
                ])

                ips_msg = "ips: {:.5f} images/sec".format(
                    batch_size / time_info["batch_cost"].avg)

                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                    epoch_id, iter_id,
                    len(self.eval_dataloader), metric_msg, time_msg, ips_msg))

            tic = time.time()

        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # do not try to save best model
        if self.eval_metric_func is None:
            return -1
        # return 1st metric in the dict
        return output_info[metric_key].avg

    def eval_retrieval(self, epoch_id=0):
        self.model.eval()
        cum_similarity_matrix = None
        # step1. build gallery
        gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature(
            name='gallery')
        query_feas, query_img_id, query_query_id = self._cal_feature(
            name='query')

        # step2. do evaluation
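        # split the query features into blocks so the query-gallery
        # similarity matrix is computed piecewise, bounding peak memory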
        sim_block_size = self.config["Global"].get("sim_block_size", 64)
        sections = [sim_block_size] * (len(query_feas) // sim_block_size)
        if len(query_feas) % sim_block_size:
            sections.append(len(query_feas) % sim_block_size)
        fea_blocks = paddle.split(query_feas, num_or_sections=sections)
        if query_query_id is not None:
            query_id_blocks = paddle.split(
                query_query_id, num_or_sections=sections)
        image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
        metric_key = None

        if self.eval_metric_func is None:
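            # no metric configured: fall back to a zero score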
            metric_dict = {metric_key: 0.}
        else:
            metric_dict = dict()
            for block_idx, block_fea in enumerate(fea_blocks):
                similarity_matrix = paddle.matmul(
                    block_fea, gallery_feas, transpose_y=True)
                if query_query_id is not None:
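                    # mask out gallery entries that match the query in both
                    # unique id and image id (i.e. the query itself)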
                    query_id_block = query_id_blocks[block_idx]
                    query_id_mask = (query_id_block != gallery_unique_id.t())

                    image_id_block = image_id_blocks[block_idx]
                    image_id_mask = (image_id_block != gallery_img_id.t())

                    keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
                    similarity_matrix = similarity_matrix * keep_mask.astype(
                        "float32")
                else:
                    keep_mask = None

                metric_tmp = self.eval_metric_func(similarity_matrix,
                                                   image_id_blocks[block_idx],
                                                   gallery_img_id, keep_mask)

                for key in metric_tmp:
                    if key not in metric_dict:
                        metric_dict[key] = metric_tmp[key] * block_fea.shape[
                            0] / len(query_feas)
                    else:
                        metric_dict[key] += metric_tmp[key] * block_fea.shape[
                            0] / len(query_feas)

        metric_info_list = []
        for key in metric_dict:
            if metric_key is None:
                metric_key = key
            metric_info_list.append("{}: {:.5f}".format(key, metric_dict[key]))
        metric_msg = ", ".join(metric_info_list)
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        return metric_dict[metric_key]

    def _cal_feature(self, name='gallery'):
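        # extract features for the whole gallery/query set and return the
        # concatenated features, image ids and (optional) unique ids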
        all_feas = None
        all_image_id = None
        all_unique_id = None
        if name == 'gallery':
            dataloader = self.gallery_dataloader
        elif name == 'query':
            dataloader = self.query_dataloader
        else:
            raise RuntimeError("Only support gallery or query dataset")

        has_unique_id = False
        # loading is very time-consuming
        for idx, batch in enumerate(dataloader()):
            if idx % self.config["Global"]["print_batch_step"] == 0:
                logger.info(
                    f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
                )
            batch = [paddle.to_tensor(x) for x in batch]
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            if len(batch) == 3:
                has_unique_id = True
                batch[2] = batch[2].reshape([-1, 1]).astype("int64")
            out = self.model(batch[0], batch[1])
            batch_feas = out["features"]

            # L2-normalize features so inner products give cosine similarity
            if self.config["Global"].get("feature_normalize", True):
                feas_norm = paddle.sqrt(
                    paddle.sum(paddle.square(batch_feas), axis=1,
                               keepdim=True))
                batch_feas = paddle.divide(batch_feas, feas_norm)

            if all_feas is None:
                all_feas = batch_feas
                if has_unique_id:
                    all_unique_id = batch[2]
                all_image_id = batch[1]
            else:
                all_feas = paddle.concat([all_feas, batch_feas])
                all_image_id = paddle.concat([all_image_id, batch[1]])
                if has_unique_id:
                    all_unique_id = paddle.concat([all_unique_id, batch[2]])

        if paddle.distributed.get_world_size() > 1:
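            # gather features and ids from all ranks so every worker
            # evaluates on the full set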
            feat_list = []
            img_id_list = []
            unique_id_list = []
            paddle.distributed.all_gather(feat_list, all_feas)
            paddle.distributed.all_gather(img_id_list, all_image_id)
            all_feas = paddle.concat(feat_list, axis=0)
            all_image_id = paddle.concat(img_id_list, axis=0)
            if has_unique_id:
                paddle.distributed.all_gather(unique_id_list, all_unique_id)
                all_unique_id = paddle.concat(unique_id_list, axis=0)

        logger.info("Build {} done, all feat shape: {}, begin to eval...".
                    format(name, all_feas.shape))
        return all_feas, all_image_id, all_unique_id

    @paddle.no_grad()
    def infer(self):
        total_trainer = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # shard the image list so each trainer infers a disjoint subset
        image_list = image_list[local_rank::total_trainer]

        preprocess_func = create_operators(self.config["Infer"]["transforms"])
        postprocess_func = build_postprocess(self.config["Infer"][
            "PostProcess"])

        batch_size = self.config["Infer"]["batch_size"]

        self.model.eval()

        batch_data = []
        image_file_list = []
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            if len(batch_data) >= batch_size or idx == len(image_list) - 1:
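                # flush: run the model once a full batch has accumulated,
                # or at the last image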
                batch_tensor = paddle.to_tensor(batch_data)
                out = self.model(batch_tensor)
                result = postprocess_func(out, image_file_list)
                print(result)
                batch_data.clear()
                image_file_list.clear()