提交 46c175a5 编写于 作者: W wsc

Modify example scripts of bert model

上级 88215d00
......@@ -26,12 +26,16 @@ DATA_DIR=$3
SCHEMA_DIR=$4
export MINDSPORE_HCCL_CONFIG_PATH=$5
export RANK_TABLE_FILE=$5
export RANK_SIZE=$1
for((i=0;i<RANK_SIZE;i++))
do
export DEVICE_ID=$i
start=`expr $i \* 12`
export DEVICE_ID=$i
export RANK_ID=$i
export DEPLOY_MODE=0
export GE_USE_STATIC_MEMORY=1
end=`expr $start \+ 11`
cmdopt=$start"-"$end
......@@ -39,7 +43,6 @@ do
mkdir ./LOG$i
cp *.py ./LOG$i
cd ./LOG$i || exit
export RANK_ID=$i
echo "start training for rank $i, device $DEVICE_ID"
env > env.log
taskset -c $cmdopt python ../run_pretrain.py \
......@@ -56,7 +59,7 @@ do
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--save_checkpoint_steps=1000 \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
......
......@@ -84,13 +84,11 @@ def run_pretrain():
if args_opt.distribute == "true":
device_num = args_opt.device_num
context.reset_auto_parallel_context()
context.set_context(enable_hccl=True)
context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True,
device_num=device_num)
D.init()
rank = args_opt.device_id % device_num
else:
context.set_context(enable_hccl=False)
rank = 0
device_num = 1
......@@ -103,7 +101,7 @@ def run_pretrain():
optimizer = Lamb(netwithloss.trainable_params(), decay_steps=ds.get_dataset_size() * ds.get_repeat_count(),
start_learning_rate=cfg.Lamb.start_learning_rate, end_learning_rate=cfg.Lamb.end_learning_rate,
power=cfg.Lamb.power, warmup_steps=cfg.Lamb.warmup_steps, weight_decay=cfg.Lamb.weight_decay,
eps=cfg.Lamb.eps, decay_filter=cfg.Lamb.decay_filter)
eps=cfg.Lamb.eps)
elif cfg.optimizer == 'Momentum':
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=cfg.Momentum.learning_rate,
momentum=cfg.Momentum.momentum)
......
......@@ -38,7 +38,7 @@ python run_pretrain.py \
--enable_data_sink="true" \
--data_sink_steps=1 \
--checkpoint_path="" \
--save_checkpoint_steps=1000 \
--save_checkpoint_steps=10000 \
--save_checkpoint_num=1 \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR > log.txt 2>&1 &
......@@ -76,26 +76,6 @@ def get_config(version='base', batch_size=1):
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float16)
elif version == 'large_mixed':
bert_config = BertConfig(
batch_size=batch_size,
seq_length=128,
vocab_size=21136,
hidden_size=1024,
num_hidden_layers=24,
num_attention_heads=16,
intermediate_size=4096,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=512,
type_vocab_size=2,
initializer_range=0.02,
use_relative_positions=True,
input_mask_from_dataset=True,
token_type_ids_from_dataset=True,
dtype=mstype.float32,
compute_type=mstype.float32)
else:
bert_config = BertConfig(batch_size=batch_size)
return bert_config
......@@ -136,8 +116,8 @@ class ModelCallback(Callback):
def step_end(self, run_context):
cb_params = run_context.original_args()
self.loss_list.append(cb_params.net_outputs[0].asnumpy()[0])
self.overflow_list.append(cb_params.net_outputs[1])
self.lossscale_list.append(cb_params.net_outputs[2])
self.overflow_list.append(cb_params.net_outputs[1].asnumpy())
self.lossscale_list.append(cb_params.net_outputs[2].asnumpy())
print("epoch: {}, outputs are: {}".format(cb_params.cur_epoch_num, str(cb_params.net_outputs)))
@pytest.mark.level0
......@@ -157,7 +137,7 @@ def test_bert_tdt():
netwithloss = BertNetworkWithLoss(config, True)
optimizer = Momentum(netwithloss.trainable_params(), learning_rate=2e-5, momentum=0.9)
scale_window = 3
scale_manager = DynamicLossScaleManager(2**32, 2, scale_window)
scale_manager = DynamicLossScaleManager(2**16, 2, scale_window)
netwithgrads = BertTrainOneStepWithLossScaleCell(netwithloss, optimizer=optimizer, scale_update_cell=scale_manager.get_update_cell())
netwithgrads.set_train(True)
model = Model(netwithgrads)
......@@ -182,22 +162,21 @@ def test_bert_tdt():
param.default_input = weight_variable(value.asnumpy().shape)
model.train(ds.get_repeat_count(), ds, callbacks=callback, dataset_sink_mode=False)
# assertion occurs while the loss_scale value is wrong
count = 0
for i in range(len(callback.overflow_list)):
if callback.overflow_list[i] == Tensor(True, mstype.bool_) and i > 0:
count = 0
assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(0.5, mstype.float32)
if callback.overflow_list[i] == Tensor(False, mstype.bool_):
count = count + 1
if count == scale_window:
count = 0
assert callback.lossscale_list[i] == callback.lossscale_list[i - 1] * Tensor(2.0, mstype.float32)
# assertion occurs while the loss value is wrong
# assertion occurs while the loss value, overflow state or loss_scale value is wrong
loss_value = np.array(callback.loss_list)
expect_value = [12.1918125, 11.966035, 11.972114, 11.982671, 11.976399, 12.616986, 12.180658, 12.850562, 12.415608, 12.640145]
expect_loss_value = [12.1918125, 11.966035, 11.972114, 11.982188, 11.974092, 12.610916, 12.17565, 12.840416, 12.40291, 12.621661]
print("loss value: {}".format(loss_value))
assert np.allclose(loss_value, expect_value, 0.00001, 0.00001)
assert np.allclose(loss_value, expect_loss_value, 0.00001, 0.00001)
overflow = np.array(callback.overflow_list)
expect_overflow = [True, True, False, False, False, True, False, False, False, True]
print("overflow: {}".format(overflow))
assert (overflow == expect_overflow).all()
loss_scale = np.array(callback.lossscale_list)
expect_loss_scale = [32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0, 16384.0, 16384.0, 32768.0, 16384.0]
print("loss scale: {}".format(loss_scale))
assert np.allclose(loss_scale, expect_loss_scale, 0.00001, 0.00001)
if __name__ == '__main__':
test_bert_tdt()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册