提交 50d1027b 编写于 作者: Z Zeyu Chen

update typo and remove useless variables

上级 14535dd4
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Finetuning on classification tasks."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import time
import argparse
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle_hub as hub
# Command-line arguments for the fine-tuning run.
# yapf: disable
parser = argparse.ArgumentParser(__doc__)
parser.add_argument("--num_epoch", type=int, default=3, help="Number of epochs for fine-tuning.")
parser.add_argument("--learning_rate", type=float, default=5e-5, help="Learning rate used to train with warmup.")
parser.add_argument("--hub_module_dir", type=str, default=None, help="PaddleHub module directory")
parser.add_argument("--lr_scheduler", type=str, default="linear_warmup_decay",
                    help="scheduler of learning rate.", choices=['linear_warmup_decay', 'noam_decay'])
parser.add_argument("--weight_decay", type=float, default=0.01, help="Weight decay rate for L2 regularizer.")
parser.add_argument("--data_dir", type=str, default=None, help="Path to training data.")
parser.add_argument("--checkpoint_dir", type=str, default=None, help="Directory to model checkpoint")
parser.add_argument("--max_seq_len", type=int, default=512, help="Number of words of the longest sequence.")
parser.add_argument("--batch_size", type=int, default=32, help="Total number of examples in a training batch.")
args = parser.parse_args()
# yapf: enable
if __name__ == '__main__':
    # Strategy controlling weight decay during BERT fine-tuning.
    finetune_strategy = hub.BERTFinetuneStrategy(weight_decay=args.weight_decay)

    # Runtime configuration: logging/eval/checkpoint cadence plus the
    # optimizer settings taken from the command line.
    run_config = hub.FinetuneConfig(
        log_interval=10,
        eval_interval=100,
        save_ckpt_interval=200,
        checkpoint_dir=args.checkpoint_dir,
        learning_rate=args.learning_rate,
        num_epoch=args.num_epoch,
        batch_size=args.batch_size,
        strategy=finetune_strategy)

    # Load the pre-trained BERT packaged as a PaddleHub module.
    bert_module = hub.Module(module_dir=args.hub_module_dir)

    # Tokenize the ChnSentiCorp dataset with the module's own vocabulary so
    # the inputs match what the pre-trained model expects.
    data_reader = hub.reader.BERTTokenizeReader(
        dataset=hub.dataset.ChnSentiCorp(),  # download chnsenticorp dataset
        vocab_path=bert_module.get_vocab_path(),
        max_seq_len=args.max_seq_len)
    num_labels = len(data_reader.get_labels())

    # Build the BERT program graph; trainable=True enables fine-tuning of
    # the pre-trained weights.
    input_dict, output_dict, program = bert_module.context(
        sign_name="tokens", trainable=True, max_seq_len=args.max_seq_len)

    with fluid.program_guard(program):
        label = fluid.layers.data(name="label", shape=[1], dtype='int64')

        # "pooled_output" is the sentence-level representation for whole-
        # sentence classification; "sequence_outputs" would be the choice
        # for token-level tasks.
        pooled_output = output_dict["pooled_output"]

        # Every tensor the BERT module consumes must be fed, plus the label.
        feed_list = [
            input_dict["input_ids"].name, input_dict["position_ids"].name,
            input_dict["segment_ids"].name, input_dict["input_mask"].name,
            label.name
        ]

        # Attach an MLP classification head on top of the pooled output.
        cls_task = hub.append_mlp_classifier(
            pooled_output, label, num_classes=num_labels)

        # Run training, evaluation, testing and model saving end to end.
        hub.finetune_and_eval(
            task=cls_task,
            data_reader=data_reader,
            feed_list=feed_list,
            config=run_config)
# Pin the run to a single GPU.
export CUDA_VISIBLE_DEVICES=5

# Location of the ChnSentiCorp training data (passed as --data_dir).
DATA_PATH=./chnsenticorp_data
# Pre-trained Chinese BERT packaged as a PaddleHub module.
HUB_MODULE_DIR="./hub_module/bert_chinese_L-12_H-768_A-12.hub_module"
#HUB_MODULE_DIR="./hub_module/ernie_stable.hub_module"
# Directory where fine-tuning checkpoints are written.
CKPT_DIR="./ckpt"
#rm -rf $CKPT_DIR

# Launch fine-tuning; -u keeps Python stdout unbuffered for live logs.
python -u finetune_with_hub.py \
--batch_size 32 \
--hub_module_dir=$HUB_MODULE_DIR \
--data_dir ${DATA_PATH} \
--weight_decay 0.01 \
--checkpoint_dir $CKPT_DIR \
--num_epoch 3 \
--max_seq_len 128 \
--learning_rate 5e-5
...@@ -34,7 +34,7 @@ from .io.type import DataType ...@@ -34,7 +34,7 @@ from .io.type import DataType
from .finetune.network import append_mlp_classifier from .finetune.network import append_mlp_classifier
from .finetune.finetune import finetune_and_eval from .finetune.finetune import finetune_and_eval
from .finetune.config import FinetuneConfig from .finetune.config import RunConfig
from .finetune.task import Task from .finetune.task import Task
from .finetune.strategy import BERTFinetuneStrategy from .finetune.strategy import BERTFinetuneStrategy
from .finetune.strategy import DefaultStrategy from .finetune.strategy import DefaultStrategy
......
...@@ -15,10 +15,12 @@ ...@@ -15,10 +15,12 @@
import time import time
from .strategy import DefaultStrategy from .strategy import DefaultStrategy
from paddle_hub.common.utils import md5 from datetime import datetime
from paddle_hub.common.logger import logger
class FinetuneConfig(object):
class RunConfig(object):
""" This class specifies the configurations for PaddleHub to finetune """ """ This class specifies the configurations for PaddleHub to finetune """
def __init__(self, def __init__(self,
...@@ -45,9 +47,13 @@ class FinetuneConfig(object): ...@@ -45,9 +47,13 @@ class FinetuneConfig(object):
self._strategy = strategy self._strategy = strategy
self._enable_memory_optim = enable_memory_optim self._enable_memory_optim = enable_memory_optim
if checkpoint_dir is None: if checkpoint_dir is None:
self._checkpoint_dir = "hub_cpkt_" + md5(str(time.time()))[0:20]
now = int(time.time())
time_str = time.strftime("%Y%m%d%H%M%S", time.localtime(now))
self._checkpoint_dir = "ckpt_" + time_str
else: else:
self._checkpoint_dir = checkpoint_dir self._checkpoint_dir = checkpoint_dir
logger.info("Checkpoint dir: {}".format(self._checkpoint_dir))
@property @property
def log_interval(self): def log_interval(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册