未验证 提交 59f79be2 编写于 作者: L lilong12 提交者: GitHub

Support universal inputs (#56)

* add support for universal input
上级 632a8f3d
......@@ -17,14 +17,13 @@ from __future__ import print_function
import errno
import json
import logging
import math
import os
import shutil
import subprocess
import sys
import tempfile
import time
import logging
import numpy as np
import paddle
......@@ -43,6 +42,7 @@ from .utils import jpeg_reader as reader
from .utils.learning_rate import lr_warmup
from .utils.parameter_converter import ParameterConverter
from .utils.verification import evaluate
from .utils.input_field import InputField
log_handler = logging.StreamHandler()
log_format = logging.Formatter(
......@@ -116,7 +116,6 @@ class Entry(object):
self.val_targets = self.config.val_targets
self.dataset_dir = self.config.dataset_dir
self.num_classes = self.config.num_classes
self.image_shape = self.config.image_shape
self.loss_type = self.config.loss_type
self.margin = self.config.margin
self.scale = self.config.scale
......@@ -142,6 +141,15 @@ class Entry(object):
self.lr_decay_factor = 0.1
self.log_period = 200
self.input_info = [{'name': 'image',
'shape': [-1, 3, 224, 224],
'dtype': 'float32'},
{'name': 'label',
'shape':[-1, 1],
'dtype': 'int64'}
]
self.input_field = None
logger.info('=' * 30)
logger.info("Default configuration:")
for key in self.config:
......@@ -152,6 +160,31 @@ class Entry(object):
logger.info('default log period: {}'.format(self.log_period))
logger.info('=' * 30)
def set_input_info(self, input):
"""
Set the information of inputs which is a list or tuple. Each element
is a dict which contains the info of a input, including name, dtype
and shape.
"""
if not (isinstance(input, list) or isinstance(input, tuple)):
raise ValueError("The type of 'input' must be list or tuple.")
has_label = False
for element in input:
assert isinstance(element, dict), (
"The type of elements for input must be dict")
assert 'name' in element.keys(), (
"Every element has to contain the key 'name'")
assert 'shape' in element.keys(), (
"Every element has to contain the key 'shape'")
assert 'dtype' in element.keys(), (
"Every element has to contain the key 'dtype'")
if element['name'] == 'label':
has_label = True
assert has_label, "The input must contain a field named 'label'"
self.input_info = input
def set_val_targets(self, targets):
"""
Set the names of validation datasets, separated by comma.
......@@ -314,12 +347,6 @@ class Entry(object):
self.loss_type = loss_type
logger.info("Set loss_type to {}.".format(loss_type))
def set_image_shape(self, shape):
if not isinstance(shape, (list, tuple)):
raise ValueError("Shape must be of type list or tuple")
self.image_shape = shape
logger.info("Set image_shape to {}.".format(shape))
def set_optimizer(self, optimizer):
if not isinstance(optimizer, Optimizer):
raise ValueError("Optimizer must be of type Optimizer")
......@@ -404,7 +431,6 @@ class Entry(object):
trainer_id = self.trainer_id
num_trainers = self.num_trainers
image_shape = [int(m) for m in self.image_shape]
# model definition
model = self.model
if model is None:
......@@ -413,15 +439,11 @@ class Entry(object):
startup_program = self.startup_program
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(name='image',
shape=image_shape,
dtype='float32')
label = fluid.layers.data(name='label',
shape=[1],
dtype='int64')
emb, loss, prob = model.get_output(input=image,
label=label,
input_field = InputField(self.input_info)
input_field.build()
self.input_field = input_field
emb, loss, prob = model.get_output(input=input_field,
num_ranks=num_trainers,
rank_id=trainer_id,
is_train=is_train,
......@@ -449,7 +471,7 @@ class Entry(object):
num_or_sections=num_trainers)
prob = fluid.layers.concat(prob_list, axis=1)
label_all = fluid.layers.collective._c_allgather(
label,
input_field.label,
nranks=num_trainers,
use_calc_stream=True)
acc1 = fluid.layers.accuracy(input=prob,
......@@ -461,10 +483,10 @@ class Entry(object):
else:
if self.calc_train_acc:
acc1 = fluid.layers.accuracy(input=prob,
label=label,
label=input_field.label,
k=1)
acc5 = fluid.layers.accuracy(input=prob,
label=label,
label=input_field.label,
k=5)
optimizer = None
......@@ -489,7 +511,7 @@ class Entry(object):
def get_files_from_hdfs(self):
assert self.fs_checkpoint_dir, \
logger.error("Please set the fs_checkpoint_dir paramerters for "
"set_hdfs_info to get models from hdfs.")
"set_llllllhdfs_info to get models from hdfs.")
self.fs_checkpoint_dir = os.path.join(self.fs_checkpoint_dir, '*')
cmd = "hadoop fs -D fs.default.name="
cmd += self.fs_name + " "
......@@ -631,15 +653,10 @@ class Entry(object):
startup_program = self.startup_program
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(name='image',
shape=image_shape,
dtype='float32')
label = fluid.layers.data(name='label',
shape=[1],
dtype='int64')
emb = model.build_network(input=image,
label=label,
input_field = InputField(self.input_info)
input_field.build()
emb = model.build_network(input=input_field,
is_train=False)
gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
......@@ -658,8 +675,12 @@ class Entry(object):
logger.info("model_save_dir for inference model ({}) exists, "
"we will overwrite it.".format(self.model_save_dir))
shutil.rmtree(self.model_save_dir)
feed_var_names = []
for name in input_field.feed_list_str:
if name == "label": continue
feed_var_names.append(name)
fluid.io.save_inference_model(self.model_save_dir,
feeded_var_names=[image.name],
feeded_var_names=feed_var_names,
target_vars=[emb],
executor=exe,
main_program=main_program)
......@@ -678,7 +699,6 @@ class Entry(object):
def predict(self):
model_name = self.model_name
image_shape = [int(m) for m in self.image_shape]
# model definition
model = self.model
if model is None:
......@@ -687,15 +707,10 @@ class Entry(object):
startup_program = self.startup_program
with fluid.program_guard(main_program, startup_program):
with fluid.unique_name.guard():
image = fluid.layers.data(name='image',
shape=image_shape,
dtype='float32')
label = fluid.layers.data(name='label',
shape=[1],
dtype='int64')
emb = model.build_network(input=image,
label=label,
input_field = InputField(self.input_info)
input_field.build()
emb = model.build_network(input=input_field,
is_train=False)
gpu_id = int(os.getenv("FLAGS_selected_gpus", 0))
......@@ -709,20 +724,20 @@ class Entry(object):
load_for_train=False)
if self.predict_reader is None:
predict_reader = paddle.batch(reader.arc_train(self.dataset_dir,
self.num_classes),
batch_size=self.train_batch_size)
predict_reader = reader.arc_train(self.dataset_dir,
self.num_classes)
else:
predict_reader = self.predict_reader
feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'],
program=main_program)
input_field.loader.set_sample_generator(
predict_reader,
batch_size=self.train_batch_size,
places=place)
fetch_list = [emb.name]
for data in predict_reader():
for data in input_field.loader:
emb = exe.run(main_program,
feed=feeder.feed(data),
feed=data,
fetch_list=fetch_list,
use_program_cache=True)
print("emb: ", emb)
......@@ -741,6 +756,14 @@ class Entry(object):
for j in range(len(data_list)):
data = data_list[j]
embeddings = None
# For multi-card test, the dataset can be partitioned into two
# part. For the first part, the total number of samples is
# divisiable by the number of cards. And then, these samples
# are split on different cards and tested parallely. For the
# second part, these samples are tested on all cards but only
# the result of the first card is used.
# The number of steps for parallel test.
parallel_test_steps = data.shape[0] // real_test_batch_size
for idx in range(parallel_test_steps):
start = idx * real_test_batch_size
......@@ -876,7 +899,7 @@ class Entry(object):
load_for_train=False)
feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'],
feed_list=self.input_field.feed_list_str,
program=test_program)
fetch_list = [emb_name]
......@@ -940,9 +963,10 @@ class Entry(object):
else:
train_reader = self.train_reader
feeder = fluid.DataFeeder(place=place,
feed_list=['image', 'label'],
program=origin_prog)
self.input_field.loader.set_sample_generator(
train_reader,
batch_size=self.train_batch_size,
places=place)
if self.calc_train_acc:
fetch_list = [loss.name, global_lr.name,
......@@ -958,19 +982,19 @@ class Entry(object):
self.train_pass_id = pass_id
train_info = [[], [], [], []]
local_train_info = [[], [], [], []]
for batch_id, data in enumerate(train_reader()):
for batch_id, data in enumerate(self.input_field.loader):
nsamples += global_batch_size
t1 = time.time()
acc1 = None
acc5 = None
if self.calc_train_acc:
loss, lr, acc1, acc5 = exe.run(train_prog,
feed=feeder.feed(data),
feed=data,
fetch_list=fetch_list,
use_program_cache=True)
else:
loss, lr = exe.run(train_prog,
feed=feeder.feed(data),
feed=data,
fetch_list=fetch_list,
use_program_cache=True)
t2 = time.time()
......
......@@ -33,7 +33,7 @@ class BaseModel(object):
def __init__(self):
super(BaseModel, self).__init__()
def build_network(self, input, label, is_train=True):
def build_network(self, input, is_train=True):
"""
Construct the custom model, and we will add the distributed fc layer
at the end of your model automatically.
......@@ -43,7 +43,6 @@ class BaseModel(object):
def get_output(self,
input,
label,
num_classes,
num_ranks=1,
rank_id=0,
......@@ -76,7 +75,8 @@ class BaseModel(object):
"Supported loss types: {}, but given: {}".format(
supported_loss_types, loss_type)
emb = self.build_network(input, label, is_train)
emb = self.build_network(input, is_train)
label = input.label
prob = None
loss = None
if loss_type == "softmax":
......
......@@ -27,7 +27,6 @@ class ResNet(BaseModel):
def build_network(self,
input,
label,
is_train=True):
layers = self.layers
supported_layers = [50, 101, 152]
......@@ -44,7 +43,7 @@ class ResNet(BaseModel):
num_filters = [64, 128, 256, 512]
conv = self.conv_bn_layer(
input=input, num_filters=64, filter_size=3, stride=1,
input=input.image, num_filters=64, filter_size=3, stride=1,
pad=1, act='prelu', is_train=is_train)
for block in range(len(depth)):
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from __future__ import division
from __future__ import print_function
import numpy as np
import paddle.fluid as fluid
class InputField(object):
"""
A high-level API for handling inputs in PaddlePaddle.
"""
def __init__(self, input_slots=[]):
self.shapes = []
self.dtypes = []
self.names = []
self.lod_levels = []
self.input_slots = {}
self.feed_list_str = []
self.feed_list = []
self.loader = None
if input_slots:
for input_slot in input_slots:
self += input_slot
def __add__(self, input_slot):
if isinstance(input_slot, list) or isinstance(input_slot, tuple):
name = input_slot[0]
shape = input_slot[1]
dtype = input_slot[2]
lod_level = input_slot[3] if len(input_slot) == 4 else 0
if isinstance(input_slot, dict):
name = input_slot["name"]
shape = input_slot["shape"]
dtype = input_slot["dtype"]
lod_level = input_slot[
"lod_level"] if "lod_level" in input_slot else 0
self.shapes.append(shape)
self.dtypes.append(dtype)
self.names.append(name)
self.lod_levels.append(lod_level)
self.feed_list_str.append(name)
return self
def __getattr__(self, name):
if name not in self.input_slots:
raise Warning("the attr %s has not been defined yet." % name)
return None
return self.input_slots[name]
def build(self, capacity=64, iterable=True):
for _name, _shape, _dtype, _lod_level in zip(
self.names, self.shapes, self.dtypes, self.lod_levels):
self.input_slots[_name] = fluid.data(
name=_name, shape=_shape, dtype=_dtype, lod_level=_lod_level)
for name in self.feed_list_str:
self.feed_list.append(self.input_slots[name])
self.loader = fluid.io.DataLoader.from_generator(
feed_list=self.feed_list,
capacity=capacity,
iterable=iterable,
use_double_buffer=True)
if __name__ == "__main__":
mnist_input_slots = [{
"name": "image",
"shape": (-1, 32, 32, 1),
"dtype": "int32"
}, {
"name": "label",
"shape": [-1, 1],
"dtype": "int64"
}]
input_field = InputField(mnist_input_slots)
input_field += {
"name": "large_image",
"shape": (-1, 64, 64, 1),
"dtype": "int32"
}
input_field += {
"name": "large_color_image",
"shape": (-1, 64, 64, 3),
"dtype": "int32"
}
input_field.build()
print(input_field.feed_list)
print(input_field.image)
print(input_field.large_color_image)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册