@@ -125,27 +125,8 @@ class MultiHeadTrainer(Trainer):
...
@@ -125,27 +125,8 @@ class MultiHeadTrainer(Trainer):
branch_index=task_id_var,
branch_index=task_id_var,
branch_fns=task_fns
branch_fns=task_fns
)
)
# self._task_id_var = task_id_var
# self._loss_var = loss_var
# self._fetch_list = [loss_var.name]
ifnotself._multi_task:
ifnotself._multi_task:
self._init_exe_prog(for_train=False)
self._init_exe_prog(for_train=False)
# return self.build_forward()
# """
# Build computation graph for evaluation and prediction.
# Arguments:
# - pred_backbone: a Backbone object with phase == 'predict'. For evaluating model during training, the predict backbone should keep the same with train backbone.
# - pred_head: a Head object with phase == 'predict'. For evaluating model during training, the predict head should keep the same with train head.
#
# Return:
# - output_vars: dict type. Each value is a computational graph variable(node) argumented by pred_head outputs_attr.
# """
# for i in self._trainers:
# assert i._predict_vars is not None, "{} need to build_predict_forward before "
assertself._train_init_progisnotNoneorself._pred_init_progisnotNone,"model graph not built. You should at least build_forward or build_predict_forward to load its checkpoint."
assertself._train_init_progisnotNoneorself._pred_init_progisnotNone,"model graph not built. You should at least build_forward or build_predict_forward to load its checkpoint."
# if self._train_init_prog is not None:
# saver.init_pretraining_params(
# self._exe,
# model_path,
# convert=False,
# main_program=self._train_init_prog,
# strict=True)
# elif self._pred_init_prog is not None:
# saver.init_pretraining_params(
# self._exe,
# model_path,
# convert=False,
# main_program=self._pred_init_prog,
# strict=True)
ifself._train_init_progisnotNone:
ifself._train_init_progisnotNone:
saver.init_pretraining_params(
print('loading checkpoint into train program')
saver.init_checkpoint(
self._exe,
self._exe,
model_path,
model_path,
convert=False,
convert=False,
main_program=self._train_init_prog,
main_program=self._train_init_prog)
strict=True)
elifself._pred_init_progisnotNone:
elifself._pred_init_progisnotNone:
saver.init_pretraining_params(
saver.init_checkpoint(
self._exe,
self._exe,
model_path,
model_path,
convert=False,
main_program=self._pred_init_prog)
main_program=self._pred_init_prog,
strict=True)
else:
else:
raiseException("model not found. You should at least build_forward or build_predict_forward to load its checkpoint.")
raiseException("model not found. You should at least build_forward or build_predict_forward to load its checkpoint.")