Distributed training of a machine translation model: with unequal data per trainer, a trainer times out waiting for parameters and crashes
Created by: ccmeteorljh
Paddle version: v0.14. Repository: https://github.com/xuezhong/transformer-nist/tree/master/transformer_cloud. Added training argument:

    parser.add_argument(
        '--iterations', type=int, default=100, help='The number of iters to break.')
Added to the training code (an early break once `batch_id` exceeds `--iterations`):

    for pass_id in xrange(args.pass_num):
        pass_start_time = 0.0
        pass_num_token = 0
        for batch_id, data in enumerate(train_data()):
            feed_list = []
            if batch_id == 5:
                pass_start_time = time.time()
                pass_num_token = 0
            if batch_id > args.iterations:
                break
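The hang this early break causes can be reproduced without Paddle. In synchronous data-parallel training, every trainer must take part in the same number of synchronization rounds per pass; the sketch below (plain Python threads, with a `Barrier` timeout standing in for the gRPC deadline; the names and numbers are illustrative, not Paddle APIs) shows what happens when one worker runs extra iterations:

```python
import threading

NUM_TRAINERS = 4
# The barrier stands in for per-step gradient aggregation on the pserver:
# every trainer must arrive before any of them may proceed.
barrier = threading.Barrier(NUM_TRAINERS)
errors = []

def trainer(rank, num_batches):
    # Trainers with unequal data (or an early `break`) run a different
    # number of synchronization rounds per pass.
    for batch_id in range(num_batches):
        try:
            barrier.wait(timeout=1.0)  # stand-in for the gRPC deadline
        except threading.BrokenBarrierError:
            errors.append((rank, batch_id))  # analogue of "Deadline Exceeded"
            return

# Three trainers stop after 100 rounds; one tries to run 105.
threads = [threading.Thread(target=trainer, args=(r, 100 if r < 3 else 105))
           for r in range(NUM_TRAINERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(errors)  # the long-running trainer times out at round 100
```

The first 100 rounds complete because all four threads reach the barrier; on round 100 the remaining trainer waits alone until the timeout breaks the barrier, mirroring the crash in the logs.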
Setup: 2 pservers, 4 trainers (3 of the 4 break out during training; 1 does not break). To make the test phase take longer, a relatively large test set is used. Training arguments:
Namespace(async_mode=False, batch_size=4096, iterations=100, local=False, opts=[], pass_num=1000, pool_size=10000, shuffle=True, shuffle_batch=True, sort_type='pool', special_token=['_GO', '_EOS', '_UNK'], src_vocab_fpath='./nist06n/cn_30001.dict', train_file_pattern='./nist06n/train_split/part-*', trg_vocab_fpath='./nist06n/en_30001.dict', use_gpu=True, use_token_batch=True, val_file_pattern='./nist06n/test/part-*')
Output of the trainer that does not break during training:
epoch: 0, batch: 98, sum loss: 309133.875000, avg loss: 10.245380, ppl: 28152.189453
epoch: 0, batch: 99, sum loss: 310333.437500, avg loss: 10.244056, ppl: 28114.921875
epoch: 0, batch: 100, sum loss: 313575.250000, avg loss: 10.240195, ppl: 28006.593750
F0626 03:53:35.576830 13122 grpc_client.cc:248] var: name:[fc_45.b_0] ep:[10.255.123.15:5002] grpc error:Deadline Exceeded
*** Check failure stack trace: ***
@ 0x7f20ed62c5ad google::LogMessage::Fail()
@ 0x7f20ed63005c google::LogMessage::SendToLog()
@ 0x7f20ed62c0d3 google::LogMessage::Flush()
@ 0x7f20ed63156e google::LogMessageFatal::~LogMessageFatal()
@ 0x7f20ed3e6208 paddle::operators::distributed::GRPCClient::Proceed()
@ 0x7f2140a357e0 execute_native_thread_routine
@ 0x7f21461616ba start_thread
@ 0x7f2145e9741d clone
@ (nil) (unknown)
Output of a trainer that breaks once the iteration count exceeds the configured limit:
epoch: 0, batch: 94, sum loss: 308947.031250, avg loss: 10.253461, ppl: 28380.593750
epoch: 0, batch: 95, sum loss: 310470.312500, avg loss: 10.250605, ppl: 28299.648438
epoch: 0, batch: 96, sum loss: 306242.968750, avg loss: 10.248409, ppl: 28237.587891
epoch: 0, batch: 97, sum loss: 311251.625000, avg loss: 10.249329, ppl: 28263.560547
epoch: 0, batch: 98, sum loss: 310409.750000, avg loss: 10.245561, ppl: 28157.263672
epoch: 0, batch: 99, sum loss: 312839.937500, avg loss: 10.244955, ppl: 28140.216797
epoch: 0, batch: 100, sum loss: 310818.000000, avg loss: 10.243820, ppl: 28108.300781
Total examples: 2851800, total time: 161.33414, 17676.35742 examples/sed
----------begin test----------
epoch: 0, val avg loss: 10.237868, val ppl: 27941.481907, consumed 474.504052s
epoch: 1, batch: 0, sum loss: 300855.437500, avg loss: 10.238401, ppl: 27956.400391
F0626 03:59:05.936844 63356 grpc_client.cc:248] var: name:[fc_95.b_0@GRAD.trainer_3] ep:[10.255.123.16:5002] grpc error:Deadline Exceeded
*** Check failure stack trace: ***
@ 0x7f47aaa6a5ad google::LogMessage::Fail()
@ 0x7f47aaa6e05c google::LogMessage::SendToLog()
@ 0x7f47aaa6a0d3 google::LogMessage::Flush()
@ 0x7f47aaa6f56e google::LogMessageFatal::~LogMessageFatal()
@ 0x7f47aa824208 paddle::operators::distributed::GRPCClient::Proceed()
@ 0x7f47fde737e0 execute_native_thread_routine
@ 0x7f480359f6ba start_thread
@ 0x7f48032d541d clone
@ (nil) (unknown)
Of the 4 trainers, three have broken out of the current pass and entered the test phase, which takes several minutes. The remaining trainer is still in the training loop, blocked waiting on the parameter server for the other trainers' updates; that wait exceeds the gRPC deadline, so it crashes.
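One way to avoid the timeout (an illustrative sketch in plain Python threads, not a Paddle API) is to have all trainers agree on a common iteration limit before the pass starts, e.g. the minimum of their local limits, so every trainer breaks on the same round and no one is left waiting:

```python
import threading

NUM_TRAINERS = 4
LOCAL_LIMITS = [100, 100, 100, 105]  # each trainer's own break point

counts = []                              # shared list of published limits
lock = threading.Lock()
ready = threading.Barrier(NUM_TRAINERS)  # wait until all limits are posted
step = threading.Barrier(NUM_TRAINERS)   # per-round synchronization
results = []

def trainer(rank):
    # Phase 1: publish the local limit and wait for everyone else's.
    with lock:
        counts.append(LOCAL_LIMITS[rank])
    ready.wait()
    # Phase 2: every trainer runs the same agreed-upon number of rounds,
    # so no peer keeps synchronizing after the others have broken out.
    limit = min(counts)
    for batch_id in range(limit):
        step.wait(timeout=1.0)
    results.append((rank, limit))

threads = [threading.Thread(target=trainer, args=(r,))
           for r in range(NUM_TRAINERS)]
for t in threads:
    t.start()
for t in threads:
    t.join()
print(sorted(results))  # every trainer stopped at the same round
```

In a real cluster the "publish and take the minimum" step would have to go through the pserver or another collective channel; alternatively, the break condition can be computed from values all trainers already share so no extra exchange is needed.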