From 883622f271bfef5e5ba9161fa511ed1181e9764f Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Tue, 17 Apr 2018 14:11:50 +0800 Subject: [PATCH] Clean code --- .../transformer/train.py | 33 +++++++------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/fluid/neural_machine_translation/transformer/train.py b/fluid/neural_machine_translation/transformer/train.py index 228c0cde..14c9acdf 100644 --- a/fluid/neural_machine_translation/transformer/train.py +++ b/fluid/neural_machine_translation/transformer/train.py @@ -196,39 +196,30 @@ def main(): loss_name=avg_cost.name if TrainTaskConfig.use_avg_cost else sum_cost.name) - dev_count = fluid.core.get_cuda_device_count() - - for pos_enc_param_name in pos_enc_param_names: - tensor = position_encoding_init(ModelHyperParams.max_length + 1, ModelHyperParams.d_model) - for place_id in xrange(dev_count): - local_scope = train_exe.executor.local_scope(place_id) - local_scope.var(pos_enc_param_name).get_tensor().set(tensor, fluid.CUDAPlace(place_id)) - - train_data = read_multiple(reader=train_data, count=dev_count) + train_data = read_multiple(reader=train_data, count=train_exe.device_count) for pass_id in xrange(TrainTaskConfig.pass_num): pass_start_time = time.time() for batch_id, data in enumerate(train_data()): + data_on_devices = [] + for place_id, data_buffer in enumerate(data): data_input_dict, util_input_dict = prepare_batch_input( data_buffer, data_input_names, util_input_names, ModelHyperParams.eos_idx, ModelHyperParams.eos_idx, ModelHyperParams.n_head, ModelHyperParams.d_model) - local_scope = train_exe.executor.local_scope(place_id) - - local_scope.find_var(lr_scheduler.learning_rate.name).get_tensor().set( - lr_scheduler.update_learning_rate(), - fluid.CUDAPlace(place_id)) + data_input_dict.update(util_input_dict) + data_input_dict.update({ + lr_scheduler.learning_rate.name: lr_scheduler.update_learning_rate() + }) - for var_name in data_input_dict: - local_scope.var(var_name).get_tensor().set(data_input_dict[var_name], - fluid.CUDAPlace(place_id)) + for pos_enc_param_name in pos_enc_param_names: + tensor = position_encoding_init(ModelHyperParams.max_length + 1, ModelHyperParams.d_model) + data_input_dict[pos_enc_param_name] = tensor - for var_name in util_input_dict: - local_scope.var(var_name).get_tensor().set(util_input_dict[var_name], - fluid.CUDAPlace(place_id)) + data_on_devices.append(data_input_dict) - outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name]) + outs = train_exe.run(fetch_list=[sum_cost.name, token_num.name], feed=data_on_devices) sum_cost_val, token_num_val = np.array(outs[0]), np.array(outs[1]) total_sum_cost = sum_cost_val.sum( ) # sum the cost from multi devices -- GitLab