Commit 7b82c250, authored by xixiaoyao

add multihead

Parent: a9c5a794
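This commit adds a MultiHeadTrainer that wraps several single-task Trainer instances behind one compiled program and routes each batch to the matching task head. A minimal usage sketch of the API added below; the trainer constructor signature and the backbone/head/reader objects are placeholders, not part of this diff:

import paddlepalm

# hypothetical setup; only the MultiHeadTrainer calls mirror this commit
trainer_a = paddlepalm.Trainer('task_a', mix_ratio=1.0)    # assumed ctor signature
trainer_b = paddlepalm.Trainer('task_b', mix_ratio=0.5)    # assumed ctor signature

mh = paddlepalm.MultiHeadTrainer([trainer_a, trainer_b])
loss_var = mh.build_forward(backbone, [head_a, head_b])    # heads: a list, or a dict keyed by trainer name
mh.fit_readers_with_mixratio([reader_a, reader_b],
                             sampling_reference='task_a',  # this task's epoch length anchors step estimates
                             num_epochs=2)                 # each reader's own num_epochs must be None
mh.train(print_steps=10)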
...
@@ -9,6 +9,7 @@ import head
 from trainer import Trainer
+from multihead_trainer import MultiHeadTrainer
 del interface
 del task_instance
...
+from paddle import fluid
+from paddle.fluid import layers
 from paddlepalm.distribute import gpu_dev_count, cpu_dev_count
 from paddlepalm import Trainer
...
@@ -9,47 +11,109 @@ VERBOSE=False
 class MultiHeadTrainer(Trainer):

     def __init__(self, trainers, reuse_flags=None):
-        assert len(trainers) == len(mix_ratios)
         if reuse_flags is not None:
             assert len(reuse_flags) == len(trainers)
         self._trainers = trainers
+        self._train_init = False
+        self._predict_init = False
+        self._feeded_var_names = None
+        self._cur_train_step = 0
+        self._target_vars = None
+        self._inputname_to_varname = {}
+        self._pred_input_name_list = []
+        self._pred_input_varname_list = []
+        self._pred_fetch_name_list = []
+        self._pred_fetch_var_list = []
+        self._exe = None
+        self._save_protocol = {
+            'input_names': 'self._pred_input_name_list',
+            'input_varnames': 'self._pred_input_varname_list',
+            'fetch_list': 'self._pred_fetch_name_list'}
+        self._check_save = lambda: False
+        for t in self._trainers:
+            t._set_multitask()

-    def build_forward(self, backbone, head_dict):
+    def build_forward(self, backbone, heads):
+        if isinstance(heads, list):
+            head_dict = {k.name: v for k, v in zip(self._trainers, heads)}
+        elif isinstance(heads, dict):
+            head_dict = heads
+        else:
+            raise ValueError()

         num_heads = len(self._trainers)
         assert len(head_dict) == num_heads

-        for t in trainers:
-            assert t.name in head_dict
+        for t in self._trainers:
+            assert t.name in head_dict, "expected: {}, exists: {}".format(t.name, head_dict.keys())

         train_prog = fluid.Program()
         train_init_prog = fluid.Program()
+        self._train_prog = train_prog
+        self._train_init_prog = train_init_prog

         def get_loss(i):
             head = head_dict[self._trainers[i].name]
-            loss_var = self._trainers[i].build_forward(backbone, head, train_prog, train_init_prog)
+            loss_var = self._trainers[i].build_forward(backbone, head)
             return loss_var

-        task_fns = {}
-        for i in range(num_heads):
-            def task_loss():
-                task_id = i
-                return lambda: get_loss(task_id)
-            task_fns[i] = task_loss()
-        head_id_var = fluid.data(name="branch", shape=[1], dtype='int64')
-        loss_var = layers.switch_case(
-            branch_index=head_id_var,
-            branch_fns=task_fns
-        )
+        # bind i at definition time (i=i) so each branch closes over its own task id
+        task_fns = {i: lambda i=i: get_loss(i) for i in range(num_heads)}
+
+        with fluid.program_guard(train_prog, train_init_prog):
+            head_id_var = fluid.data(name="branch", shape=[1], dtype='int64')
+            loss_var = layers.switch_case(
+                branch_index=head_id_var,
+                branch_fns=task_fns
+            )
         self._head_id_var = head_id_var
         return loss_var
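The branch functions handed to layers.switch_case must each capture a distinct task id. A bare dict comprehension over lambdas hits Python's late-binding rule: every closure would read the final value of i. A minimal pure-Python illustration of the pitfall and the default-argument fix used above:

fns_late = {i: (lambda: i) for i in range(3)}
print([f() for f in fns_late.values()])     # [2, 2, 2] -- all closures share one i

fns_bound = {i: (lambda i=i: i) for i in range(3)}
print([f() for f in fns_bound.values()])    # [0, 1, 2] -- i frozen per branch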
-    def fit_readers(self, reader_dict, mix_ratio):
+    def fit_readers(self, reader_dict):
+        raise NotImplementedError()
+
+    def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs):
+        if isinstance(readers, list):
+            reader_dict = {k.name: v for k, v in zip(self._trainers, readers)}
+        elif isinstance(readers, dict):
+            reader_dict = readers
+        else:
+            raise ValueError()

         num_heads = len(self._trainers)
-        assert len(head_dict) == num_heads
+        assert len(reader_dict) == num_heads

+        trainer_dict = {t.name: t for t in self._trainers}
+        assert sampling_reference in trainer_dict
+        trainer_dict[sampling_reference].fit_reader(reader_dict[sampling_reference])
+        base_steps_pur_epoch = trainer_dict[sampling_reference]._steps_pur_epoch

         name_to_position = []
         joint_shape_and_dtypes = []
...
@@ -57,9 +121,18 @@ class MultiHeadTrainer(Trainer):
         prefixes = []
         mrs = []
         net_inputs = []
+        global_steps = 0
-        for t in trainers:
+        for t in self._trainers:
             assert t.name in reader_dict
-            t.fit_reader(reader_dict[t.name])
+            assert reader_dict[t.name].num_epochs is None, "{}: num_epochs is not None. To run with multi-head mode, num_epochs of each Trainer should be set as None.".format(t.name)
+            max_train_steps = int(num_epochs * t.mix_ratio * base_steps_pur_epoch)
+            if not t.set_as_aux:
+                print('{}: expected train steps {}.'.format(t.name, max_train_steps))
+                global_steps += max_train_steps
+            if t.name != sampling_reference:
+                t.fit_reader(reader_dict[t.name])
             net_inputs.append(t._net_inputs)
             prefixes.append(t.name)
             mrs.append(t.mix_ratio)
...
@@ -67,7 +140,11 @@ class MultiHeadTrainer(Trainer):
             name_to_position.append(t._name_to_position)
             joint_shape_and_dtypes.append(t._shape_and_dtypes)

+        print('Estimated overall train steps {}.'.format(global_steps))
+        self._overall_train_steps = global_steps
+
-        iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict')
+        iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes,
+                                               mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict')
         feed_batch_process_fn = reader_helper.create_multihead_feed_batch_process_fn(net_inputs)

         if gpu_dev_count > 1:
...
@@ -82,10 +159,104 @@ class MultiHeadTrainer(Trainer):
         self._predict_reader = distribute_feeder_fn()
         self._pred_feed_batch_process_fn = feed_batch_process_fn
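The step bookkeeping above is plain arithmetic: one epoch is defined by the sampling_reference task's steps-per-epoch, and each task's expected steps scale with its mix_ratio. A small sketch with made-up numbers (task names hypothetical):

base_steps_pur_epoch = 1000            # steps per epoch of the sampling reference
num_epochs = 2
mix_ratios = {'task_a': 1.0, 'task_b': 0.5}

expected = {name: int(num_epochs * mr * base_steps_pur_epoch)
            for name, mr in mix_ratios.items()}
print(expected)                        # {'task_a': 2000, 'task_b': 1000}
print(sum(expected.values()))          # 3000 -- estimated overall train steps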
+    def train(self, save_path=None, save_steps=None, save_type='ckpt', print_steps=5):
+        iterator = self._train_reader
+        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)
+
+        save_type = save_type.split(',')
+        if 'predict' in save_type:
+            assert self._pred_head is not None, "Predict head not found! You should build_predict_head first if you want to save predict model."
+            assert save_path is not None and save_steps is not None, 'save_path and save_steps is required to save model.'
+            save_predict = True
+            if not os.path.exists(save_path):
+                os.makedirs(save_path)
+        else:
+            save_predict = False
+
+        if 'ckpt' in save_type:
+            if save_path is not None and save_steps is not None:
+                save_ckpt = True
+                if not os.path.exists(save_path):
+                    os.makedirs(save_path)
+            else:
+                print("WARNING: save_path or save_steps is not set, model will not be saved during training.")
+                save_ckpt = False
+        else:
+            save_ckpt = False
+
+        time_begin = time.time()
+        for feed in iterator:
+            batch, task_id = feed
+            rt_outputs = self.train_one_step(batch, task_id)
+
+            task_rt_outputs = {k[len(self._trainers[task_id].name+'.'):]: v for k, v in rt_outputs.items() if k.startswith(self._trainers[task_id].name+'.')}
+            self._trainers[task_id]._task_head.batch_postprocess(task_rt_outputs)
+
+            if print_steps > 0 and self._cur_train_step % print_steps == 0:
+                loss = rt_outputs[self._trainers[task_id].name+'.loss']
+                loss = np.mean(np.squeeze(loss)).tolist()
+
+                time_end = time.time()
+                time_cost = time_end - time_begin
+
+                print("global step: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format(
+                    self._cur_train_step,
+                    (self._trainers[task_id]._cur_train_step - 1) % self._trainers[task_id]._steps_pur_epoch + 1,
+                    self._trainers[task_id]._steps_pur_epoch,
+                    self._trainers[task_id]._cur_train_epoch,
+                    loss, print_steps / time_cost))
+                time_begin = time.time()
+
+            self._check_save()
+            # if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
+            #     print(cur_task.name+': train finished!')
+            #     cur_task.save()
+
+            # if (save_predict or save_ckpt) and self._cur_train_step % save_steps == 0:
+            #     if save_predict:
+            #         self.save(save_path, suffix='pred.step'+str(self._cur_train_step))
+            #     if save_ckpt:
+            #         fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
+            #         print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
+
+            if self._num_epochs is None and self._cur_train_step == self._steps_pur_epoch:
+                break

-    def train(self):
-        pass
-
-    def train_one_step(self):
-        pass
+    def train_one_step(self, batch, task_id):
+        if dev_count > 1:
+            assert isinstance(batch, list)
+            for f in batch:
+                f['branch'] = np.array([task_id], dtype='int64')
+        else:
+            assert isinstance(batch, dict)
+            batch['branch'] = np.array([task_id], dtype='int64')
+
+        rt_outputs = self._trainers[task_id].train_one_step(batch, self._exe, self._distribute_train_prog)
+        self._cur_train_step += 1
+        return rt_outputs
+
+    def predict_one_batch(self, batch):
+        raise NotImplementedError()
+
+    def predict(self, output_dir=None, print_steps=1000):
+        raise NotImplementedError()
+
+    @property
+    def overall_train_steps(self):
+        return self._overall_train_steps
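Routing works through a single scalar feed: train_one_step stamps each feed dict with a 'branch' entry, and the switch_case built in build_forward dispatches on it, so one compiled program serves every head. A sketch of just that contract, with a stubbed batch:

import numpy as np

def route_batch(batch, task_id):
    # mirrors MultiHeadTrainer.train_one_step: list of per-device dicts, or one dict
    if isinstance(batch, list):
        for f in batch:
            f['branch'] = np.array([task_id], dtype='int64')
    else:
        batch['branch'] = np.array([task_id], dtype='int64')
    return batch

print(route_batch({'input_ids': np.zeros((4, 128))}, task_id=1)['branch'])   # [1]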
...
@@ -46,7 +46,6 @@ class Adam(BaseOptimizer):
         fluid.clip.set_gradient_clip(
             clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=clip_norm_thres))
-        print(self._loss)
         _, param_grads = optimizer.minimize(self._loss)
         return param_grads
...
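For reference, GradientClipByGlobalNorm implements the usual rule: if the global norm g of all gradients exceeds clip_norm, every gradient is scaled by clip_norm / g. A numpy illustration of the rule (not Paddle code):

import numpy as np

def clip_by_global_norm(grads, clip_norm):
    global_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)    # no-op when already within bound
    return [g * scale for g in grads]

grads = [np.array([3.0, 4.0]), np.array([12.0])]       # global norm = 13
print(clip_by_global_norm(grads, clip_norm=6.5))       # every gradient halved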
...
@@ -41,6 +41,8 @@ class Trainer(object):
         self._train_init = False
         self._predict_init = False
+        self._check_save = lambda: False
         # if save_predict_model:
         #     self._save_predict_model = True
         #     assert pred_head is not None, "pred_head is required to save predict model."
...
@@ -59,6 +61,8 @@ class Trainer(object):
         self._num_examples = 0
+        self._multi_task = False
         # training process management
         self._mix_ratio = mix_ratio
         self._expected_train_steps = None
...
@@ -133,8 +137,11 @@ class Trainer(object):
         return output_vars

+    def _set_multitask(self):
+        self._multi_task = True
+
-    def build_forward(self, backbone, task_head, train_prog=None, train_init_prog=None, pred_prog=None, pred_init_prog=None):
+    def build_forward(self, backbone, task_head):
+        # assert not self._multi_task, "you cannot build_forward in a trainer that is wrapped by a MultiHeadTrainer."
         self._task_head = task_head
         # assert self._backbone is not None, "backbone is required for Trainer to build net forward to run with single task mode"
...
@@ -168,23 +175,22 @@ class Trainer(object):
         input_attrs = [[i, j, k] for i, (j, k) in zip(input_names, shape_and_dtypes)]

-        if train_prog is None:
-            train_prog = fluid.Program()
-        if train_init_prog is None:
-            train_init_prog = fluid.Program()
-        self._prog = train_prog
+        train_prog = fluid.Program()
+        train_init_prog = fluid.Program()
         self._train_prog = train_prog
         self._train_init_prog = train_init_prog

-        with fluid.program_guard(train_prog, train_init_prog):
-            net_inputs = reader_helper.create_net_inputs(input_attrs, async=False)
-            self._net_inputs = net_inputs
-            # build backbone and task layers
-            # bb_output_vars = self._backbone.build(net_inputs, scope_name='__paddlepalm_')
-            bb_output_vars = backbone.build(net_inputs)
+        if not self._multi_task:
+            with fluid.program_guard(train_prog, train_init_prog):
+                net_inputs = reader_helper.create_net_inputs(input_attrs, async=False)
+                bb_output_vars = backbone.build(net_inputs)
+        else:
+            net_inputs = reader_helper.create_net_inputs(input_attrs, async=False)
+            bb_output_vars = backbone.build(net_inputs)
+        self._net_inputs = net_inputs
         assert sorted(bb_output_vars.keys()) == sorted(backbone.outputs_attr.keys())

         # self._bb_output_vars.keys
         # fluid.framework.switch_main_program(train_prog)
         # fluid.framework.switch_startup_program(train_init_prog)
...
@@ -195,9 +201,14 @@ class Trainer(object):
         task_inputs['reader'] = task_inputs_from_reader

         scope = self.name+'.'
-        with fluid.program_guard(train_prog, train_init_prog):
-            with fluid.unique_name.guard(scope):
-                output_vars = self._build_head(task_inputs, phase='train', scope=scope)
+        if not self._multi_task:
+            with fluid.program_guard(train_prog, train_init_prog):
+                with fluid.unique_name.guard(scope):
+                    output_vars = self._build_head(task_inputs, phase='train', scope=scope)
+        else:
+            with fluid.unique_name.guard(scope):
+                output_vars = self._build_head(task_inputs, phase='train', scope=scope)

         output_vars = {self.name+'.'+key: val for key, val in output_vars.items()}
         old = len(task_output_vars)  # for debug
         task_output_vars.update(output_vars)
...
@@ -215,7 +226,10 @@ class Trainer(object):
         # task_id_vec = layers.one_hot(task_id_var, num_instances)
         # losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0)
         # loss = layers.reduce_sum(task_id_vec * losses)
-        with fluid.program_guard(train_prog, train_init_prog):
-            loss_var = fluid.layers.reduce_sum(task_output_vars[self.name+'.loss'])
+        if not self._multi_task:
+            with fluid.program_guard(train_prog, train_init_prog):
+                loss_var = fluid.layers.reduce_sum(task_output_vars[self.name+'.loss'])
+        else:
+            loss_var = fluid.layers.reduce_sum(task_output_vars[self.name+'.loss'])

         # for _id, block in enumerate(self._train_prog.blocks):
...
@@ -225,6 +239,7 @@ class Trainer(object):
         return loss_var
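The _multi_task switch exists because MultiHeadTrainer invokes each trainer's build_forward from inside its own fluid.program_guard context; re-entering program_guard here would send ops to the wrong program, while unique_name.guard still keeps per-head parameter names disjoint. A standalone sketch of the two guards, assuming Paddle 1.x fluid:

import paddle.fluid as fluid

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):     # ops and vars land in main_prog
    with fluid.unique_name.guard('task_a.'):           # auto-generated names get this prefix
        x = fluid.data(name='x', shape=[-1, 8], dtype='float32')
        y = fluid.layers.fc(x, size=4)                 # parameters named like task_a.fc_0.w_0
print([p.name for p in main_prog.global_block().all_parameters()])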
     def build_backward(self, optimizer, weight_decay=None, use_ema=False, ema_decay=0.9999):
+        # assert not self._multi_task, "you cannot build_backward in a trainer that is wrapped by a MultiHeadTrainer."
         # build optimizer
         assert self._train_init_prog is not None, "train graph not found! You should build_forward first."
         optimizer._set_prog(self._train_prog, self._train_init_prog)
...
@@ -235,7 +250,7 @@ class Trainer(object):
             param_list = dict()

-            for param in self._prog.global_block().all_parameters():
+            for param in self._train_prog.global_block().all_parameters():
                 param_list[param.name] = param * 1.0
                 param_list[param.name].stop_gradient = True
...
@@ -271,8 +286,11 @@ class Trainer(object):
         # print(self._train_prog)

     def fit_reader(self, reader, phase='train'):
+        # assert not self._multi_task, "you cannot fit_reader in a trainer that is wrapped by a MultiHeadTrainer."
         # load data
+        assert self._train_init_prog is not None or self._pred_init_prog is not None, "You need to build_forward or build_predict_head first to prepare input features."
+        assert self._shape_and_dtypes is not None or self._pred_shape_and_dtypes is not None, "You need to build_forward or build_predict_head first to prepare input features."
         # unsure whether this should round up; to be confirmed
         # tail = self._num_examples % batch_size > 0
         # self._steps_pur_epoch = self._num_examples // batch_size + 1 if tail else 0
...
@@ -378,34 +396,52 @@ class Trainer(object):
             convert=convert,
             main_program=self._train_init_prog)
-    def train(self, save_path=None, save_steps=None, save_type='ckpt', print_steps=5):
-        """
-        Argument:
-            save_type: ckpt, predict, pretrain
-        """
-        iterator = self._train_reader
-        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)
-
+    def set_saver(self, save_path, save_steps, save_type='ckpt'):
         save_type = save_type.split(',')
         if 'predict' in save_type:
             assert self._pred_head is not None, "Predict head not found! You should build_predict_head first if you want to save predict model."
             assert save_path is not None and save_steps is not None, 'save_path and save_steps is required to save model.'
-            save_predict = True
+            self._save_predict = True
             if not os.path.exists(save_path):
                 os.makedirs(save_path)
         else:
-            save_predict = False
+            self._save_predict = False

         if 'ckpt' in save_type:
             if save_path is not None and save_steps is not None:
-                save_ckpt = True
+                self._save_ckpt = True
                 if not os.path.exists(save_path):
                     os.makedirs(save_path)
             else:
                 print("WARNING: save_path or save_steps is not set, model will not be saved during training.")
-                save_ckpt = False
+                self._save_ckpt = False
         else:
-            save_ckpt = False
+            self._save_ckpt = False
+
+        def temp_func():
+            if (self._save_predict or self._save_ckpt) and self._cur_train_step % save_steps == 0:
+                if self._save_predict:
+                    self.save(save_path, suffix='pred.step'+str(self._cur_train_step))
+                    print('predict model has been saved at '+os.path.join(save_path, 'pred.step'+str(self._cur_train_step)))
+                if self._save_ckpt:
+                    fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
+                    print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))
+                return True
+            else:
+                return False
+
+        self._check_save = temp_func
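set_saver replaces the no-op default installed in __init__ (self._check_save = lambda: False) with a closure capturing save_path and save_steps, so both Trainer.train and MultiHeadTrainer.train can call self._check_save() each step without knowing any saving configuration. The pattern in isolation, with a stubbed save action:

class Saver(object):
    def __init__(self):
        self.step = 0
        self.check_save = lambda: False              # default: never saves

    def set_saver(self, save_steps):
        def check():
            if self.step % save_steps == 0:
                print('saving at step', self.step)   # stand-in for fluid.io.save_persistables
                return True
            return False
        self.check_save = check                      # swap in the configured callback

s = Saver()
s.set_saver(save_steps=2)
for i in range(1, 5):
    s.step = i
    s.check_save()                                   # saves at steps 2 and 4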
+    def train(self, save_path=None, save_steps=None, save_type='ckpt', print_steps=5):
+        """
+        Argument:
+            save_type: ckpt, predict, pretrain
+        """
+        iterator = self._train_reader
+        self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name)

         # if save_path is not None or save_steps is not None:
         #     assert self._save_predict_model, "If you want to save model, you need set save_predict_model=True when this trainer is built."
...
@@ -438,12 +474,6 @@ class Trainer(object):
             task_rt_outputs = {k[len(self.name+'.'):]: v for k, v in rt_outputs.items() if k.startswith(self.name+'.')}
             self._task_head.batch_postprocess(task_rt_outputs)

-            # rt_outputs = {k: v for k, v in zip(self._fetch_names, rt_outputs)}
-            task_rt_outputs = {k[len(self.name+'.'):]: v for k, v in rt_outputs.items() if k.startswith(self.name+'.')}
-            self._task_head.batch_postprocess(task_rt_outputs)

             # if self._save_predict_model and self._cur_train_step % save_steps == 0:
             #     self.save(save_path, suffix='.step'+str(self._cur_train_steps))
...
@@ -462,13 +492,9 @@ class Trainer(object):
             # if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps:
             #     print(cur_task.name+': train finished!')
             #     cur_task.save()
+            self._check_save()
-            if (save_predict or save_ckpt) and self._cur_train_step % save_steps == 0:
-                if save_predict:
-                    self.save(save_path, suffix='pred.step'+str(self._cur_train_step))
-                if save_ckpt:
-                    fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog)
-                    print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)))

             if self._num_epochs is None and self._cur_train_step == self._steps_pur_epoch:
                 break
...
@@ -525,23 +551,42 @@ class Trainer(object):
         results = self._pred_head.epoch_postprocess({'reader': reader_outputs}, output_dir=output_dir)
         return results

-    def train_one_step(self, batch):
+    def train_one_step(self, batch, executor=None, distribute_train_prog=None):
+        exe = self._exe if executor is None else executor
+        distribute_train_prog = self._distribute_train_prog if distribute_train_prog is None else distribute_train_prog
         if gpu_dev_count > 1:
             feed, mask = batch
-            rt_outputs = self.exe.run(self._distribute_train_prog, feed=feed, fetch_list=self._fetch_list)
+            rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=self._fetch_list)
             num_fakes = decode_fake(len(rt_outputs[0]), mask, self._batch_size)
             for _ in range(num_fakes):
                 for item in rt_outputs:
                     item.pop()
         else:
             feed = self._feed_batch_process_fn(batch)
-            rt_outputs = self._exe.run(self._distribute_train_prog, feed=feed, fetch_list=self._fetch_list)
+            rt_outputs = exe.run(distribute_train_prog, feed=feed, fetch_list=self._fetch_list)
         rt_outputs = {k: v for k, v in zip(self._fetch_names, rt_outputs)}
         self._cur_train_step += 1
         self._cur_train_epoch = (self._cur_train_step - 1) // self._steps_pur_epoch
         return rt_outputs
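On multi-card runs the feeder pads the final batch so every device receives a full batch; decode_fake (signature as used above, internals assumed) recovers how many trailing results are padding, and the loop pops that many entries from each fetched list. The same trimming written for arrays, as a sketch:

import numpy as np

def trim_fake_results(rt_outputs, num_fakes):
    # equivalent to the item.pop() loop above: drop the last num_fakes padded samples
    return [item[:len(item) - num_fakes] if num_fakes else item for item in rt_outputs]

fetched = [np.arange(8), np.arange(8) * 2]       # pretend 8 results, last 3 are padding
print(trim_fake_results(fetched, num_fakes=3))   # each array trimmed to length 5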
+    @property
+    def num_epochs(self):
+        return self._num_epochs
+
+    @property
+    def cur_train_steps(self):
+        return self._cur_train_step
+
+    @property
+    def cur_train_epoch(self):
+        return self._cur_train_epoch
+
+    @property
+    def steps_pur_epoch(self):
+        return self._steps_pur_epoch

     def predict_one_batch(self, batch):
         if gpu_dev_count > 1:
             feed, mask = batch
...