diff --git a/plsc/entry.py b/plsc/entry.py
index 589f3ce6855984852d9d83cc7e14411c7ac87b26..748ef946944f5d40d729e9e69c7bafc59dd12b74 100644
--- a/plsc/entry.py
+++ b/plsc/entry.py
@@ -17,14 +17,13 @@ from __future__ import print_function
 import errno
 import json
-import logging
-import math
 import os
 import shutil
 import subprocess
 import sys
 import tempfile
 import time
+import logging
 
 import numpy as np
 import paddle
@@ -43,6 +42,7 @@ from .utils import jpeg_reader as reader
 from .utils.learning_rate import lr_warmup
 from .utils.parameter_converter import ParameterConverter
 from .utils.verification import evaluate
+from .utils.input_field import InputField
 
 log_handler = logging.StreamHandler()
 log_format = logging.Formatter(
@@ -116,7 +116,6 @@ class Entry(object):
         self.val_targets = self.config.val_targets
         self.dataset_dir = self.config.dataset_dir
         self.num_classes = self.config.num_classes
-        self.image_shape = self.config.image_shape
         self.loss_type = self.config.loss_type
         self.margin = self.config.margin
         self.scale = self.config.scale
@@ -142,6 +141,15 @@ class Entry(object):
         self.lr_decay_factor = 0.1
         self.log_period = 200
 
+        self.input_info = [{'name': 'image',
+                            'shape': [-1, 3, 224, 224],
+                            'dtype': 'float32'},
+                           {'name': 'label',
+                            'shape': [-1, 1],
+                            'dtype': 'int64'}
+                           ]
+        self.input_field = None
+
         logger.info('=' * 30)
         logger.info("Default configuration:")
         for key in self.config:
@@ -152,6 +160,31 @@ class Entry(object):
         logger.info('default log period: {}'.format(self.log_period))
         logger.info('=' * 30)
 
+    def set_input_info(self, input):
+        """
+        Set the information of inputs, which is a list or tuple. Each element
+        is a dict that contains the info of an input, including name, dtype
+        and shape.
+        """
+        if not (isinstance(input, list) or isinstance(input, tuple)):
+            raise ValueError("The type of 'input' must be list or tuple.")
+
+        has_label = False
+        for element in input:
+            assert isinstance(element, dict), (
+                "The type of elements for input must be dict")
+            assert 'name' in element.keys(), (
+                "Every element has to contain the key 'name'")
+            assert 'shape' in element.keys(), (
+                "Every element has to contain the key 'shape'")
+            assert 'dtype' in element.keys(), (
+                "Every element has to contain the key 'dtype'")
+            if element['name'] == 'label':
+                has_label = True
+        assert has_label, "The input must contain a field named 'label'"
+
+        self.input_info = input
+
     def set_val_targets(self, targets):
         """
         Set the names of validation datasets, separated by comma.
@@ -314,12 +347,6 @@ class Entry(object):
         self.loss_type = loss_type
         logger.info("Set loss_type to {}.".format(loss_type))
 
-    def set_image_shape(self, shape):
-        if not isinstance(shape, (list, tuple)):
-            raise ValueError("Shape must be of type list or tuple")
-        self.image_shape = shape
-        logger.info("Set image_shape to {}.".format(shape))
-
     def set_optimizer(self, optimizer):
         if not isinstance(optimizer, Optimizer):
             raise ValueError("Optimizer must be of type Optimizer")
@@ -404,7 +431,6 @@ class Entry(object):
         trainer_id = self.trainer_id
         num_trainers = self.num_trainers
 
-        image_shape = [int(m) for m in self.image_shape]
         # model definition
         model = self.model
         if model is None:
@@ -413,15 +439,11 @@ class Entry(object):
         startup_program = self.startup_program
         with fluid.program_guard(main_program, startup_program):
             with fluid.unique_name.guard():
-                image = fluid.layers.data(name='image',
-                                          shape=image_shape,
-                                          dtype='float32')
-                label = fluid.layers.data(name='label',
-                                          shape=[1],
-                                          dtype='int64')
-
-                emb, loss, prob = model.get_output(input=image,
-                                                   label=label,
+                input_field = InputField(self.input_info)
+                input_field.build()
+                self.input_field = input_field
+
+                emb, loss, prob = model.get_output(input=input_field,
                                                    num_ranks=num_trainers,
                                                    rank_id=trainer_id,
                                                    is_train=is_train,
@@ -449,7 +471,7 @@ class Entry(object):
                             num_or_sections=num_trainers)
                         prob = fluid.layers.concat(prob_list, axis=1)
                         label_all = fluid.layers.collective._c_allgather(
-                            label,
+                            input_field.label,
                             nranks=num_trainers,
                             use_calc_stream=True)
                         acc1 = fluid.layers.accuracy(input=prob,
@@ -461,10 +483,10 @@ class Entry(object):
                 else:
                     if self.calc_train_acc:
                         acc1 = fluid.layers.accuracy(input=prob,
-                                                     label=label,
+                                                     label=input_field.label,
                                                      k=1)
                         acc5 = fluid.layers.accuracy(input=prob,
-                                                     label=label,
+                                                     label=input_field.label,
                                                      k=5)
 
             optimizer = None
@@ -489,7 +511,7 @@ class Entry(object):
     def get_files_from_hdfs(self):
         assert self.fs_checkpoint_dir, \
             logger.error("Please set the fs_checkpoint_dir paramerters for "
-                         "set_hdfs_info to get models from hdfs.")
+                         "set_hdfs_info to get models from hdfs.")
         self.fs_checkpoint_dir = os.path.join(self.fs_checkpoint_dir, '*')
         cmd = "hadoop fs -D fs.default.name="
         cmd += self.fs_name + " "
@@ -631,15 +653,10 @@ class Entry(object):
         startup_program = self.startup_program
         with fluid.program_guard(main_program, startup_program):
             with fluid.unique_name.guard():
-                image = fluid.layers.data(name='image',
-                                          shape=image_shape,
-                                          dtype='float32')
-                label = fluid.layers.data(name='label',
-                                          shape=[1],
-                                          dtype='int64')
-
-                emb = model.build_network(input=image,
-                                          label=label,
+                input_field = InputField(self.input_info)
+                input_field.build()
+
+                emb = model.build_network(input=input_field,
                                           is_train=False)
 
         gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
@@ -658,8 +675,12 @@ class Entry(object):
             logger.info("model_save_dir for inference model ({}) exists, "
                         "we will overwrite it.".format(self.model_save_dir))
             shutil.rmtree(self.model_save_dir)
+        feed_var_names = []
+        for name in input_field.feed_list_str:
+            if name == "label": continue
+            feed_var_names.append(name)
         fluid.io.save_inference_model(self.model_save_dir,
-                                      feeded_var_names=[image.name],
+                                      feeded_var_names=feed_var_names,
                                       target_vars=[emb],
                                       executor=exe,
                                       main_program=main_program)
@@ -678,7 +699,6 @@ class Entry(object):
 
     def predict(self):
         model_name = self.model_name
-        image_shape = [int(m) for m in self.image_shape]
         # model definition
         model = self.model
         if model is None:
@@ -687,15 +707,10 @@ class Entry(object):
         startup_program = self.startup_program
         with fluid.program_guard(main_program, startup_program):
             with fluid.unique_name.guard():
-                image = fluid.layers.data(name='image',
-                                          shape=image_shape,
-                                          dtype='float32')
-                label = fluid.layers.data(name='label',
-                                          shape=[1],
-                                          dtype='int64')
-
-                emb = model.build_network(input=image,
-                                          label=label,
+                input_field = InputField(self.input_info)
+                input_field.build()
+
+                emb = model.build_network(input=input_field,
                                           is_train=False)
 
         gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
@@ -709,20 +724,20 @@ class Entry(object):
                                  load_for_train=False)
 
         if self.predict_reader is None:
-            predict_reader = paddle.batch(reader.arc_train(self.dataset_dir,
-                                                           self.num_classes),
-                                          batch_size=self.train_batch_size)
+            predict_reader = reader.arc_train(self.dataset_dir,
+                                              self.num_classes)
         else:
             predict_reader = self.predict_reader
 
-        feeder = fluid.DataFeeder(place=place,
-                                  feed_list=['image', 'label'],
-                                  program=main_program)
+        input_field.loader.set_sample_generator(
+            predict_reader,
+            batch_size=self.train_batch_size,
+            places=place)
 
         fetch_list = [emb.name]
-        for data in predict_reader():
+        for data in input_field.loader:
             emb = exe.run(main_program,
-                          feed=feeder.feed(data),
+                          feed=data,
                           fetch_list=fetch_list,
                           use_program_cache=True)
             print("emb: ", emb)
@@ -741,6 +756,14 @@ class Entry(object):
         for j in range(len(data_list)):
             data = data_list[j]
             embeddings = None
+            # For multi-card test, the dataset can be partitioned into two
+            # parts. For the first part, the total number of samples is
+            # divisible by the number of cards. These samples are then
+            # split across the cards and tested in parallel. For the
+            # second part, the samples are tested on all cards but only
+            # the result of the first card is used.
+
+            # The number of steps for parallel test.
             parallel_test_steps = data.shape[0] // real_test_batch_size
             for idx in range(parallel_test_steps):
                 start = idx * real_test_batch_size
@@ -876,7 +899,7 @@ class Entry(object):
                                  load_for_train=False)
 
         feeder = fluid.DataFeeder(place=place,
-                                  feed_list=['image', 'label'],
+                                  feed_list=self.input_field.feed_list_str,
                                   program=test_program)
 
         fetch_list = [emb_name]
@@ -940,9 +963,10 @@ class Entry(object):
         else:
             train_reader = self.train_reader
 
-        feeder = fluid.DataFeeder(place=place,
-                                  feed_list=['image', 'label'],
-                                  program=origin_prog)
+        self.input_field.loader.set_sample_generator(
+            train_reader,
+            batch_size=self.train_batch_size,
+            places=place)
 
         if self.calc_train_acc:
             fetch_list = [loss.name, global_lr.name,
@@ -958,19 +982,19 @@ class Entry(object):
             self.train_pass_id = pass_id
             train_info = [[], [], [], []]
             local_train_info = [[], [], [], []]
-            for batch_id, data in enumerate(train_reader()):
+            for batch_id, data in enumerate(self.input_field.loader):
                 nsamples += global_batch_size
                 t1 = time.time()
                 acc1 = None
                 acc5 = None
                 if self.calc_train_acc:
                     loss, lr, acc1, acc5 = exe.run(train_prog,
-                                                   feed=feeder.feed(data),
+                                                   feed=data,
                                                    fetch_list=fetch_list,
                                                    use_program_cache=True)
                 else:
                     loss, lr = exe.run(train_prog,
-                                       feed=feeder.feed(data),
+                                       feed=data,
                                        fetch_list=fetch_list,
                                        use_program_cache=True)
                 t2 = time.time()
diff --git a/plsc/models/base_model.py b/plsc/models/base_model.py
index b13a8adb4f96a1bc8c55f66cbf5526188c2bc47b..2d8c50a35223aa66f7eb826976f16ceeb6a7fd09 100644
--- a/plsc/models/base_model.py
+++ b/plsc/models/base_model.py
@@ -33,7 +33,7 @@ class BaseModel(object):
     def __init__(self):
         super(BaseModel, self).__init__()
 
-    def build_network(self, input, label, is_train=True):
+    def build_network(self, input, is_train=True):
         """
         Construct the custom model, and we will add the distributed fc layer
         at the end of your model automatically.
@@ -43,7 +43,6 @@ class BaseModel(object):
 
     def get_output(self,
                    input,
-                   label,
                    num_classes,
                    num_ranks=1,
                    rank_id=0,
@@ -76,7 +75,8 @@ class BaseModel(object):
             "Supported loss types: {}, but given: {}".format(
                 supported_loss_types, loss_type)
 
-        emb = self.build_network(input, label, is_train)
+        emb = self.build_network(input, is_train)
+        label = input.label
         prob = None
         loss = None
         if loss_type == "softmax":
diff --git a/plsc/models/resnet.py b/plsc/models/resnet.py
index 5424647d0e972453d6a5898e11381f61563a0f00..c6b3159f0b161b790164f97610366de5b6f8a86a 100644
--- a/plsc/models/resnet.py
+++ b/plsc/models/resnet.py
@@ -27,7 +27,6 @@ class ResNet(BaseModel):
 
     def build_network(self,
                       input,
-                      label,
                       is_train=True):
         layers = self.layers
         supported_layers = [50, 101, 152]
@@ -44,7 +43,7 @@ class ResNet(BaseModel):
             num_filters = [64, 128, 256, 512]
 
         conv = self.conv_bn_layer(
-            input=input, num_filters=64, filter_size=3, stride=1,
+            input=input.image, num_filters=64, filter_size=3, stride=1,
             pad=1, act='prelu', is_train=is_train)
 
         for block in range(len(depth)):
diff --git a/plsc/utils/input_field.py b/plsc/utils/input_field.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fd71168c9c380f8955fc8cea761ffb08ac711bf
--- /dev/null
+++ b/plsc/utils/input_field.py
@@ -0,0 +1,125 @@
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+import paddle.fluid as fluid
+
+
+class InputField(object):
+    """
+    A high-level API for handling inputs in PaddlePaddle.
+    """
+
+    def __init__(self, input_slots=[]):
+
+        self.shapes = []
+        self.dtypes = []
+        self.names = []
+        self.lod_levels = []
+
+        self.input_slots = {}
+        self.feed_list_str = []
+        self.feed_list = []
+
+        self.loader = None
+
+        if input_slots:
+            for input_slot in input_slots:
+                self += input_slot
+
+    def __add__(self, input_slot):
+
+        if isinstance(input_slot, list) or isinstance(input_slot, tuple):
+            name = input_slot[0]
+            shape = input_slot[1]
+            dtype = input_slot[2]
+            lod_level = input_slot[3] if len(input_slot) == 4 else 0
+
+        if isinstance(input_slot, dict):
+            name = input_slot["name"]
+            shape = input_slot["shape"]
+            dtype = input_slot["dtype"]
+            lod_level = input_slot[
+                "lod_level"] if "lod_level" in input_slot else 0
+
+        self.shapes.append(shape)
+        self.dtypes.append(dtype)
+        self.names.append(name)
+        self.lod_levels.append(lod_level)
+
+        self.feed_list_str.append(name)
+
+        return self
+
+    def __getattr__(self, name):
+
+        if name not in self.input_slots:
+            raise Warning("the attr %s has not been defined yet." % name)
+            return None
+
+        return self.input_slots[name]
+
+    def build(self, capacity=64, iterable=True):
+
+        for _name, _shape, _dtype, _lod_level in zip(
+                self.names, self.shapes, self.dtypes, self.lod_levels):
+            self.input_slots[_name] = fluid.data(
+                name=_name, shape=_shape, dtype=_dtype, lod_level=_lod_level)
+
+        for name in self.feed_list_str:
+            self.feed_list.append(self.input_slots[name])
+
+        self.loader = fluid.io.DataLoader.from_generator(
+            feed_list=self.feed_list,
+            capacity=capacity,
+            iterable=iterable,
+            use_double_buffer=True)
+
+
+if __name__ == "__main__":
+
+    mnist_input_slots = [{
+        "name": "image",
+        "shape": (-1, 32, 32, 1),
+        "dtype": "int32"
+    }, {
+        "name": "label",
+        "shape": [-1, 1],
+        "dtype": "int64"
+    }]
+
+    input_field = InputField(mnist_input_slots)
+
+    input_field += {
+        "name": "large_image",
+        "shape": (-1, 64, 64, 1),
+        "dtype": "int32"
+    }
+    input_field += {
+        "name": "large_color_image",
+        "shape": (-1, 64, 64, 3),
+        "dtype": "int32"
+    }
+
+    input_field.build()
+
+    print(input_field.feed_list)
+
+    print(input_field.image)
+
+    print(input_field.large_color_image)
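
A minimal usage sketch of the input-description API this patch introduces (the shapes, the dataset setup and the final train() call are illustrative only; it assumes the package-level "from plsc import Entry" import). It shows how set_input_info() describes the fields that InputField later turns into fluid.data variables and a DataLoader:

    from plsc import Entry

    ins = Entry()
    # Describe the model inputs; set_input_info() requires a field named 'label',
    # and every entry must provide 'name', 'shape' and 'dtype'.
    ins.set_input_info([
        {'name': 'image', 'shape': [-1, 3, 224, 224], 'dtype': 'float32'},
        {'name': 'label', 'shape': [-1, 1], 'dtype': 'int64'},
    ])
    ins.train()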