# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

import time
import datetime
import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
from ppcls.data import build_dataloader
from ppcls.arch import build_model
from ppcls.arch import apply_to_static
from ppcls.loss import build_loss
from ppcls.metric import build_metrics
from ppcls.optimizer import build_optimizer
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.utils.save_load import init_model
from ppcls.utils import save_load

from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data import create_operators


class Trainer(object):
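    """Trainer for PaddleClas models.

    Builds the model, dataloaders, losses, metrics and optimizer from a
    config dict, and drives training (`train`), evaluation (`eval`, in
    classification or retrieval mode) and inference (`infer`).
    """
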
    def __init__(self, config, mode="train"):
        self.mode = mode
        self.config = config
        self.output_dir = self.config['Global']['output_dir']

        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(name='root', log_file=log_file)
        print_config(config)
        # set device
        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        # set dist
        self.config["Global"][
            "distributed"] = paddle.distributed.get_world_size() != 1
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()

        # recognition models configure a "Head" in the Arch config; their
        # forward pass also takes the label as input (see train/eval loops)
        if "Head" in self.config["Arch"]:
            self.is_rec = True
        else:
            self.is_rec = False

        self.model = build_model(self.config["Arch"])
        # set @to_static for benchmark, skip this by default.
        apply_to_static(self.config, self.model)

        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(self.model,
                                  self.config["Global"]["pretrained_model"])

        if self.config["Global"]["distributed"]:
            self.model = paddle.DataParallel(self.model)

        self.vdl_writer = None
        if self.config['Global']['use_visualdl']:
            from visualdl import LogWriter
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            if not os.path.exists(vdl_writer_path):
                os.makedirs(vdl_writer_path)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))
        # init members
        self.train_dataloader = None
        self.eval_dataloader = None
        self.gallery_dataloader = None
        self.query_dataloader = None
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.train_loss_func = None
        self.eval_loss_func = None
        self.train_metric_func = None
        self.eval_metric_func = None

    def train(self):
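        """Train for Global.epochs epochs: log averaged losses/metrics,
        evaluate every Global.eval_interval epochs when eval_during_train
        is set, and save best/latest/periodic checkpoints.
        """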
        # build train loss and metric info
        if self.train_loss_func is None:
            loss_info = self.config["Loss"]["Train"]
            self.train_loss_func = build_loss(loss_info)
        if self.train_metric_func is None:
            metric_config = self.config.get("Metric")
            if metric_config is not None:
                metric_config = metric_config.get("Train")
                if metric_config is not None:
                    self.train_metric_func = build_metrics(metric_config)

        if self.train_dataloader is None:
            self.train_dataloader = build_dataloader(self.config["DataLoader"],
                                                     "Train", self.device)

        step_each_epoch = len(self.train_dataloader)

        # epochs and step_each_epoch are passed so that per-iteration
        # learning-rate schedules (e.g. warm-up, decay) can be constructed
        optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
                                            self.config["Global"]["epochs"],
                                            step_each_epoch,
                                            self.model.parameters())

        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]

        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # key: loss/metric name; val: AverageMeter tracking its running average
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        global_step = 0

        if self.config["Global"]["checkpoints"] is not None:
            metric_info = init_model(self.config["Global"], self.model,
                                     optimizer)
            if metric_info is not None:
                best_metric.update(metric_info)

        tic = time.time()
        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
            acc = 0.0
            for iter_id, batch in enumerate(self.train_dataloader()):
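                # reset timing stats after the first few iterations so that
                # dataloader warm-up does not skew batch_cost/reader_cost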
                if iter_id == 5:
                    for key in time_info:
                        time_info[key].reset()
                time_info["reader_cost"].update(time.time() - tic)
                batch_size = batch[0].shape[0]
                batch[1] = batch[1].reshape([-1, 1]).astype("int64")

                global_step += 1
                # image input
                if not self.is_rec:
                    out = self.model(batch[0])
                else:
                    out = self.model(batch[0], batch[1])

                # calc loss
                loss_dict = self.train_loss_func(out, batch[1])

                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
                # calc metric
                if self.train_metric_func is not None:
                    metric_dict = self.train_metric_func(out, batch[-1])
                    for key in metric_dict:
                        if key not in output_info:
                            output_info[key] = AverageMeter(key, '7.5f')
                        output_info[key].update(metric_dict[key].numpy()[0],
                                                batch_size)

                # step opt and lr
                loss_dict["loss"].backward()
                optimizer.step()
                optimizer.clear_grad()
                lr_sch.step()

                time_info["batch_cost"].update(time.time() - tic)

                if iter_id % print_batch_step == 0:
                    lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
                    metric_msg = ", ".join([
                        "{}: {:.5f}".format(key, output_info[key].avg)
                        for key in output_info
                    ])
                    time_msg = "s, ".join([
                        "{}: {:.5f}".format(key, time_info[key].avg)
                        for key in time_info
                    ])

                    ips_msg = "ips: {:.5f} images/sec".format(
                        batch_size / time_info["batch_cost"].avg)
                    eta_sec = ((self.config["Global"]["epochs"] - epoch_id + 1
                                ) * len(self.train_dataloader) - iter_id
                               ) * time_info["batch_cost"].avg
                    eta_msg = "eta: {:s}".format(
                        str(datetime.timedelta(seconds=int(eta_sec))))
                    logger.info(
                        "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
                        format(epoch_id, self.config["Global"]["epochs"],
                               iter_id, len(self.train_dataloader), lr_msg,
                               metric_msg, time_msg, ips_msg, eta_msg))
                tic = time.time()

            metric_msg = ", ".join([
                "{}: {:.5f}".format(key, output_info[key].avg)
                for key in output_info
            ])
            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, self.config["Global"]["epochs"], metric_msg))
            output_info.clear()

            # eval model and save model if possible
            if (self.config["Global"]["eval_during_train"] and
                    epoch_id % self.config["Global"]["eval_interval"] == 0):
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_load.save_model(
                        self.model,
                        optimizer,
                        best_metric,
                        self.output_dir,
                        model_name=self.config["Arch"]["name"],
                        prefix="best_model")
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="epoch_{}".format(epoch_id))
                # save the latest model
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="latest")

    def build_avg_metrics(self, info_dict):
        return {key: AverageMeter(key, '7.5f') for key in info_dict}

    @paddle.no_grad()
    def eval(self, epoch_id=0):
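        """Evaluate the model, dispatching to eval_cls or eval_retrieval
        according to Global.eval_mode, and return the primary metric.
        """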
        self.model.eval()
        if self.eval_loss_func is None:
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
        if self.eval_mode == "classification":
            if self.eval_dataloader is None:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device)

            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric")
                if metric_config is not None:
                    metric_config = metric_config.get("Eval")
                    if metric_config is not None:
                        self.eval_metric_func = build_metrics(metric_config)

            eval_result = self.eval_cls(epoch_id)

        elif self.eval_mode == "retrieval":
            if self.gallery_dataloader is None:
                self.gallery_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Gallery", self.device)

            if self.query_dataloader is None:
                self.query_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Query", self.device)
            # build metric info
            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric", None)
                if metric_config is None:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                else:
                    metric_config = metric_config["Eval"]
                self.eval_metric_func = build_metrics(metric_config)
            eval_result = self.eval_retrieval(epoch_id)
        else:
            logger.warning("Invalid eval mode: {}".format(self.eval_mode))
            eval_result = None
        self.model.train()
        return eval_result

    @paddle.no_grad()
    def eval_cls(self, epoch_id=0):
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        print_batch_step = self.config["Global"]["print_batch_step"]

        metric_key = None
        tic = time.time()
        for iter_id, batch in enumerate(self.eval_dataloader()):
            if iter_id == 5:
                for key in time_info:
                    time_info[key].reset()

            time_info["reader_cost"].update(time.time() - tic)
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0]).astype("float32")
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            # image input
            if self.is_rec:
                out = self.model(batch[0], batch[1])
            else:
                out = self.model(batch[0])
            # calc loss
            if self.eval_loss_func is not None:
                loss_dict = self.eval_loss_func(out, batch[-1])
                for key in loss_dict:
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')
                    output_info[key].update(loss_dict[key].numpy()[0],
                                            batch_size)
            # calc metric
            if self.eval_metric_func is not None:
                metric_dict = self.eval_metric_func(out, batch[-1])
                if paddle.distributed.get_world_size() > 1:
                    for key in metric_dict:
                        paddle.distributed.all_reduce(
                            metric_dict[key],
                            op=paddle.distributed.ReduceOp.SUM)
                        metric_dict[key] = metric_dict[
                            key] / paddle.distributed.get_world_size()
                for key in metric_dict:
                    if metric_key is None:
                        metric_key = key
                    if key not in output_info:
                        output_info[key] = AverageMeter(key, '7.5f')

                    output_info[key].update(metric_dict[key].numpy()[0],
                                            batch_size)

            time_info["batch_cost"].update(time.time() - tic)

            if iter_id % print_batch_step == 0:
                time_msg = "s, ".join([
                    "{}: {:.5f}".format(key, time_info[key].avg)
                    for key in time_info
                ])

                ips_msg = "ips: {:.5f} images/sec".format(
                    batch_size / time_info["batch_cost"].avg)

                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                    epoch_id, iter_id,
                    len(self.eval_dataloader), metric_msg, time_msg, ips_msg))

            tic = time.time()

        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # no metric function available, so never treat this run as the best model
        if self.eval_metric_func is None:
            return -1
        # return 1st metric in the dict
        return output_info[metric_key].avg

    def eval_retrieval(self, epoch_id=0):
        self.model.eval()
        cum_similarity_matrix = None
        # step1. build gallery
        gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature(
            name='gallery')
        query_feas, query_img_id, query_query_id = self._cal_feature(
            name='query')

        # step2. do evaluation
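        # split the query features into blocks of sim_block_size so the
        # query-gallery similarity matrix is computed piecewise and memory
        # use stays bounded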
        sim_block_size = self.config["Global"].get("sim_block_size", 64)
        sections = [sim_block_size] * (len(query_feas) // sim_block_size)
        if len(query_feas) % sim_block_size:
            sections.append(len(query_feas) % sim_block_size)
        fea_blocks = paddle.split(query_feas, num_or_sections=sections)
        if query_query_id is not None:
            query_id_blocks = paddle.split(
                query_query_id, num_or_sections=sections)
        image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
        metric_key = None

        if self.eval_metric_func is None:
            metric_dict = {metric_key: 0.}
        else:
            metric_dict = dict()
            for block_idx, block_fea in enumerate(fea_blocks):
                similarity_matrix = paddle.matmul(
                    block_fea, gallery_feas, transpose_y=True)
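                # mask out gallery entries that share both the unique id and
                # the image id with the query (i.e. the query itself), so a
                # sample can never be retrieved as its own match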
                if query_query_id is not None:
                    query_id_block = query_id_blocks[block_idx]
                    query_id_mask = (query_id_block != gallery_unique_id.t())

                    image_id_block = image_id_blocks[block_idx]
                    image_id_mask = (image_id_block != gallery_img_id.t())

                    keep_mask = paddle.logical_or(query_id_mask, image_id_mask)
                    similarity_matrix = similarity_matrix * keep_mask.astype(
                        "float32")
                else:
                    keep_mask = None

                metric_tmp = self.eval_metric_func(similarity_matrix,
                                                   image_id_blocks[block_idx],
                                                   gallery_img_id, keep_mask)

                for key in metric_tmp:
                    if key not in metric_dict:
                        metric_dict[key] = metric_tmp[key] * block_fea.shape[
                            0] / len(query_feas)
                    else:
                        metric_dict[key] += metric_tmp[key] * block_fea.shape[
                            0] / len(query_feas)

        metric_info_list = []
        for key in metric_dict:
            if metric_key is None:
                metric_key = key
            metric_info_list.append("{}: {:.5f}".format(key, metric_dict[key]))
        metric_msg = ", ".join(metric_info_list)
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        return metric_dict[metric_key]

    def _cal_feature(self, name='gallery'):
        all_feas = None
        all_image_id = None
        all_unique_id = None
        if name == 'gallery':
            dataloader = self.gallery_dataloader
        elif name == 'query':
            dataloader = self.query_dataloader
        else:
            raise RuntimeError("Only support gallery or query dataset")

        has_unique_id = False
        # data loading is very time-consuming
        for idx, batch in enumerate(dataloader()):
            if idx % self.config["Global"]["print_batch_step"] == 0:
                logger.info(
                    f"{name} feature calculation process: [{idx}/{len(dataloader)}]"
                )
            batch = [paddle.to_tensor(x) for x in batch]
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            if len(batch) == 3:
                has_unique_id = True
                batch[2] = batch[2].reshape([-1, 1]).astype("int64")
            out = self.model(batch[0], batch[1])
            batch_feas = out["features"]

            # L2-normalize features so that the similarity matmul in
            # eval_retrieval computes cosine similarity
            if self.config["Global"].get("feature_normalize", True):
                feas_norm = paddle.sqrt(
                    paddle.sum(paddle.square(batch_feas), axis=1,
                               keepdim=True))
                batch_feas = paddle.divide(batch_feas, feas_norm)

            if all_feas is None:
                all_feas = batch_feas
                if has_unique_id:
                    all_unique_id = batch[2]
                all_image_id = batch[1]
            else:
                all_feas = paddle.concat([all_feas, batch_feas])
                all_image_id = paddle.concat([all_image_id, batch[1]])
                if has_unique_id:
                    all_unique_id = paddle.concat([all_unique_id, batch[2]])

        # under distributed evaluation, gather features and ids from every
        # rank so metrics are computed over the complete dataset
        if paddle.distributed.get_world_size() > 1:
            feat_list = []
            img_id_list = []
            unique_id_list = []
            paddle.distributed.all_gather(feat_list, all_feas)
            paddle.distributed.all_gather(img_id_list, all_image_id)
            all_feas = paddle.concat(feat_list, axis=0)
            all_image_id = paddle.concat(img_id_list, axis=0)
            if has_unique_id:
                paddle.distributed.all_gather(unique_id_list, all_unique_id)
                all_unique_id = paddle.concat(unique_id_list, axis=0)

        logger.info("Build {} done, all feat shape: {}, begin to eval..".
                    format(name, all_feas.shape))
        return all_feas, all_image_id, all_unique_id

    @paddle.no_grad()
    def infer(self):
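        """Run inference on the images listed under Infer.infer_imgs,
        sharding the image list across trainers, and print post-processed
        predictions batch by batch.
        """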
        total_trainer = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        image_list = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split
        image_list = image_list[local_rank::total_trainer]

        preprocess_func = create_operators(self.config["Infer"]["transforms"])
        postprocess_func = build_postprocess(self.config["Infer"][
            "PostProcess"])

        batch_size = self.config["Infer"]["batch_size"]

        self.model.eval()

        batch_data = []
        image_file_list = []
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)
            if len(batch_data) >= batch_size or idx == len(image_list) - 1:
                batch_tensor = paddle.to_tensor(batch_data)
                out = self.model(batch_tensor)
                result = postprocess_func(out, image_file_list)
                print(result)
                batch_data.clear()
                image_file_list.clear()
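

# Minimal usage sketch (illustrative only; assumes a config dict parsed from
# one of the YAML files shipped with PaddleClas, e.g. via
# ppcls.utils.config.get_config):
#
#     from ppcls.utils.config import get_config
#     config = get_config("ppcls/configs/ImageNet/ResNet/ResNet50.yaml",
#                         show=False)
#     trainer = Trainer(config, mode="train")
#     trainer.train()  # or trainer.eval() / trainer.infer()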