提交 29c47d15 编写于 作者: Y Yu Yang

Do not use feed

上级 6fec6837
...@@ -113,6 +113,21 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx, ...@@ -113,6 +113,21 @@ def prepare_batch_input(insts, data_input_names, util_input_names, src_pad_idx,
return data_input_dict, util_input_dict return data_input_dict, util_input_dict
def read_multiple(reader, count):
def __impl__():
res = []
for item in reader():
res.append(item)
if len(res) == count:
yield res
res = []
if len(res) == count:
yield res
return __impl__
def main(): def main():
place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if TrainTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place) exe = fluid.Executor(place)
...@@ -169,15 +184,8 @@ def main(): ...@@ -169,15 +184,8 @@ def main():
test_ppl = np.exp([min(test_avg_cost, 100)]) test_ppl = np.exp([min(test_avg_cost, 100)])
return test_avg_cost, test_ppl return test_avg_cost, test_ppl
def set_util_input(input_name_value):
tensor = fluid.global_scope().find_var(input_name_value[0]).get_tensor()
tensor.set(input_name_value[1], place)
# Initialize the parameters. # Initialize the parameters.
exe.run(fluid.framework.default_startup_program()) exe.run(fluid.framework.default_startup_program())
for pos_enc_param_name in pos_enc_param_names:
set_util_input((pos_enc_param_name, position_encoding_init(
ModelHyperParams.max_length + 1, ModelHyperParams.d_model)))
data_input_names = encoder_data_input_fields + decoder_data_input_fields[: data_input_names = encoder_data_input_fields + decoder_data_input_fields[:
-1] + label_data_names -1] + label_data_names
...@@ -188,19 +196,37 @@ def main(): ...@@ -188,19 +196,37 @@ def main():
loss_name=avg_cost.name loss_name=avg_cost.name
if TrainTaskConfig.use_avg_cost else sum_cost.name) if TrainTaskConfig.use_avg_cost else sum_cost.name)
local_scopes = train_exe.executor.local_scopes()
dev_count = fluid.core.get_cuda_device_count()
for pos_enc_param_name in pos_enc_param_names:
tensor = position_encoding_init(ModelHyperParams.max_length + 1, ModelHyperParams.d_model)
for place_id, local_scope in enumerate(local_scopes):
local_scope.find_var(pos_enc_param_name).get_tensor().set(tensor, fluid.CUDAPlace(place_id))
train_data = read_multiple(reader=train_data, count=dev_count)
for pass_id in xrange(TrainTaskConfig.pass_num): for pass_id in xrange(TrainTaskConfig.pass_num):
pass_start_time = time.time() pass_start_time = time.time()
for batch_id, data in enumerate(train_data()): for batch_id, data in enumerate(train_data()):
for place_id, data_buffer, local_scope in zip(range(len(data)), data, local_scopes):
data_input_dict, util_input_dict = prepare_batch_input( data_input_dict, util_input_dict = prepare_batch_input(
data, data_input_names, util_input_names, data_buffer, data_input_names, util_input_names,
ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx,
ModelHyperParams.n_head, ModelHyperParams.d_model) ModelHyperParams.n_head, ModelHyperParams.d_model)
map(set_util_input,
zip(util_input_dict.keys() + [lr_scheduler.learning_rate.name], local_scope.find_var(lr_scheduler.learning_rate.name).get_tensor().set(
util_input_dict.values() + lr_scheduler.update_learning_rate(),
[lr_scheduler.update_learning_rate()])) fluid.CUDAPlace(place_id))
outs = train_exe.run(feed_dict=data_input_dict,
fetch_list=[sum_cost.name, token_num.name]) for var_name in data_input_dict:
local_scope.find_var(var_name).get_tensor().set(data_input_dict[var_name],
fluid.CUDAPlace(place_id))
for var_name in util_input_dict:
local_scope.find_var(var_name).get_tensor().set(util_input_dict[var_name],
fluid.CUDAPlace(place_id))
outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name])
sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1]) sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1])
total_sum_cost = sum_cost_val.sum( total_sum_cost = sum_cost_val.sum(
) # sum the cost from multi devices ) # sum the cost from multi devices
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册