diff --git a/paddlepalm/multihead_trainer.py b/paddlepalm/multihead_trainer.py index 39ec9711cc3c88286213461d72bc3b91329cd9ed..bc7af3448eaff179e7c48b329a783acdf6555ad3 100644 --- a/paddlepalm/multihead_trainer.py +++ b/paddlepalm/multihead_trainer.py @@ -1,11 +1,10 @@ from paddle import fluid from paddle.fluid import layers -from paddlepalm.distribute import gpu_dev_count, cpu_dev_count +from paddlepalm.distribute import gpu_dev_count, cpu_dev_count, data_feeder, decode_fake from paddlepalm import Trainer from paddlepalm.utils import reader_helper import numpy as np -from paddlepalm.distribute import gpu_dev_count, data_feeder, decode_fake import time dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count @@ -61,10 +60,6 @@ class MultiHeadTrainer(Trainer): """ Build forward computation graph for training, which usually built from input layer to loss node. - Args: - backbone: a Backbone object with phase == 'train', which is used to extract multi-level text features, e.g., contextual word embedding and sentence embedding. - heads: a list of Head objects. Phase of each head should be set as 'train', which is used to build task specific output layers. - Return: - loss_var: a Variable object. The computational graph variable(node) of loss. """ @@ -115,10 +110,14 @@ class MultiHeadTrainer(Trainer): def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs, phase='train'): """ - Bind readers and loaded train/predict data to trainers. + Bind readers and loaded train/predict data to trainers. The `num_epochs` argument only + works on `sampling_reference` task(trainer), and num_epochs of other tasks are infered from + their `mix_ratio`. Args: readers: a dict or list of Reader objects. For dict case, each key is a trainer's name, and the mapped value is the reader object to bind to the trainer. For list case, each + sampling_reference: a trainer name. The task(trainer) selected as baseline for task sampling. + num_epochs: training epochs of the sampling_reference task (trainer). """ self._check_phase(phase)