From 2c49818b2b401b71be74156de2a967e3cfd665c3 Mon Sep 17 00:00:00 2001 From: xixiaoyao Date: Wed, 22 Jan 2020 15:06:29 +0800 Subject: [PATCH] update docs --- README.md | 41 +- paddlepalm/README.md | 0 paddlepalm/_downloader.py | 51 +- paddlepalm/controller/__init__.py | 3 - paddlepalm/controller/conf_controller.py | 747 ----------------------- paddlepalm/controller/controller.py | 614 ------------------- paddlepalm/default_settings.py | 42 -- paddlepalm/interface.py | 177 ------ paddlepalm/mtl_controller.py | 746 ---------------------- paddlepalm/reader/cls.py | 18 +- paddlepalm/reader/match.py | 25 + paddlepalm/task_instance.py | 309 ---------- setup.cfg | 8 +- setup.py | 20 +- 14 files changed, 87 insertions(+), 2714 deletions(-) delete mode 100644 paddlepalm/README.md delete mode 100644 paddlepalm/controller/__init__.py delete mode 100755 paddlepalm/controller/conf_controller.py delete mode 100755 paddlepalm/controller/controller.py delete mode 100644 paddlepalm/default_settings.py delete mode 100644 paddlepalm/interface.py delete mode 100755 paddlepalm/mtl_controller.py delete mode 100644 paddlepalm/task_instance.py diff --git a/README.md b/README.md index ed86548..73203ec 100644 --- a/README.md +++ b/README.md @@ -1,48 +1,17 @@ # PaddlePALM -PaddlePALM (Paddle for Multi-task) 是一个灵活通用且易用的NLP大规模预训练与多任务学习框架。通过PaddlePALM,用户可以轻松完成复杂的多任务学习与参数复用,无缝集成「**单任务训练**」、「**多任务辅助训练**」和「**多目标任务联合训练**」这 *3* 种训练方式和灵活的保存与预测机制,且仅需书写极少量代码即可”一键启动”高性能单机单卡和分布式训练与推理。 +PaddlePALM (PArallel Learning from Multi-tasks) is a flexible, general and easy-to-use NLP large-scale pretraining and multi-task learning friendly framework. PALM is a high level framework aiming at **fastly** develop **high-performance** NLP models. With PALM, 8 steps to achieve a typical NLP task for supervised learning or pretraining. 6 steps to achieve multi-task learning for prepared tasks. Zero steps to adapt your code to large-scale training/inference (with multiple GPUs and multiple computation nodes). -框架中内置了丰富的[主干网络](#附录b内置主干网络backbone)及其[预训练模型](#预训练模型)(BERT、ERNIE等)、常见的[任务范式](#附录c内置任务范式paradigm)(分类、匹配、机器阅读理解等)和相应的[数据集读取与处理工具](#附录a内置数据集载入与处理工具reader)。同时框架提供了用户自定义接口,若内置工具、主干网络和任务无法满足需求,开发者可以轻松完成相关组件的自定义。各个组件均为零耦合设计,用户仅需完成组件本身的特性开发即可完成与框架的融合。 - -PaddlePALM (PArallel Learning from Multi-tasks) is a flexible, general and easy-to-use NLP large-scale pretraining and multi-task learning friendly framework. PALM is a high level framework aiming at **fastly** develop **high-performance** NLP models. With PALM, a typical NLP task can be achieved just in 8 steps. -s +PaddlePALM also provides state-of-the-art general purpose architectures (BERT,ERNIE,RoBERTa,...) as build-in model backbones. We have decoupled the model backbone, dataset reader and task output layers, so that you can easily replace any of the component to other candidates with quite minor changes of your code. In addition, PaddlePALM support customized development of any component, e.g, backbone, task head, reader and optimizer, which gives high flexibility for developers to adapt to complicated NLP scenes. 然后给出一些成功案例和一些公开数据集的各个backbone的实验结果(BERT、ERNIE、RoBERTa)和一些成功的多任务学习示例。 -## 目录 - -- [安装](#安装) -- [前期准备](#前期准备) - - [理论准备](#理论准备) - - [框架原理](#框架原理) - - [预训练模型](#预训练模型) -- [X行代码实现文本分类](#三个demo入门paddlepalm) - - -- [] - - [DEMO1:单任务训练](#demo1单任务训练) - - [DEMO2:多任务辅助训练与目标任务预测](#demo2多任务辅助训练与目标任务预测) - - [DEMO3:多目标任务联合训练与任务层参数复用](#demo3多目标任务联合训练与任务层参数复用) -- [进阶篇](#进阶篇) - - [配置广播机制](#配置广播机制) - - [reader、backbone与paradigm的选择](#readerbackbone与paradigm的选择) - - [多目标任务下的训练终止条件与预期训练步数](#多目标任务下的训练终止条件与预期训练步数) - - [多个目标任务](#多个目标任务) - - [训练终止条件](#训练终止条件) - - [任务采样概率与预期训练步数](#任务采样概率与预期训练步数) - - [多个目标任务时预期训练步数的计算](#多个目标任务时预期训练步数的计算) - - [模型保存与预测机制](#模型保存与预测机制) - - [分布式训练](#分布式训练) -- [附录A:内置数据集载入与处理工具(reader)](#附录a内置数据集载入与处理工具reader) -- [附录B:内置主干网络(backbone)](#附录b内置主干网络backbone) -- [附录C:内置任务范式(paradigm)](#附录c内置任务范式paradigm) -- [附录D:可配置的全局参数列表](#附录d可配置的全局参数列表) - ## Package Overview | **paddlepalm** | an open source NLP pretraining and multitask learning framework, built on paddlepaddle. | | **paddlepalm.reader** | a collection of elastic task-specific dataset readers. | -| **paddlepalm.backbone** | a collection of classic NLP representation models, e.g., BERT. | +| **paddlepalm.backbone** | a collection of classic NLP representation models, e.g., BERT, ERNIE, RoBERTa. | | **paddlepalm.head** | a collection of task-specific output layers. | | **paddlepalm.lr_sched** | a collection of learning rate schedualers. | | **paddlepalm.optimizer** | a collection of optimizers. | @@ -67,7 +36,7 @@ cd PALM && python setup.py install ``` ### Library Dependencies -- Python >= 2.7 (即将支持python3) +- Python >= 2.7 - cuda >= 9.0 - cudnn >= 7.0 - PaddlePaddle >= 1.6.3 (请参考[安装指南](http://www.paddlepaddle.org/#quick-start)进行安装) @@ -108,7 +77,7 @@ Available pretrain items: 5. create a task *trainer* with `paddlepalm.Trainer`, then build forward graph with backbone and task head (created in step 2 and 4) through `trainer.build_forward`. 6. use `paddlepalm.optimizer` (and `paddlepalm.lr_sched` if is necessary) to create a *optimizer*, then build backward through `trainer.build_backward`. 7. fit prepared reader and data (achieved in step 1) to trainer with `trainer.fit_reader` method. -8. randomly initialize model parameters (and `trainer.load_pretrain` if needed), then do training with `trainer.train`. +8. load pretrain model with `trainer.load_pretrain`, or load checkpoint with `trainer.load_ckpt` or nothing to do for training from scratch, then do training with `trainer.train`. More implementation details see following demos: [Sentiment Classification](), [Quora Question Pairs matching](), [Tagging](), [SQuAD machine Reading Comprehension](). diff --git a/paddlepalm/README.md b/paddlepalm/README.md deleted file mode 100644 index e69de29..0000000 diff --git a/paddlepalm/_downloader.py b/paddlepalm/_downloader.py index b36cbcc..7d8de3c 100644 --- a/paddlepalm/_downloader.py +++ b/paddlepalm/_downloader.py @@ -38,9 +38,10 @@ _items = { 'roberta-cn-base': 'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_ext_L-12_H-768_A-12.tar.gz', 'roberta-cn-large': 'https://bert-models.bj.bcebos.com/chinese_roberta_wwm_large_ext_L-24_H-1024_A-16.tar.gz', 'utils': None}, - 'reader': {'utils': None}, + 'vocab': {'utils': None}, 'backbone': {'utils': None}, - 'tasktype': {'utils': None}, + 'head': {'utils': None}, + 'reader': {'utils': None}, } def _download(item, scope, path, silent=False, convert=False): @@ -131,20 +132,27 @@ def _convert(path, silent=False): tar_info.close() os.removedirs(path + '/params1/') -def download(item, scope='all', path='.'): - item = item.lower() +def download(scope, item='all', path='.'): + """download an item. The available scopes and contained items can be showed with `paddlepalm.downloader.ls`. + + Args: + scope: the scope the item belongs to. + item: the item to download. + path: the target dir to download to. Default is `.`, means current dir. + """ scope = scope.lower() - assert item in _items, '{} is not found. Support list: {}'.format(item, list(_items.keys())) + item = item.lower() + ascopeert scope in _scopes, '{} is not found. Support list: {}'.format(scope, list(_scopes.keys())) - if _items[item]['utils'] is not None: - _download(item, 'utils', path, silent=True) + if _scopes[scope]['utils'] is not None: + _download(scope, 'utils', path, silent=True) - if scope != 'all': - assert scope in _items[item], '{} is not found. Support scopes: {}'.format(scope, list(_items[item].keys())) - _download(item, scope, path) + if item != 'all': + ascopeert item in _scopes[scope], '{} is not found. Support items: {}'.format(item, list(_scopes[scope].keys())) + _download(scope, item, path) else: - for s in _items[item].keys(): - _download(item, s, path) + for s in _scopes[scope].keys(): + _download(scope, s, path) def _ls(item, scope, l = 10): @@ -157,19 +165,22 @@ def _ls(item, scope, l = 10): continue print (' => '+s) -def ls(item='all', scope='all'): +def ls(scope='all'): + """show all the available download items of a scope. + + Args: + scope: the scope to show items. Default is 'all', means to show all items in all scopes. Avaliable scopes: pretrain. + """ - if scope == 'utils': - return - if item != 'all': - assert item in _items, '{} is not found. Support scopes: {}'.format(item, list(_items.keys())) - print ('Available {} items:'.format(item)) - _ls(item, scope) + if scope != 'all': + assert scope in _items, '{} is not found. Support scopes: {}'.format(scope, list(_items.keys())) + print ('Available {} scopes:'.format(scope)) + _ls(scope, 'all') else: l = max(map(len, _items.keys())) for i in _items.keys(): print ('Available {} items: '.format(i)) - _ls(i, scope, l) + _ls(i, 'all', l) diff --git a/paddlepalm/controller/__init__.py b/paddlepalm/controller/__init__.py deleted file mode 100644 index 9e43b33..0000000 --- a/paddlepalm/controller/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ - -from conf_controller import ConfigController -from controller import Controller diff --git a/paddlepalm/controller/conf_controller.py b/paddlepalm/controller/conf_controller.py deleted file mode 100755 index 33f227d..0000000 --- a/paddlepalm/controller/conf_controller.py +++ /dev/null @@ -1,747 +0,0 @@ -# -*- coding: UTF-8 -*- -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import os -import sys -import importlib -import multiprocessing -from paddle import fluid -from paddle.fluid import layers -import yaml -import json -import logging -import time -import numpy as np - -from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint -from paddlepalm.utils.config_helper import PDConfig -from paddlepalm.utils.print_helper import print_dict -from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs - -from paddlepalm.default_settings import * -from paddlepalm.task_instance import TaskInstance, check_instances - -import Queue -from threading import Thread - -DEBUG=False -VERBOSE=0 - -def _get_basename(f): - return os.path.splitext(f)[0] - - -def _get_suffix(f): - return os.path.splitext(f)[-1] - - -def _parse_yaml(f, asdict=True, support_cmd_line=False): - assert os.path.exists(f), "file {} not found.".format(f) - if support_cmd_line: - args = PDConfig(yaml_file=f, fuse_args=True) - args.build() - return args.asdict() if asdict else args - else: - if asdict: - with open(f, "r") as fin: - yaml_config = yaml.load(fin, Loader=yaml.SafeLoader) - return yaml_config - else: - raise NotImplementedError() - - -def _parse_json(f, asdict=True, support_cmd_line=False): - assert os.path.exists(f), "file {} not found.".format(f) - if support_cmd_line: - args = PDConfig(json_file=f, fuse_args=support_cmd_line) - args.build() - return args.asdict() if asdict else args - else: - if asdict: - with open(f, "r") as fin: - config = json.load(fin) - return config - else: - raise NotImplementedError() - - -def _parse_list(string, astype=str): - assert isinstance(string, str), "{} is not a string.".format(string) - if ',' not in string: - return [astype(string)] - string = string.replace(',', ' ') - return [astype(i) for i in string.split()] - - -def _try_float(s): - try: - float(s) - return(float(s)) - except: - return s - - -def _check_conf(conf, checklist=None): - assert isinstance(conf, dict), "{} is not a dict.".format(conf) - ret = {} - for k,v in conf.items(): - if isinstance(v, str): - v = _try_float(v) - ret[k] = v - if checklist is not None: - for k, t in checklist: - assert k in ret, "required argument {} is NOT exist in config file.".format(k) - assert isintance(ret[k], t), "value type of argument {} should be {}".format(k, t) - return ret - - -# TODO: 增加None机制,允许hidden size、batch size和seqlen设置为None -def _check_io(in_attr, out_attr, strict=False, in_name="left", out_name="right"): - for name, attr in in_attr.items(): - assert name in out_attr, in_name+': '+name+' not found in '+out_name - if attr != out_attr[name]: - if strict: - raise ValueError(name+': shape or dtype not consistent!') - else: - logging.warning('{}: shape or dtype not consistent!\n{}:\n{}\n{}:\n{}'.format(name, in_name, attr, out_name, out_attr[name])) - - -def _merge_conf(conf1, conf2, conf1_first=True, strict=False): - assert isinstance(conf1, dict), "{} is not a dict.".format(conf1) - assert isinstance(conf2, dict), "{} is not a dict.".format(conf2) - base_conf = conf2 if conf1_first else conf1 - base_conf = base_conf.copy() - new_conf = conf1 if conf1_first else conf2 - - for k, v in new_conf.items(): - if k in base_conf: - if base_conf[k] != v: - raise Warning("value of argument {} has been updated to {}.".format(k, v)) - else: - if strict: - continue - - base_conf[k] = v - return base_conf - - -def _encode_inputs(inputs, scope_name, sep='/', cand_set=None): - outputs = {} - for k, v in inputs.items(): - if cand_set is not None: - if k in cand_set: - outputs[k] = v - if scope_name+sep+k in cand_set: - outputs[scope_name+sep+k] = v - else: - outputs[scope_name+sep+k] = v - return outputs - - -def _decode_inputs(inputs, scope_name, sep='/', keep_unk_keys=True): - outputs = {} - for name, value in inputs.items(): - # var for backbone are also available to tasks - if keep_unk_keys and sep not in name: - outputs[name] = value - # var for this inst - if name.startswith(scope_name+'/'): - outputs[name[len(scope_name+'/'):]] = value - return outputs - - -def _init_env(use_gpu): - if use_gpu: - place = fluid.CUDAPlace(0) - dev_count = fluid.core.get_cuda_device_count() - else: - place = fluid.CPUPlace() - dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - return fluid.Executor(place), dev_count - - -def _fit_attr(conf, fit_attr, strict=False): - for i, attr in fit_attr.items(): - if i not in conf: - if strict: - raise Exception('Argument {} is required to create a controller.'.format(i)) - else: - continue - conf[i] = attr(conf[i]) - return conf - - -class ConfigController(object): - - def __init__(self, config, task_dir='.', for_train=True): - """ - Args: - config: (str|dict) 字符串类型时,给出yaml格式的config配置文件路径; - """ - - self._for_train = for_train - assert isinstance(config, str) or isinstance(config, dict), "a config dict or config file path is required to create a Controller." - - if isinstance(config, str): - mtl_conf = _parse_yaml(config, support_cmd_line=True) - else: - mtl_conf = config - - mtl_conf = _check_conf(mtl_conf) - mtl_conf = _fit_attr(mtl_conf, REQUIRED_ARGS, strict=True) - mtl_conf = _fit_attr(mtl_conf, OPTIONAL_ARGS, strict=False) - - exe, dev_count = _init_env(use_gpu=mtl_conf.get('use_gpu', True)) - self.exe = exe - self.dev_count = dev_count - - print_dict(mtl_conf, title='global configuration') - - # parse task instances and target tags - instnames = _parse_list(mtl_conf['task_instance']) - assert len(instnames) == len(set(instnames)), "repeated task_instance is NOT supported." - num_instances = len(instnames) - self.num_instances = num_instances - - instname_to_conf = {} - instname_to_id = {} - for id, instname in enumerate(instnames): - instpath = os.path.join(task_dir, instname+'.yaml') - conf = _parse_yaml(instpath, support_cmd_line=False) - # conf = _check_conf(conf, TASK_INSTANCE_REQUIRED_ARGS) - conf = _check_conf(conf) - temp_conf = _merge_conf(mtl_conf, conf, strict=True) - print_dict(temp_conf, title='{} configuration'.format(instname)) - conf = _merge_conf(mtl_conf, conf) - - instname_to_conf[instname] = conf - instname_to_id[instname] = id - - # prepare backbone - if 'backbone_config_path' in mtl_conf: - bb_conf = _parse_json(mtl_conf['backbone_config_path']) - bb_conf = _merge_conf(mtl_conf, bb_conf) - else: - bb_conf = mtl_conf - print_dict(bb_conf, title = 'backbone configuration'.format(instname)) - - bb_name = mtl_conf['backbone'] - bb_mod = importlib.import_module(BACKBONE_DIR + '.' + bb_name) - Backbone = getattr(bb_mod, 'Model') - - # create task instances - instances = [] - for name in instnames: - instances.append(TaskInstance(name, instname_to_id[name], instname_to_conf[name])) - - check_instances(instances) - - # parse target_tag - if 'target_tag' in mtl_conf: - target_tag = str(mtl_conf['target_tag']) - tags = _parse_list(target_tag, astype=int) - assert len(tags) == len(instnames), "number of target_tag is NOT consistent with that in task_instance." - for tag, inst in zip(tags, instances): - inst.is_target = tag - else: - tags = [i.is_target for i in instances] - num_targets = sum(tags) - num_auxes = num_instances - num_targets - - # parse mix ratios - if 'mix_ratio' in mtl_conf: - mix_ratio = str(mtl_conf['mix_ratio']) - mrs = _parse_list(mix_ratio, astype=float) - assert len(mrs) == num_instances, "number of mix_ratios is NOT consistent with num_instances." - else: - mrs = [1.0] * num_instances - - for mr, inst in zip(mrs, instances): - inst.mix_ratio = mr - - # parse task layer reuse tags - instname_to_reusehost = {i:i for i in instnames} - if 'task_reuse_tag' in mtl_conf: - tags = _parse_list(mtl_conf['task_reuse_tag'], astype=int) - assert len(tags) == num_targets, 'number of reuse_tags is NOT consistent with number of instances.' - else: - tags = [] - mapper = {} - for inst in instances: - history = set() - history.add(inst.name) - cur_inst = inst - while True: - if cur_inst.task_reuse_scope in history: - mapper[inst.name] = len(tags) - break - elif cur_inst.task_reuse_scope in mapper: - mapper[inst.name] = mapper[cur_inst.task_reuse_scope] - break - else: - cur_inst = name_to_instance[cur_inst.task_reuse_scope] - history.add(cur_inst.name) - - tags.append(mapper[inst.name]) - - for i in range(1, num_instances): - for j in range(i): - if tags[i] == tags[j]: - assert instances[i].Paradigm == \ - instances[j].Paradigm, \ - "paradigm of reuse tasks should be consistent" - instances[i].task_reuse_scope = instances[j].name - break - - self.instances = instances - self.mrs = mrs - self.Backbone = Backbone - self.bb_conf = bb_conf - self.bb_name = bb_name - - self.has_init_train = False - self.has_init_pred = False - - if self._for_train: - print("initialing for training...") - self._init_train() - self.has_init_train = True - - def _init_train(self): - - instances = self.instances - Backbone = self.Backbone - bb_conf = self.bb_conf - bb_name = self.bb_name - dev_count = self.dev_count - num_instances = len(instances) - mrs = self.mrs - - # set first_target/main task instance - main_inst = None - for inst in instances: - if inst.is_target: - main_inst = inst - inst.is_first_target = True - break - main_conf = main_inst.config - if not os.path.exists(main_conf['save_path']): - os.makedirs(main_conf['save_path']) - os.makedirs(os.path.join(main_conf['save_path'], 'ckpt')) - - # prepare backbone - train_backbone = Backbone(bb_conf, phase='train') - pred_backbone = Backbone(bb_conf, phase='pred') - - # create reader, task - # then check i/o across reader, backbone and task_layer - task_attrs = [] - pred_task_attrs = [] - for inst in instances: - train_reader = inst.Reader(inst.config, phase='train') - inst.reader['train'] = train_reader - train_parad = inst.Paradigm(inst.config, phase='train', backbone_config=bb_conf) - inst.task_layer['train'] = train_parad - task_attr_from_reader = _encode_inputs(train_parad.inputs_attrs['reader'], inst.name) - task_attrs.append(task_attr_from_reader) - - _check_io(train_backbone.inputs_attr, train_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.train') - _check_io(train_parad.inputs_attrs['reader'], train_reader.outputs_attr, in_name='task_paradigm.train.reader', out_name='reader.train') - _check_io(train_parad.inputs_attrs['backbone'], train_backbone.outputs_attr, in_name='task_paradigm.train.backbone', out_name=bb_name+'_backbone') - - if inst.is_target: - if 'pred_file' not in inst.config: - inst.config['pred_file'] = '' - pred_reader = inst.Reader(inst.config, phase='pred') - pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=bb_conf) - inst.task_layer['pred'] = pred_parad - task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], inst.name) - pred_task_attrs.append(task_attr_from_reader) - _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred') - _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred') - _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone') - - # merge reader input attrs from backbone and task_instances - joint_input_names, joint_shape_and_dtypes, name_to_position = merge_input_attrs(train_backbone.inputs_attr, task_attrs) - pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) - # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN] - - if DEBUG: - print('----- for debug -----') - print('joint input names:') - print(joint_input_names) - print('joint input shape and dtypes:') - print(joint_shape_and_dtypes) - - # load data - for inst in instances: - print(inst.name+": preparing data...", end='') - inst.reader['train'].load_data() - print('ok!') - - # merge dataset iterators and create net input vars - iterators = [] - prefixes = [] - mrs = [] - for inst in instances: - iterators.append(inst.reader['train'].iterator()) - prefixes.append(inst.name) - mrs.append(inst.mix_ratio) - - joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict') - self._joint_iterator_fn = joint_iterator_fn - - input_attrs = [[i, j, k] for i, (j,k) in zip(joint_input_names, joint_shape_and_dtypes)] - pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)] - # net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3) - net_inputs = create_net_inputs(input_attrs, async=False) - self._net_inputs = net_inputs - - # build backbone and task layers - train_prog = fluid.default_main_program() - train_init_prog = fluid.default_startup_program() - bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_') - assert sorted(bb_output_vars.keys()) == sorted(train_backbone.outputs_attr.keys()) - - pred_prog = fluid.Program() - pred_init_prog = fluid.Program() - - with fluid.program_guard(main_program = pred_prog, startup_program = pred_init_prog): - pred_net_inputs = create_net_inputs(pred_input_attrs) - pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_') - - fluid.framework.switch_main_program(train_prog) - fluid.framework.switch_startup_program(train_init_prog) - - task_output_vars = {} - for inst in instances: - task_inputs = {'backbone': bb_output_vars} - task_inputs_from_reader = _decode_inputs(net_inputs, inst.name) - task_inputs['reader'] = task_inputs_from_reader - - scope = inst.task_reuse_scope + '/' - with fluid.unique_name.guard(scope): - output_vars = inst.build_task_layer(task_inputs, phase='train', scope=scope) - output_vars = {inst.name+'/'+key: val for key, val in output_vars.items()} - old = len(task_output_vars) # for debug - task_output_vars.update(output_vars) - assert len(task_output_vars) - old == len(output_vars) # for debug - - # prepare predict vars for saving inference model - if inst.is_target: - - with fluid.program_guard(pred_prog, pred_init_prog): - cur_inputs = _decode_inputs(pred_net_inputs, inst.name) - inst.pred_input = cur_inputs - pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs} - scope = inst.task_reuse_scope + '/' - with fluid.unique_name.guard(scope): - inst.build_task_layer(pred_task_inputs, phase='pred', scope=scope) - - - bb_fetches = {k: v.name for k,v in bb_output_vars.items()} - task_fetches = {k: v.name for k,v in task_output_vars.items()} - fetches = task_fetches - fetches['__task_id'] = net_inputs['__task_id'].name - - # compute loss - task_id_var = net_inputs['__task_id'] - task_id_vec = fluid.one_hot(task_id_var, num_instances) - losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0) - loss = layers.reduce_sum(task_id_vec * losses) - - main_reader = main_inst.reader['train'] - - num_examples = main_reader.num_examples - for inst in instances: - max_train_steps = int(main_conf['num_epochs']* inst.mix_ratio * (num_examples // main_conf['batch_size'] // dev_count)) - if inst.is_target: - print('{}: expected train steps {}.'.format(inst.name, max_train_steps)) - inst.steps_pur_epoch = inst.reader['train'].num_examples // main_conf['batch_size'] // dev_count - inst.expected_train_steps = max_train_steps - - global_max_train_steps = int(main_conf['num_epochs'] * sum(mrs) * (num_examples // main_conf['batch_size'] // dev_count)) - print('Estimated overall train steps {}.'.format(global_max_train_steps)) - - if 'warmup_proportion' in main_conf and main_conf['warmup_proportion'] > 0: - warmup_steps = int(global_max_train_steps * main_conf['warmup_proportion']) - print('Warmup steps: '+str(warmup_steps)) - else: - warmup_steps = 0 - - # build optimizer - if 'optimizer' in main_conf: - optim_mod = importlib.import_module(OPTIMIZER_DIR + '.' + main_conf['optimizer']) - optimize = getattr(optim_mod, OPTIMIZE_METHOD) - optimize(loss, main_conf, max_train_steps, warmup_steps, fluid.default_main_program()) - - loss.persistable = True - if main_conf.get('use_ema', False): - assert 'ema_decay' in main_conf, "ema_decay should be set when use_ema is enabled." - ema = fluid.optimizer.ExponentialMovingAverage(main_conf['ema_decay']) - ema.update() - - # prepare for train - self.train_backbone = train_backbone - self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) - self.saver_program = fluid.default_main_program() - - self.main_inst = main_inst - self.fetches = fetches - self.has_init_train = True - self.has_init_pred = True - - self.exe.run(fluid.default_startup_program()) - print("\nRandomly initialize parameters...\n") - - def _init_pred(self, instance, infer_model_path): - inst = instance - if 'pred_output_path' not in inst.config: - inst.config['pred_output_path'] = os.path.join(inst.config.get('save_path', '.'), inst.name) - - if not os.path.exists(inst.config['pred_output_path']): - os.makedirs(inst.config['pred_output_path']) - - pred_backbone = self.Backbone(self.bb_conf, phase='pred') - pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=self.bb_conf) - inst.task_layer['pred'] = pred_parad - pred_joint_input_names, pred_joint_shape_and_dtypes, name_to_position = merge_input_attrs( - pred_backbone.inputs_attr, inst.task_layer['pred'].inputs_attrs['reader'], - insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) - - pred_prog = inst.load(infer_model_path) - if inst.reader['pred'] is None: - pred_reader = inst.Reader(inst.config, phase='pred') - inst.reader['pred'] = pred_reader - return pred_prog - - def load_pretrain(self, pretrain_path=None): - # load pretrain model (or ckpt) - if pretrain_path is None: - assert 'pretrain_path' in self.main_conf, "pretrain_path NOT set." - pretrain_path = self.main_conf['pretrain_path'] - - init_pretraining_params( - self.exe, - pretrain_path, - main_program=fluid.default_startup_program()) - - - def train(self): - - if not self.has_init_train: - self._init_train() - self.has_init_train = True - - instances = self.instances - num_instances = self.num_instances - main_inst = self.main_inst - main_conf = main_inst.config - - backbone = self.train_backbone - train_program = self.train_program - saver_program = self.saver_program - fetches = self.fetches - - finish = [] - for inst in instances: - if inst.is_target: - if inst.expected_train_steps > 0: - finish.append(False) - else: - finish.append(True) - print(inst.name+': train finished!') - inst.save() - - def train_finish(): - for inst in instances: - if inst.is_target: - if not inst.train_finish: - return False - return True - - def pack_multicard_feed(iterator, net_inputs, dev_count): - ret = [] - mask = [] - for i in range(dev_count): - temp = {} - content, flag = next(iterator) - for q, var in net_inputs.items(): - temp[var.name] = content[q] - ret.append(temp) - mask.append(1 if flag else 0) - return ret, mask - - # do training - fetch_names, fetch_list = zip(*fetches.items()) - - main_step = 0 # only count for main task - global_step = 0 # count for all tasks - epoch = 0 - time_begin = time.time() - backbone_buffer = [] - - def multi_dev_reader(reader, dev_count): - def worker(reader, dev_count, queue): - dev_batches = [] - for index, data in enumerate(reader()): - if len(dev_batches) < dev_count: - dev_batches.append(data) - if len(dev_batches) == dev_count: - queue.put((dev_batches, 0)) - dev_batches = [] - # For the prediction of the remained batches, pad more batches to - # the number of devices and the padded samples would be removed in - # prediction outputs. - if len(dev_batches) > 0: - num_pad = dev_count - len(dev_batches) - for i in range(len(dev_batches), dev_count): - dev_batches.append(dev_batches[-1]) - queue.put((dev_batches, num_pad)) - queue.put(None) - - queue = Queue.Queue(dev_count*2) - p = Thread( - target=worker, args=(reader, dev_count, queue)) - p.daemon = True - p.start() - while True: - ret = queue.get() - if ret is not None: - batches, num_pad = ret - queue.task_done() - for batch in batches: - flag = num_pad == 0 - if num_pad > 0: - num_pad -= 1 - yield batch, flag - else: - break - queue.join() - - joint_iterator = multi_dev_reader(self._joint_iterator_fn, self.dev_count) - - while not train_finish(): - feed, mask = pack_multicard_feed(joint_iterator, self._net_inputs, self.dev_count) - rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list) - rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} - rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist() - rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id - cur_task = instances[rt_task_id] - - backbone_rt_outputs = {k:v for k,v in rt_outputs.items() if '/' not in k} - backbone_buffer.append(backbone.postprocess(backbone_rt_outputs)) - - task_rt_outputs = {k[len(cur_task.name+'/'):]: v for k,v in rt_outputs.items() if k.startswith(cur_task.name+'/')} - instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs) - - global_step += 1 - cur_task.cur_train_step += 1 - - cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch - if cur_task.is_target and cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0: - cur_task.save(suffix='.step'+str(cur_task_global_step)) - - if global_step % main_conf.get('print_every_n_steps', 5) == 0: - loss = rt_outputs[cur_task.name+'/loss'] - loss = np.mean(np.squeeze(loss)).tolist() - - time_end = time.time() - time_cost = time_end - time_begin - - print("Global step: {}. Task: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format( - global_step, cur_task.name, cur_task.cur_train_step, cur_task.steps_pur_epoch, cur_task.cur_train_epoch, - loss, main_conf.get('print_every_n_steps', 5) / time_cost)) - time_begin = time.time() - - if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps: - print(cur_task.name+': train finished!') - cur_task.save() - - if 'save_ckpt_every_n_steps' in main_conf and global_step % main_conf['save_ckpt_every_n_steps'] == 0: - save_path = os.path.join(main_conf['save_path'], 'ckpt', - "step_" + str(global_step)) - fluid.io.save_persistables(self.exe, save_path, saver_program) - print('checkpoint has been saved at '+save_path) - - save_path = os.path.join(main_conf['save_path'], 'ckpt', - "step_" + str(global_step)) - fluid.io.save_persistables(self.exe, save_path, saver_program) - print('checkpoint has been saved at '+save_path) - - print("ALL tasks train finished, exiting...") - - def pred(self, task_instance, inference_model_dir=None): - if self._for_train: - raise Exception('This controller is a trainer. Please build a new controller with for_train=False for predicting.') - - assert isinstance(task_instance, str) - if isinstance(inference_model_dir, str): - assert os.path.exists(inference_model_dir), inference_model_dir+" not found." - # if not self.has_init_pred and inference_model_dir is None: - # raise ValueError('infer_model_path is required for prediction.') - if inference_model_dir is None: - assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model." - inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model') - - instance = None - for inst in self.instances: - if inst.name == task_instance: - instance = inst - break - - if instance is None: - raise ValueError(task_instance + ' is not a valid task_instance.') - - pred_prog = self._init_pred(instance, inference_model_dir) - - inst = instance - print(inst.name+": loading data...") - inst.reader['pred'].load_data() - fetch_names, fetch_vars = inst.pred_fetch_list - - print('predicting...') - mapper = {k:v for k,v in inst.pred_input} - buf = [] - for feed in inst.reader['pred'].iterator(): - feed = _encode_inputs(feed, inst.name, cand_set=mapper) - feed = {mapper[k]: v for k,v in feed.items()} - - rt_outputs = self.exe.run(pred_prog, feed, fetch_vars) - rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} - inst.postprocess(rt_outputs, phase='pred') - if inst.task_layer['pred'].epoch_inputs_attrs: - reader_outputs = inst.reader['pred'].get_epoch_outputs() - else: - reader_outputs = None - inst.epoch_postprocess({'reader':reader_outputs}, phase='pred') - - -if __name__ == '__main__': - assert len(sys.argv) == 2, "Usage: python mtl_controller.py " - conf_path = sys.argv[1] - del sys.argv[1] - controller = Controller(conf_path) - if controller.main_conf['do_train']: - controller.train() - - - -__all__ = ["Controller"] - - - diff --git a/paddlepalm/controller/controller.py b/paddlepalm/controller/controller.py deleted file mode 100755 index 2fb73ca..0000000 --- a/paddlepalm/controller/controller.py +++ /dev/null @@ -1,614 +0,0 @@ -# -*- coding: UTF-8 -*- -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import os -import sys -import importlib -import multiprocessing -from paddle import fluid -from paddle.fluid import layers -import yaml -import json -import logging -import time -import numpy as np - -from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint -from paddlepalm.utils.config_helper import PDConfig -from paddlepalm.utils.print_helper import print_dict -from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs - -from paddlepalm.default_settings import * -from paddlepalm.task_instance import TaskInstance, check_instances - -DEBUG=False -VERBOSE=0 - -def _get_basename(f): - return os.path.splitext(f)[0] - - -def _get_suffix(f): - return os.path.splitext(f)[-1] - - -def _parse_yaml(f, asdict=True, support_cmd_line=False): - assert os.path.exists(f), "file {} not found.".format(f) - if support_cmd_line: - args = PDConfig(yaml_file=f, fuse_args=True) - args.build() - return args.asdict() if asdict else args - else: - if asdict: - with open(f, "r") as fin: - yaml_config = yaml.load(fin, Loader=yaml.SafeLoader) - return yaml_config - else: - raise NotImplementedError() - - -def _parse_json(f, asdict=True, support_cmd_line=False): - assert os.path.exists(f), "file {} not found.".format(f) - if support_cmd_line: - args = PDConfig(json_file=f, fuse_args=support_cmd_line) - args.build() - return args.asdict() if asdict else args - else: - if asdict: - with open(f, "r") as fin: - config = json.load(fin) - return config - else: - raise NotImplementedError() - - -def _parse_list(string, astype=str): - assert isinstance(string, str), "{} is not a string.".format(string) - if ',' not in string: - return [astype(string)] - string = string.replace(',', ' ') - return [astype(i) for i in string.split()] - - -def _try_float(s): - try: - float(s) - return(float(s)) - except: - return s - - -def _check_conf(conf, checklist=None): - assert isinstance(conf, dict), "{} is not a dict.".format(conf) - ret = {} - for k,v in conf.items(): - if isinstance(v, str): - v = _try_float(v) - ret[k] = v - if checklist is not None: - for k, t in checklist: - assert k in ret, "required argument {} is NOT exist in config file.".format(k) - assert isintance(ret[k], t), "value type of argument {} should be {}".format(k, t) - return ret - - -# TODO: 增加None机制,允许hidden size、batch size和seqlen设置为None -def _check_io(in_attr, out_attr, strict=False, in_name="left", out_name="right"): - for name, attr in in_attr.items(): - assert name in out_attr, in_name+': '+name+' not found in '+out_name - if attr != out_attr[name]: - if strict: - raise ValueError(name+': shape or dtype not consistent!') - else: - logging.warning('{}: shape or dtype not consistent!\n{}:\n{}\n{}:\n{}'.format(name, in_name, attr, out_name, out_attr[name])) - - -def _merge_conf(conf1, conf2, conf1_first=True, strict=False): - assert isinstance(conf1, dict), "{} is not a dict.".format(conf1) - assert isinstance(conf2, dict), "{} is not a dict.".format(conf2) - base_conf = conf2 if conf1_first else conf1 - base_conf = base_conf.copy() - new_conf = conf1 if conf1_first else conf2 - - for k, v in new_conf.items(): - if k in base_conf: - if base_conf[k] != v: - raise Warning("value of argument {} has been updated to {}.".format(k, v)) - else: - if strict: - continue - - base_conf[k] = v - return base_conf - - -def _encode_inputs(inputs, scope_name, sep='/', cand_set=None): - outputs = {} - for k, v in inputs.items(): - if cand_set is not None: - if k in cand_set: - outputs[k] = v - if scope_name+sep+k in cand_set: - outputs[scope_name+sep+k] = v - else: - outputs[scope_name+sep+k] = v - return outputs - - -def _decode_inputs(inputs, scope_name, sep='/', keep_unk_keys=True): - outputs = {} - for name, value in inputs.items(): - # var for backbone are also available to tasks - if keep_unk_keys and sep not in name: - outputs[name] = value - # var for this inst - if name.startswith(scope_name+'/'): - outputs[name[len(scope_name+'/'):]] = value - return outputs - - -def _init_env(use_gpu): - if use_gpu: - place = fluid.CUDAPlace(0) - dev_count = fluid.core.get_cuda_device_count() - else: - place = fluid.CPUPlace() - dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - return fluid.Executor(place), dev_count - - -def _fit_attr(conf, fit_attr, strict=False): - for i, attr in fit_attr.items(): - if i not in conf: - if strict: - raise Exception('Argument {} is required to create a controller.'.format(i)) - else: - continue - conf[i] = attr(conf[i]) - return conf - - -class Controller(object): - - def __init__(self, tasks, mix_ratios=None, task_reuse_tag=None, use_gpu=True): - """ - Args: - """ - - exe, dev_count = _init_env(use_gpu=use_gpu) - self.exe = exe - self.dev_count = dev_count - - # parse task instances and target tags - for id in len(tasks): - tasks[id]._set_id(id) - - # parse mix ratios - if mix_ratios is not None: - if isinstance(mix_ratios, str): - mix_ratios = _parse_list(mix_ratios, astype=float) - else: - assert isinstance(mix_ratios, list) - assert len(mix_ratios) == len(tasks), "number of mix_ratios is NOT consistent with num_instances." - - for mr, t in zip(mix_ratios, tasks): - t.mix_ratio = mr - - # parse task layer reuse tags - instname_to_reusehost = {i:i for i in instnames} - if task_reuse_tag is not None: - if isinstance(task_reuse_tag, str): - tags = _parse_list(task_reuse_tag, astype=int) - else: - assert isinstance(task_reuse_tag, list) - assert len(task_reuse_tag) == len(tasks), "number of task_reuse_tag is NOT consistent with num_tasks." - tags = task_reuse_tag - - else: - tags = [] - mapper = {} - for inst in tasks: - history = set() - history.add(inst.name) - cur_inst = inst - while True: - if cur_inst.task_reuse_scope in history: - mapper[inst.name] = len(tags) - break - elif cur_inst.task_reuse_scope in mapper: - mapper[inst.name] = mapper[cur_inst.task_reuse_scope] - break - else: - cur_inst = name_to_instance[cur_inst.task_reuse_scope] - history.add(cur_inst.name) - - tags.append(mapper[inst.name]) - - for i in range(1, len(tasks)): - for j in range(i): - if tags[i] == tags[j]: - # assert tasks[i].tasktype == \ - # instances[j].tasktype, \ - # "paradigm of reuse tasks should be consistent" - tasks[i]._task_reuse_scope = task[j].name - break - - # self.instances = instances - # self.mrs = mrs - # self.Backbone = Backbone - # self.bb_conf = bb_conf - # self.bb_name = bb_name - - # self.has_init_train = False - # self.has_init_pred = False - - # if self._for_train: - # print("initialing for training...") - # self._init_train() - # self.has_init_train = True - # - def build_forward(self, backbone, mask_task=[]): - - task_instances = self._tasks - Backbone = self.Backbone - bb_conf = self.bb_conf - bb_name = self.bb_name - dev_count = self.dev_count - num_instances = len(instances) - mrs = self.mrs - - # set first_target/main task instance - main_inst = None - for inst in task_instances: - if inst.is_target: - main_inst = inst - inst._as_main = True - break - - if save_path is not None and not os.path.exists(save_path): - os.makedirs(save_path) - - # create reader, task - # then check i/o across reader, backbone and task_layer - task_attrs = [] - pred_task_attrs = [] - for inst in task_instances: - task_attr_from_reader = _encode_inputs(inst._taskblock['train'].inputs_attrs['reader'], inst.name) - task_attrs.append(task_attr_from_reader) - - _check_io(backbone.inputs_attr, inst._reader['train'].outputs_attr, in_name=bb_name+'_backbone', out_name='reader.train') - _check_io(inst.taskblock['train'].inputs_attrs['reader'], inst._reader['train'].outputs_attr, in_name='task_paradigm.train.reader', out_name='reader.train') - _check_io(inst._taskblock['train'].inputs_attrs['backbone'], train_backbone.outputs_attr, in_name='task_paradigm.train.backbone', out_name=bb_name+'_backbone') - - if inst.is_target: - if 'pred_file' not in inst.config: - inst.config['pred_file'] = '' - pred_reader = inst.Reader(inst.config, phase='pred') - pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=bb_conf) - inst.task_layer['pred'] = pred_parad - task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], inst.name) - pred_task_attrs.append(task_attr_from_reader) - _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred') - _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred') - _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone') - - # merge reader input attrs from backbone and task_instances - joint_input_names, joint_shape_and_dtypes, name_to_position = merge_input_attrs(train_backbone.inputs_attr, task_attrs) - pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) - # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN] - - if DEBUG: - print('----- for debug -----') - print('joint input names:') - print(joint_input_names) - print('joint input shape and dtypes:') - print(joint_shape_and_dtypes) - - # load data - for inst in instances: - print(inst.name+": preparing data...", end='') - inst.reader['train'].load_data() - print('ok!') - - # merge dataset iterators and create net input vars - iterators = [] - prefixes = [] - mrs = [] - for inst in instances: - iterators.append(inst.reader['train'].iterator()) - prefixes.append(inst.name) - mrs.append(inst.mix_ratio) - - joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE) - - input_attrs = [[i, j, k] for i, (j,k) in zip(joint_input_names, joint_shape_and_dtypes)] - pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)] - net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3) - - # build backbone and task layers - train_prog = fluid.default_main_program() - train_init_prog = fluid.default_startup_program() - bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_') - assert sorted(bb_output_vars.keys()) == sorted(train_backbone.outputs_attr.keys()) - - pred_prog = fluid.Program() - pred_init_prog = fluid.Program() - - with fluid.program_guard(main_program = pred_prog, startup_program = pred_init_prog): - pred_net_inputs = create_net_inputs(pred_input_attrs) - pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_') - - fluid.framework.switch_main_program(train_prog) - fluid.framework.switch_startup_program(train_init_prog) - - task_output_vars = {} - for inst in instances: - task_inputs = {'backbone': bb_output_vars} - task_inputs_from_reader = _decode_inputs(net_inputs, inst.name) - task_inputs['reader'] = task_inputs_from_reader - - scope = inst.task_reuse_scope + '/' - with fluid.unique_name.guard(scope): - output_vars = inst.build_task_layer(task_inputs, phase='train', scope=scope) - output_vars = {inst.name+'/'+key: val for key, val in output_vars.items()} - old = len(task_output_vars) # for debug - task_output_vars.update(output_vars) - assert len(task_output_vars) - old == len(output_vars) # for debug - - # prepare predict vars for saving inference model - if inst.is_target: - - with fluid.program_guard(pred_prog, pred_init_prog): - cur_inputs = _decode_inputs(pred_net_inputs, inst.name) - inst.pred_input = cur_inputs - pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs} - scope = inst.task_reuse_scope + '/' - with fluid.unique_name.guard(scope): - inst.build_task_layer(pred_task_inputs, phase='pred', scope=scope) - - - bb_fetches = {k: v.name for k,v in bb_output_vars.items()} - task_fetches = {k: v.name for k,v in task_output_vars.items()} - fetches = task_fetches - fetches['__task_id'] = net_inputs['__task_id'].name - - # compute loss - task_id_var = net_inputs['__task_id'] - task_id_vec = layers.one_hot(task_id_var, num_instances) - losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0) - loss = layers.reduce_sum(task_id_vec * losses) - - def init_train(self, basetask, num_epochs, ): - main_reader = main_inst.reader['train'] - - num_examples = main_reader.num_examples - for inst in instances: - max_train_steps = int(main_conf['num_epochs']* inst.mix_ratio * (num_examples // main_conf['batch_size'] // dev_count)) - if inst.is_target: - print('{}: expected train steps {}.'.format(inst.name, max_train_steps)) - inst.steps_pur_epoch = inst.reader['train'].num_examples // main_conf['batch_size'] // dev_count - inst.expected_train_steps = max_train_steps - - global_max_train_steps = int(main_conf['num_epochs'] * sum(mrs) * (num_examples // main_conf['batch_size'] // dev_count)) - print('Estimated overall train steps {}.'.format(global_max_train_steps)) - - # if 'warmup_proportion' in main_conf and main_conf['warmup_proportion'] > 0: - # warmup_steps = int(global_max_train_steps * main_conf['warmup_proportion']) - # print('Warmup steps: '+str(warmup_steps)) - # else: - # warmup_steps = 0 - - return loss, max_train_steps - - - def build_backward(self, optimizer, use_ema=False, ema_decay=0.9999): - # build optimizer - optimizer.optimize(fluid.default_main_program()) - - # loss.persistable = True - if use_ema: - ema = fluid.optimizer.ExponentialMovingAverage(ema_decay) - ema.update() - - def random_init_params(self): - if not self._init_finish: - # prepare for train - self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) - self.saver_program = fluid.default_main_program() - self._init_finish = True - - print("\nRandomly initialize parameters...\n") - self.exe.run(fluid.default_startup_program()) - - def load_pretrain_params(self, pretrain_model_path=None): - # load pretrain model (or ckpt) - if pretrain_model_path is None: - assert 'pretrain_model_path' in self.main_conf, "pretrain_model_path NOT set." - pretrain_model_path = self.main_conf['pretrain_model_path'] - - init_pretraining_params( - self.exe, - pretrain_model_path, - main_program=fluid.default_startup_program()) - - if not self._init_finish: - self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) - self.saver_program = fluid.default_main_program() - self._init_finish = True - - def load_infermodel(self, instance, infer_model_path): - inst = instance - if 'pred_output_path' not in inst.config: - inst.config['pred_output_path'] = os.path.join(inst.config.get('save_path', '.'), inst.name) - - if not os.path.exists(inst.config['pred_output_path']): - os.makedirs(inst.config['pred_output_path']) - - pred_backbone = self.Backbone(self.bb_conf, phase='pred') - pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=self.bb_conf) - inst.task_layer['pred'] = pred_parad - pred_joint_input_names, pred_joint_shape_and_dtypes, name_to_position = merge_input_attrs( - pred_backbone.inputs_attr, inst.task_layer['pred'].inputs_attrs['reader'], - insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) - - pred_prog = inst.load(infer_model_path) - if inst.reader['pred'] is None: - pred_reader = inst.Reader(inst.config, phase='pred') - inst.reader['pred'] = pred_reader - return pred_prog - - def train(self, num_epochs): - - if not self._init_finish: - raise Exception('params has not been initialized! Please init params with random_init_params or load_pretrain_params.') - - instances = self.instances - num_instances = self.num_instances - main_inst = self.main_inst - main_conf = main_inst.config - - backbone = self.train_backbone - train_program = self.train_program - saver_program = self.saver_program - fetches = self.fetches - - finish = [] - for inst in instances: - if inst.is_target: - if inst.expected_train_steps > 0: - finish.append(False) - else: - finish.append(True) - print(inst.name+': train finished!') - inst.save() - - def train_finish(): - for inst in instances: - if inst.is_target: - if not inst.train_finish: - return False - return True - - # do training - fetch_names, fetch_list = zip(*fetches.items()) - - main_step = 0 # only count for main task - global_step = 0 # count for all tasks - epoch = 0 - time_begin = time.time() - backbone_buffer = [] - while not train_finish(): - rt_outputs = self.exe.run(train_program, fetch_list=fetch_list) - rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} - rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist() - rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id - cur_task = instances[rt_task_id] - - backbone_rt_outputs = {k:v for k,v in rt_outputs.items() if '/' not in k} - backbone_buffer.append(backbone.postprocess(backbone_rt_outputs)) - - task_rt_outputs = {k[len(cur_task.name+'/'):]: v for k,v in rt_outputs.items() if k.startswith(cur_task.name+'/')} - instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs) - - global_step += 1 - cur_task.cur_train_step += 1 - - if cur_task.save_infermodel_every_n_steps > 0 and cur_task.cur_train_step % cur_task.save_infermodel_every_n_steps == 0: - cur_task.save(suffix='.step'+str(cur_task.cur_train_step)) - - if global_step % main_conf.get('print_every_n_steps', 5) == 0: - loss = rt_outputs[cur_task.name+'/loss'] - loss = np.mean(np.squeeze(loss)).tolist() - - time_end = time.time() - time_cost = time_end - time_begin - - print("Global step: {}. Task: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format( - global_step, cur_task.name, cur_task.cur_train_step, cur_task.steps_pur_epoch, cur_task.cur_train_epoch, - loss, main_conf.get('print_every_n_steps', 5) / time_cost)) - time_begin = time.time() - - if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps: - print(cur_task.name+': train finished!') - cur_task.save() - - if 'save_every_n_steps' in main_conf and global_step % main_conf['save_every_n_steps'] == 0: - save_path = os.path.join(main_conf['save_path'], - "step_" + str(global_step)) - fluid.io.save_persistables(self.exe, save_path, saver_program) - - print("ALL tasks train finished, exiting...") - - def pred(self, task_instance, inference_model_dir=None): - if self._for_train: - raise Exception('This controller is a trainer. Please build a new controller with for_train=False for predicting.') - - assert isinstance(task_instance, str) - if isinstance(inference_model_dir, str): - assert os.path.exists(inference_model_dir), inference_model_dir+" not found." - # if not self.has_init_pred and inference_model_dir is None: - # raise ValueError('infer_model_path is required for prediction.') - if inference_model_dir is None: - assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model." - inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model') - - instance = None - for inst in self.instances: - if inst.name == task_instance: - instance = inst - break - - if instance is None: - raise ValueError(task_instance + ' is not a valid task_instance.') - - pred_prog = self._init_pred(instance, inference_model_dir) - - inst = instance - print(inst.name+": loading data...") - inst.reader['pred'].load_data() - fetch_names, fetch_vars = inst.pred_fetch_list - - print('predicting...') - mapper = {k:v for k,v in inst.pred_input} - buf = [] - for feed in inst.reader['pred'].iterator(): - feed = _encode_inputs(feed, inst.name, cand_set=mapper) - feed = {mapper[k]: v for k,v in feed.items()} - - rt_outputs = self.exe.run(pred_prog, feed, fetch_vars) - rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} - inst.postprocess(rt_outputs, phase='pred') - if inst.task_layer['pred'].epoch_inputs_attrs: - reader_outputs = inst.reader['pred'].get_epoch_outputs() - else: - reader_outputs = None - inst.epoch_postprocess({'reader':reader_outputs}, phase='pred') - - -if __name__ == '__main__': - assert len(sys.argv) == 2, "Usage: python mtl_controller.py " - conf_path = sys.argv[1] - del sys.argv[1] - controller = Controller(conf_path) - if controller.main_conf['do_train']: - controller.train() - - - - - - diff --git a/paddlepalm/default_settings.py b/paddlepalm/default_settings.py deleted file mode 100644 index 4f003ea..0000000 --- a/paddlepalm/default_settings.py +++ /dev/null @@ -1,42 +0,0 @@ -# -*- coding: UTF-8 -*- -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -BACKBONE_DIR='paddlepalm.backbone' -TASK_INSTANCE_DIR='paddlepalm.task_instance' -READER_DIR='paddlepalm.reader' -PARADIGM_DIR='paddlepalm.task_paradigm' -OPTIMIZER_DIR='paddlepalm.optimizer' -OPTIMIZE_METHOD='optimize' - -REQUIRED_ARGS={ - 'task_instance': str, - 'backbone': str, - 'optimizer': str, - 'learning_rate': float, - 'batch_size': int - } - -OPTIONAL_ARGS={ - 'mix_ratio': str, - 'target_tag': str, - 'reuse_rag': str - } - -TASK_REQUIRED_ARGS={ - 'paradigm': str, - 'reader': str, - 'train_file': str - } - diff --git a/paddlepalm/interface.py b/paddlepalm/interface.py deleted file mode 100644 index b8c3f78..0000000 --- a/paddlepalm/interface.py +++ /dev/null @@ -1,177 +0,0 @@ -# -*- coding: UTF-8 -*- -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""v1.1""" - -class reader(object): - """interface of data manager.""" - - def __init__(self, config): - assert isinstance(config, dict) - - # @property - # def inputs_attr(self): - # """描述reader输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1. - # Return: - # dict类型。对各个输入对象的属性描述。例如, - # 对于文本分类任务,可能需要包含输入文本和所属标签的id - # {"text": ([], 'str'), - # "label": ([], 'int')} - # 对于标注任务,可能需要输入词序列和对应的标签 - # {"tokens", ([-1], 'str'), - # "tags", ([-1], 'str')} - # 对于机器阅读理解任务,可能需要包含上下文、问题、回答、答案区域的起止位置等 - # {"paragraph", ([], 'str'), - # "question", ([], 'str'), - # "start_position", ([], 'int') - # """ - # raise NotImplementedError() - - @property - def outputs_attr(self): - """描述reader输出对象(被yield出的对象)的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 - 注意:当使用mini-batch梯度下降学习策略时,,应为常规的输入对象设置batch_size维度(一般为-1) - Return: - dict类型。对各个输入对象的属性描述。例如, - 对于文本分类和匹配任务,yield的输出内容可能包含如下的对象(下游backbone和task可按需访问其中的对象) - {"token_ids": ([-1, max_len], 'int64'), - "input_ids": ([-1, max_len], 'int64'), - "segment_ids": ([-1, max_len], 'int64'), - "input_mask": ([-1, max_len], 'float32'), - "label": ([-1], 'int')} - """ - raise NotImplementedError() - - # def parse_line(self): - # """框架内部使用字典描述每个样本,字典的key为inputs_attr,value为每个input对应的符合attr描述的值。 - # 该函数负责将文本行解析成符合inputs_attr描述的字典类型的样本。默认的parse_line方法会读取json格式的数据集文件,数据集的每一行为json格式描述的样本。 - # 用户可通过对该方法的继承改写来适配不同格式的数据集,例如csv格式甚至tfrecord文件。 - # """ - # raise NotImplementedError() - # - # def tokenize(self, line): - # """框架中内置了word piece tokenizer等分词器,用户可通过修改tokenizer超参数来制定使用的分词器,若内置的分词器均无法满足需求,用户可通过对该方法的继承改写来自定义分词器。 - # Args: - # - line: a unicode string. - # Return: - # a list of tokens - # """ - # raise NotImplementedError() - - def iterator(self): - """数据集遍历接口,注意,当数据集遍历到尾部时该接口应自动完成指针重置,即重新从数据集头部开始新的遍历。 - Yield: - (dict) elements that meet the requirements in output_templete - """ - raise NotImplementedError() - - @property - def num_examples(self): - """数据集中的样本数量,即每个epoch中iterator所生成的样本数。注意,使用滑动窗口等可能导致数据集样本数发生变化的策略时,该接口应返回runtime阶段的实际样本数。""" - raise NotImplementedError() - - - -class backbone(object): - """interface of backbone model.""" - - def __init__(self, config, phase): - """ - Args: - config: dict类型。描述了 多任务配置文件+预训练模型配置文件 中定义超参数 - phase: str类型。运行阶段,目前支持train和predict - """ - assert isinstance(config, dict) - - @property - def inputs_attr(self): - """描述backbone从reader处需要得到的输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 - Return: - dict类型。对各个输入对象的属性描述。例如, - 对于文本分类和匹配任务,bert backbone依赖的reader对象主要包含如下的对象 - {"token_ids": ([-1, max_len], 'int64'), - "input_ids": ([-1, max_len], 'int64'), - "segment_ids": ([-1, max_len], 'int64'), - "input_mask": ([-1, max_len], 'float32')}""" - raise NotImplementedError() - - @property - def outputs_attr(self): - """描述backbone输出对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 - Return: - dict类型。对各个输出对象的属性描述。例如, - 对于文本分类和匹配任务,bert backbone的输出内容可能包含如下的对象 - {"word_emb": ([-1, max_seqlen, word_emb_size], 'float32'), - "sentence_emb": ([-1, hidden_size], 'float32'), - "sim_vec": ([-1, hidden_size], 'float32')}""" - raise NotImplementedError() - - def build(self, inputs): - """建立backbone的计算图。将符合inputs_attr描述的静态图Variable输入映射成符合outputs_attr描述的静态图Variable输出。 - Args: - inputs: dict类型。字典中包含inputs_attr中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象 - Return: - 需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 - """ - raise NotImplementedError() - - - - -class task_paradigm(object): - - def __init__(self, config, phase, backbone_config): - """ - config: dict类型。描述了 任务实例(task instance)+多任务配置文件 中定义超参数 - phase: str类型。运行阶段,目前支持train和predict - """ - - @property - def inputs_attrs(self): - """描述task_layer需要从reader, backbone等输入对象集合所读取到的输入对象的属性,第一级key为对象集和的名字,如backbone,reader等(后续会支持更灵活的输入),第二级key为对象集和中各对象的属性,包括对象的名字,shape和dtype。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 - Return: - dict类型。对各个对象集及其输入对象的属性描述。""" - raise NotImplementedError() - - @property - def outputs_attr(self): - """描述task输出对象的属性,包括对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 - 当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 - Return: - dict类型。对各个输入对象的属性描述。注意,训练阶段必须包含名为loss的输出对象。 - """ - - raise NotImplementedError() - - @property - def epoch_inputs_attrs(self): - return {} - - def build(self, inputs, scope_name=""): - """建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。 - Args: - inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象 - Return: - 需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 - - """ - raise NotImplementedError() - - def postprocess(self, rt_outputs): - """每个训练或推理step后针对当前batch的task_layer的runtime计算结果进行相关后处理。注意,rt_outputs除了包含build方法,还自动包含了loss的计算结果。""" - pass - - def epoch_postprocess(self, post_inputs): - pass - diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py deleted file mode 100755 index 6cca513..0000000 --- a/paddlepalm/mtl_controller.py +++ /dev/null @@ -1,746 +0,0 @@ -# -*- coding: UTF-8 -*- -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import print_function - -import os -import sys -import importlib -import multiprocessing -from paddle import fluid -from paddle.fluid import layers -import yaml -import json -import logging -import time -import numpy as np - -from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint -from paddlepalm.utils.config_helper import PDConfig -from paddlepalm.utils.print_helper import print_dict -from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs -from paddlepalm.distribute import data_feeder, decode_fake - -from default_settings import * -from task_instance import TaskInstance, check_instances - - -DEBUG=False -VERBOSE=0 - -def _get_basename(f): - return os.path.splitext(f)[0] - - -def _get_suffix(f): - return os.path.splitext(f)[-1] - - -def _parse_yaml(f, asdict=True, support_cmd_line=False): - assert os.path.exists(f), "file {} not found.".format(f) - if support_cmd_line: - args = PDConfig(yaml_file=f, fuse_args=True) - args.build() - return args.asdict() if asdict else args - else: - if asdict: - with open(f, "r") as fin: - yaml_config = yaml.load(fin, Loader=yaml.SafeLoader) - return yaml_config - else: - raise NotImplementedError() - - -def _parse_json(f, asdict=True, support_cmd_line=False): - assert os.path.exists(f), "file {} not found.".format(f) - if support_cmd_line: - args = PDConfig(json_file=f, fuse_args=support_cmd_line) - args.build() - return args.asdict() if asdict else args - else: - if asdict: - with open(f, "r") as fin: - config = json.load(fin) - return config - else: - raise NotImplementedError() - - -def _parse_list(string, astype=str): - assert isinstance(string, str), "{} is not a string.".format(string) - if ',' not in string: - return [astype(string)] - string = string.replace(',', ' ') - return [astype(i) for i in string.split()] - - -def _try_float(s): - try: - float(s) - return(float(s)) - except: - return s - - -def _check_conf(conf, checklist=None): - assert isinstance(conf, dict), "{} is not a dict.".format(conf) - ret = {} - for k,v in conf.items(): - if isinstance(v, str): - v = _try_float(v) - ret[k] = v - if checklist is not None: - for k, t in checklist: - assert k in ret, "required argument {} is NOT exist in config file.".format(k) - assert isintance(ret[k], t), "value type of argument {} should be {}".format(k, t) - return ret - - -# TODO: 增加None机制,允许hidden size、batch size和seqlen设置为None -def _check_io(in_attr, out_attr, strict=False, in_name="left", out_name="right"): - for name, attr in in_attr.items(): - assert name in out_attr, in_name+': '+name+' not found in '+out_name - if attr != out_attr[name]: - if strict: - raise ValueError(name+': shape or dtype not consistent!') - else: - logging.warning('{}: shape or dtype not consistent!\n{}:\n{}\n{}:\n{}'.format(name, in_name, attr, out_name, out_attr[name])) - - -def _merge_conf(conf1, conf2, conf1_first=True, strict=False): - assert isinstance(conf1, dict), "{} is not a dict.".format(conf1) - assert isinstance(conf2, dict), "{} is not a dict.".format(conf2) - base_conf = conf2 if conf1_first else conf1 - base_conf = base_conf.copy() - new_conf = conf1 if conf1_first else conf2 - - for k, v in new_conf.items(): - if k in base_conf: - if base_conf[k] != v: - raise Warning("value of argument {} has been updated to {}.".format(k, v)) - else: - if strict: - continue - - base_conf[k] = v - return base_conf - - -def _encode_inputs(inputs, scope_name, sep='/', cand_set=None): - outputs = {} - for k, v in inputs.items(): - if cand_set is not None: - if k in cand_set: - outputs[k] = v - if scope_name+sep+k in cand_set: - outputs[scope_name+sep+k] = v - else: - outputs[scope_name+sep+k] = v - return outputs - - -def _decode_inputs(inputs, scope_name, sep='/', keep_unk_keys=True): - outputs = {} - for name, value in inputs.items(): - # var for backbone are also available to tasks - if keep_unk_keys and sep not in name: - outputs[name] = value - # var for this inst - if name.startswith(scope_name+'/'): - outputs[name[len(scope_name+'/'):]] = value - return outputs - - -def _init_env(use_gpu): - if use_gpu: - place = fluid.CUDAPlace(0) - dev_count = fluid.core.get_cuda_device_count() - else: - place = fluid.CPUPlace() - dev_count = int(os.environ.get('CPU_NUM', multiprocessing.cpu_count())) - return fluid.Executor(place), dev_count - - -def _fit_attr(conf, fit_attr, strict=False): - for i, attr in fit_attr.items(): - if i not in conf: - if strict: - raise Exception('Argument {} is required to create a controller.'.format(i)) - else: - continue - conf[i] = attr(conf[i]) - return conf - - -def create_feed_batch_process_fn(net_inputs): - - - def feed_batch_process_fn(data, id=-1): - # temps = {} - # for i in range(len(net_inputs)): - temp = {} - inputs = net_inputs[id] if id != -1 else net_inputs - - for q, var in inputs.items(): - if isinstance(var, str) or isinstance(var, unicode): - temp[var] = data[q] - else: - temp[var.name] = data[q] - # temps[i] = temp - - return temp - - return feed_batch_process_fn - - -class Controller(object): - - def __init__(self, config, task_dir='.', for_train=True): - """ - Args: - config: (str|dict) 字符串类型时,给出yaml格式的config配置文件路径; - """ - - self._for_train = for_train - assert isinstance(config, str) or isinstance(config, dict), "a config dict or config file path is required to create a Controller." - - if isinstance(config, str): - mtl_conf = _parse_yaml(config, support_cmd_line=True) - else: - mtl_conf = config - - mtl_conf = _check_conf(mtl_conf) - mtl_conf = _fit_attr(mtl_conf, REQUIRED_ARGS, strict=True) - mtl_conf = _fit_attr(mtl_conf, OPTIONAL_ARGS, strict=False) - - exe, dev_count = _init_env(use_gpu=mtl_conf.get('use_gpu', True)) - self.exe = exe - self.dev_count = dev_count - self.batch_size = mtl_conf.get('batch_size') - - print_dict(mtl_conf, title='global configuration') - - # parse task instances and target tags - instnames = _parse_list(mtl_conf['task_instance']) - assert len(instnames) == len(set(instnames)), "repeated task_instance is NOT supported." - num_instances = len(instnames) - self.num_instances = num_instances - - instname_to_conf = {} - instname_to_id = {} - for id, instname in enumerate(instnames): - instpath = os.path.join(task_dir, instname+'.yaml') - conf = _parse_yaml(instpath, support_cmd_line=False) - # conf = _check_conf(conf, TASK_INSTANCE_REQUIRED_ARGS) - conf = _check_conf(conf) - temp_conf = _merge_conf(mtl_conf, conf, strict=True) - print_dict(temp_conf, title='{} configuration'.format(instname)) - conf = _merge_conf(mtl_conf, conf) - - instname_to_conf[instname] = conf - instname_to_id[instname] = id - - # prepare backbone - if 'backbone_config_path' in mtl_conf: - bb_conf = _parse_json(mtl_conf['backbone_config_path']) - bb_conf = _merge_conf(mtl_conf, bb_conf) - else: - bb_conf = mtl_conf - print_dict(bb_conf, title = 'backbone configuration'.format(instname)) - - bb_name = mtl_conf['backbone'] - bb_mod = importlib.import_module(BACKBONE_DIR + '.' + bb_name) - Backbone = getattr(bb_mod, 'Model') - - # create task instances - instances = [] - for name in instnames: - instances.append(TaskInstance(name, instname_to_id[name], instname_to_conf[name])) - - check_instances(instances) - - # parse target_tag - if 'target_tag' in mtl_conf: - target_tag = str(mtl_conf['target_tag']) - tags = _parse_list(target_tag, astype=int) - assert len(tags) == len(instnames), "number of target_tag is NOT consistent with that in task_instance." - for tag, inst in zip(tags, instances): - inst.is_target = tag - else: - tags = [i.is_target for i in instances] - num_targets = sum(tags) - num_auxes = num_instances - num_targets - - # parse mix ratios - if 'mix_ratio' in mtl_conf: - mix_ratio = str(mtl_conf['mix_ratio']) - mrs = _parse_list(mix_ratio, astype=float) - assert len(mrs) == num_instances, "number of mix_ratios is NOT consistent with num_instances." - else: - mrs = [1.0] * num_instances - - for mr, inst in zip(mrs, instances): - inst.mix_ratio = mr - - # parse task layer reuse tags - instname_to_reusehost = {i:i for i in instnames} - if 'task_reuse_tag' in mtl_conf: - tags = _parse_list(mtl_conf['task_reuse_tag'], astype=int) - assert len(tags) == num_targets, 'number of reuse_tags is NOT consistent with number of instances.' - else: - tags = [] - mapper = {} - for inst in instances: - history = set() - history.add(inst.name) - cur_inst = inst - while True: - if cur_inst.task_reuse_scope in history: - mapper[inst.name] = len(tags) - break - elif cur_inst.task_reuse_scope in mapper: - mapper[inst.name] = mapper[cur_inst.task_reuse_scope] - break - else: - cur_inst = name_to_instance[cur_inst.task_reuse_scope] - history.add(cur_inst.name) - - tags.append(mapper[inst.name]) - - for i in range(1, num_instances): - for j in range(i): - if tags[i] == tags[j]: - assert instances[i].Paradigm == \ - instances[j].Paradigm, \ - "paradigm of reuse tasks should be consistent" - instances[i].task_reuse_scope = instances[j].name - break - - self.instances = instances - self.mrs = mrs - self.Backbone = Backbone - self.bb_conf = bb_conf - self.bb_name = bb_name - - self.has_init_train = False - self.has_init_pred = False - - if self._for_train: - print("initialing for training...") - self._init_train() - self.has_init_train = True - - def _init_train(self): - - instances = self.instances - Backbone = self.Backbone - bb_conf = self.bb_conf - bb_name = self.bb_name - dev_count = self.dev_count - num_instances = len(instances) - mrs = self.mrs - branch = fluid.data(name="branch",shape=[1],dtype='int64') - - # set first_target/main task instance - main_inst = None - for inst in instances: - if inst.is_target: - main_inst = inst - inst.is_first_target = True - break - main_conf = main_inst.config - if not os.path.exists(main_conf['save_path']): - os.makedirs(main_conf['save_path']) - os.makedirs(os.path.join(main_conf['save_path'], 'ckpt')) - - # prepare backbone - train_backbone = Backbone(bb_conf, phase='train') - pred_backbone = Backbone(bb_conf, phase='pred') - - # create reader, task - # then check i/o across reader, backbone and task_layer - - # check_fns = {} - task_attrs = {} - pred_task_attrs = [] - joint_input_names = {} - joint_shape_and_dtypes = {} - name_to_position = {} - for i in range(num_instances): - # def check_tasks(): - # i = s - # def checkeach(): - - train_reader = instances[i].Reader(instances[i].config, phase='train') - instances[i].reader['train'] = train_reader - train_parad = instances[i].Paradigm(instances[i].config, phase='train', backbone_config=bb_conf) - instances[i].task_layer['train'] = train_parad - task_attr_from_reader = _encode_inputs(train_parad.inputs_attrs['reader'], instances[i].name) - task_attrs[i] = task_attr_from_reader - - _check_io(train_backbone.inputs_attr, train_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.train') - _check_io(train_parad.inputs_attrs['reader'], train_reader.outputs_attr, in_name='task_paradigm.train.reader', out_name='reader.train') - _check_io(train_parad.inputs_attrs['backbone'], train_backbone.outputs_attr, in_name='task_paradigm.train.backbone', out_name=bb_name+'_backbone') - # merge reader input attrs from backbone and task_instances - # pred_joint_input_names = [] - # pred_joint_shape_and_dtypes = [] - if instances[i].is_target: - if 'pred_file' not in instances[i].config: - instances[i].config['pred_file'] = '' - pred_reader = instances[i].Reader(instances[i].config, phase='pred') - pred_parad = instances[i].Paradigm(instances[i].config, phase='pred', backbone_config=bb_conf) - instances[i].task_layer['pred'] = pred_parad - task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], instances[i].name) - pred_task_attrs.append(task_attr_from_reader) - _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred') - _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred') - _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone') - # pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) - # return joint_input_names[i], joint_shape_and_dtypes[i], name_to_position[i], pred_joint_input_names, pred_joint_shape_and_dtypes - # return checkeach - # check_fns[i] = check_tasks() - joint_input_names[i], joint_shape_and_dtypes[i], name_to_position[i] = merge_input_attrs(train_backbone.inputs_attr, task_attrs[i]) - - pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) - - - # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN] - - if DEBUG: - print('----- for debug -----') - print('joint input names:') - print(joint_input_names) - print('joint input shape and dtypes:') - print(joint_shape_and_dtypes) - - # load data - data_fns={} - for i in range(num_instances): - print(instances[i].name+": preparing data...", end='') - instances[i].reader['train'].load_data() - print('ok!') - - # merge dataset iterators and create net input vars - iterators = [] - prefixes = [] - mrs = [] - - for inst in instances: - iterators.append(inst.reader['train'].iterator()) - prefixes.append(inst.name) - mrs.append(inst.mix_ratio) - - joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict') - self._joint_iterator_fn = joint_iterator_fn - - input_attrs = {} - net_inputs = {} - bb_output_vars = {} - bb_output_fns = {} - - # prepare predict vars for saving inference model - pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)] - pred_prog = fluid.Program() - pred_init_prog = fluid.Program() - self._pred_prog = pred_prog - - with fluid.program_guard(main_program = pred_prog, startup_program = pred_init_prog): - pred_net_inputs = create_net_inputs(pred_input_attrs) - pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_') - - task_inputs = {} - task_output_vars = {} - task_fns = {} - - def get_loss(i): - input_attrs[i] = [[m, j, k] for m, (j,k) in zip(joint_input_names[i], joint_shape_and_dtypes[i])] - net_inputs[i] = create_net_inputs(input_attrs[i], async=False) - # net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3) - bb_output_vars[i] = train_backbone.build(net_inputs[i], scope_name='__paddlepalm_') - assert sorted(bb_output_vars[i].keys()) == sorted(train_backbone.outputs_attr.keys()) - - # build backbone and task layers - task_inputs[i] = {'backbone': bb_output_vars[i]} - task_inputs_from_reader = _decode_inputs(net_inputs[i], instances[i].name) - task_inputs[i]['reader'] = task_inputs_from_reader - - scope = instances[i].task_reuse_scope + '/' - with fluid.unique_name.guard(scope): - output_vars = instances[i].build_task_layer(task_inputs[i], phase='train', scope=scope) - output_vars = {instances[i].name+'/'+key: val for key, val in output_vars.items()} - loss_var = output_vars[instances[i].name+'/loss'] - task_output_vars[i] = output_vars - - if instances[i].is_target: - with fluid.program_guard(pred_prog, pred_init_prog): - cur_inputs = _decode_inputs(pred_net_inputs, instances[i].name) - instances[i].pred_input = cur_inputs - pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs} - scope = instances[i].task_reuse_scope + '/' - with fluid.unique_name.guard(scope): - instances[i].build_task_layer(pred_task_inputs, phase='pred', scope=scope) - return loss_var - - for i in range(num_instances): - def task_loss(): - task_id = i - return lambda: get_loss(task_id) - task_fns[i] = task_loss() - - loss = layers.switch_case( - branch_index=branch, - branch_fns=task_fns - ) - self._switched_loss = loss.name - main_reader = main_inst.reader['train'] - - num_examples = main_reader.num_examples - for inst in instances: - max_train_steps = int(main_conf['num_epochs']* inst.mix_ratio * (num_examples // main_conf['batch_size'] // dev_count)) - if inst.is_target: - print('{}: expected train steps {}.'.format(inst.name, max_train_steps)) - inst.steps_pur_epoch = inst.reader['train'].num_examples // main_conf['batch_size'] // dev_count - inst.expected_train_steps = max_train_steps - - global_max_train_steps = int(main_conf['num_epochs'] * sum(mrs) * (num_examples // main_conf['batch_size'] // dev_count)) - print('Estimated overall train steps {}.'.format(global_max_train_steps)) - - if 'warmup_proportion' in main_conf and main_conf['warmup_proportion'] > 0: - warmup_steps = int(global_max_train_steps * main_conf['warmup_proportion']) - print('Warmup steps: '+str(warmup_steps)) - else: - warmup_steps = 0 - - # build optimizer - if 'optimizer' in main_conf: - optim_mod = importlib.import_module(OPTIMIZER_DIR + '.' + main_conf['optimizer']) - optimize = getattr(optim_mod, OPTIMIZE_METHOD) - optimize(loss, main_conf, max_train_steps, warmup_steps, fluid.default_main_program()) - - loss.persistable = True - if main_conf.get('use_ema', False): - assert 'ema_decay' in main_conf, "ema_decay should be set when use_ema is enabled." - ema = fluid.optimizer.ExponentialMovingAverage(main_conf['ema_decay']) - ema.update() - - # prepare for train - self.train_backbone = train_backbone - self.train_program = fluid.CompiledProgram(fluid.default_main_program()).with_data_parallel(loss_name=loss.name) - self.saver_program = fluid.default_main_program() - - self.main_inst = main_inst - self.has_init_train = True - self.has_init_pred = True - self._net_inputs = net_inputs - - self.exe.run(fluid.default_startup_program()) - print("\nRandomly initialize parameters...\n") - - def _init_pred(self, instance, infer_model_path): - inst = instance - if 'pred_output_path' not in inst.config: - inst.config['pred_output_path'] = os.path.join(inst.config.get('save_path', '.'), inst.name) - - if not os.path.exists(inst.config['pred_output_path']): - os.makedirs(inst.config['pred_output_path']) - - pred_backbone = self.Backbone(self.bb_conf, phase='pred') - pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=self.bb_conf) - inst.task_layer['pred'] = pred_parad - pred_joint_input_names, pred_joint_shape_and_dtypes, name_to_position = merge_input_attrs( - pred_backbone.inputs_attr, inst.task_layer['pred'].inputs_attrs['reader'], - insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) - - pred_prog = inst.load(infer_model_path) - pred_prog = fluid.CompiledProgram(pred_prog).with_data_parallel() - if inst.reader['pred'] is None: - pred_reader = inst.Reader(inst.config, phase='pred') - inst.reader['pred'] = pred_reader - return pred_prog - - def load_pretrain(self, pretrain_path=None): - # load pretrain model (or ckpt) - if pretrain_path is None: - assert 'pretrain_path' in self.main_conf, "pretrain_path NOT set." - pretrain_path = self.main_conf['pretrain_path'] - - init_pretraining_params( - self.exe, - pretrain_path, - main_program=fluid.default_startup_program()) - - - def train(self): - - if not self.has_init_train: - self._init_train() - self.has_init_train = True - - instances = self.instances - num_instances = self.num_instances - main_inst = self.main_inst - main_conf = main_inst.config - - backbone = self.train_backbone - train_program = self.train_program - saver_program = self.saver_program - finish = [] - for inst in instances: - if inst.is_target: - if inst.expected_train_steps > 0: - finish.append(False) - else: - finish.append(True) - print(inst.name+': train finished!') - inst.save() - - def train_finish(): - for inst in instances: - if inst.is_target: - if not inst.train_finish: - return False - return True - - - # do training - fetch_names = {} - fetch_list = [] - main_step = 0 # only count for main task - global_step = 0 # count for all tasks - epoch = 0 - time_begin = time.time() - backbone_buffer = [] - - feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs) - distribute_feeder = data_feeder(self._joint_iterator_fn, feed_batch_process_fn) - - while not train_finish(): - feed, mask, id = next(distribute_feeder) - for i in range(self.dev_count): - feed[i].update({'branch':np.array([id],dtype='int64')}) - fetch_list.append(self._switched_loss) - rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list) - rt_loss = rt_outputs.pop() - - rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} - cur_task = instances[id] - - # backbone_rt_outputs = {k:v for k,v in rt_outputs.items() if '/' not in k} - # backbone_buffer.append(backbone.postprocess(backbone_rt_outputs)) - - # task_rt_outputs = {k[len(cur_task.name+'/'):]: v for k,v in rt_outputs.items() if k.startswith(cur_task.name+'/')} - # instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs) - - global_step += 1 - cur_task.cur_train_step += 1 - - cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch - if cur_task.is_target and cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0: - cur_task.save(suffix='.step'+str(cur_task_global_step), prog=self._pred_prog) - - if global_step % main_conf.get('print_every_n_steps', 5) == 0: - loss = rt_loss - loss = np.mean(np.squeeze(loss)).tolist() - - time_end = time.time() - time_cost = time_end - time_begin - - print("Global step: {}. Task: {}, step {}/{} (epoch {}), loss: {:.3f}, speed: {:.2f} steps/s".format( - global_step, cur_task.name, cur_task.cur_train_step, cur_task.steps_pur_epoch, cur_task.cur_train_epoch, - loss, main_conf.get('print_every_n_steps', 5) / time_cost)) - time_begin = time.time() - - if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps: - print(cur_task.name+': train finished!') - cur_task.save(prog=self._pred_prog) - - if 'save_ckpt_every_n_steps' in main_conf and global_step % main_conf['save_ckpt_every_n_steps'] == 0: - save_path = os.path.join(main_conf['save_path'], 'ckpt', - "step_" + str(global_step)) - fluid.io.save_persistables(self.exe, save_path, saver_program) - print('checkpoint has been saved at '+save_path) - - save_path = os.path.join(main_conf['save_path'], 'ckpt', - "step_" + str(global_step)) - fluid.io.save_persistables(self.exe, save_path, saver_program) - print('checkpoint has been saved at '+save_path) - - print("ALL tasks train finished, exiting...") - - def pred(self, task_instance, inference_model_dir=None): - if self._for_train: - raise Exception('This controller is a trainer. Please build a new controller with for_train=False for predicting.') - - assert isinstance(task_instance, str) - if isinstance(inference_model_dir, str): - assert os.path.exists(inference_model_dir), inference_model_dir+" not found." - # if not self.has_init_pred and inference_model_dir is None: - # raise ValueError('infer_model_path is required for prediction.') - if inference_model_dir is None: - assert 'save_path' in self.mtl_conf, "one of the `inference_model_dir` and 'save_path' should be set to load inference model." - inference_model_dir = os.path.join(self.mtl_conf['save_path'], task_instance, 'infer_model') - - instance = None - for inst in self.instances: - if inst.name == task_instance: - instance = inst - break - - if instance is None: - raise ValueError(task_instance + ' is not a valid task_instance.') - - pred_prog = self._init_pred(instance, inference_model_dir) - - inst = instance - print(inst.name+": loading data...") - inst.reader['pred'].load_data() - fetch_names, fetch_vars = inst.pred_fetch_list - - print('predicting...') - feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input) - distribute_feeder = data_feeder(inst.reader['pred'].iterator, feed_batch_process_fn, prefetch_steps=1, phase='pred') - - buf = [] - for feed, mask, id in distribute_feeder: - - rt_outputs = self.exe.run(pred_prog, feed, fetch_vars) - - num_fakes = decode_fake(len(rt_outputs[0]), mask, self.batch_size) - for _ in range(num_fakes): - for item in rt_outputs: - item.pop() - - rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} - inst.postprocess(rt_outputs, phase='pred') - - if inst.task_layer['pred'].epoch_inputs_attrs: - reader_outputs = inst.reader['pred'].get_epoch_outputs() - else: - reader_outputs = None - - inst.epoch_postprocess({'reader':reader_outputs}, phase='pred') - - -if __name__ == '__main__': - assert len(sys.argv) == 2, "Usage: python mtl_controller.py " - conf_path = sys.argv[1] - del sys.argv[1] - controller = Controller(conf_path) - if controller.main_conf['do_train']: - controller.train() - - - -__all__ = ["Controller"] diff --git a/paddlepalm/reader/cls.py b/paddlepalm/reader/cls.py index f7c9b51..c1e91d7 100644 --- a/paddlepalm/reader/cls.py +++ b/paddlepalm/reader/cls.py @@ -24,15 +24,15 @@ class ClassifyReader(Reader): For tsv format, training dataset file should have two header areas, i.e., `label` and `text`, and test set only requires `text` area. For example, ``` - label text - 1 Today is a good day. - 0 Such a terriable day! - 1 I feel lucky to meet you, dear. - 1 He likes sunshine and I like him :). - 0 JUST! GO! OUT! - ``` - - CAUTIOUS: The first line of the file must be header! And areas are splited by tab (\\t). + label [TAB] text + 1 [TAB] Today is a good day. + 0 [TAB] Such a terriable day! + 1 [TAB] I feel lucky to meet you, dear. + 1 [TAB] He likes sunshine and I like him :). + 0 [TAB] JUST! GO! OUT! + ``` + + CAUTIOUS: The first line of the file must be header! And areas are splited by tab (\\t). """ diff --git a/paddlepalm/reader/match.py b/paddlepalm/reader/match.py index 7f25c33..0bbe5a4 100644 --- a/paddlepalm/reader/match.py +++ b/paddlepalm/reader/match.py @@ -18,6 +18,31 @@ from paddlepalm.reader.utils.reader4ernie import ClassifyReader as CLSReader class MatchReader(Reader): + """ + The reader completes the loading and processing of matching-like task (e.g, query-query, question-answer, text similarity, natural language inference) dataset. Supported file format: tsv. + + For pointwise learning strategy, there should be two fields in training dataset file, i.e., `text_a`, `text_b` and `label`. For pairwise learning, there should exist three fields, i.e., `text_a`, `text_b` and `text_b_neg`. For predicting, only `text_a` and `text_b` are required. + + A pointwise learning case shows as follows: + ``` + label [TAB] text_a [TAB] text_b + 1 [TAB] Today is a good day. [TAB] what a nice day! + 0 [TAB] Such a terriable day! [TAB] There is a dog. + 1 [TAB] I feel lucky to meet you, dear. [TAB] You are my lucky, darling. + 1 [TAB] He likes sunshine and I like him :). [TAB] I like him. He like sunshine. + 0 [TAB] JUST! GO! OUT! [TAB] Come in please. + ``` + A pairwise learning case shows as follows: + text_a [TAB] text_b [TAB] text_b_neg + Today is a good day. [TAB] what a nice day! [TAB] terriable day! + Such a terriable day! [TAB] So terriable today! [TAB] There is a dog. + I feel lucky to meet you, dear. [TAB] You are my lucky, darling. [TAB] + He likes sunshine and I like him :). [TAB] I like him. He like sunshine. + JUST! GO! OUT! [TAB] Come in please. + + CAUTIOUS: The first line of the file must be header! And areas are splited by tab (\\t). + + """ def __init__(self, vocab_path, max_len, tokenizer='wordpiece', lang='en', seed=None, \ do_lower_case=False, learning_strategy='pointwise', phase='train', dev_count=1, print_prefix=''): # 需要什么加什么 diff --git a/paddlepalm/task_instance.py b/paddlepalm/task_instance.py deleted file mode 100644 index bcf053b..0000000 --- a/paddlepalm/task_instance.py +++ /dev/null @@ -1,309 +0,0 @@ -# -*- coding: UTF-8 -*- -# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from paddlepalm.interface import reader as base_reader -from paddlepalm.interface import task_paradigm as base_paradigm -import os -import json -from paddle import fluid -import importlib -from paddlepalm.default_settings import * - - -def check_req_args(conf, name): - assert 'reader' in conf, name+': reader is required to build TaskInstance.' - assert 'paradigm' in conf, name+': paradigm is required to build TaskInstance.' - assert 'train_file' in conf or 'pred_file' in conf, name+': at least train_file or pred_file should be provided to build TaskInstance.' - - -class TaskInstance(object): - - def __init__(self, name, id, config, verbose=True): - self._name = name - self._config = config - self._verbose = verbose - - check_req_args(config, name) - - # parse Reader and Paradigm - reader_name = config['reader'] - reader_mod = importlib.import_module(READER_DIR + '.' + reader_name) - Reader = getattr(reader_mod, 'Reader') - - parad_name = config['paradigm'] - parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name) - Paradigm = getattr(parad_mod, 'TaskParadigm') - - self._Reader = Reader - self._Paradigm = Paradigm - - self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model') - self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt') - self._save_infermodel_every_n_steps = config.get('save_infermodel_every_n_steps', -1) - - # following flags can be fetch from instance config file - self._is_target = config.get('is_target', True) - self._first_target = config.get('is_first_target', False) - self._task_reuse_scope = config.get('task_reuse_scope', name) - - self._feeded_var_names = None - self._target_vars = None - - # training process management - self._mix_ratio = None - self._expected_train_steps = None - self._expected_train_epochs = None - self._steps_pur_epoch = None - self._cur_train_epoch = 0 - self._cur_train_step = 0 - self._train_finish = False - - # 存放不同运行阶段(train,eval,pred)的数据集reader,key为phase,value为Reader实例 - self._reader = {'train': None, 'eval': None, 'predict': None} - self._input_layer = None - self._inputname_to_varname = {} - self._task_layer = {'train': None, 'eval': None, 'predict': None} - self._pred_input_name_list = [] - self._pred_input_varname_list = [] - self._pred_fetch_name_list = [] - self._pred_fetch_var_list = [] - - self._exe = fluid.Executor(fluid.CPUPlace()) - - self._save_protocol = { - 'input_names': 'self._pred_input_name_list', - 'input_varnames': 'self._pred_input_varname_list', - 'fetch_list': 'self._pred_fetch_name_list'} - - - def build_task_layer(self, net_inputs, phase, scope=""): - output_vars = self._task_layer[phase].build(net_inputs, scope_name=scope) - if phase == 'predict': - if output_vars is not None: - self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items()) - else: - self._pred_fetch_name_list = [] - self._pred_fetch_var_list = [] - return output_vars - - def postprocess(self, rt_outputs, phase): - return self._task_layer[phase].postprocess(rt_outputs) - - def epoch_postprocess(self, epoch_inputs, phase): - return self._task_layer[phase].epoch_postprocess(epoch_inputs) - - def save(self, suffix=''): - dirpath = self._save_infermodel_path + suffix - self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list] - - # fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment = True) - prog = fluid.default_main_program().clone() - fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog) - - conf = {} - for k, strv in self._save_protocol.items(): - d = None - v = locals() - exec('d={}'.format(strv), globals(), v) - conf[k] = v['d'] - with open(os.path.join(dirpath, '__conf__'), 'w') as writer: - writer.write(json.dumps(conf, indent=1)) - print(self._name + ': inference model saved at ' + dirpath) - - def load(self, infer_model_path=None): - if infer_model_path is None: - infer_model_path = self._save_infermodel_path - for k,v in json.load(open(os.path.join(infer_model_path, '__conf__'))).items(): - strv = self._save_protocol[k] - exec('{}=v'.format(strv)) - pred_prog, self._pred_input_varname_list, self._pred_fetch_var_list = \ - fluid.io.load_inference_model(infer_model_path, self._exe) - print(self._name+': inference model loaded from ' + infer_model_path) - return pred_prog - - @property - def name(self): - return self._name - - @property - def Reader(self): - return self._Reader - - # @Reader.setter - # def Reader(self, cls): - # assert base_reader.__name__ == cls.__bases__[-1].__name__, \ - # "expect: {}, receive: {}.".format(base_reader.__name__, \ - # cls.__bases__[-1].__name__) - # self._Reader = cls - - @property - def Paradigm(self): - return self._Paradigm - - # @Paradigm.setter - # def Paradigm(self, cls): - # assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \ - # "expect: {}, receive: {}.".format(base_paradigm.__name__, \ - # cls.__bases__[-1].__name__) - # self._Paradigm = cls - - @property - def config(self): - return self._config - - @property - def reader(self): - return self._reader - - @property - def pred_input(self): - return zip(*[self._pred_input_name_list, self._pred_input_varname_list]) - - @pred_input.setter - def pred_input(self, val): - assert isinstance(val, dict) - self._pred_input_name_list, self._pred_input_varname_list = \ - zip(*[[k, v.name] for k,v in val.items()]) - - @property - def pred_fetch_list(self): - return [self._pred_fetch_name_list, self._pred_fetch_var_list] - - @property - def task_layer(self): - return self._task_layer - - @property - def is_first_target(self): - return self._is_first_target - - @is_first_target.setter - def is_first_target(self, value): - self._is_first_target = bool(value) - if self._is_first_target: - assert self._is_target, "ERROR: only target task could be set as main task." - if self._verbose and self._is_first_target: - print("{}: set as main task".format(self._name)) - - @property - def is_target(self): - if self._is_target is not None: - return self._is_target - else: - raise ValueError("{}: is_target is None".format(self._name)) - - @is_target.setter - def is_target(self, value): - self._is_target = bool(value) - if self._verbose: - if self._is_target: - print('{}: set as target task.'.format(self._name)) - else: - print('{}: set as aux task.'.format(self._name)) - - @property - def mix_ratio(self): - if self._mix_ratio is not None: - return self._mix_ratio - else: - raise ValueError("{}: mix_ratio is None".format(self._name)) - - @mix_ratio.setter - def mix_ratio(self, value): - self._mix_ratio = float(value) - if self._verbose: - print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio)) - - @property - def save_infermodel_every_n_steps(self): - return self._save_infermodel_every_n_steps - - @property - def expected_train_steps(self): - return self._expected_train_steps - - @expected_train_steps.setter - def expected_train_steps(self, value): - self._expected_train_steps = value - self._expected_train_epochs = value / float(self._steps_pur_epoch) - - @property - def expected_train_epochs(self): - return self._expected_train_epochs - - @property - def cur_train_epoch(self): - return self._cur_train_epoch - - @cur_train_epoch.setter - def cur_train_epoch(self, value): - self._cur_train_epoch = value - - @property - def cur_train_step(self): - return self._cur_train_step - - @cur_train_step.setter - def cur_train_step(self, value): - self._cur_train_step = value - if self._cur_train_step > self._steps_pur_epoch: - self._cur_train_epoch += 1 - self._cur_train_step = 1 - if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps: - self._train_finish = True - - @property - def steps_pur_epoch(self): - return self._steps_pur_epoch - - @steps_pur_epoch.setter - def steps_pur_epoch(self, value): - self._steps_pur_epoch = value - - @property - def train_finish(self): - return self._train_finish - - @property - def task_reuse_scope(self): - if self._task_reuse_scope is not None: - return self._task_reuse_scope - else: - raise ValueError("{}: task_reuse_scope is None".format(self._name)) - - @task_reuse_scope.setter - def task_reuse_scope(self, scope_name): - self._task_reuse_scope = str(scope_name) - if self._verbose: - print('{}: task_reuse_scope is set to {}'.format(self._name, self._task_reuse_scope)) - - - - - - - -def check_instances(insts): - """to check ids, first_target""" - pass - -def _check_ids(): - pass - -def _check_targets(): - pass - -def _check_reuse_scopes(): - pass diff --git a/setup.cfg b/setup.cfg index 77d6221..4ee53b6 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,13 +1,13 @@ [metadata] -name = paddle-palm +name = paddlepalm author = zhangyiming author_email = zhangyiming04@baidu.com -version = 1.2 +version = 1.0.0 -description = Paddle-PALM +description = PaddlePALM long_description = file: README.md long_description_content_type = text/markdown @@ -27,6 +27,8 @@ classifier = keywords = paddlepaddle paddle + nlp + pretrain multi-task-learning [options] diff --git a/setup.py b/setup.py index 52488eb..f59fa6f 100644 --- a/setup.py +++ b/setup.py @@ -18,7 +18,7 @@ """ Setup script. Authors: zhouxiangyang(zhouxiangyang@baidu.com) -Date: 2019/09/29 21:00:01 +Date: 2020/1/22 12:00:01 """ import setuptools with open("README.md", "r") as fh: @@ -28,10 +28,10 @@ setuptools.setup( version="1.0.0", author="PaddlePaddle", author_email="zhangyiming04@baidu.com", - description="A Multi-task Learning Lib for PaddlePaddle Users.", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/PaddlePadd", + description="a flexible, general and easy-to-use NLP large-scale pretraining and multi-task learning framework.", + # long_description=long_description, + # long_description_content_type="text/markdown", + url="https://github.com/PaddlePaddle/PALM", # packages=setuptools.find_packages(), packages = ['paddlepalm', 'paddlepalm.backbone', @@ -39,16 +39,20 @@ setuptools.setup( 'paddlepalm.optimizer', 'paddlepalm.reader', 'paddlepalm.reader.utils', - 'paddlepalm.task_paradigm', + 'paddlepalm.head', + 'paddlepalm.distribute', + 'paddlepalm.lr_sched', 'paddlepalm.tokenizer', 'paddlepalm.utils'], package_dir={'paddlepalm':'./paddlepalm', 'paddlepalm.backbone':'./paddlepalm/backbone', 'paddlepalm.backbone.utils':'./paddlepalm/backbone/utils', 'paddlepalm.optimizer':'./paddlepalm/optimizer', + 'paddlepalm.lr_sched': './paddlepalm/lr_sched', + 'paddlepalm.distribute': './paddlepalm/distribute', 'paddlepalm.reader':'./paddlepalm/reader', 'paddlepalm.reader.utils':'./paddlepalm/reader/utils', - 'paddlepalm.task_paradigm':'./paddlepalm/task_paradigm', + 'paddlepalm.head':'./paddlepalm/head', 'paddlepalm.tokenizer':'./paddlepalm/tokenizer', 'paddlepalm.utils':'./paddlepalm/utils'}, platforms = "any", @@ -64,7 +68,7 @@ setuptools.setup( 'Programming Language :: Python :: 3.7', ], install_requires = [ - 'paddlepaddle-gpu>=1.6.1' + 'paddlepaddle-gpu>=1.6.3' ] ) -- GitLab