Unverified commit a3012b87, authored by Xiaoyao Xi, committed by GitHub

fix multi-dev predict

fix bugs
@@ -5,5 +5,5 @@ import multiprocessing
 gpu_dev_count = int(fluid.core.get_cuda_device_count())
 cpu_dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count()))
-from reader import yield_pieces, data_feeder
+from reader import yield_pieces, data_feeder, decode_fake
@@ -58,7 +58,6 @@ def yield_pieces(data, distribute_strategy, batch_size):
         yield temp

 def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
     if postprocess_fn is None:
         def postprocess_fn(batch):
             return batch
@@ -108,3 +107,15 @@ def data_feeder(reader, postprocess_fn=None, prefetch_steps=2, phase='train'):
     queue.join()
+
+
+def decode_fake(nums, mask, bs):
+    n_t = 0
+    for flag in mask:
+        if not flag:
+            break
+        n_t = n_t + 1
+    n_f = len(mask) - n_t
+    p1 = nums - (n_t-1) * bs
+    each_f = p1 / (n_f+1)
+    return each_f * n_f
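The new `decode_fake` helper computes how many of the outputs returned by a multi-device prediction run are padding. When the final batch cannot fill every device, the feeder pads it out, and `mask` carries one flag per device, the leading `True` entries marking devices that hold real data. The arithmetic assumes the first `n_t - 1` real devices hold full batches of `bs` samples, with the remaining outputs split evenly across the last real device and the `n_f` padded ones. A minimal standalone sketch with hypothetical values; `//` makes the integer division explicit, which the committed `/` only is under Python 2:

    def decode_fake(nums, mask, bs):
        # Count the leading devices that carry real data.
        n_t = 0
        for flag in mask:
            if not flag:
                break
            n_t += 1
        n_f = len(mask) - n_t        # devices fed purely with padding
        p1 = nums - (n_t - 1) * bs   # outputs beyond the full real batches
        each_f = p1 // (n_f + 1)     # per-device share of the tail
        return each_f * n_f          # total padded outputs to discard

    # 4 devices, batch size 8, only the first two carried real data:
    print(decode_fake(32, [True, True, False, False], 8))  # -> 16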
@@ -31,7 +31,7 @@ from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint
 from paddlepalm.utils.config_helper import PDConfig
 from paddlepalm.utils.print_helper import print_dict
 from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs
-from paddlepalm.distribute import data_feeder
+from paddlepalm.distribute import data_feeder, decode_fake
 from default_settings import *
 from task_instance import TaskInstance, check_instances
@@ -228,6 +228,7 @@ class Controller(object):
         exe, dev_count = _init_env(use_gpu=mtl_conf.get('use_gpu', True))
         self.exe = exe
         self.dev_count = dev_count
+        self.batch_size = mtl_conf.get('batch_size')
         print_dict(mtl_conf, title='global configuration')
@@ -350,7 +351,7 @@ class Controller(object):
         dev_count = self.dev_count
         num_instances = len(instances)
         mrs = self.mrs
-        branch = fluid.data(name="branch",shape=[1],dtype='int32')
+        branch = fluid.data(name="branch",shape=[1],dtype='int64')

         # set first_target/main task instance
         main_inst = None
@@ -536,9 +537,8 @@ class Controller(object):
         # prepare for train
         self.train_backbone = train_backbone
-        #self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
+        self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name)
         self.saver_program = fluid.default_main_program()
-        self.train_program = self.saver_program
         self.main_inst = main_inst
         self.has_init_train = True
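Uncommenting the `CompiledProgram` line makes training genuinely data-parallel across devices instead of aliasing the single-device program, while `saver_program` keeps the raw `Program`, which checkpoint saving expects. A sketch of the Paddle 1.x pattern, assuming `loss` has already been built:

    # Compile the default main program for multi-device execution.
    train_program = fluid.CompiledProgram(fluid.default_main_program()) \
                         .with_data_parallel(loss_name=loss.name)
    # Keep the uncompiled Program around for checkpointing.
    saver_program = fluid.default_main_program()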
@@ -564,7 +564,7 @@ class Controller(object):
             insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False)
         pred_prog = inst.load(infer_model_path)
-        # pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
+        pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel()
         if inst.reader['pred'] is None:
             pred_reader = inst.Reader(inst.config, phase='pred')
             inst.reader['pred'] = pred_reader
@@ -628,8 +628,8 @@ class Controller(object):
             while not train_finish():
                 feed, mask, id = next(distribute_feeder)
-                feed[0].update({'branch':np.array([id],dtype='int32')})
+                for i in range(self.dev_count):
+                    feed[i].update({'branch':np.array([id],dtype='int64')})
                 fetch_list.append(self._switched_loss)
                 rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list)
                 rt_loss = rt_outputs.pop()
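This fix matters because a program compiled with `with_data_parallel` is fed a list of dicts, one per device, so writing the task id only into `feed[0]` leaves every other device without a `branch` input. A hypothetical illustration of the per-device feed shape (`token_ids` and the sizes are made up):

    import numpy as np

    dev_count = 2
    # One feed dict per device, as the data-parallel executor expects.
    feed = [{'token_ids': np.zeros((8, 128), dtype='int64')}
            for _ in range(dev_count)]
    task_id = 1
    for i in range(dev_count):
        feed[i].update({'branch': np.array([task_id], dtype='int64')})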
@@ -714,33 +714,23 @@ class Controller(object):
         buf = []
         for feed, mask, id in distribute_feeder:
-            # print('before run')
             rt_outputs = self.exe.run(pred_prog, feed, fetch_vars)
-            # print('after run')
-            splited_rt_outputs = []
-            for item in rt_outputs:
-                splited_rt_outputs.append(np.split(item, len(mask)))
-            # assert len(rt_outputs) == len(mask), [len(rt_outputs), len(mask)]
-            # print(mask)
-            while mask.pop() == False:
-                print(mask)
-                for item in splited_rt_outputs:
-                    item.pop()
-            rt_outputs = []
-            # print('cancat')
-            for item in splited_rt_outputs:
-                rt_outputs.append(np.concatenate(item))
+            nums_fake = decode_fake(len(rt_outputs[0]), mask, self.batch_size)
+            while nums_fake:
+                for item in rt_outputs:
+                    item.pop()
+                nums_fake = nums_fake - 1

             rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)}
             inst.postprocess(rt_outputs, phase='pred')
-        # print('leave feeder')
         if inst.task_layer['pred'].epoch_inputs_attrs:
             reader_outputs = inst.reader['pred'].get_epoch_outputs()
         else:
             reader_outputs = None
-        # print('epoch postprocess')
         inst.epoch_postprocess({'reader':reader_outputs}, phase='pred')
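The rewritten loop drops the old `np.split`/`mask.pop` bookkeeping: `decode_fake` reports how many tail entries of each fetched variable are padding, and exactly that many are popped before postprocessing. A toy run of the trimming logic with hypothetical list-valued outputs (the real fetches are tensors):

    fetch_a = [0.1, 0.9, 0.3, 0.7, 0.5, 0.2]   # six outputs came back
    fetch_b = [1, 1, 0, 1, 0, 0]
    rt_outputs = [fetch_a, fetch_b]

    # 2 devices, batch size 4, only the first device carried real data.
    nums_fake = decode_fake(len(rt_outputs[0]), [True, False], 4)  # -> 3
    while nums_fake:
        for item in rt_outputs:
            item.pop()   # drop one padded entry from every fetched variable
        nums_fake -= 1

    print(fetch_a)  # [0.1, 0.9, 0.3] -- only the real predictions remain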
@@ -754,4 +744,4 @@ if __name__ == '__main__':
 __all__ = ["Controller"]
\ No newline at end of file