# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import division
from __future__ import print_function

import errno
import json
import logging
import math
import os
import shutil
import subprocess
import sys
import tempfile
import time

import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.incubate.fleet.base.role_maker as role_maker
import paddle.fluid.transpiler.distribute_transpiler as dist_transpiler
import sklearn
from paddle.fluid.incubate.fleet.collective import fleet, DistributedStrategy
from paddle.fluid.optimizer import Optimizer
from paddle.fluid.transpiler.details.program_utils import program_to_code

from . import config
from .models import DistributedClassificationOptimizer
from .models import base_model
from .models import resnet
from .utils import jpeg_reader as reader
from .utils.learning_rate import lr_warmup
from .utils.parameter_converter import ParameterConverter
from .utils.verification import evaluate

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    datefmt='%d %b %Y %H:%M:%S')
logger = logging.getLogger(__name__)


class Entry(object):
    """
    The class to encapsulate all operations.
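
    A minimal usage sketch (mirrors the __main__ block at the bottom of
    this file; the dataset path is an illustrative placeholder):

        ins = Entry()
        ins.set_dataset_dir('./data')
        ins.train()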
    """

    def _check(self):
        """
        Check the validity of parameters.
        """
        supported_types = ["softmax", "arcface",
                           "dist_softmax", "dist_arcface"]
        assert self.loss_type in supported_types, \
            "All supported types are {}, but given {}.".format(
                supported_types, self.loss_type)

        if self.loss_type in ["dist_softmax", "dist_arcface"]:
            assert self.num_trainers > 1, \
                "At least 2 trainers are required for distributed fc-layer. " \
                "You can start your job using paddle.distributed.launch module."

    def __init__(self):
        self.config = config.config
        super(Entry, self).__init__()
        num_trainers = int(os.getenv("PADDLE_TRAINERS_NUM", 1))
        trainer_id = int(os.getenv("PADDLE_TRAINER_ID", 0))

        self.trainer_id = trainer_id
        self.num_trainers = num_trainers
        self.train_batch_size = self.config.train_batch_size
        self.test_batch_size = self.config.test_batch_size
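        # The effective global batch size is the per-trainer batch size
        # times the number of trainers, e.g. 128 x 8 = 1024 (illustrative).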
        self.global_train_batch_size = self.train_batch_size * num_trainers
        self.global_test_batch_size = self.test_batch_size * num_trainers

        self.optimizer = None
        self.model = None
        self.train_reader = None
        self.test_reader = None
        self.predict_reader = None

        self.train_program = fluid.Program()
        self.startup_program = fluid.Program()
        self.test_program = fluid.Program()
        self.predict_program = fluid.Program()

        self.fs_name = None
        self.fs_ugi = None
        self.fs_dir_for_save = None
        self.fs_checkpoint_dir = None

        self.param_attr = None
        self.bias_attr = None

        self.has_run_train = False  # whether training has been run
        self.test_initialized = False
        self.train_pass_id = -1

        self.use_fp16 = False
        self.fp16_user_dict = None
        self.data_format = 'NCHW'

        self.val_targets = self.config.val_targets
        self.dataset_dir = self.config.dataset_dir
        self.num_classes = self.config.num_classes
        self.image_shape = self.config.image_shape
        self.loss_type = self.config.loss_type
        self.margin = self.config.margin
        self.scale = self.config.scale
        self.lr = self.config.lr
        self.lr_steps = self.config.lr_steps
        self.train_image_num = self.config.train_image_num
        self.model_name = self.config.model_name
        self.emb_dim = self.config.emb_dim
        self.train_epochs = self.config.train_epochs
        self.checkpoint_dir = self.config.checkpoint_dir
        self.with_test = self.config.with_test
        self.model_save_dir = self.config.model_save_dir
        self.warmup_epochs = self.config.warmup_epochs
        self.calc_train_acc = False

        if self.checkpoint_dir:
            self.checkpoint_dir = os.path.abspath(self.checkpoint_dir)
        if self.model_save_dir:
            self.model_save_dir = os.path.abspath(self.model_save_dir)
        if self.dataset_dir:
            self.dataset_dir = os.path.abspath(self.dataset_dir)

        logger.info('=' * 30)
        logger.info("Default configuration:")
        for key in self.config:
            logger.info('\t' + str(key) + ": " + str(self.config[key]))
        logger.info('trainer_id: {}, num_trainers: {}'.format(
            trainer_id, num_trainers))
        logger.info('=' * 30)

    def set_val_targets(self, targets):
        """
        Set the names of the validation datasets, separated by commas.
        """
        self.val_targets = targets
        logger.info("Set val_targets to {}.".format(targets))

    def set_train_batch_size(self, batch_size):
        self.train_batch_size = batch_size
        self.global_train_batch_size = batch_size * self.num_trainers
        logger.info("Set train batch size to {}.".format(batch_size))

    def set_mixed_precision(self,
                            use_fp16,
                            init_loss_scaling=1.0,
                            incr_every_n_steps=2000,
                            decr_every_n_nan_or_inf=2,
                            incr_ratio=2.0,
                            decr_ratio=0.5,
                            use_dynamic_loss_scaling=True,
                            amp_lists=None):
        """
        Set whether to use mixed precision training.
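
        A usage sketch (the loss-scaling value below is illustrative, not
        a recommendation):

            ins.set_mixed_precision(True, init_loss_scaling=128.0)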
        """
        self.use_fp16 = use_fp16
        if self.use_fp16:
            self.data_format = 'NHWC'
        self.fp16_user_dict = dict()
        self.fp16_user_dict['init_loss_scaling'] = init_loss_scaling
        self.fp16_user_dict['incr_every_n_steps'] = incr_every_n_steps
        self.fp16_user_dict['decr_every_n_nan_or_inf'] = decr_every_n_nan_or_inf
        self.fp16_user_dict['incr_ratio'] = incr_ratio
        self.fp16_user_dict['decr_ratio'] = decr_ratio
        self.fp16_user_dict['use_dynamic_loss_scaling'] = use_dynamic_loss_scaling
        self.fp16_user_dict['amp_lists'] = amp_lists
        logger.info("Use mixed precision training: {}.".format(use_fp16))
        for key in self.fp16_user_dict:
            logger.info("Set init {} to {}.".format(key, self.fp16_user_dict[key]))

    def set_test_batch_size(self, batch_size):
        self.test_batch_size = batch_size
        self.global_test_batch_size = batch_size * self.num_trainers
        logger.info("Set test batch size to {}.".format(batch_size))

    def set_hdfs_info(self,
                      fs_name,
                      fs_ugi,
                      fs_dir_for_save=None,
                      fs_checkpoint_dir=None):
        """
        Set the info to download from or upload to hdfs filesystems.
        If the information is provided, we will download the pretrained
        model from hdfs at the beginning and upload trained models
        to hdfs at the end automatically.
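
        A usage sketch (all values are illustrative placeholders):

            ins.set_hdfs_info("hdfs://example:9000", "user,passwd",
                              fs_dir_for_save="/output",
                              fs_checkpoint_dir="/checkpoints")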
        """
        self.fs_name = fs_name
        self.fs_ugi = fs_ugi
        self.fs_dir_for_save = fs_dir_for_save
        self.fs_checkpoint_dir = fs_checkpoint_dir
        logger.info("HDFS Info:")
        logger.info("\tfs_name: {}".format(fs_name))
        logger.info("\tfs_ugi: {}".format(fs_ugi))
        logger.info("\tfs dir for save: {}".format(self.fs_dir_for_save))
        logger.info("\tfs checkpoint dir: {}".format(self.fs_checkpoint_dir))

    def set_model_save_dir(self, directory):
        """
        Set the directory to save models.
        """
        if directory:
            directory = os.path.abspath(directory)
        self.model_save_dir = directory
        logger.info("Set model_save_dir to {}.".format(directory))

    def set_calc_acc(self, calc):
        """
        Whether to calculate acc1 and acc5 during training.
        """
        self.calc_train_acc = calc
        logger.info("Calculating acc1 and acc5 during training: {}.".format(
            calc))

    def set_dataset_dir(self, directory):
        """
        Set the root directory for datasets.
        """
        if directory:
            directory = os.path.abspath(directory)
        self.dataset_dir = directory
        logger.info("Set dataset_dir to {}.".format(directory))

    def set_train_image_num(self, num):
        """
        Set the total number of images for train.
        """
        self.train_image_num = num
        logger.info("Set train_image_num to {}.".format(num))

    def set_class_num(self, num):
        """
        Set the number of classes.
        """
        self.num_classes = num
        logger.info("Set num_classes to {}.".format(num))

    def set_emb_size(self, size):
        """
        Set the size of the last hidden layer before the distributed fc-layer.
        """
        self.emb_dim = size
        logger.info("Set emb_dim to {}.".format(size))

    def set_model(self, model):
        """
        Set user-defined model to use.
        """
        if not isinstance(model, base_model.BaseModel):
            raise ValueError("The parameter for set_model must be an "
                             "instance of BaseModel.")
        self.model = model
        logger.info("Set model to {}.".format(model))

    def set_train_epochs(self, num):
        """
        Set the number of epochs to train.
        """
        self.train_epochs = num
        logger.info("Set train_epochs to {}.".format(num))

    def set_checkpoint_dir(self, directory):
        """
        Set the directory for checkpoint loaded before training/testing.
        """
        if directory:
            directory = os.path.abspath(directory)
        self.checkpoint_dir = directory
        logger.info("Set checkpoint_dir to {}.".format(directory))

    def set_warmup_epochs(self, num):
        self.warmup_epochs = num
        logger.info("Set warmup_epochs to {}.".format(num))

    def set_loss_type(self, loss_type):
        supported_types = ["dist_softmax", "dist_arcface", "softmax", "arcface"]
        if loss_type not in supported_types:
            raise ValueError("All supported loss types: {}".format(
                supported_types))
        self.loss_type = loss_type
        logger.info("Set loss_type to {}.".format(loss_type))

    def set_image_shape(self, shape):
        if not isinstance(shape, (list, tuple)):
            raise ValueError("Shape must be of type list or tuple")
        self.image_shape = shape
        logger.info("Set image_shape to {}.".format(shape))

    def set_optimizer(self, optimizer):
        if not isinstance(optimizer, Optimizer):
            raise ValueError("Optimizer must be of type Optimizer")
        self.optimizer = optimizer
        logger.info("User manually set optimizer.")

    def set_with_test(self, with_test):
        self.with_test = with_test
        logger.info("Set with_test to {}.".format(with_test))

    def set_distfc_attr(self, param_attr=None, bias_attr=None):
        self.param_attr = param_attr
        logger.info("Set param_attr for distfc to {}.".format(self.param_attr))
        if bias_attr:
            self.bias_attr = bias_attr
            logger.info(
                "Set bias_attr for distfc to {}.".format(self.bias_attr))

    def _get_optimizer(self):
        if not self.optimizer:
            bd = [step for step in self.lr_steps]
            start_lr = self.lr

            global_batch_size = self.global_train_batch_size
            train_image_num = self.train_image_num
            images_per_trainer = int(math.ceil(
                train_image_num * 1.0 / self.num_trainers))
            steps_per_pass = int(math.ceil(
                images_per_trainer * 1.0 / self.train_batch_size))
            logger.info("Steps per epoch: %d" % steps_per_pass)
            warmup_steps = steps_per_pass * self.warmup_epochs
            batch_denom = 1024
            base_lr = start_lr * global_batch_size / batch_denom
            lr = [base_lr * (0.1 ** i) for i in range(len(bd) + 1)]
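            # Linear-scaling rule, e.g. (illustrative numbers): with
            # start_lr=0.1 and a global batch of 1024, base_lr equals
            # 0.1 * 1024 / 1024 = 0.1, and each boundary in lr_steps
            # divides the rate by another factor of 10.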
            logger.info("LR boundaries: {}".format(bd))
            logger.info("lr_step: {}".format(lr))
            if self.warmup_epochs:
                lr_val = lr_warmup(fluid.layers.piecewise_decay(boundaries=bd,
                                                                values=lr),
                                   warmup_steps,
                                   start_lr,
                                   base_lr)
            else:
                lr_val = fluid.layers.piecewise_decay(boundaries=bd, values=lr)

            optimizer = fluid.optimizer.Momentum(
                learning_rate=lr_val, momentum=0.9,
                regularization=fluid.regularizer.L2Decay(5e-4))
            self.optimizer = optimizer

        if self.loss_type in ["dist_softmax", "dist_arcface"]:
            self.optimizer = DistributedClassificationOptimizer(
                self.optimizer,
                self.train_batch_size,
                use_fp16=self.use_fp16,
                loss_type=self.loss_type,
                fp16_user_dict=self.fp16_user_dict)
        elif self.use_fp16:
            self.optimizer = fluid.contrib.mixed_precision.decorate(
                optimizer=self.optimizer,
                init_loss_scaling=self.fp16_user_dict['init_loss_scaling'],
                incr_every_n_steps=self.fp16_user_dict['incr_every_n_steps'],
                decr_every_n_nan_or_inf=self.fp16_user_dict[
                    'decr_every_n_nan_or_inf'],
                incr_ratio=self.fp16_user_dict['incr_ratio'],
                decr_ratio=self.fp16_user_dict['decr_ratio'],
                use_dynamic_loss_scaling=self.fp16_user_dict[
                    'use_dynamic_loss_scaling'],
                amp_lists=self.fp16_user_dict['amp_lists']
            )
        return self.optimizer

    def build_program(self,
                      is_train=True,
                      use_parallel_test=False,
                      dist_strategy=None):
        model_name = self.model_name
        assert not (is_train and use_parallel_test), \
            "is_train and use_parallel_test cannot be set simultaneously."

        trainer_id = self.trainer_id
        num_trainers = self.num_trainers

        image_shape = [int(m) for m in self.image_shape]
        if self.data_format == "NHWC":
            image_shape = [image_shape[1], image_shape[2], image_shape[0]]
        # model definition
        model = self.model
        if model is None:
            model = resnet.__dict__[model_name](emb_dim=self.emb_dim)
        main_program = self.train_program if is_train else self.test_program
        startup_program = self.startup_program
        with fluid.program_guard(main_program, startup_program):
            with fluid.unique_name.guard():
                image = fluid.layers.data(name='image',
                                          shape=image_shape,
                                          dtype='float32')
                label = fluid.layers.data(name='label',
                                          shape=[1],
                                          dtype='int64')

                emb, loss, prob = model.get_output(input=image,
                                                   label=label,
                                                   num_ranks=num_trainers,
                                                   rank_id=trainer_id,
                                                   is_train=is_train,
                                                   num_classes=self.num_classes,
                                                   loss_type=self.loss_type,
                                                   param_attr=self.param_attr,
                                                   bias_attr=self.bias_attr,
                                                   margin=self.margin,
                                                   scale=self.scale,
                                                   data_format=self.data_format)

                acc1 = None
                acc5 = None

                if self.loss_type in ["dist_softmax", "dist_arcface"]:
                    if self.calc_train_acc:
                        shard_prob = loss._get_info("shard_prob")
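                        # shard_prob holds only this trainer's shard of the
                        # class dimension; the allgather below stacks all
                        # shards along dim 0, so we split and re-concat on
                        # axis 1 to recover the full probability matrix.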

                        prob_all = fluid.layers.collective._c_allgather(
                            shard_prob,
                            nranks=num_trainers,
                            use_calc_stream=True)
                        prob_list = fluid.layers.split(
                            prob_all,
                            dim=0,
                            num_or_sections=num_trainers)
                        prob = fluid.layers.concat(prob_list, axis=1)
                        label_all = fluid.layers.collective._c_allgather(
                            label,
                            nranks=num_trainers,
                            use_calc_stream=True)
                        acc1 = fluid.layers.accuracy(input=prob,
                                                     label=label_all,
                                                     k=1)
                        acc5 = fluid.layers.accuracy(input=prob,
                                                     label=label_all,
                                                     k=5)
                else:
                    if self.calc_train_acc:
                        acc1 = fluid.layers.accuracy(input=prob,
                                                     label=label,
                                                     k=1)
                        acc5 = fluid.layers.accuracy(input=prob,
                                                     label=label,
                                                     k=5)

                optimizer = None
                if is_train:
                    # initialize optimizer
                    optimizer = self._get_optimizer()
                    if self.num_trainers > 1:
                        dist_optimizer = fleet.distributed_optimizer(
                            optimizer, strategy=dist_strategy)
                        dist_optimizer.minimize(loss)
                    else:  # single card training
                        optimizer.minimize(loss)
                    if "dist" in self.loss_type or self.use_fp16:
                        optimizer = optimizer._optimizer
                elif use_parallel_test:
                    emb = fluid.layers.collective._c_allgather(
                        emb,
                        nranks=num_trainers,
                        use_calc_stream=True)
        return emb, loss, acc1, acc5, optimizer

    def get_files_from_hdfs(self):
        assert self.fs_checkpoint_dir, \
            logger.error("Please set the fs_checkpoint_dir parameter of "
                         "set_hdfs_info to get models from hdfs.")
        self.fs_checkpoint_dir = os.path.join(self.fs_checkpoint_dir, '*')
        cmd = "hadoop fs -D fs.default.name="
        cmd += self.fs_name + " "
        cmd += "-D hadoop.job.ugi="
        cmd += self.fs_ugi + " "
        cmd += "-get " + self.fs_checkpoint_dir
        cmd += " " + self.checkpoint_dir
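        # The assembled command has the form (illustrative values):
        #   hadoop fs -D fs.default.name=hdfs://example:9000 \
        #       -D hadoop.job.ugi=user,passwd -get /checkpoints/* ./ckpt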
        logger.info("hdfs download cmd: {}".format(cmd))
        cmd = cmd.split(' ')
        process = subprocess.Popen(cmd,
                                   stdout=sys.stdout,
                                   stderr=subprocess.STDOUT)
        process.wait()

    def put_files_to_hdfs(self, local_dir):
        assert self.fs_dir_for_save, \
            logger.error("Please set the fs_dir_for_save parameter of "
                         "set_hdfs_info to save models to hdfs.")
        cmd = "hadoop fs -D fs.default.name="
        cmd += self.fs_name + " "
        cmd += "-D hadoop.job.ugi="
        cmd += self.fs_ugi + " "
        cmd += "-put " + local_dir
        cmd += " " + self.fs_dir_for_save
        logger.info("hdfs upload cmd: {}".format(cmd))
        cmd = cmd.split(' ')
        process = subprocess.Popen(cmd,
                                   stdout=sys.stdout,
                                   stderr=subprocess.STDOUT)
        process.wait()

    def process_distributed_params(self, local_dir):
        local_dir = os.path.abspath(local_dir)
        output_dir = tempfile.mkdtemp()
        converter = ParameterConverter(local_dir, output_dir, self.num_trainers)
        converter.process()
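        # Shards of the distributed fc layer carry 'dist@' and '@rank@'
        # in their names; drop the stale local shards and move in the
        # re-sharded ones produced by the converter.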

        for file in os.listdir(local_dir):
            if "dist@" in file and "@rank@" in file:
                file = os.path.join(local_dir, file)
                os.remove(file)

        for file in os.listdir(output_dir):
            if "dist@" in file and "@rank@" in file:
                file = os.path.join(output_dir, file)
                shutil.move(file, local_dir)
        shutil.rmtree(output_dir)

    def _append_broadcast_ops(self, program):
        """
        Before testing, we broadcast batchnorm-related parameters from
        trainer 0 to all other trainers.
        """
        bn_vars = [var for var in program.list_vars()
                   if 'batch_norm' in var.name and var.persistable]
        block = program.current_block()
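        # Insert the broadcast ops at the head of the block so batchnorm
        # statistics from trainer 0 are synchronized before any test op
        # consumes them.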
        for var in bn_vars:
            block._insert_op(
                0,
                type='c_broadcast',
                inputs={'X': var},
                outputs={'Out': var},
                attrs={'use_calc_stream': True})

    def load_checkpoint(self,
                        executor,
                        main_program,
                        use_per_trainer_checkpoint=False,
                        load_for_train=True):
        if use_per_trainer_checkpoint:
            checkpoint_dir = os.path.join(
                self.checkpoint_dir, str(self.trainer_id))
        else:
            checkpoint_dir = self.checkpoint_dir

        if self.fs_name is not None:
            if os.path.exists(checkpoint_dir):
                logger.info("Local dir {} exists, we'll overwrite it.".format(
                    checkpoint_dir))

            # sync all trainers to avoid loading checkpoints before
            # parameters are downloaded
            file_name = os.path.join(checkpoint_dir, '.lock')
            if self.trainer_id == 0:
                self.get_files_from_hdfs()
                with open(file_name, 'w') as f:
                    pass
                time.sleep(10)
                os.remove(file_name)
            else:
                while True:
                    if not os.path.exists(file_name):
                        time.sleep(1)
                    else:
                        break
        
        # Preprocess distributed parameters.
        file_name = os.path.join(checkpoint_dir, '.lock')
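        # Same lock-file protocol as above: trainer 0 converts the
        # distributed parameters and briefly creates the lock file;
        # the other trainers wait for it to appear before loading.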
        meta_file = os.path.join(checkpoint_dir, 'meta.json')
        if not os.path.exists(meta_file):
            logger.error("Please make sure the checkpoint dir {} exists, and "
                         "parameters in that dir are valid.".format(
                             checkpoint_dir))
            exit()
        distributed = self.loss_type in ["dist_softmax", "dist_arcface"]
        if load_for_train and self.trainer_id == 0 and distributed:
            self.process_distributed_params(checkpoint_dir)
            with open(file_name, 'w') as f:
                pass
            time.sleep(10)
            os.remove(file_name)
        elif load_for_train and distributed:
            # wait trainer_id (0) to complete
            while True:
                if not os.path.exists(file_name):
                    time.sleep(1)
                else:
                    break

        def if_exist(var):
            has_var = os.path.exists(os.path.join(checkpoint_dir, var.name))
            if has_var:
                logger.info('var: %s found' % (var.name))
            return has_var

        fluid.io.load_vars(executor,
                           checkpoint_dir,
                           predicate=if_exist,
                           main_program=main_program)

    def convert_for_prediction(self):
        model_name = self.model_name
        image_shape = [int(m) for m in self.image_shape]
        # model definition
        model = self.model
        if model is None:
            model = resnet.__dict__[model_name](emb_dim=self.emb_dim)
        main_program = self.predict_program
        startup_program = self.startup_program
        with fluid.program_guard(main_program, startup_program):
            with fluid.unique_name.guard():
                image = fluid.layers.data(name='image',
                                          shape=image_shape,
                                          dtype='float32')
                label = fluid.layers.data(name='label',
                                          shape=[1],
                                          dtype='int64')

                emb = model.build_network(input=image,
                                          label=label,
                                          is_train=False,
                                          data_format=self.data_format)

        gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
        place = fluid.CUDAPlace(gpu_id)
        exe = fluid.Executor(place)
        exe.run(startup_program)

        assert self.checkpoint_dir, "No checkpoint found for converting."
        self.load_checkpoint(executor=exe,
                             main_program=main_program,
                             load_for_train=False)

        assert self.model_save_dir, \
            "model_save_dir is not set for inference model conversion."
        if os.path.exists(self.model_save_dir):
            logger.info("model_save_dir for inference model ({}) exists, "
                        "we will overwrite it.".format(self.model_save_dir))
            shutil.rmtree(self.model_save_dir)
        fluid.io.save_inference_model(self.model_save_dir,
                                      feeded_var_names=[image.name],
                                      target_vars=[emb],
                                      executor=exe,
                                      main_program=main_program)
        if self.fs_name:
            self.put_files_to_hdfs(self.model_save_dir)

    def _set_info(self, key, value):
        if not hasattr(self, '_info'):
            self._info = {}
        self._info[key] = value

    def _get_info(self, key):
        if hasattr(self, '_info') and key in self._info:
            return self._info[key]
        return None

    def predict(self):
        model_name = self.model_name
        image_shape = [int(m) for m in self.image_shape]
        # model definition
        model = self.model
        if model is None:
            model = resnet.__dict__[model_name](emb_dim=self.emb_dim)
        main_program = self.predict_program
        startup_program = self.startup_program
        with fluid.program_guard(main_program, startup_program):
            with fluid.unique_name.guard():
                image = fluid.layers.data(name='image',
                                          shape=image_shape,
                                          dtype='float32')
                label = fluid.layers.data(name='label',
                                          shape=[1],
                                          dtype='int64')

                emb = model.build_network(input=image,
                                          label=label,
                                          is_train=False)

        gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
        place = fluid.CUDAPlace(gpu_id)
        exe = fluid.Executor(place)
        exe.run(startup_program)

        assert self.checkpoint_dir, "No checkpoint found for predicting."
        self.load_checkpoint(executor=exe,
                             main_program=main_program,
                             load_for_train=False)

        if self.predict_reader is None:
            predict_reader = paddle.batch(reader.arc_train(self.dataset_dir,
                                                           self.num_classes,
                                                           data_format=self.data_format),
                                          batch_size=self.train_batch_size)
        else:
            predict_reader = self.predict_reader

        feeder = fluid.DataFeeder(place=place,
                                  feed_list=['image', 'label'],
                                  program=main_program)

        fetch_list = [emb.name]
        for data in predict_reader():
            emb = exe.run(main_program,
                          feed=feeder.feed(data),
                          fetch_list=fetch_list,
                          use_program_cache=True)
            print("emb: ", emb)

    def _run_test(self,
                  exe,
                  test_list,
                  test_name_list,
                  feeder,
                  fetch_list):
        trainer_id = self.trainer_id
        real_test_batch_size = self.global_test_batch_size
        for i in range(len(test_list)):
            data_list, issame_list = test_list[i]
            embeddings_list = []
            for j in range(len(data_list)):
                data = data_list[j]
                embeddings = None
                parallel_test_steps = data.shape[0] // real_test_batch_size
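                # Each trainer embeds its own slice of every global test
                # batch: trainer i handles samples
                # [start + i * test_batch_size, start + (i + 1) * test_batch_size).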
                for idx in range(parallel_test_steps):
                    start = idx * real_test_batch_size
                    offset = trainer_id * self.test_batch_size
                    begin = start + offset
                    end = begin + self.test_batch_size
                    _data = []
                    for k in range(begin, end):
                        _data.append((data[k], 0))
                    assert len(_data) == self.test_batch_size
                    [_embeddings] = exe.run(self.test_program,
                                            fetch_list=fetch_list,
                                            feed=feeder.feed(_data),
                                            use_program_cache=True)
                    if embeddings is None:
                        embeddings = np.zeros((data.shape[0],
                                               _embeddings.shape[1]))
                    end = start + real_test_batch_size
                    embeddings[start:end, :] = _embeddings[:, :]
                beg = parallel_test_steps * real_test_batch_size

                while beg < data.shape[0]:
                    end = min(beg + self.test_batch_size, data.shape[0])
                    count = end - beg
                    _data = []
                    for k in range(end - self.test_batch_size, end):
                        _data.append((data[k], 0))
                    [_embeddings] = exe.run(self.test_program,
                                            fetch_list=fetch_list,
                                            feed=feeder.feed(_data),
                                            use_program_cache=True)
                    _embeddings = _embeddings[0:self.test_batch_size, :]
                    embeddings[beg:end, :] = _embeddings[
                                             (self.test_batch_size - count):, :]
                    beg = end
                embeddings_list.append(embeddings)

            xnorm = 0.0
            xnorm_cnt = 0
            for embed in embeddings_list:
                xnorm += np.sqrt((embed * embed).sum(axis=1)).sum(axis=0)
                xnorm_cnt += embed.shape[0]
            xnorm /= xnorm_cnt

            embeddings = embeddings_list[0] + embeddings_list[1]
            embeddings = sklearn.preprocessing.normalize(embeddings)
            _, _, accuracy, val, val_std, far = evaluate(embeddings,
                                                         issame_list,
                                                         nrof_folds=10)
            acc, std = np.mean(accuracy), np.std(accuracy)

            if self.train_pass_id >= 0:
                logger.info('[{}][{}]XNorm: {:.5f}'.format(test_name_list[i],
                                                           self.train_pass_id,
                                                           xnorm))
                logger.info('[{}][{}]Accuracy-Flip: {:.5f}+-{:.5f}'.format(
                    test_name_list[i],
                    self.train_pass_id,
                    acc,
                    std))
            else:
                logger.info('[{}]XNorm: {:.5f}'.format(test_name_list[i],
                                                       xnorm))
                logger.info('[{}]Accuracy-Flip: {:.5f}+-{:.5f}'.format(
                    test_name_list[i],
                    acc,
                    std))
            sys.stdout.flush()

    def test(self):
        self._check()

        trainer_id = self.trainer_id
        num_trainers = self.num_trainers

        # if the test program is not built, which means that is the first time
        # to call the test method, we will first build the test program and
        # add ops to broadcast bn-related parameters from trainer 0 to other
        # trainers for distributed tests.
        if not self.test_initialized:
            emb, loss, _, _, _ = self.build_program(False,
                                                    self.num_trainers > 1)
            emb_name = emb.name
            assert self._get_info('emb_name') is None
            self._set_info('emb_name', emb.name)

            if num_trainers > 1 and self.has_run_train:
                self._append_broadcast_ops(self.test_program)

            if num_trainers > 1 and not self.has_run_train:
                worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS")
                current_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT")

                config = dist_transpiler.DistributeTranspilerConfig()
                config.mode = "collective"
                config.collective_mode = "grad_allreduce"
                t = dist_transpiler.DistributeTranspiler(config=config)
                t.transpile(trainer_id=trainer_id,
                            trainers=worker_endpoints,
                            startup_program=self.startup_program,
                            program=self.test_program,
                            current_endpoint=current_endpoint)
        else:
            emb_name = self._get_info('emb_name')

        gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
        place = fluid.CUDAPlace(gpu_id)
        exe = fluid.Executor(place)
        if not self.has_run_train:
            exe.run(self.startup_program)

        if not self.test_reader:
            test_reader = reader.test
        else:
            test_reader = self.test_reader
        if not self.test_initialized:
            test_list, test_name_list = test_reader(self.dataset_dir,
                                                    self.val_targets)
            assert self._get_info('test_list') is None
            assert self._get_info('test_name_list') is None
            self._set_info('test_list', test_list)
            self._set_info('test_name_list', test_name_list)
        else:
            test_list = self._get_info('test_list')
            test_name_list = self._get_info('test_name_list')

        test_program = self.test_program

        if not self.has_run_train:
            assert self.checkpoint_dir, "No checkpoint found for test."
            self.load_checkpoint(executor=exe,
                                 main_program=test_program,
                                 load_for_train=False)

        feeder = fluid.DataFeeder(place=place,
                                  feed_list=['image', 'label'],
                                  program=test_program)
        fetch_list = [emb_name]

        self.test_initialized = True

        test_start = time.time()
        self._run_test(exe,
                       test_list,
                       test_name_list,
                       feeder,
                       fetch_list)
        test_end = time.time()
        logger.info("test time: {:.4f}".format(test_end - test_start))

    def train(self):
        self._check()
        self.has_run_train = True

        trainer_id = self.trainer_id
        num_trainers = self.num_trainers

        strategy = None
        if num_trainers > 1:
            role = role_maker.PaddleCloudRoleMaker(is_collective=True)
            fleet.init(role)
            strategy = DistributedStrategy()
            strategy.mode = "collective"
            strategy.collective_mode = "grad_allreduce"
        emb, loss, acc1, acc5, optimizer = self.build_program(
            True,
            False,
            dist_strategy=strategy)
    
        global_lr = optimizer._global_learning_rate(
            program=self.train_program)
    
        if num_trainers > 1:
            origin_prog = fleet._origin_program
            train_prog = fleet.main_program
        else:
            origin_prog = self.train_program
            train_prog = self.train_program

        if trainer_id == 0:
            with open('start.program', 'w') as fout:
                program_to_code(self.startup_program, fout, True)
            with open('main.program', 'w') as fout:
                program_to_code(train_prog, fout, True)
            with open('origin.program', 'w') as fout:
                program_to_code(origin_prog, fout, True)

        gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
        place = fluid.CUDAPlace(gpu_id)
        exe = fluid.Executor(place)
        exe.run(self.startup_program)

        if self.checkpoint_dir:
            load_checkpoint = True
        else:
            load_checkpoint = False
        if load_checkpoint:
            self.load_checkpoint(executor=exe, main_program=origin_prog)

        if self.train_reader is None:
            train_reader = paddle.batch(reader.arc_train(
                self.dataset_dir, self.num_classes, data_format=self.data_format),
                batch_size=self.train_batch_size)
        else:
            train_reader = self.train_reader

        feeder = fluid.DataFeeder(place=place,
                                  feed_list=['image', 'label'],
                                  program=origin_prog)
    
        if self.calc_train_acc:
            fetch_list = [loss.name, global_lr.name,
                          acc1.name, acc5.name]
        else:
            fetch_list = [loss.name, global_lr.name]
    
        local_time = 0.0
        nsamples = 0
        inspect_steps = 200
        global_batch_size = self.global_train_batch_size
        for pass_id in range(self.train_epochs):
            self.train_pass_id = pass_id
            train_info = [[], [], [], []]
            local_train_info = [[], [], [], []]
            for batch_id, data in enumerate(train_reader()):
                nsamples += global_batch_size
                t1 = time.time()
                acc1 = None
                acc5 = None
                if self.calc_train_acc:
                    loss, lr, acc1, acc5 = exe.run(train_prog,
                                                   feed=feeder.feed(data),
                                                   fetch_list=fetch_list,
                                                   use_program_cache=True)
                else:
                    loss, lr = exe.run(train_prog,
                                       feed=feeder.feed(data),
                                       fetch_list=fetch_list,
                                       use_program_cache=True)
S
sandyhouse 已提交
969 970 971 972 973 974 975 976 977 978
                t2 = time.time()
                period = t2 - t1
                local_time += period
                train_info[0].append(np.array(loss)[0])
                train_info[1].append(np.array(lr)[0])
                local_train_info[0].append(np.array(loss)[0])
                local_train_info[1].append(np.array(lr)[0])
                if batch_id % inspect_steps == 0:
                    avg_loss = np.mean(local_train_info[0])
                    avg_lr = np.mean(local_train_info[1])
                    speed = nsamples / local_time
                    if self.calc_train_acc:
                        logger.info("Pass:{} batch:{} lr:{:.8f} loss:{:.6f} "
                                    "qps:{:.2f} acc1:{:.6f} acc5:{:.6f}".format(
                            pass_id,
                            batch_id,
                            avg_lr,
                            avg_loss,
                            speed,
                            acc1[0],
                            acc5[0]))
                    else:
                        logger.info("Pass:{} batch:{} lr:{:.8f} loss:{:.6f} "
                                    "qps:{:.2f}".format(pass_id,
                                                        batch_id,
                                                        avg_lr,
                                                        avg_loss,
                                                        speed))
                    local_time = 0
                    nsamples = 0
                    local_train_info = [[], [], [], []]

            train_loss = np.array(train_info[0]).mean()
            logger.info("End pass {}, train_loss {:.6f}".format(pass_id,
                                                                train_loss))
            sys.stdout.flush()

            if self.with_test:
                self.test()

            # save model
            if self.model_save_dir:
                model_save_dir = os.path.join(
                    self.model_save_dir, str(pass_id))
                if not os.path.exists(model_save_dir):
                    # more than one process may be trying
                    # to create the directory
                    try:
                        os.makedirs(model_save_dir)
                    except OSError as exc:
                        if exc.errno != errno.EEXIST:
                            raise
                        pass
                if trainer_id == 0:
                    fluid.io.save_persistables(exe,
                                               model_save_dir,
                                               origin_prog)
                else:
                    def save_var(var):
                        to_save = "dist@" in var.name and '@rank@' in var.name
                        return to_save and var.persistable

                    fluid.io.save_vars(exe,
                                       model_save_dir,
                                       origin_prog,
                                       predicate=save_var)

            # save training info
            if self.model_save_dir and trainer_id == 0:
                config_file = os.path.join(
                    self.model_save_dir, str(pass_id), 'meta.json')
                train_info = dict()
                train_info["pretrain_nranks"] = self.num_trainers
                train_info["emb_dim"] = self.emb_dim
                train_info['num_classes'] = self.num_classes
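                # meta.json records the sharding context needed to reload
                # the checkpoint, e.g. (illustrative values):
                # {"pretrain_nranks": 8, "emb_dim": 512, "num_classes": 85742}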
                with open(config_file, 'w') as f:
                    json.dump(train_info, f)

        # upload model
        if self.model_save_dir and self.fs_name and trainer_id == 0:
            self.put_files_to_hdfs(self.model_save_dir)


if __name__ == '__main__':
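    # Single-process demo. Multi-GPU jobs are typically started via the
    # paddle.distributed.launch module (a sketch; the GPU list is
    # illustrative):
    #     python -m paddle.distributed.launch --selected_gpus=0,1,2,3 entry.py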
    ins = Entry()
    ins.train()