提交 2ef32167 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5378 update tiny bert shell script and readme for support mindrecord

Merge pull request !5378 from guozhijian/update_tinybert_readme
......@@ -44,12 +44,12 @@ After installing MindSpore via the official website, you can start general disti
# run standalone general distill example
bash scripts/run_standalone_gd.sh
Before running the shell script, please set the `load_teacher_ckpt_path`, `data_dir` and `schema_dir` in the run_standalone_gd.sh file first. If running on GPU, please set the `device_target=GPU`.
Before running the shell script, please set the `load_teacher_ckpt_path`, `data_dir`, `schema_dir` and `dataset_type` in the run_standalone_gd.sh file first. If running on GPU, please set the `device_target=GPU`.
# For Ascend device, run distributed general distill example
bash scripts/run_distributed_gd_ascend.sh 8 1 /path/hccl.json
Before running the shell script, please set the `load_teacher_ckpt_path`, `data_dir` and `schema_dir` in the run_distributed_gd_ascend.sh file first.
Before running the shell script, please set the `load_teacher_ckpt_path`, `data_dir`, `schema_dir` and `dataset_type` in the run_distributed_gd_ascend.sh file first.
# For GPU device, run distributed general distill example
bash scripts/run_distributed_gd_gpu.sh 8 1 /path/data/ /path/schema.json /path/teacher.ckpt
......@@ -57,7 +57,7 @@ bash scripts/run_distributed_gd_gpu.sh 8 1 /path/data/ /path/schema.json /path/t
# run task distill and evaluation example
bash scripts/run_standalone_td.sh
Before running the shell script, please set the `task_name`, `load_teacher_ckpt_path`, `load_gd_ckpt_path`, `train_data_dir`, `eval_data_dir` and `schema_dir` in the run_standalone_td.sh file first.
Before running the shell script, please set the `task_name`, `load_teacher_ckpt_path`, `load_gd_ckpt_path`, `train_data_dir`, `eval_data_dir`, `schema_dir` and `dataset_type` in the run_standalone_td.sh file first.
If running on GPU, please set the `device_target=GPU`.
```
......@@ -101,7 +101,7 @@ usage: run_general_distill.py [--distribute DISTRIBUTE] [--epoch_size N] [----
[--save_ckpt_path SAVE_CKPT_PATH]
[--load_teacher_ckpt_path LOAD_TEACHER_CKPT_PATH]
[--save_checkpoint_step N] [--max_ckpt_num N]
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] [train_steps N]
[--data_dir DATA_DIR] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE] [train_steps N]
options:
--device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
......@@ -118,6 +118,7 @@ options:
--load_teacher_ckpt_path path to load teacher checkpoint files: PATH, default is ""
--data_dir path to dataset directory: PATH, default is ""
--schema_dir path to schema.json file, PATH, default is ""
--dataset_type the dataset type which can be tfrecord/mindrecord, default is tfrecord
```
### Task Distill
......@@ -132,7 +133,7 @@ usage: run_general_task.py [--device_target DEVICE_TARGET] [--do_train DO_TRAIN
[--load_td1_ckpt_path LOAD_TD1_CKPT_PATH]
[--train_data_dir TRAIN_DATA_DIR]
[--eval_data_dir EVAL_DATA_DIR]
[--task_name TASK_NAME] [--schema_dir SCHEMA_DIR]
[--task_name TASK_NAME] [--schema_dir SCHEMA_DIR] [--dataset_type DATASET_TYPE]
options:
--device_target device where the code will be implemented: "Ascend" | "GPU", default is "Ascend"
......@@ -153,6 +154,7 @@ options:
--eval_data_dir path to eval dataset directory: PATH, default is ""
--task_name classification task: "SST-2" | "QNLI" | "MNLI", default is ""
--schema_dir path to schema.json file, PATH, default is ""
--dataset_type the dataset type which can be tfrecord/mindrecord, default is tfrecord
```
## Options and Parameters
......@@ -344,4 +346,4 @@ In run_general_distill.py, we set the random seed to make sure distribute traini
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
\ No newline at end of file
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
......@@ -55,7 +55,8 @@ def run_general_distill():
parser.add_argument("--load_teacher_ckpt_path", type=str, default="", help="Load checkpoint file path")
parser.add_argument("--data_dir", type=str, default="", help="Data path, it is better to use absolute path")
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
parser.add_argument("--dataset_type", type=str, default="tfrecord",
help="dataset type tfrecord/mindrecord, default is tfrecord")
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target, device_id=args_opt.device_id)
......
......@@ -68,7 +68,8 @@ def parse_args():
parser.add_argument("--schema_dir", type=str, default="", help="Schema path, it is better to use absolute path")
parser.add_argument("--task_name", type=str, default="", choices=["SST-2", "QNLI", "MNLI"],
help="The name of the task to train.")
parser.add_argument("--dataset_type", type=str, default="tfrecord", help="dataset type, default is tfrecord")
parser.add_argument("--dataset_type", type=str, default="tfrecord",
help="dataset type tfrecord/mindrecord, default is tfrecord")
args = parser.parse_args()
return args
......
......@@ -65,6 +65,7 @@ do
--max_ckpt_num=1 \
--load_teacher_ckpt_path="" \
--data_dir="" \
--schema_dir="" > log.txt 2>&1 &
--schema_dir="" \
--dataset_type="tfrecord" > log.txt 2>&1 &
cd ../
done
......@@ -37,5 +37,6 @@ mpirun --allow-run-as-root -n $RANK_SIZE \
--save_ckpt_path="" \
--data_dir=$DATA_DIR \
--schema_dir=$SCHEMA_DIR \
--dataset_type="tfrecord" \
--enable_data_sink=False \
--load_teacher_ckpt_path=$TEACHER_CKPT_PATH > log.txt 2>&1 &
......@@ -43,5 +43,6 @@ python ${PROJECT_DIR}/../run_task_distill.py \
--load_td1_ckpt_path="" \
--train_data_dir="" \
--eval_data_dir="" \
--schema_dir="" > log.txt 2>&1 &
--schema_dir="" \
--dataset_type="tfrecord" > log.txt 2>&1 &
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册