提交 4e0447b0 编写于 作者: D dengyutao

support minddataset for tinybert

上级 b2cff284
......@@ -28,7 +28,7 @@ from mindspore.train.parallel_utils import ParallelMode
from mindspore.nn.optim import AdamWeightDecay
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore import log as logger
from src.dataset import create_tinybert_dataset
from src.dataset import create_tinybert_dataset, DataType
from src.utils import LossCallBack, ModelSaveCkpt, BertLearningRate
from src.gd_config import common_cfg, bert_teacher_net_cfg, bert_student_net_cfg
from src.tinybert_for_gd_td import BertTrainWithLossScaleCell, BertNetworkWithLoss_gd, BertTrainCell
......@@ -55,6 +55,7 @@ def run_general_distill():
parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
......@@ -99,8 +100,15 @@ def run_general_distill():
student_config=bert_student_net_cfg,
is_training=True, use_one_hot_embeddings=False)
if args_opt.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif arg_opt.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
raise Exception("dataset format is not supported yet")
dataset = create_tinybert_dataset('gd', bert_teacher_net_cfg.batch_size, device_num, rank,
args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir)
args_opt.do_shuffle, args_opt.data_dir, args_opt.schema_dir,
data_type=dataset_type)
dataset_size = dataset.get_dataset_size()
print('dataset size: ', dataset_size)
print("dataset repeatcount: ", dataset.get_repeat_count())
......
......@@ -27,7 +27,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
from mindspore.nn.optim import AdamWeightDecay
from mindspore import log as logger
from src.dataset import create_tinybert_dataset
from src.dataset import create_tinybert_dataset, DataType
from src.utils import LossCallBack, ModelSaveCkpt, EvalCallBack, BertLearningRate
from src.assessment_method import Accuracy
from src.td_config import phase1_cfg, phase2_cfg, td_teacher_net_cfg, td_student_net_cfg
......@@ -68,7 +68,7 @@ def parse_args():
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
help="The name of the task to train.")
parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
args = parser.parse_args()
return args
......@@ -119,9 +119,17 @@ def run_predistill():
rank = 0
device_num = 1
if arg_opt.dataset_type == "tfrecord":
dataset_type = DataType.TFRECORD
elif arg_opt.dataset_type == "mindrecord":
dataset_type = DataType.MINDRECORD
else:
raise Exception("dataset format is not supported yet")
dataset = create_tinybert_dataset('td', td_teacher_net_cfg.batch_size,
device_num, rank, args_opt.do_shuffle,
args_opt.train_data_dir, args_opt.schema_dir)
args_opt.train_data_dir, args_opt.schema_dir,
data_tpye=dataset_type)
dataset_size = dataset.get_dataset_size()
print('td1 dataset size: ', dataset_size)
......
......@@ -39,4 +39,5 @@ python ${PROJECT_DIR}/../run_general_distill.py \
--save_ckpt_path="" \
--load_teacher_ckpt_path="" \
--data_dir="" \
--schema_dir="" > log.txt 2>&1 &
--schema_dir="" \
--dataset_type="tfrecord" > log.txt 2>&1 &
......@@ -16,26 +16,38 @@
"""create tinybert dataset"""
import os
from enum import Enum
import mindspore.common.dtype as mstype
import mindspore.dataset.engine.datasets as de
import mindspore.dataset.transforms.c_transforms as C
class DataType(Enum):
"""Enumerate supported dataset format"""
TFRECORD = 1
MINDRECORD = 2
def create_tinybert_dataset(task='td', batch_size=32, device_num=1, rank=0,
do_shuffle="true", data_dir=None, schema_dir=None):
do_shuffle="true", data_dir=None, schema_dir=None,
data_type=DataType.TFRECORD):
"""create tinybert dataset"""
files = os.listdir(data_dir)
data_files = []
for file_name in files:
if "record" in file_name:
if "record" in file_name and "db" not in file_name:
data_files.append(os.path.join(data_dir, file_name))
if task == "td":
columns_list = ["input_ids", "input_mask", "segment_ids", "label_ids"]
else:
columns_list = ["input_ids", "input_mask", "segment_ids"]
ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
shard_equal_rows=True)
if data_type == DataType.MINDRECORD:
ds = de.MindDataset(data_files, columns_list=columns_list,
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank)
else:
ds = de.TFRecordDataset(data_files, schema_dir, columns_list=columns_list,
shuffle=(do_shuffle == "true"), num_shards=device_num, shard_id=rank,
shard_equal_rows=True)
type_cast_op = C.TypeCast(mstype.int32)
ds = ds.map(input_columns="segment_ids", operations=type_cast_op)
ds = ds.map(input_columns="input_mask", operations=type_cast_op)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册