@@ -61,10 +60,6 @@ class MultiHeadTrainer(Trainer):
"""
Build the forward computation graph for training, which is usually built from the input layer to the loss node.
Args:
backbone: a Backbone object with phase == 'train', which is used to extract multi-level text features, e.g., contextual word embedding and sentence embedding.
heads: a list of Head objects, which are used to build task-specific output layers. The phase of each head should be set to 'train'.
Return:
- loss_var: a Variable object. The computational graph variable(node) of loss.
"""
...
...
@@ -115,10 +110,14 @@ class MultiHeadTrainer(Trainer):
Bind readers and loaded train/predict data to trainers.
Bind readers and loaded train/predict data to trainers. The `num_epochs` argument applies
only to the `sampling_reference` task (trainer); the num_epochs of other tasks are inferred from
their `mix_ratio`.
Args:
readers: a dict or list of Reader objects. For dict case, each key is a trainer's name, and the mapped value is the reader object to bind to the trainer. For list case, each
sampling_reference: a trainer name. The task(trainer) selected as baseline for task sampling.
num_epochs: training epochs of the sampling_reference task (trainer).