提交 ac321194 编写于 作者: littletomatodonkey's avatar littletomatodonkey

fix fetchs

上级 7713736e
......@@ -49,7 +49,7 @@ def create_dataloader():
dataloader(fluid dataloader):
"""
trainer_num = int(os.environ.get('PADDLE_TRAINERS_NUM', 1))
capacity = 64 if trainer_num <= 1 else 8
capacity = 64 if trainer_num == 1 else 8
dataloader = fluid.io.DataLoader.from_generator(
capacity=capacity, use_double_buffer=True, iterable=True)
......@@ -163,15 +163,7 @@ def create_metric(out,
return fetchs
def create_fetchs(feeds,
out,
config,
architecture,
topk=5,
classes_num=1000,
epsilon=None,
use_mix=False,
use_distillation=False):
def create_fetchs(feeds, net, config, mode="train"):
"""
Create fetchs as model outputs(included loss and measures),
will call create_loss and create_metric(if use_mix).
......@@ -190,6 +182,15 @@ def create_fetchs(feeds,
Returns:
fetchs(dict): dict of model outputs(included loss and measures)
"""
architecture = config.ARCHITECTURE
topk = config.topk
classes_num = config.classes_num
epsilon = config.get('ls_epsilon')
use_mix = config.get('use_mix') and mode == 'train'
use_distillation = config.get('use_distillation')
out = net(feeds["image"])
fetchs = OrderedDict()
fetchs['loss'] = create_loss(feeds, out, architecture, classes_num,
epsilon, use_mix, use_distillation)
......@@ -276,40 +277,6 @@ def mixed_precision_optimizer(config, optimizer):
return optimizer
def compute(feeds, net, config, mode='train'):
"""
Build a program using a model and an optimizer
1. create feeds
2. create a dataloader
3. create a model
4. create fetchs
5. create an optimizer
Args:
config(dict): config
main_prog(): main program
startup_prog(): startup program
is_train(bool): train or valid
Returns:
dataloader(): a bridge between the model and the data
fetchs(dict): dict of model outputs(included loss and measures)
"""
out = net(feeds["image"])
fetchs = create_fetchs(
feeds,
out,
config,
config.ARCHITECTURE,
config.topk,
config.classes_num,
epsilon=config.get('ls_epsilon'),
use_mix=config.get('use_mix') and mode == 'train',
use_distillation=config.get('use_distillation'))
return fetchs
def create_feeds(batch, use_mix):
image = to_variable(batch[0].numpy().astype("float32"))
if use_mix:
......@@ -360,7 +327,7 @@ def run(dataloader, config, net, optimizer=None, epoch=0, mode='train'):
for idx, batch in enumerate(dataloader()):
batch_size = len(batch[0])
feeds = create_feeds(batch, use_mix)
fetchs = compute(feeds, net, config, mode)
fetchs = create_fetchs(feeds, net, config, mode)
if mode == 'train':
avg_loss = net.scale_loss(fetchs['loss'])
avg_loss.backward()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册