未验证 提交 abb108dc 编写于 作者: X Xiaoyao Xi 提交者: GitHub

Update base_backbone.py

上级 c4b03ce3
...@@ -12,23 +12,26 @@ ...@@ -12,23 +12,26 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""v1.1"""
class Backbone(object): class Backbone(object):
"""interface of backbone model.""" """interface of backbone model."""
def __init__(self, config, phase): def __init__(self, phase):
""" """该函数完成一个主干网络的构造,至少需要包含一个phase参数。
注意:实现该构造函数时,必须保证对基类构造函数的调用,以创建必要的框架内建的成员变量。
Args: Args:
config: dict类型。描述了 多任务配置文件+预训练模型配置文件 中定义超参数 phase: str类型。用于区分主干网络被调用时所处的运行阶段,目前支持训练阶段train和预测阶段predict
phase: str类型。运行阶段,目前支持train和predict
""" """
assert isinstance(config, dict) assert isinstance(config, dict)
@property @property
def inputs_attr(self): def inputs_attr(self):
"""描述backbone从reader处需要得到的输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 """描述backbone从reader处需要得到的输入对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象
为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape
中的相应维度设置为-1。
Return: Return:
dict类型。对各个输入对象的属性描述。例如, dict类型。对各个输入对象的属性描述。例如,
对于文本分类和匹配任务,bert backbone依赖的reader对象主要包含如下的对象 对于文本分类和匹配任务,bert backbone依赖的reader对象主要包含如下的对象
...@@ -40,7 +43,9 @@ class Backbone(object): ...@@ -40,7 +43,9 @@ class Backbone(object):
@property @property
def outputs_attr(self): def outputs_attr(self):
"""描述backbone输出对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 """描述backbone输出对象的属性,包含各个对象的名字、shape以及数据类型。当某个对象为标量数据类型(如
str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。
Return: Return:
dict类型。对各个输出对象的属性描述。例如, dict类型。对各个输出对象的属性描述。例如,
对于文本分类和匹配任务,bert backbone的输出内容可能包含如下的对象 对于文本分类和匹配任务,bert backbone的输出内容可能包含如下的对象
...@@ -57,4 +62,3 @@ class Backbone(object): ...@@ -57,4 +62,3 @@ class Backbone(object):
需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。
""" """
raise NotImplementedError() raise NotImplementedError()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册