From 061b9e9a3ce092d5f9dff27fe4af1f96f6a20eaf Mon Sep 17 00:00:00 2001
From: xixiaoyao
Date: Thu, 12 Dec 2019 20:56:07 +0800
Subject: [PATCH] add trainer

---
 paddlepalm/base_reader.py                   |  96 ++++++
 paddlepalm/base_task.py                     |  62 ++++
 paddlepalm/basebackbone.py                  | 177 ++++++++++
 paddlepalm/interface.py                     | 177 ++++++++++
 .../{task_paradigm => task}/__init__.py     |   0
 .../cls.py => task/classify.py}             |   0
 paddlepalm/{task_paradigm => task}/match.py |   0
 paddlepalm/{task_paradigm => task}/mlm.py   |   0
 paddlepalm/{task_paradigm => task}/mrc.py   |   0
 paddlepalm/task_instance.py                 | 309 ++++++++++++++++++
 paddlepalm/trainer.py                       |  18 +-
 11 files changed, 832 insertions(+), 7 deletions(-)
 create mode 100644 paddlepalm/base_reader.py
 create mode 100644 paddlepalm/base_task.py
 create mode 100644 paddlepalm/basebackbone.py
 create mode 100644 paddlepalm/interface.py
 rename paddlepalm/{task_paradigm => task}/__init__.py (100%)
 rename paddlepalm/{task_paradigm/cls.py => task/classify.py} (100%)
 rename paddlepalm/{task_paradigm => task}/match.py (100%)
 rename paddlepalm/{task_paradigm => task}/mlm.py (100%)
 rename paddlepalm/{task_paradigm => task}/mrc.py (100%)
 create mode 100644 paddlepalm/task_instance.py

diff --git a/paddlepalm/base_reader.py b/paddlepalm/base_reader.py
new file mode 100644
index 0000000..d3e378e
--- /dev/null
+++ b/paddlepalm/base_reader.py
@@ -0,0 +1,96 @@
+# -*- coding: UTF-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""v1.1"""
+
+from copy import copy
+
+
+class reader(object):
+    """interface of data manager."""
+
+    def __init__(self, config, phase='train'):
+        assert isinstance(config, dict)
+        self._config = config
+        self._phase = phase
+
+    def copy(self, phase=None):
+        # NOTE: a default argument cannot reference `self`, so `phase`
+        # defaults to None and falls back to the current phase.
+        ret = copy(self)
+        if phase is not None:
+            ret._phase = phase
+        return ret
+
+    # @property
+    # def inputs_attr(self):
+    #     """Describes the attributes of the reader's input objects: each object's
+    #     name, shape and data type. When an object is of a scalar type (e.g. str,
+    #     int, float), its shape is the empty list []; when some dimension of an
+    #     object has variable length, that dimension is set to -1.
+    #     Return:
+    #         dict. Attribute descriptions of each input object. For example,
+    #         a text classification task may need the input text and a label id:
+    #             {"text": ([], 'str'),
+    #              "label": ([], 'int')}
+    #         a sequence labeling task may need a token sequence and its tags:
+    #             {"tokens": ([-1], 'str'),
+    #              "tags": ([-1], 'str')}
+    #         a machine reading comprehension task may need the context, question,
+    #         answer and the start/end positions of the answer span:
+    #             {"paragraph": ([], 'str'),
+    #              "question": ([], 'str'),
+    #              "start_position": ([], 'int')}
+    #     """
+    #     raise NotImplementedError()
+
+    @property
+    def outputs_attr(self):
+        """Describes the attributes of the reader's output objects (the objects it
+        yields): each object's name, shape and data type. When an object is of a
+        scalar type (e.g. str, int, float), its shape is the empty list []; when
+        some dimension of an object has variable length, that dimension is set to -1.
+        Note: under mini-batch gradient descent, regular input objects should carry
+        a batch_size dimension (usually -1).
+        Return:
+            dict. Attribute descriptions of each output object. For example, for
+            text classification and matching tasks, the yielded output may contain
+            the following objects (downstream backbones and tasks access them on
+            demand):
+                {"token_ids": ([-1, max_len], 'int64'),
+                 "input_ids": ([-1, max_len], 'int64'),
+                 "segment_ids": ([-1, max_len], 'int64'),
+                 "input_mask": ([-1, max_len], 'float32'),
+                 "label": ([-1], 'int')}
+        """
+        raise NotImplementedError()
+
+    # def parse_line(self):
+    #     """The framework describes each example internally with a dict whose keys
+    #     come from inputs_attr and whose values conform to the corresponding attr.
+    #     This method parses one text line into such a dict. The default parse_line
+    #     reads json dataset files in which every line is one json-encoded example.
+    #     Users can override this method to adapt to other dataset formats, e.g.
+    #     csv or even tfrecord files.
+    #     """
+    #     raise NotImplementedError()
+    #
+    # def tokenize(self, line):
+    #     """The framework ships tokenizers such as a word piece tokenizer; the one
+    #     to use is selected via the tokenizer hyper-parameter. If no built-in
+    #     tokenizer fits, users can override this method with a custom tokenizer.
+    #     Args:
+    #         - line: a unicode string.
+    #     Return:
+    #         a list of tokens
+    #     """
+    #     raise NotImplementedError()
+
+    def iterator(self):
+        """Dataset traversal interface. Note that on reaching the end of the
+        dataset, the iterator should reset itself automatically, i.e. start a
+        new pass from the head of the dataset.
+        Yield:
+            (dict) elements that meet the requirements of outputs_attr
+        """
+        raise NotImplementedError()
+
+    @property
+    def num_examples(self):
+        """Number of examples in the dataset, i.e. the number of examples the
+        iterator yields per epoch. Note that with strategies that may change the
+        number of examples (e.g. sliding window), this should return the actual
+        number observed at runtime."""
+        raise NotImplementedError()
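To make the contract above concrete, the following is a minimal sketch of a reader implementation. The class name, in-memory dataset and padding scheme are illustrative assumptions, not part of this patch:

    # Illustrative only: a toy in-memory reader built on the interface above.
    from paddlepalm.base_reader import reader

    class ToyClassifyReader(reader):
        """Yields padded token-id sequences with labels for a classification task."""

        def __init__(self, config, phase='train'):
            super(ToyClassifyReader, self).__init__(config, phase)
            self._max_len = config.get('max_seq_len', 128)
            # hypothetical in-memory dataset: (token_ids, label) pairs
            self._examples = [([101, 2054, 102], 0), ([101, 2129, 2001, 102], 1)]

        @property
        def outputs_attr(self):
            return {"token_ids": ([-1, self._max_len], 'int64'),
                    "label": ([-1], 'int64')}

        @property
        def num_examples(self):
            return len(self._examples)

        def iterator(self):
            while True:  # auto-reset: restart from the head once the tail is reached
                for token_ids, label in self._examples:
                    padded = token_ids + [0] * (self._max_len - len(token_ids))
                    yield {"token_ids": padded, "label": label}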
diff --git a/paddlepalm/base_task.py b/paddlepalm/base_task.py
new file mode 100644
index 0000000..51b2025
--- /dev/null
+++ b/paddlepalm/base_task.py
@@ -0,0 +1,62 @@
+# -*- coding: UTF-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class task(object):
+
+    def __init__(self, config, phase, backbone_config):
+        """
+        config: dict. Hyper-parameters defined in the task instance config and
+            the multi-task config file.
+        phase: str. The running phase; currently 'train' and 'predict' are
+            supported.
+        """
+
+    @property
+    def inputs_attrs(self):
+        """Describes the attributes of the input objects the task_layer reads
+        from input collections such as the reader and the backbone. The
+        first-level key is the collection name, e.g. 'backbone' or 'reader'
+        (more flexible inputs will be supported later); the second level
+        describes each object in the collection: its name, shape and dtype.
+        When an object is of a scalar type (e.g. str, int, float), its shape is
+        the empty list []; variable-length dimensions are set to -1.
+        Return:
+            dict. Attribute descriptions of each collection and its input objects."""
+        raise NotImplementedError()
+
+    @property
+    def outputs_attr(self):
+        """Describes the attributes of the task's output objects: name, shape
+        and dtype. Output objects are added to the fetch_list, so their runtime
+        values are available at every train/predict step and are passed to the
+        postprocess method for the user to handle.
+        When an object is of a scalar type (e.g. str, int, float), its shape is
+        the empty list []; variable-length dimensions are set to -1.
+        Return:
+            dict. Attribute descriptions of each output object. Note that during
+            training an output object named loss is required.
+        """
+        raise NotImplementedError()
+
+    @property
+    def epoch_inputs_attrs(self):
+        return {}
+
+    def build(self, inputs, scope_name=""):
+        """Builds the computation graph of the task_layer: maps the static-graph
+        Variables from the input collections (conforming to inputs_attrs) to
+        static-graph output Variables conforming to outputs_attr.
+        Args:
+            inputs: dict. Maps the object names in inputs_attrs to
+                computation-graph Variables; inputs contains at least the
+                objects defined in inputs_attrs.
+        Return:
+            The computation-graph variables to output. They are added to the
+            fetch_list, so their runtime values are available at every
+            train/predict step and are passed to the postprocess method for the
+            user to handle.
+        """
+        raise NotImplementedError()
+
+    def postprocess(self, rt_outputs):
+        """Post-processes the task_layer's runtime results on the current batch
+        after every train/predict step. Note that besides the outputs of the
+        build method, rt_outputs automatically contains the computed loss as
+        well."""
+        pass
+
+    def epoch_postprocess(self, post_inputs):
+        pass
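A sketch of how a concrete task head might satisfy this interface, using a single fully-connected classification layer. The nested layout of `inputs` and the attribute names ("sentence_emb", "label") are borrowed from the docstring examples and are assumptions, not something this patch fixes:

    # Illustrative only: a classification head implementing the `task` interface.
    from paddle import fluid
    from paddlepalm.base_task import task

    class ToyClassifyTask(task):

        def __init__(self, config, phase, backbone_config):
            self._num_classes = config.get('n_classes', 2)
            self._is_training = phase == 'train'

        @property
        def inputs_attrs(self):
            reader_attrs = {"label": ([-1], 'int64')} if self._is_training else {}
            return {"reader": reader_attrs,
                    "backbone": {"sentence_emb": ([-1, 768], 'float32')}}

        @property
        def outputs_attr(self):
            if self._is_training:
                return {"loss": ([1], 'float32')}
            return {"logits": ([-1, self._num_classes], 'float32')}

        def build(self, inputs, scope_name=""):
            emb = inputs["backbone"]["sentence_emb"]
            logits = fluid.layers.fc(input=emb, size=self._num_classes)
            if not self._is_training:
                return {"logits": logits}
            # softmax_with_cross_entropy expects labels of shape [-1, 1]
            label = fluid.layers.reshape(inputs["reader"]["label"], [-1, 1])
            loss = fluid.layers.softmax_with_cross_entropy(logits=logits, label=label)
            return {"loss": fluid.layers.mean(loss)}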
diff --git a/paddlepalm/basebackbone.py b/paddlepalm/basebackbone.py
new file mode 100644
index 0000000..b8c3f78
--- /dev/null
+++ b/paddlepalm/basebackbone.py
@@ -0,0 +1,177 @@
+# -*- coding: UTF-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""v1.1"""
+
+
+class reader(object):
+    """interface of data manager."""
+
+    def __init__(self, config):
+        assert isinstance(config, dict)
+
+    # @property
+    # def inputs_attr(self):
+    #     """Describes the attributes of the reader's input objects: each object's
+    #     name, shape and data type. When an object is of a scalar type (e.g. str,
+    #     int, float), its shape is the empty list []; when some dimension of an
+    #     object has variable length, that dimension is set to -1.
+    #     Return:
+    #         dict. Attribute descriptions of each input object. For example,
+    #         a text classification task may need the input text and a label id:
+    #             {"text": ([], 'str'),
+    #              "label": ([], 'int')}
+    #         a sequence labeling task may need a token sequence and its tags:
+    #             {"tokens": ([-1], 'str'),
+    #              "tags": ([-1], 'str')}
+    #         a machine reading comprehension task may need the context, question,
+    #         answer and the start/end positions of the answer span:
+    #             {"paragraph": ([], 'str'),
+    #              "question": ([], 'str'),
+    #              "start_position": ([], 'int')}
+    #     """
+    #     raise NotImplementedError()
+
+    @property
+    def outputs_attr(self):
+        """Describes the attributes of the reader's output objects (the objects it
+        yields): each object's name, shape and data type. When an object is of a
+        scalar type (e.g. str, int, float), its shape is the empty list []; when
+        some dimension of an object has variable length, that dimension is set to -1.
+        Note: under mini-batch gradient descent, regular input objects should carry
+        a batch_size dimension (usually -1).
+        Return:
+            dict. Attribute descriptions of each output object. For example, for
+            text classification and matching tasks, the yielded output may contain
+            the following objects (downstream backbones and tasks access them on
+            demand):
+                {"token_ids": ([-1, max_len], 'int64'),
+                 "input_ids": ([-1, max_len], 'int64'),
+                 "segment_ids": ([-1, max_len], 'int64'),
+                 "input_mask": ([-1, max_len], 'float32'),
+                 "label": ([-1], 'int')}
+        """
+        raise NotImplementedError()
+
+    # def parse_line(self):
+    #     """The framework describes each example internally with a dict whose keys
+    #     come from inputs_attr and whose values conform to the corresponding attr.
+    #     This method parses one text line into such a dict. The default parse_line
+    #     reads json dataset files in which every line is one json-encoded example.
+    #     Users can override this method to adapt to other dataset formats, e.g.
+    #     csv or even tfrecord files.
+    #     """
+    #     raise NotImplementedError()
+    #
+    # def tokenize(self, line):
+    #     """The framework ships tokenizers such as a word piece tokenizer; the one
+    #     to use is selected via the tokenizer hyper-parameter. If no built-in
+    #     tokenizer fits, users can override this method with a custom tokenizer.
+    #     Args:
+    #         - line: a unicode string.
+    #     Return:
+    #         a list of tokens
+    #     """
+    #     raise NotImplementedError()
+
+    def iterator(self):
+        """Dataset traversal interface. Note that on reaching the end of the
+        dataset, the iterator should reset itself automatically, i.e. start a
+        new pass from the head of the dataset.
+        Yield:
+            (dict) elements that meet the requirements of outputs_attr
+        """
+        raise NotImplementedError()
+
+    @property
+    def num_examples(self):
+        """Number of examples in the dataset, i.e. the number of examples the
+        iterator yields per epoch. Note that with strategies that may change the
+        number of examples (e.g. sliding window), this should return the actual
+        number observed at runtime."""
+        raise NotImplementedError()
+
+
+class backbone(object):
+    """interface of backbone model."""
+
+    def __init__(self, config, phase):
+        """
+        Args:
+            config: dict. Hyper-parameters defined in the multi-task config file
+                and the pretrained model config file.
+            phase: str. The running phase; currently 'train' and 'predict' are
+                supported.
+        """
+        assert isinstance(config, dict)
+
+    @property
+    def inputs_attr(self):
+        """Describes the attributes of the input objects the backbone requires
+        from the reader: each object's name, shape and data type. Scalar-typed
+        objects (e.g. str, int, float) use the empty list [] as shape;
+        variable-length dimensions are set to -1.
+        Return:
+            dict. Attribute descriptions of each input object. For example, for
+            text classification and matching tasks, the reader objects a bert
+            backbone depends on mainly include:
+                {"token_ids": ([-1, max_len], 'int64'),
+                 "input_ids": ([-1, max_len], 'int64'),
+                 "segment_ids": ([-1, max_len], 'int64'),
+                 "input_mask": ([-1, max_len], 'float32')}"""
+        raise NotImplementedError()
+
+    @property
+    def outputs_attr(self):
+        """Describes the attributes of the backbone's output objects: each
+        object's name, shape and data type. Scalar-typed objects (e.g. str, int,
+        float) use the empty list [] as shape; variable-length dimensions are set
+        to -1.
+        Return:
+            dict. Attribute descriptions of each output object. For example, for
+            text classification and matching tasks, the outputs of a bert backbone
+            may include:
+                {"word_emb": ([-1, max_seqlen, word_emb_size], 'float32'),
+                 "sentence_emb": ([-1, hidden_size], 'float32'),
+                 "sim_vec": ([-1, hidden_size], 'float32')}"""
+        raise NotImplementedError()
+
+    def build(self, inputs):
+        """Builds the computation graph of the backbone: maps static-graph input
+        Variables conforming to inputs_attr to static-graph output Variables
+        conforming to outputs_attr.
+        Args:
+            inputs: dict. Maps the object names in inputs_attr to
+                computation-graph Variables; inputs contains at least the
+                objects defined in inputs_attr.
+        Return:
+            The computation-graph variables to output. They are added to the
+            fetch_list, so their runtime values are available at every
+            train/predict step and are passed to the postprocess method for the
+            user to handle.
+        """
+        raise NotImplementedError()
+
+
+class task_paradigm(object):
+
+    def __init__(self, config, phase, backbone_config):
+        """
+        config: dict. Hyper-parameters defined in the task instance config and
+            the multi-task config file.
+        phase: str. The running phase; currently 'train' and 'predict' are
+            supported.
+        """
+
+    @property
+    def inputs_attrs(self):
+        """Describes the attributes of the input objects the task_layer reads
+        from input collections such as the reader and the backbone. The
+        first-level key is the collection name, e.g. 'backbone' or 'reader'
+        (more flexible inputs will be supported later); the second level
+        describes each object in the collection: its name, shape and dtype.
+        When an object is of a scalar type (e.g. str, int, float), its shape is
+        the empty list []; variable-length dimensions are set to -1.
+        Return:
+            dict. Attribute descriptions of each collection and its input objects."""
+        raise NotImplementedError()
+
+    @property
+    def outputs_attr(self):
+        """Describes the attributes of the task's output objects: name, shape
+        and dtype. Output objects are added to the fetch_list, so their runtime
+        values are available at every train/predict step and are passed to the
+        postprocess method for the user to handle.
+        When an object is of a scalar type (e.g. str, int, float), its shape is
+        the empty list []; variable-length dimensions are set to -1.
+        Return:
+            dict. Attribute descriptions of each output object. Note that during
+            training an output object named loss is required.
+        """
+        raise NotImplementedError()
+
+    @property
+    def epoch_inputs_attrs(self):
+        return {}
+
+    def build(self, inputs, scope_name=""):
+        """Builds the computation graph of the task_layer: maps the static-graph
+        Variables from the input collections (conforming to inputs_attrs) to
+        static-graph output Variables conforming to outputs_attr.
+        Args:
+            inputs: dict. Maps the object names in inputs_attrs to
+                computation-graph Variables; inputs contains at least the
+                objects defined in inputs_attrs.
+        Return:
+            The computation-graph variables to output. They are added to the
+            fetch_list, so their runtime values are available at every
+            train/predict step and are passed to the postprocess method for the
+            user to handle.
+        """
+        raise NotImplementedError()
+
+    def postprocess(self, rt_outputs):
+        """Post-processes the task_layer's runtime results on the current batch
+        after every train/predict step. Note that besides the outputs of the
+        build method, rt_outputs automatically contains the computed loss as
+        well."""
+        pass
+
+    def epoch_postprocess(self, post_inputs):
+        pass
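The attr dictionaries above act as a lightweight contract between components. As a sketch, a framework could verify that a reader is able to feed a backbone with a hypothetical helper like this (not part of the patch):

    def check_io(producer_outputs_attr, consumer_inputs_attr):
        """Assert every input a consumer declares is produced with a matching dtype."""
        for name, (shape, dtype) in consumer_inputs_attr.items():
            assert name in producer_outputs_attr, \
                "input '{}' is not produced upstream".format(name)
            p_shape, p_dtype = producer_outputs_attr[name]
            assert p_dtype == dtype, \
                "dtype mismatch on '{}': {} vs {}".format(name, p_dtype, dtype)

    # e.g. a reader feeding a bert-style backbone
    check_io({"token_ids": ([-1, 128], 'int64'),
              "input_mask": ([-1, 128], 'float32')},
             {"token_ids": ([-1, 128], 'int64')})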
diff --git a/paddlepalm/interface.py b/paddlepalm/interface.py
new file mode 100644
index 0000000..b8c3f78
--- /dev/null
+++ b/paddlepalm/interface.py
@@ -0,0 +1,177 @@
+# -*- coding: UTF-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""v1.1"""
+
+
+class reader(object):
+    """interface of data manager."""
+
+    def __init__(self, config):
+        assert isinstance(config, dict)
+
+    # @property
+    # def inputs_attr(self):
+    #     """Describes the attributes of the reader's input objects: each object's
+    #     name, shape and data type. When an object is of a scalar type (e.g. str,
+    #     int, float), its shape is the empty list []; when some dimension of an
+    #     object has variable length, that dimension is set to -1.
+    #     Return:
+    #         dict. Attribute descriptions of each input object. For example,
+    #         a text classification task may need the input text and a label id:
+    #             {"text": ([], 'str'),
+    #              "label": ([], 'int')}
+    #         a sequence labeling task may need a token sequence and its tags:
+    #             {"tokens": ([-1], 'str'),
+    #              "tags": ([-1], 'str')}
+    #         a machine reading comprehension task may need the context, question,
+    #         answer and the start/end positions of the answer span:
+    #             {"paragraph": ([], 'str'),
+    #              "question": ([], 'str'),
+    #              "start_position": ([], 'int')}
+    #     """
+    #     raise NotImplementedError()
+
+    @property
+    def outputs_attr(self):
+        """Describes the attributes of the reader's output objects (the objects it
+        yields): each object's name, shape and data type. When an object is of a
+        scalar type (e.g. str, int, float), its shape is the empty list []; when
+        some dimension of an object has variable length, that dimension is set to -1.
+        Note: under mini-batch gradient descent, regular input objects should carry
+        a batch_size dimension (usually -1).
+        Return:
+            dict. Attribute descriptions of each output object. For example, for
+            text classification and matching tasks, the yielded output may contain
+            the following objects (downstream backbones and tasks access them on
+            demand):
+                {"token_ids": ([-1, max_len], 'int64'),
+                 "input_ids": ([-1, max_len], 'int64'),
+                 "segment_ids": ([-1, max_len], 'int64'),
+                 "input_mask": ([-1, max_len], 'float32'),
+                 "label": ([-1], 'int')}
+        """
+        raise NotImplementedError()
+
+    # def parse_line(self):
+    #     """The framework describes each example internally with a dict whose keys
+    #     come from inputs_attr and whose values conform to the corresponding attr.
+    #     This method parses one text line into such a dict. The default parse_line
+    #     reads json dataset files in which every line is one json-encoded example.
+    #     Users can override this method to adapt to other dataset formats, e.g.
+    #     csv or even tfrecord files.
+    #     """
+    #     raise NotImplementedError()
+    #
+    # def tokenize(self, line):
+    #     """The framework ships tokenizers such as a word piece tokenizer; the one
+    #     to use is selected via the tokenizer hyper-parameter. If no built-in
+    #     tokenizer fits, users can override this method with a custom tokenizer.
+    #     Args:
+    #         - line: a unicode string.
+    #     Return:
+    #         a list of tokens
+    #     """
+    #     raise NotImplementedError()
+
+    def iterator(self):
+        """Dataset traversal interface. Note that on reaching the end of the
+        dataset, the iterator should reset itself automatically, i.e. start a
+        new pass from the head of the dataset.
+        Yield:
+            (dict) elements that meet the requirements of outputs_attr
+        """
+        raise NotImplementedError()
+
+    @property
+    def num_examples(self):
+        """Number of examples in the dataset, i.e. the number of examples the
+        iterator yields per epoch. Note that with strategies that may change the
+        number of examples (e.g. sliding window), this should return the actual
+        number observed at runtime."""
+        raise NotImplementedError()
+
+
+class backbone(object):
+    """interface of backbone model."""
+
+    def __init__(self, config, phase):
+        """
+        Args:
+            config: dict. Hyper-parameters defined in the multi-task config file
+                and the pretrained model config file.
+            phase: str. The running phase; currently 'train' and 'predict' are
+                supported.
+        """
+        assert isinstance(config, dict)
+
+    @property
+    def inputs_attr(self):
+        """Describes the attributes of the input objects the backbone requires
+        from the reader: each object's name, shape and data type. Scalar-typed
+        objects (e.g. str, int, float) use the empty list [] as shape;
+        variable-length dimensions are set to -1.
+        Return:
+            dict. Attribute descriptions of each input object. For example, for
+            text classification and matching tasks, the reader objects a bert
+            backbone depends on mainly include:
+                {"token_ids": ([-1, max_len], 'int64'),
+                 "input_ids": ([-1, max_len], 'int64'),
+                 "segment_ids": ([-1, max_len], 'int64'),
+                 "input_mask": ([-1, max_len], 'float32')}"""
+        raise NotImplementedError()
+
+    @property
+    def outputs_attr(self):
+        """Describes the attributes of the backbone's output objects: each
+        object's name, shape and data type. Scalar-typed objects (e.g. str, int,
+        float) use the empty list [] as shape; variable-length dimensions are set
+        to -1.
+        Return:
+            dict. Attribute descriptions of each output object. For example, for
+            text classification and matching tasks, the outputs of a bert backbone
+            may include:
+                {"word_emb": ([-1, max_seqlen, word_emb_size], 'float32'),
+                 "sentence_emb": ([-1, hidden_size], 'float32'),
+                 "sim_vec": ([-1, hidden_size], 'float32')}"""
+        raise NotImplementedError()
+
+    def build(self, inputs):
+        """Builds the computation graph of the backbone: maps static-graph input
+        Variables conforming to inputs_attr to static-graph output Variables
+        conforming to outputs_attr.
+        Args:
+            inputs: dict. Maps the object names in inputs_attr to
+                computation-graph Variables; inputs contains at least the
+                objects defined in inputs_attr.
+        Return:
+            The computation-graph variables to output. They are added to the
+            fetch_list, so their runtime values are available at every
+            train/predict step and are passed to the postprocess method for the
+            user to handle.
+        """
+        raise NotImplementedError()
+
+
+class task_paradigm(object):
+
+    def __init__(self, config, phase, backbone_config):
+        """
+        config: dict. Hyper-parameters defined in the task instance config and
+            the multi-task config file.
+        phase: str. The running phase; currently 'train' and 'predict' are
+            supported.
+        """
+
+    @property
+    def inputs_attrs(self):
+        """Describes the attributes of the input objects the task_layer reads
+        from input collections such as the reader and the backbone. The
+        first-level key is the collection name, e.g. 'backbone' or 'reader'
+        (more flexible inputs will be supported later); the second level
+        describes each object in the collection: its name, shape and dtype.
+        When an object is of a scalar type (e.g. str, int, float), its shape is
+        the empty list []; variable-length dimensions are set to -1.
+        Return:
+            dict. Attribute descriptions of each collection and its input objects."""
+        raise NotImplementedError()
+
+    @property
+    def outputs_attr(self):
+        """Describes the attributes of the task's output objects: name, shape
+        and dtype. Output objects are added to the fetch_list, so their runtime
+        values are available at every train/predict step and are passed to the
+        postprocess method for the user to handle.
+        When an object is of a scalar type (e.g. str, int, float), its shape is
+        the empty list []; variable-length dimensions are set to -1.
+        Return:
+            dict. Attribute descriptions of each output object. Note that during
+            training an output object named loss is required.
+        """
+        raise NotImplementedError()
+
+    @property
+    def epoch_inputs_attrs(self):
+        return {}
+
+    def build(self, inputs, scope_name=""):
+        """Builds the computation graph of the task_layer: maps the static-graph
+        Variables from the input collections (conforming to inputs_attrs) to
+        static-graph output Variables conforming to outputs_attr.
+        Args:
+            inputs: dict. Maps the object names in inputs_attrs to
+                computation-graph Variables; inputs contains at least the
+                objects defined in inputs_attrs.
+        Return:
+            The computation-graph variables to output. They are added to the
+            fetch_list, so their runtime values are available at every
+            train/predict step and are passed to the postprocess method for the
+            user to handle.
+        """
+        raise NotImplementedError()
+
+    def postprocess(self, rt_outputs):
+        """Post-processes the task_layer's runtime results on the current batch
+        after every train/predict step. Note that besides the outputs of the
+        build method, rt_outputs automatically contains the computed loss as
+        well."""
+        pass
+
+    def epoch_postprocess(self, post_inputs):
+        pass
diff --git a/paddlepalm/task_paradigm/__init__.py b/paddlepalm/task/__init__.py
similarity index 100%
rename from paddlepalm/task_paradigm/__init__.py
rename to paddlepalm/task/__init__.py
diff --git a/paddlepalm/task_paradigm/cls.py b/paddlepalm/task/classify.py
similarity index 100%
rename from paddlepalm/task_paradigm/cls.py
rename to paddlepalm/task/classify.py
diff --git a/paddlepalm/task_paradigm/match.py b/paddlepalm/task/match.py
similarity index 100%
rename from paddlepalm/task_paradigm/match.py
rename to paddlepalm/task/match.py
diff --git a/paddlepalm/task_paradigm/mlm.py b/paddlepalm/task/mlm.py
similarity index 100%
rename from paddlepalm/task_paradigm/mlm.py
rename to paddlepalm/task/mlm.py
diff --git a/paddlepalm/task_paradigm/mrc.py b/paddlepalm/task/mrc.py
similarity index 100%
rename from paddlepalm/task_paradigm/mrc.py
rename to paddlepalm/task/mrc.py
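The module moves above change downstream import paths. Assuming the paradigm modules are imported directly somewhere, the shift would look like this (hypothetical usage):

    # Hypothetical direct imports illustrating the renames above.
    # Before this patch:
    # from paddlepalm.task_paradigm import cls, match, mlm, mrc
    # After this patch:
    from paddlepalm.task import classify, match, mlm, mrc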
diff --git a/paddlepalm/task_instance.py b/paddlepalm/task_instance.py
new file mode 100644
index 0000000..0915269
--- /dev/null
+++ b/paddlepalm/task_instance.py
@@ -0,0 +1,309 @@
+# -*- coding: UTF-8 -*-
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddlepalm.interface import reader as base_reader
+from paddlepalm.interface import task_paradigm as base_paradigm
+import os
+import json
+from paddle import fluid
+import importlib
+from paddlepalm.default_settings import *
+
+
+def check_req_args(conf, name):
+    assert 'reader' in conf, name + ': reader is required to build TaskInstance.'
+    assert 'paradigm' in conf, name + ': paradigm is required to build TaskInstance.'
+    assert 'train_file' in conf or 'pred_file' in conf, \
+        name + ': at least train_file or pred_file should be provided to build TaskInstance.'
+
+
+class TaskInstance(object):
+
+    def __init__(self, name, id, config, verbose=True):
+        self._name = name
+        self._config = config
+        self._verbose = verbose
+
+        check_req_args(config, name)
+
+        # parse Reader and Paradigm
+        reader_name = config['reader']
+        reader_mod = importlib.import_module(READER_DIR + '.' + reader_name)
+        Reader = getattr(reader_mod, 'Reader')
+
+        parad_name = config['paradigm']
+        parad_mod = importlib.import_module(PARADIGM_DIR + '.' + parad_name)
+        Paradigm = getattr(parad_mod, 'TaskParadigm')
+
+        self._Reader = Reader
+        self._Paradigm = Paradigm
+
+        self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model')
+        self._save_ckpt_path = os.path.join(self._config['save_path'], 'ckpt')
+        self._save_infermodel_every_n_steps = config.get('save_infermodel_every_n_steps', -1)
+
+        # the following flags can be fetched from the instance config file
+        self._is_target = config.get('is_target', True)
+        # NOTE: stored as _is_first_target so the property and its setter below
+        # read and write the same attribute
+        self._is_first_target = config.get('is_first_target', False)
+        self._task_reuse_scope = config.get('task_reuse_scope', name)
+
+        self._feeded_var_names = None
+        self._target_vars = None
+
+        # training process management
+        self._mix_ratio = None
+        self._expected_train_steps = None
+        self._expected_train_epochs = None
+        self._steps_pur_epoch = None
+        self._cur_train_epoch = 0
+        self._cur_train_step = 0
+        self._train_finish = False
+
+        # dataset readers for each running phase (train, eval, pred);
+        # key is the phase, value is a Reader instance
+        self._reader = {'train': None, 'eval': None, 'pred': None}
+        self._input_layer = None
+        self._inputname_to_varname = {}
+        self._task_layer = {'train': None, 'eval': None, 'pred': None}
+        self._pred_input_name_list = []
+        self._pred_input_varname_list = []
+        self._pred_fetch_name_list = []
+        self._pred_fetch_var_list = []
+
+        self._exe = fluid.Executor(fluid.CPUPlace())
+
+        self._save_protocol = {
+            'input_names': 'self._pred_input_name_list',
+            'input_varnames': 'self._pred_input_varname_list',
+            'fetch_list': 'self._pred_fetch_name_list'}
+
+    def build_task_layer(self, net_inputs, phase, scope=""):
+        output_vars = self._task_layer[phase].build(net_inputs, scope_name=scope)
+        if phase == 'pred':
+            if output_vars is not None:
+                self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items())
+            else:
+                self._pred_fetch_name_list = []
+                self._pred_fetch_var_list = []
+        return output_vars
+
+    def postprocess(self, rt_outputs, phase):
+        return self._task_layer[phase].postprocess(rt_outputs)
+
+    def epoch_postprocess(self, epoch_inputs, phase):
+        return self._task_layer[phase].epoch_postprocess(epoch_inputs)
+
+    def save(self, suffix=''):
+        dirpath = self._save_infermodel_path + suffix
+        self._pred_input_varname_list = [str(i) for i in self._pred_input_varname_list]
+
+        # fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, export_for_deployment=True)
+        prog = fluid.default_main_program().clone()
+        fluid.io.save_inference_model(dirpath, self._pred_input_varname_list, self._pred_fetch_var_list, self._exe, prog)
+
+        # serialize the attributes named in _save_protocol; each protocol value
+        # is an attribute expression on self, so eval is sufficient here
+        conf = {}
+        for k, strv in self._save_protocol.items():
+            conf[k] = eval(strv)
+        with open(os.path.join(dirpath, '__conf__'), 'w') as writer:
+            writer.write(json.dumps(conf, indent=1))
+        print(self._name + ': inference model saved at ' + dirpath)
+
+    def load(self, infer_model_path=None):
+        if infer_model_path is None:
+            infer_model_path = self._save_infermodel_path
+        # restore the attributes named in _save_protocol
+        with open(os.path.join(infer_model_path, '__conf__')) as reader:
+            conf = json.load(reader)
+        for k, v in conf.items():
+            strv = self._save_protocol[k]
+            exec('{}=v'.format(strv))
+        pred_prog, self._pred_input_varname_list, self._pred_fetch_var_list = \
+            fluid.io.load_inference_model(infer_model_path, self._exe)
+        print(self._name + ': inference model loaded from ' + infer_model_path)
+        return pred_prog
+
+    @property
+    def name(self):
+        return self._name
+
+    @property
+    def Reader(self):
+        return self._Reader
+
+    # @Reader.setter
+    # def Reader(self, cls):
+    #     assert base_reader.__name__ == cls.__bases__[-1].__name__, \
+    #         "expect: {}, receive: {}.".format(base_reader.__name__, \
+    #             cls.__bases__[-1].__name__)
+    #     self._Reader = cls
+
+    @property
+    def Paradigm(self):
+        return self._Paradigm
+
+    # @Paradigm.setter
+    # def Paradigm(self, cls):
+    #     assert base_paradigm.__name__ == cls.__bases__[-1].__name__, \
+    #         "expect: {}, receive: {}.".format(base_paradigm.__name__, \
+    #             cls.__bases__[-1].__name__)
+    #     self._Paradigm = cls
+
+    @property
+    def config(self):
+        return self._config
+
+    @property
+    def reader(self):
+        return self._reader
+
+    @property
+    def pred_input(self):
+        return zip(*[self._pred_input_name_list, self._pred_input_varname_list])
+
+    @pred_input.setter
+    def pred_input(self, val):
+        assert isinstance(val, dict)
+        self._pred_input_name_list, self._pred_input_varname_list = \
+            zip(*[[k, v.name] for k, v in val.items()])
+
+    @property
+    def pred_fetch_list(self):
+        return [self._pred_fetch_name_list, self._pred_fetch_var_list]
+
+    @property
+    def task_layer(self):
+        return self._task_layer
+
+    @property
+    def is_first_target(self):
+        return self._is_first_target
+
+    @is_first_target.setter
+    def is_first_target(self, value):
+        self._is_first_target = bool(value)
+        if self._is_first_target:
+            assert self._is_target, "ERROR: only target task could be set as main task."
+        if self._verbose and self._is_first_target:
+            print("{}: set as main task".format(self._name))
+
+    @property
+    def is_target(self):
+        if self._is_target is not None:
+            return self._is_target
+        else:
+            raise ValueError("{}: is_target is None".format(self._name))
+
+    @is_target.setter
+    def is_target(self, value):
+        self._is_target = bool(value)
+        if self._verbose:
+            if self._is_target:
+                print('{}: set as target task.'.format(self._name))
+            else:
+                print('{}: set as aux task.'.format(self._name))
+
+    @property
+    def mix_ratio(self):
+        if self._mix_ratio is not None:
+            return self._mix_ratio
+        else:
+            raise ValueError("{}: mix_ratio is None".format(self._name))
+
+    @mix_ratio.setter
+    def mix_ratio(self, value):
+        self._mix_ratio = float(value)
+        if self._verbose:
+            print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio))
+
+    @property
+    def save_infermodel_every_n_steps(self):
+        return self._save_infermodel_every_n_steps
+
+    @property
+    def expected_train_steps(self):
+        return self._expected_train_steps
+
+    @expected_train_steps.setter
+    def expected_train_steps(self, value):
+        self._expected_train_steps = value
+        self._expected_train_epochs = value / float(self._steps_pur_epoch)
+
+    @property
+    def expected_train_epochs(self):
+        return self._expected_train_epochs
+
+    @property
+    def cur_train_epoch(self):
+        return self._cur_train_epoch
+
+    @cur_train_epoch.setter
+    def cur_train_epoch(self, value):
+        self._cur_train_epoch = value
+
+    @property
+    def cur_train_step(self):
+        return self._cur_train_step
+
+    @cur_train_step.setter
+    def cur_train_step(self, value):
+        self._cur_train_step = value
+        if self._cur_train_step > self._steps_pur_epoch:
+            self._cur_train_epoch += 1
+            self._cur_train_step = 1
+        if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps:
+            self._train_finish = True
+
+    @property
+    def steps_pur_epoch(self):
+        return self._steps_pur_epoch
+
+    @steps_pur_epoch.setter
+    def steps_pur_epoch(self, value):
+        self._steps_pur_epoch = value
+
+    @property
+    def train_finish(self):
+        return self._train_finish
+
+    @property
+    def task_reuse_scope(self):
+        if self._task_reuse_scope is not None:
+            return self._task_reuse_scope
+        else:
+            raise ValueError("{}: task_reuse_scope is None".format(self._name))
+
+    @task_reuse_scope.setter
+    def task_reuse_scope(self, scope_name):
+        self._task_reuse_scope = str(scope_name)
+        if self._verbose:
+            print('{}: task_reuse_scope is set to {}'.format(self._name, self._task_reuse_scope))
+
+
+def check_instances(insts):
+    """to check ids, first_target"""
+    pass
+
+def _check_ids():
+    pass
+
+def _check_targets():
+    pass
+
+def _check_reuse_scopes():
+    pass
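For reference, a sketch of the per-task config that check_req_args validates and TaskInstance consumes. The keys are the ones read in __init__ above; the module names and file paths are illustrative placeholders:

    # Illustrative only: module names and paths are placeholders.
    conf = {
        'reader': 'cls',           # module looked up under READER_DIR
        'paradigm': 'classify',    # module looked up under PARADIGM_DIR
        'train_file': './data/cls/train.tsv',
        'save_path': './output',
        'save_infermodel_every_n_steps': -1,
        'is_target': True,
        'task_reuse_scope': 'cls_task',
    }
    check_req_args(conf, 'cls_task')    # passes: reader, paradigm, train_file present
    inst = TaskInstance('cls_task', 0, conf)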
diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py
index e3c0f18..ed3cc6f 100644
--- a/paddlepalm/trainer.py
+++ b/paddlepalm/trainer.py
@@ -22,17 +22,21 @@ import importlib
 from paddlepalm.default_settings import *
 
 
-def Task(object):
-    def __init__(self, name, reader, taskblock, mix_ratio=1.0, \
-                 pred_reader=None, pred_taskblock=None,
-                 infermodel_save_path=None, save_infermodel_every_n_steps=-1, \
-                 as_target_task=True, task_layer_reuse=None, silent=False):
+class Trainer(object):
+
+    def __init__(self, name, reader, task, mix_ratio=1.0, \
+                 save_predict_model=True, save_path=None, save_steps=-1, \
+                 reuse_with=None, silent=False):
 
         self._name = name
         self._verbose = not silent
-        if infermodel_save_path is None:
-            self._save_infermodel_path = os.path.join(self._config['save_path'], self._name, 'infer_model')
+        if save_predict_model:
+            assert save_path is not None, "save_path is required when save_predict_model is set."
+            assert save_steps == -1 or save_steps > 0, "save_steps should be -1 (only save the last step of this task) or larger than 0."
+            assert reader is not None and task is not None, "reader and task are required to save the predict model."
+
+            self._save_infermodel_path = os.path.join(save_path, self._name, 'infer_model')
         else:
             self._save_infermodel_path = infermodel_save_path
-- 
GitLab