diff --git a/tools/program.py b/tools/program.py index b63d416c709279a8b18fe322d90aded83e821b16..34541043623ea9c4f78387488203f57e7fa8a0c7 100644 --- a/tools/program.py +++ b/tools/program.py @@ -49,7 +49,7 @@ def create_dataloader(): dataloader(fluid dataloader): """ trainer_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1)) - capacity = 64 if trainer_num <= 1 else 8 + capacity = 64 if trainer_num == 1 else 8 dataloader = fluid.io.DataLoader.from_generator( capacity=capacity, use_double_buffer=True, iterable=True) @@ -163,15 +163,7 @@ def create_metric(out, return fetchs -def create_fetchs(feeds, - out, - config, - architecture, - topk=5, - classes_num=1000, - epsilon=None, - use_mix=False, - use_distillation=False): +def create_fetchs(feeds, net, config, mode="train"): """ Create fetchs as model outputs(included loss and measures), will call create_loss and create_metric(if use_mix). @@ -190,6 +182,15 @@ def create_fetchs(feeds, Returns: fetchs(dict): dict of model outputs(included loss and measures) """ + architecture = config.ARCHITECTURE + topk = config.topk + classes_num = config.classes_num + epsilon = config.get('ls_epsilon') + use_mix = config.get('use_mix') and mode == 'train' + use_distillation = config.get('use_distillation') + + out = net(feeds["image"]) + fetchs = OrderedDict() fetchs['loss'] = create_loss(feeds, out, architecture, classes_num, epsilon, use_mix, use_distillation) @@ -276,40 +277,6 @@ def mixed_precision_optimizer(config, optimizer): return optimizer -def compute(feeds, net, config, mode='train'): - """ - Build a program using a model and an optimizer - 1. create feeds - 2. create a dataloader - 3. create a model - 4. create fetchs - 5. create an optimizer - - Args: - config(dict): config - main_prog(): main program - startup_prog(): startup program - is_train(bool): train or valid - - Returns: - dataloader(): a bridge between the model and the data - fetchs(dict): dict of model outputs(included loss and measures) - """ - out = net(feeds["image"]) - fetchs = create_fetchs( - feeds, - out, - config, - config.ARCHITECTURE, - config.topk, - config.classes_num, - epsilon=config.get('ls_epsilon'), - use_mix=config.get('use_mix') and mode == 'train', - use_distillation=config.get('use_distillation')) - - return fetchs - - def create_feeds(batch, use_mix): image = to_variable(batch[0].numpy().astype("float32")) if use_mix: @@ -360,7 +327,7 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'): for idx, batch in enumerate(dataloader()): batch_size = len(batch[0]) feeds = create_feeds(batch, use_mix) - fetchs = compute(feeds, net, config, mode) + fetchs = create_fetchs(feeds, net, config, mode) if mode == 'train': avg_loss = net.scale_loss(fetchs['loss']) avg_loss.backward()