@@ -61,10 +60,6 @@ class MultiHeadTrainer(Trainer):
...
"""
"""
Build forward computation graph for training, which usually built from input layer to loss node.
Build forward computation graph for training, which usually built from input layer to loss node.
Args:
backbone: a Backbone object with phase == 'train', used to extract multi-level text features, e.g., contextual word embeddings and sentence embeddings.
heads: a list of Head objects. The phase of each head should be set to 'train'; heads are used to build the task-specific output layers.
Return:
- loss_var: a Variable object. The computational graph variable (node) of the loss.
"""
...
@@ -115,10 +110,14 @@ class MultiHeadTrainer(Trainer):
Bind readers and loaded train/predict data to trainers. The `num_epochs` argument only
applies to the `sampling_reference` task (trainer); num_epochs of the other tasks are inferred from
their `mix_ratio`.
Args:
readers: a dict or list of Reader objects. For the dict case, each key is a trainer's name, and the mapped value is the reader object to bind to that trainer. For the list case, each
sampling_reference: a trainer name. The task (trainer) selected as the baseline for task sampling.
num_epochs: the number of training epochs of the sampling_reference task (trainer).
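The epoch-inference rule described above can be sketched as follows, assuming a proportional-sampling scheme in which each task's expected step count scales with its `mix_ratio` relative to the reference task. The function name, the `steps_per_epoch` field, and the exact formula are assumptions for illustration, not the library's API.

```python
# Hypothetical sketch: infer effective epochs for non-reference tasks from
# mix_ratio, given num_epochs of the sampling_reference task.

def infer_epochs(tasks, sampling_reference, num_epochs):
    """tasks: dict of name -> {'mix_ratio': float, 'steps_per_epoch': int}."""
    ref = tasks[sampling_reference]
    ref_steps = num_epochs * ref["steps_per_epoch"]
    inferred = {}
    for name, task in tasks.items():
        # Expected training steps scale with mix_ratio relative to the
        # reference task; dividing by steps_per_epoch yields epochs.
        steps = ref_steps * task["mix_ratio"] / ref["mix_ratio"]
        inferred[name] = steps / task["steps_per_epoch"]
    return inferred

tasks = {
    "mrc":   {"mix_ratio": 1.0, "steps_per_epoch": 1000},
    "match": {"mix_ratio": 0.5, "steps_per_epoch": 500},
}
epochs = infer_epochs(tasks, sampling_reference="mrc", num_epochs=2)
```

Under these toy numbers the "match" task receives half the steps of "mrc" (2000 * 0.5 = 1000), which over its smaller dataset works out to 2.0 effective epochs.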