diff --git a/benchmark/fluid/args.py b/benchmark/fluid/args.py
index ed696e82f8723eba573e8affd3f25e2aa6426e63..0d5c9652de6b814627e54018366137e214726619 100644
--- a/benchmark/fluid/args.py
+++ b/benchmark/fluid/args.py
@@ -140,5 +140,11 @@ def parse_args():
         '--use_lars',
         action='store_true',
         help='If set, use lars for optimizers, ONLY support resnet module.')
+    parser.add_argument(
+        '--reduce_strategy',
+        type=str,
+        choices=['reduce', 'all_reduce'],
+        default='all_reduce',
+        help='Specify the reduce strategy, can be reduce, all_reduce')
     args = parser.parse_args()
     return args
diff --git a/benchmark/fluid/fluid_benchmark.py b/benchmark/fluid/fluid_benchmark.py
index 25622ee06c69e13181f34dfffadd5e299d31c8a8..ddd9fe809853a830ca676cc98f1819f683866def 100644
--- a/benchmark/fluid/fluid_benchmark.py
+++ b/benchmark/fluid/fluid_benchmark.py
@@ -170,6 +170,14 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
     strategy = fluid.ExecutionStrategy()
     strategy.num_threads = args.cpus
     strategy.allow_op_delay = False
+    build_strategy = fluid.BuildStrategy()
+    if args.reduce_strategy == "reduce":
+        build_strategy.reduce_strategy = fluid.BuildStrategy(
+        ).ReduceStrategy.Reduce
+    else:
+        build_strategy.reduce_strategy = fluid.BuildStrategy(
+        ).ReduceStrategy.AllReduce
+
     avg_loss = train_args[0]
 
     if args.update_method == "pserver":
@@ -184,6 +192,7 @@ def train_parallel(train_args, test_args, args, train_prog, test_prog,
         avg_loss.name,
         main_program=train_prog,
         exec_strategy=strategy,
+        build_strategy=build_strategy,
         num_trainers=num_trainers,
         trainer_id=trainer_id)
 
diff --git a/benchmark/fluid/models/mnist.py b/benchmark/fluid/models/mnist.py
index cef8657ee629dcbc19221fd3440844a56627e920..f123e07fb711bd8ff67c1ecf5ec9a02c1e79eb1d 100644
--- a/benchmark/fluid/models/mnist.py
+++ b/benchmark/fluid/models/mnist.py
@@ -67,11 +67,14 @@ def cnn_model(data):
 
 def get_model(args, is_train, main_prog, startup_prog):
     # NOTE: mnist is small, we don't implement data sharding yet.
-    filelist = [
-        os.path.join(args.data_path, f) for f in os.listdir(args.data_path)
-    ]
+    opt = None
+    data_file_handle = None
     with fluid.program_guard(main_prog, startup_prog):
         if args.use_reader_op:
+            filelist = [
+                os.path.join(args.data_path, f)
+                for f in os.listdir(args.data_path)
+            ]
             data_file_handle = fluid.layers.open_files(
                 filenames=filelist,
                 shapes=[[-1, 1, 28, 28], (-1, 1)],
@@ -100,7 +103,7 @@ def get_model(args, is_train, main_prog, startup_prog):
     if is_train:
         opt = fluid.optimizer.AdamOptimizer(
             learning_rate=0.001, beta1=0.9, beta2=0.999)
-        opt.minimize()
+        opt.minimize(avg_cost)
 
     if args.memory_optimize:
         fluid.memory_optimize(main_prog)
diff --git a/paddle/fluid/framework/details/all_reduce_op_handle.cc b/paddle/fluid/framework/details/all_reduce_op_handle.cc
index bf493a3fa44e48deec734250d04b2a413c3ed9da..8450d8eb8b057d18d2e004964d1e25b32c142823 100644
--- a/paddle/fluid/framework/details/all_reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/all_reduce_op_handle.cc
@@ -46,7 +46,12 @@ AllReduceOpHandle::AllReduceOpHandle(ir::Node *node,
 #endif
 
 void AllReduceOpHandle::RunImpl() {
-  platform::RecordEvent r("all_reduce", nullptr);
+  if (dev_ctxes_.size() > 0UL) {
+    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+  } else {
+    platform::RecordEvent record_event(Name(), nullptr);
+  }
+
   if (NoDummyInputSize() == 1) {
     return;  // No need to all reduce when GPU count = 1;
   } else {
diff --git a/paddle/fluid/framework/details/broadcast_op_handle.cc b/paddle/fluid/framework/details/broadcast_op_handle.cc
index 1d9f1bd6e417e30f0799f0bbed1739cedb4e8fbf..35962ade99c024235d2bb4a203bec52bf4ef2063 100644
--- a/paddle/fluid/framework/details/broadcast_op_handle.cc
+++ b/paddle/fluid/framework/details/broadcast_op_handle.cc
@@ -15,12 +15,19 @@
 #include "paddle/fluid/framework/details/broadcast_op_handle.h"
 #include "paddle/fluid/framework/details/container_cast.h"
 #include "paddle/fluid/framework/details/variable_visitor.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace framework {
 namespace details {
 
 void BroadcastOpHandle::RunImpl() {
+  if (dev_ctxes_.size() > 0UL) {
+    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+  } else {
+    platform::RecordEvent record_event(Name(), nullptr);
+  }
+
   if (places_.size() == 1) return;
 
   // The input and output may have dummy vars.
diff --git a/paddle/fluid/framework/details/data_balance_op_handle.cc b/paddle/fluid/framework/details/data_balance_op_handle.cc
index 525d24322442ef4dd6e8c24212af61c908959b87..91f6a42e6eab29ca5aa2ddf0442e6706f66e205f 100644
--- a/paddle/fluid/framework/details/data_balance_op_handle.cc
+++ b/paddle/fluid/framework/details/data_balance_op_handle.cc
@@ -15,6 +15,7 @@
 #include "paddle/fluid/framework/details/data_balance_op_handle.h"
 #include <algorithm>
 #include "paddle/fluid/framework/details/container_cast.h"
+#include "paddle/fluid/platform/profiler.h"
 
 namespace paddle {
 namespace framework {
@@ -86,6 +87,12 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
 }
 
 void DataBalanceOpHandle::RunImpl() {
+  if (dev_ctxes_.size() > 0UL) {
+    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+  } else {
+    platform::RecordEvent record_event(Name(), nullptr);
+  }
+
   PADDLE_ENFORCE_GT(places_.size(), 1,
                     "Data balance can only be enabled when the number of "
                     "places to run larger than 1.");
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.cc b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
index 7a99169849debcbc57d6f197b36c5045b211f3ef..cd6c8b50a9370428e6fbd79dec81c409cab8fc64 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.cc
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.cc
@@ -348,14 +348,31 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
 
   size_t cur_device_id = 0;
   bool is_forwarding = true;
+  bool is_dist_train = false;
 
   for (ir::Node *node : sorted_ops) {
     if (boost::get<int>(
            node->Op()->GetAttr(OpProtoAndCheckerMaker::OpRoleAttrName())) ==
        static_cast<int>(OpRole::kRPC)) {
-      CreateRPCOp(&result, node);
+      int op_dev_id = CreateRPCOp(&result, node);
+      PADDLE_ENFORCE(op_dev_id != -1,
+                     "Can not schedule the RPC operator to the right place.");
+      if (node->Op()->Type() == "recv") {
+        auto recv_vars_attr =
+            boost::get<std::vector<std::string>>(node->Op()->GetNullableAttr(
+                OpProtoAndCheckerMaker::OpRoleVarAttrName()));
+        PADDLE_ENFORCE(recv_vars_attr.size() == 2UL);  // [parameter, gradient]
+        if (recv_vars_attr[0].find(".block") == std::string::npos) {
+          bcast_var_name_set[op_dev_id].emplace(recv_vars_attr[0]);
+        }
+      }
+      is_dist_train = true;
     } else if (IsDistTrainOp(node, send_vars, recv_vars)) {
-      CreateDistTrainOp(&result, node);
+      int op_dev_id = CreateDistTrainOp(&result, node);
+      if (node->Op()->Type() == "concat") {
+        auto origin_param_name = node->Op()->OutputArgumentNames()[0];
+        bcast_var_name_set[op_dev_id].emplace(origin_param_name);
+      }
     } else if (IsScaleLossOp(node)) {
       // user can customize loss@grad if not use_default_grad_scale_
       if (strategy_.gradient_scale_ !=
@@ -414,7 +431,10 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
             CreateReduceOp(&result, g_name, cur_device_id);
             graph->Get<ShardedVarDevice>(kShardedVarDevice)
                 .emplace(g_name, cur_device_id);
-            bcast_var_name_set[cur_device_id].emplace(p_name);
+            if (!is_dist_train) {
+              // will send gradients directly when distributed training
+              bcast_var_name_set[cur_device_id].emplace(p_name);
+            }
             break;
           case BuildStrategy::ReduceStrategy::kAllReduce:
             if (IsSparseGradient(g_name)) {
@@ -436,14 +456,14 @@ std::unique_ptr<ir::Graph> MultiDevSSAGraphBuilder::ApplyImpl(
       }
     }
   }
-
   bool use_gpu = false;
 #ifdef PADDLE_WITH_CUDA
   use_gpu = nccl_ctxs_ != nullptr;
 #endif
 
-  if (use_gpu ||
-      strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce) {
+  if ((use_gpu &&
+       strategy_.reduce_ == BuildStrategy::ReduceStrategy::kReduce) ||
+      is_dist_train) {
     // Insert BCast Ops
     for (size_t dev_id = 0; dev_id < bcast_var_name_set.size(); ++dev_id) {
       auto &to_bcast_set = bcast_var_name_set[dev_id];
@@ -676,8 +696,8 @@ VarHandle *MultiDevSSAGraphBuilder::CreateReduceOp(ir::Graph *result,
   return var;
 }
 
-void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
-                                                ir::Node *node) const {
+int MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
+                                               ir::Node *node) const {
   int op_dev_id = -1;
   std::vector<std::string> input_var_names;
   std::vector<std::string> output_var_names;
@@ -720,6 +740,7 @@ void MultiDevSSAGraphBuilder::CreateDistTrainOp(ir::Graph *result,
                    node->Op()->Type());
 
   CreateComputationalOp(result, node, op_dev_id);
+  return op_dev_id;
 }
 
 void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
@@ -738,8 +759,8 @@ void SetOpInputsAllPlaces(ir::Graph *result, ir::Node *node, int num_places) {
 }
 
 // Create RPC related op handles that connects its in ops and out ops.
-void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
-                                          ir::Node *node) const {
+int MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
+                                         ir::Node *node) const {
   int op_dev_id = -1;
   if (node->Op()->Type() == "send") {
     // TODO(paddle-dev): getting the first var is not safe.
@@ -825,6 +846,7 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(ir::Graph *result,
       CreateOpOutput(result, op_handle, new_node, p, outvar_dev_id);
     }
   }
+  return op_dev_id;
 }
 
 bool MultiDevSSAGraphBuilder::IsScaleLossOp(ir::Node *node) const {
diff --git a/paddle/fluid/framework/details/multi_devices_graph_pass.h b/paddle/fluid/framework/details/multi_devices_graph_pass.h
index ac6d9c5a64cfde60f75c76dae0a30cc7d735e996..1ca8c4b855f9468589e537245380451a91a50b14 100644
--- a/paddle/fluid/framework/details/multi_devices_graph_pass.h
+++ b/paddle/fluid/framework/details/multi_devices_graph_pass.h
@@ -54,8 +54,8 @@ class MultiDevSSAGraphBuilder : public ir::Pass {
 
   bool IsScaleLossOp(ir::Node *node) const;
 
-  void CreateRPCOp(ir::Graph *result, ir::Node *node) const;
-  void CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
+  int CreateRPCOp(ir::Graph *result, ir::Node *node) const;
+  int CreateDistTrainOp(ir::Graph *result, ir::Node *node) const;
 
   /**
    * Is this operator as the end-point operator before/after send operator.
diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc
index 6c7e5c1fb06620b1c071b00fcfcc1b4a29bf8d62..878828693bcbf637d8456bc1a63915bcde1f774b 100644
--- a/paddle/fluid/framework/details/reduce_op_handle.cc
+++ b/paddle/fluid/framework/details/reduce_op_handle.cc
@@ -27,7 +27,11 @@ namespace framework {
 namespace details {
 
 void ReduceOpHandle::RunImpl() {
-  platform::RecordEvent r("reduce", nullptr);
+  if (dev_ctxes_.size() > 0UL) {
+    platform::RecordEvent record_event(Name(), dev_ctxes_.begin()->second);
+  } else {
+    platform::RecordEvent record_event(Name(), nullptr);
+  }
   if (places_.size() == 1) return;
   // the input and output may have dummy var.
   auto in_var_handles = DynamicCast<VarHandle>(inputs_);
diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
index 609e18581957f62b040e04e937873b7a8fa5785a..ba243979b34aa1f683de707525403becaf0a1c00 100644
--- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
+++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.cc
@@ -51,7 +51,7 @@ void ScaleLossGradOpHandle::RunImpl() {
                      ->stream();
     memory::Copy(boost::get<platform::CUDAPlace>(place_), tmp,
                  platform::CPUPlace(), &coeff_, sizeof(float), stream);
-    VLOG(1) << place_ << "RUN Scale loss grad op";
+    VLOG(10) << place_ << "RUN Scale loss grad op";
   });
 #endif
 }
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 20fc08e21dadc12be8903476df374abb5caecf61..8bc30fc123163983f4bddc19af489920db93e0c0 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -683,7 +683,6 @@ All parameter, weight, gradient are variables in Paddle.
                     const std::string &, Scope *, std::vector<Scope *> &,
                     const ExecutionStrategy &, const BuildStrategy &, size_t,
                     size_t>())
-      .def("_bcast_params", &ParallelExecutor::BCastParamsToDevices)
       // NOTE: even we return a vec* to Python use reference policy.
       // We still cannot get local_scope from this vector, since the element
       // of vec will be freed by Python GC. We can only return Scope*
diff --git a/python/paddle/fluid/parallel_executor.py b/python/paddle/fluid/parallel_executor.py
index 4790e0f6119e96b11b049bfdd3b46d40a382683b..058f414e9b7130bc6ff68c4b9d7d65ca7db300cc 100644
--- a/python/paddle/fluid/parallel_executor.py
+++ b/python/paddle/fluid/parallel_executor.py
@@ -279,21 +279,11 @@ class ParallelExecutor(object):
         self.executor.run(fetch_list, fetch_var_name)
         arr = self.scope.find_var(fetch_var_name).get_lod_tensor_array()
 
-        if self.is_dist:
-            self._bcast_params()
-
         if return_numpy:
             return executor.as_numpy(arr)
 
         return [arr[i] for i in range(len(arr))]
 
-    def _bcast_params(self):
-        """
-        Broadcast the parameters to other devices. It is used during
-        distributed training.
-        """
-        self.executor._bcast_params(set(self.persistable_vars))
-
     @property
     def device_count(self):
         return len(self._act_places)
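
Usage sketch (not part of the patch): the snippet below mirrors how the new
--reduce_strategy flag is wired into fluid_benchmark.py above. The names
`args`, `avg_loss`, and `train_prog` are assumed to come from the parsed
benchmark arguments and an already-built training program; the ParallelExecutor
arguments follow the call shown in the diff.

    import paddle.fluid as fluid

    # Pick the graph-building reduce strategy from the CLI flag
    # (--reduce_strategy is 'reduce' or 'all_reduce', default 'all_reduce').
    build_strategy = fluid.BuildStrategy()
    if args.reduce_strategy == "reduce":
        build_strategy.reduce_strategy = fluid.BuildStrategy(
        ).ReduceStrategy.Reduce
    else:
        build_strategy.reduce_strategy = fluid.BuildStrategy(
        ).ReduceStrategy.AllReduce

    exec_strategy = fluid.ExecutionStrategy()
    exe = fluid.ParallelExecutor(
        True,  # use_cuda
        avg_loss.name,
        main_program=train_prog,
        exec_strategy=exec_strategy,
        build_strategy=build_strategy)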