# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
import numpy as np
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../../')))

import argparse
import paddle
import paddle.nn as nn
import paddle.distributed as dist

from ppcls.utils.check import check_gpu
from ppcls.utils.misc import AverageMeter
from ppcls.utils import logger
from ppcls.data import build_dataloader
from ppcls.arch import build_model
W
weishengyu 已提交
33
from ppcls.loss import build_loss
W
weishengyu 已提交
34
from ppcls.metric import build_metrics
L
littletomatodonkey 已提交
35
from ppcls.optimizer import build_optimizer
36 37
from ppcls.utils.save_load import load_dygraph_pretrain
from ppcls.utils.save_load import init_model
L
littletomatodonkey 已提交
38 39
from ppcls.utils import save_load

40 41 42 43
from ppcls.data.utils.get_image_list import get_image_list
from ppcls.data.postprocess import build_postprocess
from ppcls.data.reader import create_operators

L
littletomatodonkey 已提交
44 45

class Trainer(object):
46
    def __init__(self, config, mode="train"):
        """Set up the device, distributed environment, model and logging.

        Args:
            config: Nested configuration mapping. The ``Global`` and
                ``Arch`` sections are read here; ``Global.distributed``
                (and ``Arch.Head.class_num`` when a ``Head`` section
                exists) are written back into it.
            mode: Run-mode label stored on ``self.mode``; it is not used
                for anything else inside this method.
        """
        self.mode = mode
        self.config = config
        self.output_dir = self.config['Global']['output_dir']

        # set device
        device_name = self.config["Global"]["device"]
        assert device_name in ["cpu", "gpu", "xpu"], (
            "Global.device must be one of cpu/gpu/xpu, got {}".format(
                device_name))
        self.device = paddle.set_device(device_name)

        # set dist: more than one process in the world means distributed
        self.config["Global"][
            "distributed"] = paddle.distributed.get_world_size() != 1
        if self.config["Global"]["distributed"]:
            dist.init_parallel_env()

        # A "Head" section marks the model as a recognition model, whose
        # forward pass is later called with (images, labels).
        if "Head" in self.config["Arch"]:
            self.config["Arch"]["Head"]["class_num"] = self.config["Global"][
                "class_num"]
            self.is_rec = True
        else:
            self.is_rec = False

        self.model = build_model(self.config["Arch"])

        if self.config["Global"]["pretrained_model"] is not None:
            load_dygraph_pretrain(self.model,
                                  self.config["Global"]["pretrained_model"])

        if self.config["Global"]["distributed"]:
            self.model = paddle.DataParallel(self.model)

        self.vdl_writer = None
        if self.config['Global']['use_visualdl']:
            from visualdl import LogWriter
            vdl_writer_path = os.path.join(self.output_dir, "vdl")
            # exist_ok avoids the check-then-create race when several
            # ranks reach this point at the same time.
            os.makedirs(vdl_writer_path, exist_ok=True)
            self.vdl_writer = LogWriter(logdir=vdl_writer_path)
        logger.info('train with paddle {} and device {}'.format(
            paddle.__version__, self.device))

        # init members; dataloaders, losses and metrics are built lazily
        # by train() / eval()
        self.train_dataloader = None
        self.eval_dataloader = None
        self.gallery_dataloader = None
        self.query_dataloader = None
        self.eval_mode = self.config["Global"].get("eval_mode",
                                                   "classification")
        self.train_loss_func = None
        self.eval_loss_func = None
        self.train_metric_func = None
        self.eval_metric_func = None
L
littletomatodonkey 已提交
95 96 97

    def train(self):
        """Run the training loop with periodic evaluation and checkpoints.

        Lazily builds the train loss, metric and dataloader when they are
        still ``None``, builds the optimizer and LR scheduler, optionally
        resumes from ``Global.checkpoints`` and then iterates epochs from
        the resumed epoch + 1 up to and including ``Global.epochs``.

        After every epoch the averaged losses/metrics are logged. The
        model is evaluated when ``Global.eval_during_train`` is truthy and
        divides the epoch id, and the best-scoring model is saved under
        the ``best_model`` prefix. A regular checkpoint is written every
        ``Global.save_interval`` epochs.
        """
        # build train loss and metric info
        if self.train_loss_func is None:
            # NOTE(review): eval() reads config["Loss"]["Eval"], whereas
            # the whole "Loss" section is passed here -- confirm intended.
            self.train_loss_func = build_loss(self.config["Loss"])
        if self.train_metric_func is None:
            metric_config = self.config.get("Metric", None)
            if metric_config is None:
                metric_config = [{"name": "TopkAcc", "topk": (1, 5)}]
            else:
                metric_config = metric_config["Train"]
            self.train_metric_func = build_metrics(metric_config)

        if self.train_dataloader is None:
            self.train_dataloader = build_dataloader(self.config["DataLoader"],
                                                     "Train", self.device)

        step_each_epoch = len(self.train_dataloader)

        optimizer, lr_sch = build_optimizer(self.config["Optimizer"],
                                            self.config["Global"]["epochs"],
                                            step_each_epoch,
                                            self.model.parameters())

        print_batch_step = self.config['Global']['print_batch_step']
        save_interval = self.config["Global"]["save_interval"]

        best_metric = {
            "metric": 0.0,
            "epoch": 0,
        }
        # key: name of a loss or metric
        # val: AverageMeter accumulating that value over the current epoch
        output_info = dict()
        # global iter counter
        global_step = 0

        def update_meters(value_dict, num_samples):
            # Fold one batch of named scalar tensors into output_info.
            for name in value_dict:
                if name not in output_info:
                    output_info[name] = AverageMeter(name, '7.5f')
                output_info[name].update(value_dict[name].numpy()[0],
                                         num_samples)

        def format_meters():
            # Render the running averages as "name: value, name: value".
            return ", ".join([
                "{}: {:.5f}".format(name, output_info[name].avg)
                for name in output_info
            ])

        if self.config["Global"]["checkpoints"] is not None:
            metric_info = init_model(self.config["Global"], self.model,
                                     optimizer)
            if metric_info is not None:
                best_metric.update(metric_info)

        for epoch_id in range(best_metric["epoch"] + 1,
                              self.config["Global"]["epochs"] + 1):
            acc = 0.0
            for iter_id, batch in enumerate(self.train_dataloader()):
                batch_size = batch[0].shape[0]
                batch[1] = paddle.to_tensor(batch[1].numpy().astype("int64")
                                            .reshape([-1, 1]))
                global_step += 1
                # image input; recognition models also receive the labels
                if not self.is_rec:
                    out = self.model(batch[0])
                else:
                    out = self.model(batch[0], batch[1])
                # calc loss
                loss_dict = self.train_loss_func(out, batch[1])
                update_meters(loss_dict, batch_size)
                # calc metric
                if self.train_metric_func is not None:
                    metric_dict = self.train_metric_func(out, batch[-1])
                    update_meters(metric_dict, batch_size)

                if iter_id % print_batch_step == 0:
                    lr_msg = "lr: {:.5f}".format(lr_sch.get_lr())
                    metric_msg = format_meters()
                    logger.info("[Train][Epoch {}][Iter: {}/{}]{}, {}".format(
                        epoch_id, iter_id, step_each_epoch, lr_msg,
                        metric_msg))

                # step opt and lr
                loss_dict["loss"].backward()
                optimizer.step()
                optimizer.clear_grad()
                lr_sch.step()

            metric_msg = format_meters()
            logger.info("[Train][Epoch {}][Avg]{}".format(epoch_id,
                                                          metric_msg))
            output_info.clear()

            # eval model and save model if possible;
            # "eval_during_train" doubles as the evaluation interval
            if self.config["Global"][
                    "eval_during_train"] and epoch_id % self.config["Global"][
                        "eval_during_train"] == 0:
                eval_acc = self.eval(epoch_id)
                # eval() returns None for an unsupported eval mode (it has
                # already logged a warning); comparing None with a float
                # would raise TypeError, so keep acc at its default then.
                if eval_acc is not None:
                    acc = eval_acc
                    if acc > best_metric["metric"]:
                        best_metric["metric"] = acc
                        best_metric["epoch"] = epoch_id
                        save_load.save_model(
                            self.model,
                            optimizer,
                            best_metric,
                            self.output_dir,
                            model_name=self.config["Arch"]["name"],
                            prefix="best_model")
                self.model.train()

            # save model
            if epoch_id % save_interval == 0:
                save_load.save_model(
                    self.model,
                    optimizer, {"metric": acc,
                                "epoch": epoch_id},
                    self.output_dir,
                    model_name=self.config["Arch"]["name"],
                    prefix="ppcls_epoch_{}".format(epoch_id))

    def build_avg_metrics(self, info_dict):
        """Return a dict mapping each key of ``info_dict`` to a new meter.

        Every meter is an ``AverageMeter`` named after its key and created
        with the ``'7.5f'`` format argument.
        """
        meters = {}
        for name in info_dict:
            meters[name] = AverageMeter(name, '7.5f')
        return meters

    @paddle.no_grad()
    def eval(self, epoch_id=0):
        """Evaluate the model according to ``self.eval_mode``.

        Builds the eval loss, the dataloaders and the metric function on
        first use, then dispatches to ``eval_cls`` ("classification") or
        ``eval_retrieval`` ("retrieval"). The model is switched to eval
        mode on entry and back to train mode before returning.

        Args:
            epoch_id: Epoch number forwarded to the mode-specific method.

        Returns:
            The value returned by the mode-specific method, or ``None``
            (after logging a warning) when the eval mode is not supported.
        """
        self.model.eval()

        if self.eval_loss_func is None:
            loss_config = self.config.get("Loss", None)
            if loss_config is not None:
                loss_config = loss_config["Eval"]
            else:
                loss_config = [{"CELoss": {"weight": 1.0}}]
            self.eval_loss_func = build_loss(loss_config)

        def ensure_metric_func(default_name):
            # Build the eval metric once; fall back to a top-(1, 5) metric
            # of the given name when the config has no "Metric" section.
            if self.eval_metric_func is not None:
                return
            metric_config = self.config.get("Metric", None)
            if metric_config is not None:
                metric_config = metric_config["Eval"]
            else:
                metric_config = [{"name": default_name, "topk": (1, 5)}]
            self.eval_metric_func = build_metrics(metric_config)

        if self.eval_mode == "classification":
            if self.eval_dataloader is None:
                self.eval_dataloader = build_dataloader(
                    self.config["DataLoader"], "Eval", self.device)
            ensure_metric_func("TopkAcc")
            result = self.eval_cls(epoch_id)
        elif self.eval_mode == "retrieval":
            if self.gallery_dataloader is None:
                self.gallery_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Gallery", self.device)
            if self.query_dataloader is None:
                self.query_dataloader = build_dataloader(
                    self.config["DataLoader"]["Eval"], "Query", self.device)
            # build metric info
            ensure_metric_func("Recallk")
            result = self.eval_retrieval(epoch_id)
        else:
            logger.warning("Invalid eval mode: {}".format(self.eval_mode))
            result = None

        self.model.train()
        return result
W
weishengyu 已提交
268 269 270

    def eval_cls(self, epoch_id=0):
        """Run classification evaluation over ``self.eval_dataloader``.

        For every batch the loss (if an eval loss is set) and the metrics
        (if an eval metric is set) are accumulated in per-name
        ``AverageMeter`` objects. With more than one process, each metric
        is summed across processes and divided by the world size before
        being accumulated.

        Args:
            epoch_id: Epoch number, used only in log messages.

        Returns:
            The running average of the first metric produced by the metric
            function, or ``-1`` when no metric function is set or no
            metric was ever produced (e.g. an empty dataloader).
        """
        output_info = dict()
        print_batch_step = self.config["Global"]["print_batch_step"]
        # loop-invariant, so query it once instead of once per batch
        world_size = paddle.distributed.get_world_size()

        def update_meters(value_dict, num_samples):
            # Fold one batch of named scalar tensors into output_info.
            for name in value_dict:
                if name not in output_info:
                    output_info[name] = AverageMeter(name, '7.5f')
                output_info[name].update(value_dict[name].numpy()[0],
                                         num_samples)

        metric_key = None
        for iter_id, batch in enumerate(self.eval_dataloader()):
            batch_size = batch[0].shape[0]
            batch[0] = paddle.to_tensor(batch[0]).astype("float32")
            batch[1] = paddle.to_tensor(batch[1]).reshape([-1, 1])
            # image input; recognition models also receive the labels
            if self.is_rec:
                out = self.model(batch[0], batch[1])
            else:
                out = self.model(batch[0])
            # calc loss
            if self.eval_loss_func is not None:
                loss_dict = self.eval_loss_func(out, batch[-1])
                update_meters(loss_dict, batch_size)
            # calc metric
            if self.eval_metric_func is not None:
                metric_dict = self.eval_metric_func(out, batch[-1])
                if world_size > 1:
                    # average every metric over all processes
                    for key in metric_dict:
                        paddle.distributed.all_reduce(
                            metric_dict[key],
                            op=paddle.distributed.ReduceOp.SUM)
                        metric_dict[key] = metric_dict[key] / world_size
                if metric_key is None:
                    # remember the first metric name; it is the one
                    # reported back to the caller
                    metric_key = next(iter(metric_dict), None)
                update_meters(metric_dict, batch_size)

            if iter_id % print_batch_step == 0:
                # per-iteration log shows the latest value, not the average
                metric_msg = ", ".join([
                    "{}: {:.5f}".format(key, output_info[key].val)
                    for key in output_info
                ])
                logger.info("[Eval][Epoch {}][Iter: {}/{}]{}".format(
                    epoch_id, iter_id, len(self.eval_dataloader), metric_msg))

        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, output_info[key].avg)
            for key in output_info
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        # do not try to save best model
        if self.eval_metric_func is None:
            return -1
        if metric_key is None:
            # No metric was recorded (no batches, or the metric function
            # returned nothing); indexing output_info would raise KeyError.
            logger.warning(
                "No evaluation metric was computed for epoch {}".format(
                    epoch_id))
            return -1
        # return 1st metric in the dict
        return output_info[metric_key].avg
329

W
weishengyu 已提交
330 331
    def eval_retrieval(self, epoch_id=0):
        """Evaluate retrieval by comparing query features with the gallery.

        Features and ids for both sets come from ``_cal_feature``. The
        query features are split into blocks of ``Global.sim_block_size``
        rows (default 1), each block is multiplied with the transposed
        gallery features, masked, and the per-block results are
        concatenated before being handed to the eval metric function.

        Args:
            epoch_id: Epoch number, used only in the log message.

        Returns:
            The value of the first metric returned by the metric function,
            or ``0.`` when no metric function is set.
        """
        self.model.eval()
        cum_similarity_matrix = None
        # step1. build gallery
        gallery_feas, gallery_img_id, gallery_camera_id = self._cal_feature(
            name='gallery')
        query_feas, query_img_id, query_camera_id = self._cal_feature(
            name='query')
        # NOTE(review): _cal_feature already returns tensors; wrapping them
        # in a list before to_tensor()/t() looks suspicious -- confirm the
        # resulting shape broadcasts against the query id blocks.
        gallery_img_id = paddle.to_tensor([gallery_img_id]).t()
        if gallery_camera_id is not None:
            gallery_camera_id = paddle.to_tensor([gallery_camera_id]).t()
        query_img_id = paddle.to_tensor(query_img_id)
        if query_camera_id is not None:
            query_camera_id = paddle.to_tensor(query_camera_id)

        # step2. do evaluation
        sim_block_size = self.config["Global"].get("sim_block_size", 1)
        num_query = len(query_feas)
        sections = [sim_block_size] * (num_query // sim_block_size)
        # Only a non-zero remainder forms a trailing, smaller block; a
        # zero-sized section must not be appended.
        remainder = num_query % sim_block_size
        if remainder:
            sections.append(remainder)
        fea_blocks = paddle.split(query_feas, num_or_sections=sections)
        if query_camera_id is not None:
            camera_id_blocks = paddle.split(
                query_camera_id, num_or_sections=sections)
        image_id_blocks = paddle.split(query_img_id, num_or_sections=sections)
        metric_key = None

        for block_idx, block_fea in enumerate(fea_blocks):
            similarity_matrix = paddle.matmul(
                block_fea, gallery_feas, transpose_y=True)
            # NOTE(review): paddle.masked_select is documented to return a
            # 1-D tensor, so the 2-D layout is lost after the first call
            # and the camera mask is applied to the flattened result --
            # confirm this is what the metric function expects.
            image_id_block = image_id_blocks[block_idx]
            image_id_mask = (image_id_block != gallery_img_id)
            similarity_matrix = similarity_matrix.masked_select(image_id_mask)
            if query_camera_id is not None:
                camera_id_block = camera_id_blocks[block_idx]
                camera_id_mask = (camera_id_block != gallery_camera_id)
                similarity_matrix = similarity_matrix.masked_select(
                    camera_id_mask)
            # The accumulator (not the fresh block) decides whether this is
            # the first block; concat takes a list of tensors.
            if cum_similarity_matrix is None:
                cum_similarity_matrix = similarity_matrix
            else:
                cum_similarity_matrix = paddle.concat(
                    [cum_similarity_matrix, similarity_matrix])

        # calc metric
        if self.eval_metric_func is not None:
            metric_dict = self.eval_metric_func(cum_similarity_matrix,
                                                query_img_id, gallery_img_id)
            # report the 1st metric in the dict back to the caller
            metric_key = next(iter(metric_dict), None)
        else:
            metric_dict = {metric_key: 0.}
        # Metric values are formatted directly; float() is assumed to
        # accept them (Python/numpy scalars or single-element tensors).
        metric_msg = ", ".join([
            "{}: {:.5f}".format(key, float(metric_dict[key]))
            for key in metric_dict
        ])
        logger.info("[Eval][Epoch {}][Avg]{}".format(epoch_id, metric_msg))

        return metric_dict[metric_key]
W
weishengyu 已提交
386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442

    def _cal_feature(self, name='gallery'):
        """Compute features and ids for the gallery or the query set.

        Every batch is run through the model with ``(batch[0], batch[1])``
        and the ``"features"`` entry of the output is collected. Features
        are L2-normalised along axis 1 unless ``Global.feature_normalize``
        is false. A third batch element, when present, is treated as the
        camera id. With more than one process the results of all
        processes are gathered and concatenated.

        Args:
            name: ``'gallery'`` or ``'query'``; selects the dataloader.

        Returns:
            Tuple ``(all_feas, all_image_id, all_camera_id)``;
            ``all_camera_id`` is ``None`` when no batch carried a third
            element.

        Raises:
            RuntimeError: If ``name`` is not supported or the selected
                dataloader yields no batches.
        """
        if name == 'gallery':
            dataloader = self.gallery_dataloader
        elif name == 'query':
            dataloader = self.query_dataloader
        else:
            raise RuntimeError("Only support gallery or query dataset")

        normalize = self.config["Global"].get("feature_normalize", True)

        # Collect per-batch tensors and concatenate once at the end rather
        # than re-copying the growing result on every batch.
        feas_per_batch = []
        image_id_per_batch = []
        camera_id_per_batch = []
        has_cam_id = False
        for batch in dataloader():  # load is very time-consuming
            batch = [paddle.to_tensor(x) for x in batch]
            batch[1] = batch[1].reshape([-1, 1])
            if len(batch) == 3:
                has_cam_id = True
                batch[2] = batch[2].reshape([-1, 1])
            out = self.model(batch[0], batch[1])
            batch_feas = out["features"]

            # do norm
            if normalize:
                feas_norm = paddle.sqrt(
                    paddle.sum(paddle.square(batch_feas), axis=1,
                               keepdim=True))
                batch_feas = paddle.divide(batch_feas, feas_norm)

            feas_per_batch.append(batch_feas)
            image_id_per_batch.append(batch[1])
            if has_cam_id:
                camera_id_per_batch.append(batch[2])

        if not feas_per_batch:
            raise RuntimeError(
                "No data was loaded from the {} dataloader".format(name))

        all_feas = paddle.concat(feas_per_batch)
        all_image_id = paddle.concat(image_id_per_batch)
        all_camera_id = None
        if has_cam_id:
            all_camera_id = paddle.concat(camera_id_per_batch)

        if paddle.distributed.get_world_size() > 1:
            feat_list = []
            img_id_list = []
            cam_id_list = []
            paddle.distributed.all_gather(feat_list, all_feas)
            paddle.distributed.all_gather(img_id_list, all_image_id)
            all_feas = paddle.concat(feat_list, axis=0)
            all_image_id = paddle.concat(img_id_list, axis=0)
            if has_cam_id:
                paddle.distributed.all_gather(cam_id_list, all_camera_id)
                all_camera_id = paddle.concat(cam_id_list, axis=0)

        logger.info("Build {} done, all feat shape: {}, begin to eval..".
                    format(name, all_feas.shape))
        return all_feas, all_image_id, all_camera_id

443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474
    @paddle.no_grad()
    def infer(self, ):
        """Run inference on the configured images and print the results.

        The image list from ``Infer.infer_imgs`` is strided across
        processes by rank. Each file is read as bytes, passed through the
        ``Infer.transforms`` operators and batched; a batch is run as soon
        as it holds ``Infer.batch_size`` items or the last image has been
        reached. The post-processed result of every batch is printed.
        """
        world_size = paddle.distributed.get_world_size()
        rank = paddle.distributed.get_rank()
        all_images = get_image_list(self.config["Infer"]["infer_imgs"])
        # data split: this process takes every world_size-th image
        image_list = all_images[rank::world_size]

        infer_config = self.config["Infer"]
        preprocess_func = create_operators(infer_config["transforms"])
        postprocess_func = build_postprocess(infer_config["PostProcess"])
        batch_size = infer_config["batch_size"]

        self.model.eval()

        last_index = len(image_list) - 1
        pending_inputs = []
        pending_files = []
        for index, image_path in enumerate(image_list):
            with open(image_path, 'rb') as image_fp:
                data = image_fp.read()
            for operator in preprocess_func:
                data = operator(data)
            pending_inputs.append(data)
            pending_files.append(image_path)

            # keep collecting until the batch is full or the list ends
            if len(pending_inputs) < batch_size and index != last_index:
                continue

            batch_tensor = paddle.to_tensor(pending_inputs)
            out = self.model(batch_tensor)
            result = postprocess_func(out, pending_files)
            print(result)
            pending_inputs.clear()
            pending_files.clear()