Commit 69cbf585 authored by mindspore-ci-bot, committed by Gitee

!1902 Fix bert scripts.

Merge pull request !1902 from chenhaozhe/fix-bert-scripts
@@ -19,6 +19,7 @@ python run_pretrain.py
 import os
 import argparse
+import numpy
 import mindspore.communication.management as D
 from mindspore import context
 from mindspore.train.model import Model
@@ -142,4 +143,5 @@ def run_pretrain():
     model = Model(netwithgrads)
     model.train(new_repeat_count, ds, callbacks=callback, dataset_sink_mode=(args_opt.enable_data_sink == "true"))
 if __name__ == '__main__':
+    numpy.random.seed(0)
     run_pretrain()
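The two hunks above seed NumPy's global RNG once, right before run_pretrain() is called, so any numpy-based randomness in the run is reproducible across launches. A minimal sketch of the effect, outside the BERT script:

import numpy

# With a fixed seed, every run draws the same sequence, so runs that
# differ only in code changes stay comparable.
numpy.random.seed(0)
print(numpy.random.permutation(5))  # identical output on every run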
@@ -39,6 +39,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
                         shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
                         shard_equal_rows=True)
     ori_dataset_size = ds.get_dataset_size()
+    print('origin dataset size: ', ori_dataset_size)
     new_size = ori_dataset_size
     if enable_data_sink == "true":
         new_size = data_sink_steps * bert_net_cfg.batch_size
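A note on the sizing above: in data-sink mode the dataset is resized so that one sink epoch supplies exactly data_sink_steps batches to the device. The arithmetic, with made-up numbers (both values below are hypothetical):

# Hypothetical values, for illustration only.
data_sink_steps = 100
batch_size = 32
new_size = data_sink_steps * batch_size            # 3200 raw rows per sink epoch
assert new_size // batch_size == data_sink_steps   # exactly 100 batches after ds.batch()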
@@ -53,7 +54,7 @@ def create_bert_dataset(epoch_size=1, device_num=1, rank=0, do_shuffle="true", e
     ds = ds.map(input_columns="input_ids", operations=type_cast_op)
     # apply batch operations
     ds = ds.batch(bert_net_cfg.batch_size, drop_remainder=True)
-    ds = ds.repeat(new_repeat_count)
+    ds = ds.repeat(max(new_repeat_count, repeat_count))
     logger.info("data size: {}".format(ds.get_dataset_size()))
     logger.info("repeatcount: {}".format(ds.get_repeat_count()))
     return ds, new_repeat_count
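The one-line fix above guards against the derived new_repeat_count coming out smaller than the caller's requested repeat_count after the sink-mode resize, presumably so the pipeline never supplies fewer repeats than requested. The guard itself is just (values below are made up):

new_repeat_count = 3  # hypothetical: derived from the resized dataset
repeat_count = 5      # hypothetical: requested by the caller
# The dataset is repeated to satisfy whichever count is larger.
assert max(new_repeat_count, repeat_count) >= repeat_count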
@@ -32,7 +32,6 @@ from .bert_model import BertModel
 GRADIENT_CLIP_TYPE = 1
 GRADIENT_CLIP_VALUE = 1.0
-_nn_clip_by_norm = nn.ClipByNorm()
 clip_grad = C.MultitypeFuncGraph("clip_grad")
@@ -57,7 +56,7 @@ def _clip_grad(clip_type, clip_value, grad):
         new_grad = C.clip_by_value(grad, F.cast(F.tuple_to_array((-clip_value,)), dt),
                                    F.cast(F.tuple_to_array((clip_value,)), dt))
     else:
-        new_grad = _nn_clip_by_norm(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
+        new_grad = nn.ClipByNorm()(grad, F.cast(F.tuple_to_array((clip_value,)), dt))
     return new_grad
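The two hunks above drop the module-level _nn_clip_by_norm singleton and instead construct nn.ClipByNorm() at the call site, presumably so each invocation of _clip_grad gets its own Cell instance rather than sharing one across compiled graphs. For reference, a plain-NumPy paraphrase of the two clipping modes (my sketch, not the MindSpore implementation):

import numpy as np

def clip_grad_reference(clip_type, clip_value, grad):
    if clip_type == 1:
        # Clip-by-value: force each element into [-clip_value, clip_value].
        return np.clip(grad, -clip_value, clip_value)
    # Clip-by-norm: rescale the whole tensor so its L2 norm is <= clip_value.
    norm = np.linalg.norm(grad)
    return grad if norm <= clip_value else grad * (clip_value / norm)

print(clip_grad_reference(1, 1.0, np.array([-2.0, 0.5, 3.0])))  # [-1.   0.5  1. ]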
......
@@ -56,7 +56,7 @@ if cfg.bert_network == 'base':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=768,
         num_hidden_layers=12,
         num_attention_heads=12,
@@ -77,7 +77,7 @@ if cfg.bert_network == 'nezha':
     bert_net_cfg = BertConfig(
         batch_size=32,
         seq_length=128,
-        vocab_size=21136,
+        vocab_size=21128,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
@@ -98,7 +98,7 @@ if cfg.bert_network == 'large':
     bert_net_cfg = BertConfig(
         batch_size=16,
         seq_length=512,
-        vocab_size=30528,
+        vocab_size=30522,
         hidden_size=1024,
         num_hidden_layers=24,
         num_attention_heads=16,
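The three vocab_size corrections above bring the configs in line with the actual tokenizer vocabularies: 21128 entries in Google's Chinese BERT vocab.txt (the base and nezha configs) and 30522 in the uncased English BERT vocab (the large config). A quick sanity check, assuming a local copy of the tokenizer file (the path is hypothetical):

# vocab_size in BertConfig must equal the number of entries in vocab.txt,
# or embedding lookups for the highest token ids go out of range.
with open("vocab.txt", encoding="utf-8") as f:
    vocab_size = sum(1 for _ in f)
print(vocab_size)  # expect 21128 (Chinese BERT) or 30522 (uncased English BERT)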
......