trainer.py 22.5 KB
Newer Older
D
dongshuilong 已提交
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
L
littletomatodonkey 已提交
2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

littletomatodonkey's avatar
littletomatodonkey 已提交
23 24
import time
import datetime
L
littletomatodonkey 已提交
25 26 27 28 29 30 31 32
import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
L
littletomatodonkey 已提交
33 34
from ppcls.utils.logger import init_logger
from ppcls.utils.config import print_config
L
littletomatodonkey 已提交
35 36
from ppcls.data import build_dataloader
from ppcls.arch import build_model
W
weishengyu 已提交
37
from ppcls.loss import build_loss
W
weishengyu 已提交
38
from ppcls.metric import build_metrics
L
littletomatodonkey 已提交
39
from ppcls.optimizer import build_optimizer
40 41
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.utils.save_load import init_model
L
littletomatodonkey 已提交
42 43
from ppcls.utils import save_load

44 45
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
littletomatodonkey's avatar
littletomatodonkey 已提交
46
from ppcls.data import create_operators
47

L
littletomatodonkey 已提交
48 49

class Trainer(object):
50
    def __init__(self, config, mode="train"):
        """Prepare logging, device, distributed env, model and writers.

        Args:
            config: nested mapping; the ``Global`` and ``Arch`` sections are
                read here (``output_dir``, ``device``, ``class_num``,
                ``pretrained_model``, ``use_visualdl``, ``eval_mode``).
            mode: tag used to name the log file (``<mode>.log``).
        """
        self.mode = mode
        self.config = config
        self.output_dir = self.config['Global']['output_dir']

        log_file = os.path.join(self.output_dir, self.config["Arch"]["name"],
                                f"{mode}.log")
        init_logger(name='root', log_file=log_file)
        print_config(config)
        # set device
        assert self.config["Global"]["device"] in ["cpu", "gpu", "xpu"]
        self.device = paddle.set_device(self.config["Global"]["device"])
        # set dist: the flag is derived from the world size, not user config
        self.config["Global"][
            "distributed"] = paddle.distributed.get_world_size() != 1
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()

        # an Arch with a "Head" section gets class_num injected and is
        # flagged so that forward passes also receive the labels
        if "Head" in self.config["Arch"]:
            self.config["Arch"]["Head"]["class_num"] = self.config["Global"][
                "class_num"]
            self.is_rec = True
        else:
            self.is_rec = False

        self.model = build_model(self.config["Arch"])

        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(self.model,
                                  self.config["Global"]["pretrained_model"])

        if self.config["Global"]["distributed"]:
            self.model = paddle.DataParallel(self.model)

        self.vdl_writer = None
        if self.config['Global']['use_visualdl']:
            from visualdl import LogWriter
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            # exist_ok avoids the check-then-create race when several
            # distributed ranks construct a Trainer at the same time
            os.makedirs(vdl_writer_path, exist_ok=True)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))
        # init members; dataloaders, losses and metrics are built lazily by
        # train() / eval()
        self.train_dataloader = None
        self.eval_dataloader = None
        self.gallery_dataloader = None
        self.query_dataloader = None
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.train_loss_func = None
        self.eval_loss_func = None
        self.train_metric_func = None
        self.eval_metric_func = None
L
littletomatodonkey 已提交
104 105 106

    def train(self):
        """Run the training loop described by the ``Global`` config section.

        Builds the train loss, optional train metric, dataloader and
        optimizer on demand, optionally resumes from ``Global.checkpoints``,
        then iterates epochs: forward, loss, backward, optimizer and LR
        scheduler step per batch, periodic logging, optional evaluation
        (saving ``best_model`` on improvement) and periodic checkpoints.
        """
        global_cfg = self.config["Global"]

        # build train loss and metric info
        if self.train_loss_func is None:
            self.train_loss_func = build_loss(self.config["Loss"]["Train"])
        if self.train_metric_func is None:
            metric_config = self.config.get("Metric")
            if metric_config is not None:
                train_metric_config = metric_config.get("Train")
                if train_metric_config is not None:
                    self.train_metric_func = build_metrics(
                        train_metric_config)

        if self.train_dataloader is None:
            self.train_dataloader = build_dataloader(
                self.config["DataLoader"], "Train", self.device)

        step_each_epoch = len(self.train_dataloader)
        total_epochs = global_cfg["epochs"]

        optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
                                            total_epochs, step_each_epoch,
                                            self.model.parameters())

        print_batch_step = global_cfg['print_batch_step']
        save_interval = global_cfg["save_interval"]

        best_metric = {"metric": 0.0, "epoch": 0}
        # key: loss / metric name
        # val: AverageMeter tracking that value over the current epoch
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        # global iter counter
        global_step = 0

        def record(values, num_samples):
            # fold every entry of a loss/metric dict into its running meter
            for name in values:
                if name not in output_info:
                    output_info[name] = AverageMeter(name, '7.5f')
                output_info[name].update(values[name].numpy()[0],
                                         num_samples)

        def summary():
            # "name: avg" pairs for everything tracked so far
            return ", ".join([
                "{}: {:.5f}".format(name, meter.avg)
                for name, meter in output_info.items()
            ])

        def save_checkpoint(metric_info, prefix):
            save_load.save_model(
                self.model,
                optimizer,
                metric_info,
                self.output_dir,
                model_name=self.config["Arch"]["name"],
                prefix=prefix)

        if global_cfg["checkpoints"] is not None:
            metric_info = init_model(global_cfg, self.model, optimizer)
            if metric_info is not None:
                best_metric.update(metric_info)

        tic = time.time()
        # resume right after the epoch stored in the checkpoint (if any)
        for epoch_id in range(best_metric["epoch"] + 1, total_epochs + 1):
            acc = 0.0
            for iter_id, batch in enumerate(self.train_dataloader()):
                # drop the first iterations from the timing statistics
                if iter_id == 5:
                    for timer in time_info.values():
                        timer.reset()
                time_info["reader_cost"].update(time.time() - tic)
                batch_size = batch[0].shape[0]
                batch[1] = batch[1].reshape([-1, 1]).astype("int64")

                global_step += 1
                # image input
                if self.is_rec:
                    out = self.model(batch[0], batch[1])
                else:
                    out = self.model(batch[0])

                # calc loss
                loss_dict = self.train_loss_func(out, batch[1])
                record(loss_dict, batch_size)
                # calc metric
                if self.train_metric_func is not None:
                    metric_dict = self.train_metric_func(out, batch[-1])
                    record(metric_dict, batch_size)

                # step opt and lr
                loss_dict["loss"].backward()
                optimizer.step()
                optimizer.clear_grad()
                lr_sch.step()

                time_info["batch_cost"].update(time.time() - tic)

                if iter_id % print_batch_step == 0:
                    lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
                    metric_msg = summary()
                    time_msg = "s, ".join([
                        "{}: {:.5f}".format(name, timer.avg)
                        for name, timer in time_info.items()
                    ])

                    batch_cost_avg = time_info["batch_cost"].avg
                    ips_msg = "ips: {:.5f} images/sec".format(
                        batch_size / batch_cost_avg)
                    # remaining iterations (this epoch included) times the
                    # average cost of one batch
                    remaining_iters = (total_epochs - epoch_id + 1
                                       ) * step_each_epoch - iter_id
                    eta_sec = remaining_iters * batch_cost_avg
                    eta_msg = "eta: {:s}".format(
                        str(datetime.timedelta(seconds=int(eta_sec))))
                    logger.info(
                        "[Train][Epoch {}/{}][Iter: {}/{}]{}, {}, {}, {}, {}".
                        format(epoch_id, total_epochs, iter_id,
                               step_each_epoch, lr_msg, metric_msg, time_msg,
                               ips_msg, eta_msg))
                tic = time.time()

            logger.info("[Train][Epoch {}/{}][Avg]{}".format(
                epoch_id, total_epochs, summary()))
            output_info.clear()

            # eval model and save model if possible
            if global_cfg["eval_during_train"] and epoch_id % global_cfg[
                    "eval_interval"] == 0:
                acc = self.eval(epoch_id)
                if acc > best_metric["metric"]:
                    best_metric["metric"] = acc
                    best_metric["epoch"] = epoch_id
                    save_checkpoint(best_metric, "best_model")
                logger.info("[Eval][Epoch {}][best metric: {}]".format(
                    epoch_id, best_metric["metric"]))
                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_checkpoint({
                    "metric": acc,
                    "epoch": epoch_id
                }, "epoch_{}".format(epoch_id))
                # save the latest model
                save_checkpoint({
                    "metric": acc,
                    "epoch": epoch_id
                }, "latest")
L
littletomatodonkey 已提交
267 268 269 270 271 272 273

    def build_avg_metrics(self, info_dict):
        """Return a dict with one fresh ``AverageMeter`` per key of *info_dict*."""
        meters = {}
        for name in info_dict:
            meters[name] = AverageMeter(name, '7.5f')
        return meters

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        """Evaluate the model according to ``self.eval_mode``.

        Lazily builds the eval loss, dataloaders and metric function, then
        dispatches to ``eval_cls`` ("classification") or ``eval_retrieval``
        ("retrieval"). The model is put back into train mode before
        returning.

        Args:
            epoch_id: epoch number, forwarded to the mode-specific
                evaluation method.

        Returns:
            The value returned by the mode-specific evaluation, or ``None``
            (after a warning) when ``eval_mode`` is not recognised.
        """
        self.model.eval()
        if self.eval_loss_func is None:
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config.get("Eval")
                if loss_config is not None:
                    self.eval_loss_func = build_loss(loss_config)
        if self.eval_mode == "classification":
            if self.eval_dataloader is None:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device)

            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric")
                if metric_config is not None:
                    metric_config = metric_config.get("Eval")
                    if metric_config is not None:
                        self.eval_metric_func = build_metrics(metric_config)

            eval_result = self.eval_cls(epoch_id)

        elif self.eval_mode == "retrieval":
            if self.gallery_dataloader is None:
                self.gallery_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Gallery", self.device)

            if self.query_dataloader is None:
                self.query_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Query", self.device)
            # build metric info
            if self.eval_metric_func is None:
                metric_config = self.config.get("Metric", None)
                if metric_config is not None:
                    # .get (not indexing) so a Metric section without an
                    # "Eval" entry falls through to the default below
                    # instead of raising KeyError
                    metric_config = metric_config.get("Eval")
                if metric_config is None:
                    metric_config = [{"name": "Recallk", "topk": (1, 5)}]
                self.eval_metric_func = build_metrics(metric_config)
            eval_result = self.eval_retrieval(epoch_id)
        else:
            logger.warning("Invalid eval mode: {}".format(self.eval_mode))
            eval_result = None
        self.model.train()
        return eval_result
W
weishengyu 已提交
316

littletomatodonkey's avatar
littletomatodonkey 已提交
317
    @paddle.no_grad()
    def eval_cls(self, epoch_id=0):
        """Evaluate on ``self.eval_dataloader`` in classification mode.

        Tracks the eval loss (when ``self.eval_loss_func`` is set) and the
        eval metrics (when ``self.eval_metric_func`` is set) with running
        averages, logging every ``Global.print_batch_step`` iterations.

        Args:
            epoch_id: epoch number, used in log messages only.

        Returns:
            ``-1`` when no eval metric function is configured, otherwise the
            running average of the first metric produced.
        """
        output_info = dict()
        time_info = {
            "batch_cost": AverageMeter(
                "batch_cost", '.5f', postfix=" s,"),
            "reader_cost": AverageMeter(
                "reader_cost", ".5f", postfix=" s,"),
        }
        print_batch_step = self.config["Global"]["print_batch_step"]

        def record(values, num_samples):
            # fold every entry of a loss/metric dict into its running meter
            for name in values:
                if name not in output_info:
                    output_info[name] = AverageMeter(name, '7.5f')
                output_info[name].update(values[name].numpy()[0],
                                         num_samples)

        # name of the first metric seen; its average is the return value
        metric_key = None
        tic = time.time()
        for iter_id, batch in enumerate(self.eval_dataloader()):
            # drop the first iterations from the timing statistics
            if iter_id == 5:
                for timer in time_info.values():
                    timer.reset()

            time_info["reader_cost"].update(time.time() - tic)
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0]).astype("float32")
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            # image input
            if not self.is_rec:
                out = self.model(batch[0])
            else:
                out = self.model(batch[0], batch[1])
            # calc loss
            if self.eval_loss_func is not None:
                record(self.eval_loss_func(out, batch[-1]), batch_size)
            # calc metric
            if self.eval_metric_func is not None:
                metric_dict = self.eval_metric_func(out, batch[-1])
                world_size = paddle.distributed.get_world_size()
                if world_size > 1:
                    # average each metric over all ranks
                    for name in metric_dict:
                        paddle.distributed.all_reduce(
                            metric_dict[name],
                            op=paddle.distributed.ReduceOp.SUM)
                        metric_dict[name] = metric_dict[
                            name] / paddle.distributed.get_world_size()
                if metric_key is None:
                    for name in metric_dict:
                        metric_key = name
                        break
                record(metric_dict, batch_size)

            time_info["batch_cost"].update(time.time() - tic)

            if iter_id % print_batch_step == 0:
                time_msg = "s, ".join([
                    "{}: {:.5f}".format(name, timer.avg)
                    for name, timer in time_info.items()
                ])

                ips_msg = "ips: {:.5f} images/sec".format(
                    batch_size / time_info["batch_cost"].avg)

                # per-iteration log shows the latest value, not the average
                metric_msg = ", ".join([
                    "{}: {:.5f}".format(name, meter.val)
                    for name, meter in output_info.items()
                ])
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}, {}, {}".format(
                    epoch_id, iter_id,
                    len(self.eval_dataloader), metric_msg, time_msg, ips_msg))

            tic = time.time()

        metric_msg = ", ".join([
            "{}: {:.5f}".format(name, meter.avg)
            for name, meter in output_info.items()
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # do not try to save best model
        if self.eval_metric_func is None:
            return -1
        # return 1st metric in the dict
        return output_info[metric_key].avg
403

W
weishengyu 已提交
404 405
    def eval_retrieval(self, epoch_id=0):
        """Evaluate in retrieval mode using the gallery and query dataloaders.

        Features for both sets come from ``_cal_feature``. Query features are
        split into blocks of ``Global.sim_block_size`` rows (default 64);
        each block is multiplied against the gallery features and passed to
        ``self.eval_metric_func``. Per-block metrics are combined as a
        weighted sum, weighted by each block's share of the queries.

        Args:
            epoch_id: epoch number, used in the log message only.

        Returns:
            The combined value of the first metric (``0.`` when no metric
            function is set).
        """
        self.model.eval()
        # step1. build gallery
        gallery_feas, gallery_img_id, gallery_unique_id = self._cal_feature(
            name='gallery')
        query_feas, query_img_id, query_query_id = self._cal_feature(
            name='query')

        # step2. do evaluation
        sim_block_size = self.config["Global"].get("sim_block_size", 64)
        num_query = len(query_feas)
        full_blocks, remainder = divmod(num_query, sim_block_size)
        sections = [sim_block_size] * full_blocks
        if remainder:
            sections.append(remainder)
        fea_blocks = paddle.split(query_feas, num_or_sections=sections)
        if query_query_id is not None:
            query_id_blocks = paddle.split(
                query_query_id, num_or_sections=sections)
        image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
        metric_key = None

        if self.eval_metric_func is None:
            metric_dict = {metric_key: 0.}
        else:
            metric_dict = dict()
            for block_idx, block_fea in enumerate(fea_blocks):
                similarity_matrix = paddle.matmul(
                    block_fea, gallery_feas, transpose_y=True)
                if query_query_id is not None:
                    # zero the similarity of gallery entries that match the
                    # query on both the unique id and the image id
                    differs_by_unique_id = (
                        query_id_blocks[block_idx] != gallery_unique_id.t())
                    differs_by_image_id = (
                        image_id_blocks[block_idx] != gallery_img_id.t())
                    keep_mask = paddle.logical_or(differs_by_unique_id,
                                                  differs_by_image_id)
                    similarity_matrix = similarity_matrix * keep_mask.astype(
                        "float32")

                metric_tmp = self.eval_metric_func(similarity_matrix,
                                                   image_id_blocks[block_idx],
                                                   gallery_img_id)

                block_rows = block_fea.shape[0]
                for name in metric_tmp:
                    weighted = metric_tmp[name] * block_rows / num_query
                    if name in metric_dict:
                        metric_dict[name] += weighted
                    else:
                        metric_dict[name] = weighted

        # the first metric in the dict is the one reported to the caller
        for name in metric_dict:
            metric_key = name
            break
        metric_msg = ", ".join([
            "{}: {:.5f}".format(name, value)
            for name, value in metric_dict.items()
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        return metric_dict[metric_key]
W
weishengyu 已提交
464 465 466 467

    def _cal_feature(self, name='gallery'):
        """Extract features and ids for the whole gallery or query set.

        Each batch is run through the model with its second element as the
        label argument; ``out["features"]`` is L2-normalised per row unless
        ``Global.feature_normalize`` is false. With more than one
        distributed process the per-rank results are all-gathered.

        Args:
            name: ``'gallery'`` or ``'query'``, selecting the dataloader.

        Returns:
            Tuple ``(all_feas, all_image_id, all_unique_id)``;
            ``all_unique_id`` is ``None`` when batches have no third element.

        Raises:
            RuntimeError: for an unknown *name*, when the dataloader yields
                no batch, or when only some batches carry a unique id.
        """
        if name == 'gallery':
            dataloader = self.gallery_dataloader
        elif name == 'query':
            dataloader = self.query_dataloader
        else:
            raise RuntimeError("Only support gallery or query dataset")

        normalize = self.config["Global"].get("feature_normalize", True)

        # collect per-batch tensors and concatenate once after the loop;
        # concatenating onto the growing result every batch re-copies all
        # previous rows each time (quadratic in the dataset size)
        feas_chunks = []
        image_id_chunks = []
        unique_id_chunks = []
        has_unique_id = False
        for batch in dataloader():  # load is very time-consuming
            batch = [paddle.to_tensor(x) for x in batch]
            batch[1] = batch[1].reshape([-1, 1]).astype("int64")
            if len(batch) == 3:
                has_unique_id = True
                batch[2] = batch[2].reshape([-1, 1]).astype("int64")
            out = self.model(batch[0], batch[1])
            batch_feas = out["features"]

            # do norm
            if normalize:
                feas_norm = paddle.sqrt(
                    paddle.sum(paddle.square(batch_feas), axis=1,
                               keepdim=True))
                batch_feas = paddle.divide(batch_feas, feas_norm)

            feas_chunks.append(batch_feas)
            image_id_chunks.append(batch[1])
            if has_unique_id:
                unique_id_chunks.append(batch[2])

        if not feas_chunks:
            raise RuntimeError(
                "The {} dataloader produced no batch".format(name))
        if has_unique_id and len(unique_id_chunks) != len(feas_chunks):
            # ids would no longer line up row-by-row with the features
            raise RuntimeError(
                "Only some {} batches provide a unique id".format(name))

        all_feas = paddle.concat(feas_chunks)
        all_image_id = paddle.concat(image_id_chunks)
        all_unique_id = None
        if has_unique_id:
            all_unique_id = paddle.concat(unique_id_chunks)

        if paddle.distributed.get_world_size() > 1:
            feat_list = []
            img_id_list = []
            unique_id_list = []
            paddle.distributed.all_gather(feat_list, all_feas)
            paddle.distributed.all_gather(img_id_list, all_image_id)
            all_feas = paddle.concat(feat_list, axis=0)
            all_image_id = paddle.concat(img_id_list, axis=0)
            if has_unique_id:
                paddle.distributed.all_gather(unique_id_list, all_unique_id)
                all_unique_id = paddle.concat(unique_id_list, axis=0)

        logger.info("Build {} done, all feat shape: {}, begin to eval..".
                    format(name, all_feas.shape))
        return all_feas, all_image_id, all_unique_id
W
weishengyu 已提交
520

521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552
    @paddle.no_grad()
    def infer(self, ):
        """Run inference over the images listed by ``Infer.infer_imgs``.

        The image list is sharded across distributed processes by rank. Each
        file is read as bytes, passed through the ``Infer.transforms``
        operators, batched up to ``Infer.batch_size``, fed to the model, and
        the post-processed result of every batch is printed.
        """
        infer_cfg = self.config["Infer"]
        world_size = paddle.distributed.get_world_size()
        rank = paddle.distributed.get_rank()
        # data split: every rank takes each world_size-th image
        image_list = get_image_list(infer_cfg["infer_imgs"])[rank::world_size]

        preprocess_func = create_operators(infer_cfg["transforms"])
        postprocess_func = build_postprocess(infer_cfg["PostProcess"])

        batch_size = infer_cfg["batch_size"]

        self.model.eval()

        batch_data = []
        image_file_list = []
        last_idx = len(image_list) - 1
        for idx, image_file in enumerate(image_list):
            with open(image_file, 'rb') as f:
                x = f.read()
            for process in preprocess_func:
                x = process(x)
            batch_data.append(x)
            image_file_list.append(image_file)

            # keep filling until the batch is full or the list is exhausted
            if len(batch_data) < batch_size and idx != last_idx:
                continue

            batch_tensor = paddle.to_tensor(batch_data)
            out = self.model(batch_tensor)
            result = postprocess_func(out, image_file_list)
            print(result)
            batch_data.clear()
            image_file_list.clear()