提交 9ce9c215 编写于 作者: W wilfChen

gpu bert script

上级 5aeba82a
......@@ -21,6 +21,7 @@ import os
import argparse
import numpy
import mindspore.communication.management as D
import mindspore.common.dtype as mstype
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.parallel_utils import ParallelMode
......@@ -28,6 +29,7 @@ from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.train.callback import Callback, ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.optim import Lamb, Momentum, AdamWeightDecayDynamicLR
from mindspore import log as logger
from src import BertNetworkWithLoss, BertTrainOneStepCell, BertTrainOneStepWithLossScaleCell
from src.dataset import create_bert_dataset
from src.config import cfg, bert_net_cfg
......@@ -55,6 +57,8 @@ class LossCallBack(Callback):
def run_pretrain():
"""pre-train bert_clue"""
parser = argparse.ArgumentParser(description='bert pre_training')
parser.add_argument('--device_target', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='device where the code will be implemented. (Default: Ascend)')
parser.add_argument("--distribute", type=str, default="false", help="Run distribute, default is false.")
parser.add_argument("--epoch_size", type=int, default="1", help="Epoch size, default is 1.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default is 0.")
......@@ -74,11 +78,21 @@ def run_pretrain():
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=args_opt.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
context.set_context(reserve_class_name_in_scope=False)
ckpt_save_dir = args_opt.checkpoint_path
if args_opt.distribute == "true":
device_num = args_opt.device_num
if args_opt.device_target == 'Ascend':
D.init('hccl')
device_num = args_opt.device_num
rank = args_opt.device_id % device_num
else:
D.init('nccl')
device_num = D.get_group_size()
rank = D.get_rank()
ckpt_save_dir = args_opt.checkpoint_path + 'ckpt_' + str(rank) + '/'
context.reset_auto_parallel_context()
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
device_num=device_num)
......@@ -93,12 +107,15 @@ def run_pretrain():
auto_parallel_context().set_all_reduce_fusion_split_indices([30, 90, 150, 210, 270, 330, 390, 421])
else:
auto_parallel_context().set_all_reduce_fusion_split_indices([38, 93, 148, 203, 258, 313, 368, 397])
D.init()
rank = args_opt.device_id % device_num
else:
rank = 0
device_num = 1
if args_opt.device_target == 'GPU' and bert_net_cfg.compute_type != mstype.float32:
logger.warning('Gpu only support fp32 temporarily, run with fp32.')
bert_net_cfg.compute_type = mstype.float32
ds, new_repeat_count = create_bert_dataset(args_opt.epoch_size, device_num, rank, args_opt.do_shuffle,
args_opt.enable_data_sink, args_opt.data_sink_steps,
args_opt.data_dir, args_opt.schema_dir)
......@@ -130,7 +147,7 @@ def run_pretrain():
if args_opt.enable_save_ckpt == "true":
config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_checkpoint_steps,
keep_checkpoint_max=args_opt.save_checkpoint_num)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', config=config_ck)
ckpoint_cb = ModelCheckpoint(prefix='checkpoint_bert', directory=ckpt_save_dir, config=config_ck)
callback.append(ckpoint_cb)
if args_opt.checkpoint_path:
......
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_distribute_pretrain.sh DEVICE_NUM EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash run_distribute_pretrain.sh 8 40 /path/zh-wiki/ /path/Schema.json"
echo "It is better to use absolute path."
echo "=============================================================================================================="
RANK_SIZE=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
mpirun --allow-run-as-root -n $RANK_SIZE \
python run_pretrain.py \
--device_target="GPU" \
--distribute="true" \
--epoch_size=$EPOCH_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="false" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
......@@ -37,7 +37,7 @@ python run_pretrain.py \
--enable_lossscale="true" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=100 \
--data_sink_steps=1 \
--checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
......
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
echo "=============================================================================================================="
echo "Please run the scipt as: "
echo "bash run_standalone_pretrain.sh DEVICE_ID EPOCH_SIZE DATA_DIR SCHEMA_DIR"
echo "for example: bash run_standalone_pretrain.sh 0 40 /path/zh-wiki/ /path/Schema.json"
echo "=============================================================================================================="
DEVICE_ID=$1
EPOCH_SIZE=$2
DATA_DIR=$3
SCHEMA_DIR=$4
export CUDA_VISIBLE_DEVICES=$DEVICE_ID
mkdir -p ms_log
CUR_DIR=`pwd`
export GLOG_log_dir=${CUR_DIR}/ms_log
export GLOG_logtostderr=0
python run_pretrain.py \
--device_target="GPU" \
--distribute="false" \
--epoch_size=$EPOCH_SIZE \
--enable_save_ckpt="true" \
--enable_lossscale="false" \
--do_shuffle="true" \
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册