未验证 提交 d91c9c63 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update multihead_trainer.py

上级 b88394ac
from paddle import fluid from paddle import fluid
from paddle.fluid import layers from paddle.fluid import layers
from paddlepalm.distribute import gpu_dev_count, cpu_dev_count from paddlepalm.distribute import gpu_dev_count, cpu_dev_count, data_feeder, decode_fake
from paddlepalm import Trainer from paddlepalm import Trainer
from paddlepalm.utils import reader_helper from paddlepalm.utils import reader_helper
import numpy as np import numpy as np
from paddlepalm.distribute import gpu_dev_count, data_feeder, decode_fake
import time import time
dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count
...@@ -61,10 +60,6 @@ class MultiHeadTrainer(Trainer): ...@@ -61,10 +60,6 @@ class MultiHeadTrainer(Trainer):
""" """
Build forward computation graph for training, which usually built from input layer to loss node. Build forward computation graph for training, which usually built from input layer to loss node.
Args:
backbone: a Backbone object with phase == 'train', which is used to extract multi-level text features, e.g., contextual word embedding and sentence embedding.
heads: a list of Head objects. Phase of each head should be set as 'train', which is used to build task specific output layers.
Return: Return:
- loss_var: a Variable object. The computational graph variable(node) of loss. - loss_var: a Variable object. The computational graph variable(node) of loss.
""" """
...@@ -115,10 +110,14 @@ class MultiHeadTrainer(Trainer): ...@@ -115,10 +110,14 @@ class MultiHeadTrainer(Trainer):
def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs, phase='train'): def fit_readers_with_mixratio(self, readers, sampling_reference, num_epochs, phase='train'):
""" """
Bind readers and loaded train/predict data to trainers. Bind readers and loaded train/predict data to trainers. The `num_epochs` argument only
works on `sampling_reference` task(trainer), and num_epochs of other tasks are infered from
their `mix_ratio`.
Args: Args:
readers: a dict or list of Reader objects. For dict case, each key is a trainer's name, and the mapped value is the reader object to bind to the trainer. For list case, each readers: a dict or list of Reader objects. For dict case, each key is a trainer's name, and the mapped value is the reader object to bind to the trainer. For list case, each
sampling_reference: a trainer name. The task(trainer) selected as baseline for task sampling.
num_epochs: training epochs of the sampling_reference task (trainer).
""" """
self._check_phase(phase) self._check_phase(phase)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册