提交 caecc97a 编写于 作者: X Xin Pan

single device

上级 d2243979
class TrainTaskConfig(object):
use_gpu = False
use_gpu = True
# the epoch number to train.
pass_num = 2
# number of sequences contained in a mini-batch.
batch_size = 64
batch_size = 32
# the hyper params for Adam optimizer.
learning_rate = 0.001
......
import numpy as np
import sys
import time
import paddle.v2 as paddle
import paddle.fluid as fluid
import paddle.fluid.profiler as profiler
from model import transformer, position_encoding_init
from optim import LearningRateScheduler
......@@ -127,23 +130,41 @@ def main():
position_encoding_init(ModelHyperParams.max_length + 1,
ModelHyperParams.d_model), place)
def fn(pass_id, batch_id, data):
t1 = time.time()
data_input = prepare_batch_input(
data, input_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head, place)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
fetch_list=[cost],
use_program_cache=True)
cost_val = np.array(outs[0])
print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
" cost = " + str(cost_val))
return time.time() - t1
# with open('/tmp/program', 'w') as f:
# f.write('%s' % fluid.framework.default_main_program())
total_time = 0.0
count = 0
for pass_id in xrange(TrainTaskConfig.pass_num):
for batch_id, data in enumerate(train_data()):
# The current program desc is coupled with batch_size, thus all
# mini-batches must have the same number of instances currently.
if len(data) != TrainTaskConfig.batch_size:
continue
data_input = prepare_batch_input(
data, input_data_names, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.max_length,
ModelHyperParams.n_head, place)
lr_scheduler.update_learning_rate(data_input)
outs = exe.run(fluid.framework.default_main_program(),
feed=data_input,
fetch_list=[cost])
cost_val = np.array(outs[0])
print("pass_id = " + str(pass_id) + " batch = " + str(batch_id) +
" cost = " + str(cost_val))
if pass_id == 0 and batch_id >= 10 and batch_id < 12:
with profiler.profiler('All', 'total', '/tmp/transformer'):
duration = fn(pass_id, batch_id, data)
else:
duration = fn(pass_id, batch_id, data)
count += 1
total_time += duration
print("avg: " + str(total_time / count) + " cur: " + str(duration))
sys.stdout.flush()
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册