提交 883622f2 作者: Yu Yang

Clean code

上级 8e3e6392
......
@@ -196,39 +196,30 @@ def main():
loss_name=avg_cost.name
if TrainTaskConfig.use_avg_cost else sum_cost.name)
dev_count = fluid.core.get_cuda_device_count()
for pos_enc_param_name in pos_enc_param_names:
tensor = position_encoding_init(ModelHyperParams.max_length + 1, ModelHyperParams.d_model)
for place_id in xrange(dev_count):
local_scope = train_exe.executor.local_scope(place_id)
local_scope.var(pos_enc_param_name).get_tensor().set(tensor, fluid.CUDAPlace(place_id))
train_data = read_multiple(reader=train_data, count=dev_count)
train_data = read_multiple(reader=train_data, count=train_exe.device_count)
for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time()
for batch_id, data in enumerate(train_data()):
data_on_devices = []
for place_id, data_buffer in enumerate(data):
data_input_dict, util_input_dict = prepare_batch_input(
data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model)
local_scope = train_exe.executor.local_scope(place_id)
local_scope.find_var(lr_scheduler.learning_rate.name).get_tensor().set(
lr_scheduler.update_learning_rate(),
fluid.CUDAPlace(place_id))
data_input_dict.update(util_input_dict)
data_input_dict.update({
lr_scheduler.learning_rate.name: lr_scheduler.update_learning_rate()
})
for var_name in data_input_dict:
local_scope.var(var_name).get_tensor().set(data_input_dict[var_name],
fluid.CUDAPlace(place_id))
for pos_enc_param_name in pos_enc_param_names:
tensor = position_encoding_init(ModelHyperParams.max_length + 1, ModelHyperParams.d_model)
data_input_dict[pos_enc_param_name] = tensor
for var_name in util_input_dict:
local_scope.var(var_name).get_tensor().set(util_input_dict[var_name],
fluid.CUDAPlace(place_id))
data_on_devices.append(data_input_dict)
outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name])
outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name], feed=data_on_devices)
sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
total_sum_cost = sum_cost_val.sum(
) # sum the cost from multi devices
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册