Unverified commit a3012b87 authored by Xiaoyao Xi and committed by GitHub

fix multi-dev predict

fix bugs
@@ -5,5 +5,5 @@ import multiprocessing
 gpu_dev_count = int(fluid.core.get_cuda_device_count())
 cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-from reader import yield_pieces, data_feeder
+from reader import yield_pieces, data_feeder, decode_fake
@@ -58,7 +58,6 @@ def yield_pieces(data, distribute_strategy, batch_size):
         yield temp
 
 def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
     if postprocess_fn is None:
         def postprocess_fn(batch):
             return batch
@@ -108,3 +107,15 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
     queue.join()
+
+def decode_fake(nums, mask, bs):
+    n_t = 0
+    for flag in mask:
+        if not flag:
+            break
+        n_t = n_t + 1
+
+    n_f = len(mask) - n_t
+    p1 = nums - (n_t-1) * bs
+    each_f = p1 / (n_f+1)
+    return each_f * n_f
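For context: decode_fake computes how many padded (fake) prediction rows must be dropped when the last batch cannot fill every device. A minimal worked example of the arithmetic with illustrative numbers (note that under Python 3 the `/` above yields a float, so the code appears to assume Python 2 integer division; the sketch below uses `//` to make that explicit):

# Illustrative last predict step: 4 devices, batch size 8, device 3 all fake.
mask = [True, True, True, False]  # one flag per device; False marks fake data
nums = 32                         # total rows fetched across all devices
bs = 8

n_t = 3                       # leading True flags: devices with real batches
n_f = 1                       # remaining devices, padded with fake rows
p1 = nums - (n_t - 1) * bs    # rows beyond the first n_t-1 full real batches
each_f = p1 // (n_f + 1)      # rows per device inside that remaining region
print(each_f * n_f)           # -> 8 fake rows to pop before postprocessing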
@@ -31,7 +31,7 @@ from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint
 from paddlepalm.utils.config_helper import PDConfig
 from paddlepalm.utils.print_helper import print_dict
 from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs
-from paddlepalm.distribute import data_feeder
+from paddlepalm.distribute import data_feeder, decode_fake
 from default_settings import *
 from task_instance import TaskInstance, check_instances
@@ -228,6 +228,7 @@ class Controller(object):
         exe, dev_count = _init_env(use_gpu=mtl_conf.get('use_gpu', True))
         self.exe = exe
         self.dev_count = dev_count
+        self.batch_size = mtl_conf.get('batch_size')
 
         print_dict(mtl_conf, title='global configuration')
@@ -350,7 +351,7 @@ class Controller(object):
         dev_count = self.dev_count
         num_instances = len(instances)
         mrs = self.mrs
-        branch = fluid.data(name="branch",shape=[1],dtype='int32')
+        branch = fluid.data(name="branch",shape=[1],dtype='int64')
 
         # set first_target/main task instance
         main_inst = None
@@ -536,9 +537,8 @@ class Controller(object):
         # prepare for train
         self.train_backbone = train_backbone
-        #self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
+        self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
         self.saver_program = fluid.default_main_program()
-        self.train_program = self.saver_program
 
         self.main_inst = main_inst
         self.has_init_train = True
@@ -564,7 +564,7 @@ class Controller(object):
             insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
         pred_prog = inst.load(infer_model_path)
-        # pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
+        pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
 
         if inst.reader['pred'] is None:
             pred_reader = inst.Reader(inst.config, phase='pred')
             inst.reader['pred'] = pred_reader
@@ -628,8 +628,8 @@ class Controller(object):
         while not train_finish():
             feed, mask, id = next(distribute_feeder)
-            feed[0].update({'branch':np.array([id],dtype='int32')})
+            for i in range(self.dev_count):
+                feed[i].update({'branch':np.array([id],dtype='int64')})
             fetch_list.append(self._switched_loss)
             rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list)
             rt_loss = rt_outputs.pop()
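This is the training-side half of the fix: under data-parallel execution every device consumes its own feed dict, so the task-switch id must be present in all of them (not just feed[0]) and must match the int64 dtype of the branch variable declared earlier. A minimal sketch with illustrative values:

import numpy as np

dev_count = 4                           # illustrative device count
task_id = 2                             # branch chosen for this step
feed = [{} for _ in range(dev_count)]   # one feed dict per device

# Every device gets its own copy of the branch selector; updating only
# feed[0], as before the fix, leaves the other devices without the input.
for i in range(dev_count):
    feed[i].update({'branch': np.array([task_id], dtype='int64')})

assert all('branch' in f for f in feed)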
@@ -714,33 +714,23 @@ class Controller(object):
         buf = []
         for feed, mask, id in distribute_feeder:
             # print('before run')
             rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)
             # print('after run')
-            splited_rt_outputs = []
-            for item in rt_outputs:
-                splited_rt_outputs.append(np.split(item, len(mask)))
-            # assert len(rt_outputs) == len(mask), [len(rt_outputs), len(mask)]
-            # print(mask)
-            while mask.pop() == False:
-                print(mask)
-                for item in splited_rt_outputs:
+            nums_fake = decode_fake(len(rt_outputs[0]), mask, self.batch_size)
+            while nums_fake:
+                for item in rt_outputs:
                     item.pop()
-            rt_outputs = []
-            # print('cancat')
-            for item in splited_rt_outputs:
-                rt_outputs.append(np.concatenate(item))
+                nums_fake = nums_fake - 1
             rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
             inst.postprocess(rt_outputs, phase='pred')
-            # print('leave feeder')
             if inst.task_layer['pred'].epoch_inputs_attrs:
                 reader_outputs = inst.reader['pred'].get_epoch_outputs()
             else:
                 reader_outputs = None
-            # print('epoch postprocess')
             inst.epoch_postprocess({'reader':reader_outputs}, phase='pred')
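And this is the prediction-side half: instead of splitting the fetched arrays per device and popping whole slices while consuming the mask, the new code asks decode_fake for the exact number of trailing fake rows and pops them one at a time before zipping the outputs with fetch_names. A minimal sketch of the trimming loop, assuming rt_outputs holds plain Python lists of per-example results (illustrative data):

nums_fake = 2                        # as returned by decode_fake(...)
rt_outputs = [[0.1, 0.9, 0.7, 0.7],  # one list per fetched variable;
              ['a', 'b', 'c', 'c']]  # the last nums_fake rows are padding

# Drop one trailing fake row from every fetched variable, nums_fake times.
while nums_fake:
    for item in rt_outputs:
        item.pop()
    nums_fake = nums_fake - 1

assert rt_outputs == [[0.1, 0.9], ['a', 'b']]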
@@ -754,4 +744,4 @@ if __name__ == '__main__':
-__all__ = ["Controller"]
\ No newline at end of file
+__all__ = ["Controller"]