未验证 提交 7babcff1 编写于 作者: T tianxin 提交者: GitHub

Merge pull request #197 from zhengya01/ce_ernie

add ce for ERNIE
set -eux
export FLAGS_sync_nccl_allreduce=1
MODEL_PATH=ERNIE_1.0.1
TASK_DATA_PATH=task_data
train() {
python -u run_classifier.py \
--use_cuda true \
--do_train true \
--do_val true \
--do_test true \
--verbose true \
--batch_size 8192 \
--in_tokens true \
--init_pretraining_params ${MODEL_PATH}/params \
--train_set ${TASK_DATA_PATH}/xnli/train.tsv \
--dev_set ${TASK_DATA_PATH}/xnli/dev.tsv \
--test_set ${TASK_DATA_PATH}/xnli/test.tsv \
--vocab_path config/vocab.txt \
--label_map ${TASK_DATA_PATH}/xnli/label_map.json \
--ernie_config_path config/ernie_config.json \
--checkpoints ./checkpoints \
--save_steps 2000 \
--weight_decay 0.01 \
--warmup_proportion 0.0 \
--validation_steps 25 \
--epoch 1 \
--max_seq_len 512 \
--learning_rate 1e-4 \
--skip_steps 10 \
--num_iteration_per_drop_scope 1 \
--num_labels 3 \
--random_seed 100 \
--enable_ce \
--shuffle false
}
export CUDA_VISIBLE_DEVICES=0
train | python _ce.py
export CUDA_VISIBLE_DEVICES=0,1,2,3
train | python _ce.py
####this file is only used for continuous evaluation test!
import os
import sys
sys.path.insert(0, os.environ['ceroot'])
from kpi import CostKpi, DurationKpi, AccKpi
#### NOTE kpi.py should shared in models in some way!!!!
train_loss_card1_kpi = CostKpi('train_loss_card1', 0.03, 0, actived=True)
train_acc_card1_kpi = AccKpi('train_acc_card1', 0.06, 0, actived=True)
train_duration_card1_kpi = DurationKpi(
'train_duration_card1', 0.01, 0, actived=True)
train_loss_card4_kpi = CostKpi('train_loss_card4', 0.01, 0, actived=True)
train_acc_card4_kpi = AccKpi('train_acc_card4', 0.02, 0, actived=True)
train_duration_card4_kpi = DurationKpi(
'train_duration_card4', 0.02, 0, actived=True)
tracking_kpis = [
train_loss_card1_kpi,
train_acc_card1_kpi,
train_duration_card1_kpi,
train_loss_card4_kpi,
train_acc_card4_kpi,
train_duration_card4_kpi,
]
def parse_log(log):
'''
This method should be implemented by model developers.
The suggestion:
each line in the log should be key, value, for example:
"
train_loss\t1.0
test_loss\t1.0
train_loss\t1.0
train_acc\t1.2
"
'''
for line in log.split('\n'):
fs = line.strip().split('\t')
print(fs)
if len(fs) == 3 and fs[0] == 'kpis':
print("-----%s" % fs)
kpi_name = fs[1]
kpi_value = float(fs[2])
yield kpi_name, kpi_value
def log_to_ce(log):
kpi_tracker = {}
for kpi in tracking_kpis:
kpi_tracker[kpi.name] = kpi
for (kpi_name, kpi_value) in parse_log(log):
print(kpi_name, kpi_value)
kpi_tracker[kpi_name].add_record(kpi_value)
kpi_tracker[kpi_name].persist()
if __name__ == '__main__':
log = sys.stdin.read()
print("*****")
print(log)
print("****")
log_to_ce(log)
......@@ -74,4 +74,7 @@ run_type_g.add_arg("do_train", bool, True, "Whether to pe
run_type_g.add_arg("do_val", bool, True, "Whether to perform evaluation on dev data set.")
run_type_g.add_arg("do_test", bool, True, "Whether to perform evaluation on test data set.")
run_type_g.add_arg("metrics", bool, True, "Whether to perform evaluation on test data set.")
run_type_g.add_arg("shuffle", bool, True, "")
parser.add_argument("--enable_ce", action='store_true', help="The flag indicating whether to run the task for continuous evaluation.")
# yapf: enable
......@@ -29,6 +29,7 @@ from finetune.classifier import create_model, evaluate
from optimization import optimization
from utils.args import print_arguments, check_cuda
from utils.init import init_pretraining_params, init_checkpoint
from utils.cards import get_cards
from finetune_args import parser
args = parser.parse_args()
......@@ -67,7 +68,7 @@ def main(args):
input_file=args.train_set,
batch_size=args.batch_size,
epoch=args.epoch,
shuffle=True,
shuffle=args.shuffle,
phase="train")
num_train_examples = reader.get_num_examples(args.train_set)
......@@ -85,6 +86,8 @@ def main(args):
print("Num warmup steps: %d" % warmup_steps)
train_program = fluid.Program()
if args.random_seed is not None and args.enable_ce:
train_program.random_seed = args.random_seed
with fluid.program_guard(train_program, startup_prog):
with fluid.unique_name.guard():
......@@ -187,6 +190,7 @@ def main(args):
if warmup_steps > 0:
graph_vars["learning_rate"] = scheduled_lr
ce_info = []
time_begin = time.time()
while True:
try:
......@@ -213,6 +217,7 @@ def main(args):
(current_epoch, current_example, num_train_examples,
steps, outputs["loss"], outputs["accuracy"],
args.skip_steps / used_time))
ce_info.append([outputs["loss"], outputs["accuracy"], used_time])
time_begin = time.time()
if steps % args.save_steps == 0:
......@@ -246,6 +251,24 @@ def main(args):
fluid.io.save_persistables(exe, save_path, train_program)
train_pyreader.reset()
break
if args.enable_ce:
card_num = get_cards()
ce_loss = 0
ce_acc = 0
ce_time = 0
try:
ce_loss = ce_info[-2][0]
ce_acc = ce_info[-2][1]
ce_time = ce_info[-2][2]
except:
print("ce info error")
print("kpis\ttrain_duration_card%s\t%s" %
(card_num, ce_time))
print("kpis\ttrain_loss_card%s\t%f" %
(card_num, ce_loss))
print("kpis\ttrain_acc_card%s\t%f" %
(card_num, ce_acc))
# final eval on dev set
if args.do_val:
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
def get_cards():
"""
get gpu cards number
"""
num = 0
cards = os.environ.get('CUDA_VISIBLE_DEVICES', '')
if cards != '':
num = len(cards.split(","))
return num
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册