multihead_trainer.py 2.9 KB
Newer Older
X
xixiaoyao 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91

from paddlepalm.distribute import gpu_dev_count, cpu_dev_count
from paddlepalm import Trainer

# Number of devices the feeders split batches over: gpu_dev_count, clamped to
# a minimum of 1 (e.g. when gpu_dev_count is 0).
dev_count = max(1, gpu_dev_count)
# Verbosity flag forwarded to the joint iterator built in fit_readers.
VERBOSE = False


class MultiHeadTrainer(Trainer):
    """Trainer that drives several task trainers sharing one backbone.

    Each wrapped trainer contributes one task head. In ``build_forward`` the
    head whose loss is computed is selected at run time by the int64 feed
    variable named ``"branch"``.

    NOTE(review): ``Trainer.__init__`` is not called here; confirm that no
    inherited method relies on state it would set up.
    NOTE(review): ``fluid``, ``layers``, ``create_joint_iterator_fn``,
    ``reader_helper`` and ``data_feeder`` are used below but do not appear in
    this file's visible imports -- presumably ``paddle.fluid``,
    ``paddle.fluid.layers`` and PaddlePALM reader/distribute helpers; add the
    imports before use.
    """

    def __init__(self, trainers, reuse_flags=None):
        """Wrap a list of per-task trainers.

        Args:
            trainers: sequence of task trainers. Each must expose ``name``,
                ``build_forward``, ``fit_reader`` and the reader attributes
                read in ``fit_readers``.
            reuse_flags: optional sequence with one entry per trainer.

        Raises:
            AssertionError: if ``reuse_flags`` is given and its length does
                not match ``trainers``.
        """
        # Per-task mix ratios come from each trainer's ``mix_ratio`` attribute
        # (see fit_readers), so only reuse_flags needs a length check here.
        if reuse_flags is not None:
            assert len(reuse_flags) == len(trainers)

        self._trainers = trainers
        self._reuse_flags = reuse_flags

    def build_forward(self, backbone, head_dict):
        """Build a switch-case loss over all task heads.

        Args:
            backbone: shared backbone passed to every trainer's build_forward.
            head_dict: mapping from trainer name to that trainer's head. It
                must contain exactly one entry per wrapped trainer.

        Returns:
            The loss variable produced by ``layers.switch_case``, which picks
            the branch whose index is fed through the ``"branch"`` variable.
            The branch variable is also stored on ``self._head_id_var``.
        """
        num_heads = len(self._trainers)
        assert len(head_dict) == num_heads

        for t in self._trainers:
            assert t.name in head_dict

        # NOTE(review): these programs are handed to each trainer's
        # build_forward, but no program_guard is entered here; confirm the
        # switch_case ops below are added to the intended program.
        train_prog = fluid.Program()
        train_init_prog = fluid.Program()

        def get_loss(i):
            # Build the i-th trainer's forward pass on the shared backbone.
            head = head_dict[self._trainers[i].name]
            return self._trainers[i].build_forward(backbone, head, train_prog, train_init_prog)

        # Bind the index as a default argument so each branch keeps its own
        # value; a plain closure over ``i`` would late-bind to the last index.
        task_fns = {i: (lambda task_id=i: get_loss(task_id)) for i in range(num_heads)}

        head_id_var = fluid.data(name="branch", shape=[1], dtype='int64')
        loss_var = layers.switch_case(
            branch_index=head_id_var,
            branch_fns=task_fns,
        )
        self._head_id_var = head_id_var
        return loss_var

    def fit_readers(self, reader_dict, mix_ratio, phase='train'):
        """Fit each trainer to its reader and build one joint feeder.

        Args:
            reader_dict: mapping from trainer name to that trainer's reader.
                It must contain exactly one entry per wrapped trainer.
            mix_ratio: currently unused. Per-task ratios are read from each
                trainer's ``mix_ratio`` attribute; the parameter is kept for
                interface compatibility.
            phase: ``'train'`` or ``'predict'``. It selects which reader and
                feed-processing attributes are set on ``self``.

        Raises:
            ValueError: if ``phase`` is not ``'train'`` or ``'predict'``.
            AssertionError: if ``reader_dict`` does not have one entry per
                trainer name.
        """
        # Validate up front: an unknown phase would otherwise do all the
        # reader work and then silently store nothing.
        if phase not in ('train', 'predict'):
            raise ValueError("phase must be 'train' or 'predict', got %r" % (phase,))

        num_heads = len(self._trainers)
        assert len(reader_dict) == num_heads

        name_to_position = []
        joint_shape_and_dtypes = []
        iterators = []
        prefixes = []
        mrs = []
        net_inputs = []
        for t in self._trainers:
            assert t.name in reader_dict
            t.fit_reader(reader_dict[t.name])
            net_inputs.append(t._net_inputs)
            prefixes.append(t.name)
            mrs.append(t.mix_ratio)
            iterators.append(t._raw_iterator_fn())
            name_to_position.append(t._name_to_position)
            joint_shape_and_dtypes.append(t._shape_and_dtypes)

        iterator_fn = create_joint_iterator_fn(
            iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position,
            dev_count=dev_count, verbose=VERBOSE, return_type='dict')
        feed_batch_process_fn = reader_helper.create_multihead_feed_batch_process_fn(net_inputs)

        # With more than one GPU the iterator is wrapped by data_feeder
        # (presumably to spread batches across devices -- confirm); otherwise
        # the raw joint iterator is used directly.
        if gpu_dev_count > 1:
            distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn)
        else:
            distribute_feeder_fn = iterator_fn

        if phase == 'train':
            self._train_reader = distribute_feeder_fn()
            self._feed_batch_process_fn = feed_batch_process_fn
        else:  # phase == 'predict' (validated above)
            self._predict_reader = distribute_feeder_fn()
            self._pred_feed_batch_process_fn = feed_batch_process_fn

    def train(self):
        """Placeholder: not implemented yet; currently does nothing."""
        # TODO: implement the multi-head training loop.
        pass

    def train_one_step(self):
        """Placeholder: not implemented yet; currently does nothing."""
        # TODO: implement a single multi-head training step.
        pass