提交 586d6673 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!5307 change enable_data_sink value to true for transformer

Merge pull request !5307 from yuchaojie/transformer2
...@@ -52,7 +52,7 @@ do ...@@ -52,7 +52,7 @@ do
--enable_save_ckpt="true" \ --enable_save_ckpt="true" \
--enable_lossscale="true" \ --enable_lossscale="true" \
--do_shuffle="true" \ --do_shuffle="true" \
--enable_data_sink="false" \ --enable_data_sink="true" \
--checkpoint_path="" \ --checkpoint_path="" \
--save_checkpoint_steps=2500 \ --save_checkpoint_steps=2500 \
--save_checkpoint_num=30 \ --save_checkpoint_num=30 \
......
...@@ -37,7 +37,7 @@ python train.py \ ...@@ -37,7 +37,7 @@ python train.py \
--enable_save_ckpt="true" \ --enable_save_ckpt="true" \
--enable_lossscale="true" \ --enable_lossscale="true" \
--do_shuffle="true" \ --do_shuffle="true" \
--enable_data_sink="false" \ --enable_data_sink="true" \
--checkpoint_path="" \ --checkpoint_path="" \
--save_checkpoint_steps=2500 \ --save_checkpoint_steps=2500 \
--save_checkpoint_num=30 \ --save_checkpoint_num=30 \
......
...@@ -170,7 +170,8 @@ def run_transformer_train(): ...@@ -170,7 +170,8 @@ def run_transformer_train():
netwithgrads.set_train(True) netwithgrads.set_train(True)
model = Model(netwithgrads) model = Model(netwithgrads)
model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true")) model.train(args.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=(args.enable_data_sink == "true"),
sink_size=args.save_checkpoint_steps)
if __name__ == '__main__': if __name__ == '__main__':
run_transformer_train() run_transformer_train()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册