提交 f98797ef 编写于 作者: Z Zeyu Chen

add hub task

上级 02086c0d
...@@ -26,3 +26,4 @@ from .tools.logger import logger ...@@ -26,3 +26,4 @@ from .tools.logger import logger
from .tools.paddle_helper import connect_program from .tools.paddle_helper import connect_program
from .io.type import DataType from .io.type import DataType
from .hub_server import default_hub_server from .hub_server import default_hub_server
from .finetune.task import append_mlp_classifier
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import paddle.fluid as fluid
__all__ = ['append_mlp_classifier']
def append_mlp_classifier(feature, label, num_classes=2, hidden_units=None):
cls_feats = fluid.layers.dropout(
x=feature, dropout_prob=0.1, dropout_implementation="upscale_in_train")
# append fully connected layer according to hidden_units
if hidden_units != None:
for n_hidden in hidden_units:
cls_feats = fluid.layers.fc(input=cls_feats, size=n_hidden)
logits = fluid.layers.fc(
input=cls_feats,
size=num_classes,
param_attr=fluid.ParamAttr(
name="cls_out_w",
initializer=fluid.initializer.TruncatedNormal(scale=0.02)),
bias_attr=fluid.ParamAttr(
name="cls_out_b", initializer=fluid.initializer.Constant(0.)))
ce_loss, probs = fluid.layers.softmax_with_cross_entropy(
logits=logits, label=label, return_softmax=True)
loss = fluid.layers.mean(x=ce_loss)
num_example = fluid.layers.create_tensor(dtype='int64')
accuracy = fluid.layers.accuracy(
input=probs, label=label, total=num_example)
# TODO: encapsulate to Task
return loss, probs, accuracy, num_example
class Task(object):
def __init__(self):
pass
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册