/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/parallel_executor.h"

#include <algorithm>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/details/async_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/bind_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/fast_threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/multi_devices_helper.h"
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/details/parallel_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/memory_optimization_var_info.h"
#include "paddle/fluid/framework/ir/memory_optimize_pass/reference_count_pass_helper.h"
#include "paddle/fluid/framework/ir/multi_devices_graph_pass/set_reader_device_info_utils.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
#include "paddle/fluid/platform/event.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif

DECLARE_double(eager_delete_tensor_gb);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
DECLARE_bool(sync_nccl_allreduce);
#endif

#ifdef WITH_GPERFTOOLS
#include "gperftools/profiler.h"
#endif

PADDLE_DEFINE_EXPORTED_string(
    pe_profile_fname,
    "",
    "Profiler filename for PE, which is generated by gperftools. "
    "Only valid when compiled with `WITH_PROFILER=ON`. Empty means disabled.");
PADDLE_DEFINE_EXPORTED_bool(
    enable_parallel_graph,
    false,
    "Force-disable parallel graph execution mode if set to false.");

namespace paddle {
namespace framework {

static std::once_flag gProfileOnce;
#ifdef WITH_GPERFTOOLS
static bool gProfileStarted = false;
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std::once_flag p2p_init_flag;
#endif

class ParallelExecutorPrivate {
 public:
  ParallelExecutorPrivate(const std::vector<platform::Place> &places,
                          Scope *global_scope)
      : places_(places), global_scope_(global_scope) {
    if (!FLAGS_pe_profile_fname.empty()) {
      std::call_once(gProfileOnce, [] {
#ifdef WITH_GPERFTOOLS
        ProfilerStart(FLAGS_pe_profile_fname.c_str());
        gProfileStarted = true;
#else
        LOG(WARNING) << "Paddle is not compiled with gperftools. "
                        "FLAGS_pe_profile_fname will be ignored";
#endif
      });
    }
  }

  ~ParallelExecutorPrivate() {
    if (own_local_scope_) {
      for (size_t i = 1; i < local_scopes_.size(); ++i) {
        // Skip the first scope, since it is the global scope.
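        // Local scopes created by ParallelExecutor are registered as kids of
        // the global scope, so they must be released through the global scope
        // to keep its kid list consistent.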
Scope *local_scope = local_scopes_[i]; if (global_scope_->HasKid(local_scope)) { global_scope_->DeleteScope(local_scope); } } } } bool IsUseCUDA(DeviceType use_device); void SetHasFeed(size_t dev_idx, bool has_feed = true); bool AllowPartialFeed() const; ir::Graph *ApplyMemoryOptimizePass(ir::Graph *graph); inline bool HasGarbageCollectors() const { return !gcs_.empty(); } void ApplyFixOpRunOrderPass(ir::Graph *graph) { if (build_strategy_.fix_op_run_order_) { auto pass = ir::PassRegistry::Instance().Get("fix_op_run_order_pass"); pass->Apply(graph); } } /** * NOTE(zengjinle): the fed variables of users should not be reused, * because users may feed them into another network. Changing the fed * variables that users can visit may cause calculation wrong, which is * a very subtle bug when traning networks. However, these variables * can be garbage collected. * * ParallelExecutor provides 2 methods to feed variables: * * - FeedTensorsIntoLocalScopes: this method would share memory of fed * variables, so we have to skip these. * * - FeedAndSplitTensorIntoLocalScopes: this method would copy data of fed * variables, so we do not need to skip * them. */ inline void SetSkipMemoryReuse(size_t scope_idx, const std::string &name) { if (mem_opt_var_infos_.size() == 0) { VLOG(4) << "The mem_opt_var_infos_ is empty, maybe no memory " "optimization strategy is enabled"; return; } auto iter = mem_opt_var_infos_[scope_idx].find(name); if (iter != mem_opt_var_infos_[scope_idx].end()) { iter->second->SetSkipMemoryReuse(true); } } #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) void InitNCCLCtxs(framework::Scope *scope, const BuildStrategy &bst) { VLOG(1) << "nccl comm num:" << bst.nccl_comm_num_ << ", nranks:" << nranks_ << ", num_trainers:" << bst.num_trainers_ << ", trainer_id:" << bst.trainer_id_; if (bst.use_hierarchical_allreduce_) { VLOG(1) << ", use_hierarchical_allreduce:" << bst.use_hierarchical_allreduce_ << ", inter_trainers_num:" << bst.hierarchical_allreduce_inter_nranks_ << ", exter_trainers_num:" << bst.hierarchical_allreduce_exter_nranks_; } std::vector flat_nccl_ids; if (nranks_ == 1) { // FIXME(gongwb): need not to create ncclid when nranks==1 nccl_ctxs_->InitFlatCtxs( places_, flat_nccl_ids, bst.num_trainers_, bst.trainer_id_); return; } if (bst.enable_parallel_graph_) { VLOG(1) << "use only one ncclid in pg model"; ncclUniqueId *nccl_id = nullptr; std::string var_name = platform::GetFlatNCCLVarName(0); auto nccl_id_var = scope->FindVar(var_name); if (nccl_id_var) { nccl_id = nccl_id_var->GetMutable(); VLOG(10) << "find nccl_id_var:" << var_name << ", nccl_id:" << nccl_id; } else { nccl_id = new ncclUniqueId(); PADDLE_ENFORCE_EQ( platform::dynload::ncclGetUniqueId(nccl_id), ncclSuccess, platform::errors::PreconditionNotMet( "PaddlePaddle failed to get NCCL unique ID. 
It may due to your " "system settings or NCCL library error, please debug on NCCL")); VLOG(10) << "can't find nccl_id_var:" << var_name << ", nccl_id:" << nccl_id; } flat_nccl_ids.push_back(nccl_id); nccl_ctxs_->InitFlatCtxs( places_, flat_nccl_ids, bst.num_trainers_, bst.trainer_id_); VLOG(1) << "init bst nccl context complete!"; return; } // num_trainers ==1 && places > 1 if (bst.num_trainers_ == 1) { nccl_ctxs_->InitFlatCtxs( places_, flat_nccl_ids, bst.num_trainers_, bst.trainer_id_); return; } for (int i = 0; i < static_cast(bst.nccl_comm_num_); i++) { std::string var_name = platform::GetFlatNCCLVarName(i); auto nccl_id_var = scope->FindVar(var_name); PADDLE_ENFORCE_NOT_NULL( nccl_id_var, platform::errors::NotFound("Can't find nccl_id_var '%s'.", var_name)); auto nccl_id = nccl_id_var->GetMutable(); flat_nccl_ids.push_back(nccl_id); } nccl_ctxs_->InitFlatCtxs( places_, flat_nccl_ids, bst.num_trainers_, bst.trainer_id_); if (bst.use_hierarchical_allreduce_) { std::vector inter_nccl_ids; for (int i = 0; i < static_cast(bst.nccl_comm_num_); i++) { std::string var_name = platform::GetHierarchicalInterNCCLVarName(i); auto nccl_id_var = scope->FindVar(var_name); PADDLE_ENFORCE_NOT_NULL(nccl_id_var, platform::errors::NotFound( "Can't find nccl_id_var '%s'.", var_name)); auto inter_nccl_id = nccl_id_var->GetMutable(); inter_nccl_ids.push_back(inter_nccl_id); } std::vector exter_nccl_ids; for (int i = 0; i < static_cast(bst.nccl_comm_num_); i++) { std::string var_name = platform::GetHierarchicalExterNCCLVarName(i); auto nccl_id_var = scope->FindVar(var_name); PADDLE_ENFORCE_NOT_NULL(nccl_id_var, platform::errors::NotFound( "Can't find nccl_id_var '%s'.", var_name)); auto nccl_id = nccl_id_var->GetMutable(); exter_nccl_ids.push_back(nccl_id); } nccl_ctxs_->InitHierarchicalCtxs( places_, inter_nccl_ids, exter_nccl_ids, bst.num_trainers_, bst.trainer_id_, bst.hierarchical_allreduce_inter_nranks_, bst.hierarchical_allreduce_exter_nranks_); } } void InitOrGetNCCLCommunicator(framework::Scope *scope, BuildStrategy *bst) { const std::string var_name = "NCCLCommunicator"; auto var = scope->FindVar(var_name); if (var != nullptr) { PADDLE_ENFORCE_EQ(var->IsInitialized(), true, platform::errors::PreconditionNotMet( "if %s exists, it must be initialized", var_name)); VLOG(1) << "find " << var_name << " in scope, so use it and does not recreate!"; nccl_ctxs_ = var->GetMutable(); return; } if (bst->use_hierarchical_allreduce_) { PADDLE_ENFORCE_GT( bst->num_trainers_, 1, platform::errors::PreconditionNotMet( "The num_trainers should be greater than 1, but received %llu.", bst->num_trainers_)); PADDLE_ENFORCE_GT( bst->hierarchical_allreduce_inter_nranks_, 1, platform::errors::PreconditionNotMet( "The inter_nranks should be greater than 1, but received %d.", bst->hierarchical_allreduce_inter_nranks_)); PADDLE_ENFORCE_EQ( bst->num_trainers_ % bst->hierarchical_allreduce_inter_nranks_, 0, platform::errors::PreconditionNotMet( "num_trainers:%llu mod inter_nranks:%d != 0", bst->num_trainers_, bst->hierarchical_allreduce_inter_nranks_)); bst->hierarchical_allreduce_exter_nranks_ = bst->num_trainers_ / bst->hierarchical_allreduce_inter_nranks_; } VLOG(1) << "not find " << var_name << " in scope, so recreate it!"; nccl_ctxs_ = scope->Var(var_name)->GetMutable(); InitNCCLCtxs(scope, *bst); } #endif #if defined(PADDLE_WITH_XPU_BKCL) void InitBKCLCtxs(framework::Scope *scope, const BuildStrategy &bst) { VLOG(1) << "bkcl comm num:" << bst.bkcl_comm_num_ << ", nranks:" << nranks_ << ", num_trainers:" << bst.num_trainers_ 
<< ", trainer_id:" << bst.trainer_id_; PADDLE_ENFORCE_EQ(bst.use_hierarchical_allreduce_, false, platform::errors::Unimplemented( "xpu doesn't support use_hierarchical_allreduce")); std::vector flat_bkcl_ids; if (nranks_ == 1) { // FIXME(gongwb): need not to create bkclid when nranks==1 bkcl_ctxs_->InitFlatCtxs( places_, flat_bkcl_ids, bst.num_trainers_, bst.trainer_id_); return; } if (bst.enable_parallel_graph_) { VLOG(1) << "use only one bkclid in pg model"; BKCLUniqueId *bkcl_id = nullptr; std::string var_name = platform::GetFlatBKCLVarName(0); auto bkcl_id_var = scope->FindVar(var_name); std::unique_ptr id(new BKCLUniqueId()); if (bkcl_id_var) { bkcl_id = bkcl_id_var->GetMutable(); } else { PADDLE_ENFORCE_EQ( bkcl_get_unique_id(id.get()), BKCL_SUCCESS, platform::errors::Unavailable("bkcl get unique id failed")); bkcl_id = id.get(); } flat_bkcl_ids.push_back(bkcl_id); bkcl_ctxs_->InitFlatCtxs( places_, flat_bkcl_ids, bst.num_trainers_, bst.trainer_id_); VLOG(1) << "init bst bkcl context complete!"; return; } // num_trainers ==1 && places > 1 if (bst.num_trainers_ == 1) { bkcl_ctxs_->InitFlatCtxs( places_, flat_bkcl_ids, bst.num_trainers_, bst.trainer_id_); return; } for (int i = 0; i < static_cast(bst.bkcl_comm_num_); i++) { std::string var_name = platform::GetFlatBKCLVarName(i); auto bkcl_id_var = scope->FindVar(var_name); PADDLE_ENFORCE_NOT_NULL( bkcl_id_var, platform::errors::NotFound("can't find %s bkcl_id_var", var_name)); auto bkcl_id = bkcl_id_var->GetMutable(); flat_bkcl_ids.push_back(bkcl_id); } bkcl_ctxs_->InitFlatCtxs( places_, flat_bkcl_ids, bst.num_trainers_, bst.trainer_id_); } void InitOrGetBKCLCommunicator(framework::Scope *scope, const BuildStrategy &bst) { const std::string var_name = "BKCLCommunicator"; auto var = scope->FindVar(var_name); if (var != nullptr) { PADDLE_ENFORCE_EQ(var->IsInitialized(), true, platform::errors::PreconditionNotMet( "if %s exists, it must be initialized", var_name)); VLOG(1) << "find " << var_name << " in scope, so use it and does not recreate!"; bkcl_ctxs_ = var->GetMutable(); return; } VLOG(1) << "not find " << var_name << " in scope, so recreate it!"; bkcl_ctxs_ = scope->Var(var_name)->GetMutable(); InitBKCLCtxs(scope, bst); } #endif inline bool IsPersistable(const std::string &name) const { auto iter = is_persistable_.find(name); return iter != is_persistable_.end() && iter->second; } BuildStrategy build_strategy_; std::vector places_; std::vector local_scopes_; std::vector local_exec_scopes_; Scope *global_scope_; // not owned std::unique_ptr executor_; std::unordered_map is_persistable_; #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) platform::NCCLCommunicator *nccl_ctxs_{nullptr}; #elif defined(PADDLE_WITH_XPU_BKCL) platform::BKCLCommunicator *bkcl_ctxs_{nullptr}; #endif bool own_local_scope_; DeviceType use_device_; bool use_all_reduce_; size_t nranks_; ir::MemOptVarInfoMapList mem_opt_var_infos_; ir::GarbageCollectorMap gcs_; details::ParallelSSAGraphExecutor *inference_executor_{nullptr}; }; bool ParallelExecutorPrivate::IsUseCUDA(DeviceType use_device) { return use_device == p::kCUDA; } void ParallelExecutorPrivate::SetHasFeed(size_t dev_idx, bool has_feed) { if (inference_executor_) { inference_executor_->SetHasFeed(dev_idx, has_feed); } } bool ParallelExecutorPrivate::AllowPartialFeed() const { return inference_executor_ && inference_executor_->SupportPartialFeed(); } ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) { /** * NOTE(zengjinle): If BuildStrategy.memory_optimize = None in 
Python, * set BuildStrategy.memory_optimize according to whether gc is enabled. * If gc is enabled, BuildStrategy.memory_optimize = False. * If gc is disabled, BuildStrategy.memory_optimize = True. * This is because gc+memory_optimize is worse than gc only. * * As an option, users can enable BuildStrategy.memory_optimize forcely * by setting True, and disable it forcely by setting False. */ bool is_gc_enabled = (GetEagerDeletionThreshold() >= 0); if (!build_strategy_.memory_optimize_) { build_strategy_.memory_optimize_ = !is_gc_enabled; } bool need_mem_opt = build_strategy_.enable_inplace_ || build_strategy_.enable_addto_ || build_strategy_.memory_optimize_.get() || is_gc_enabled; if (!need_mem_opt) return graph; std::vector last_live_ops_of_vars; auto ref_cnt_pass = ir::PassRegistry::Instance().Get("reference_count_pass"); ref_cnt_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); ref_cnt_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); graph = ref_cnt_pass->Apply(graph); VLOG(10) << "ReferenceCountPass Applied"; if (build_strategy_.enable_addto_) { auto addto_pass = ir::PassRegistry::Instance().Get("inplace_addto_op_pass"); addto_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); addto_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); addto_pass->Set(ir::kUseCuda, new bool(use_device_ == p::kCUDA)); VLOG(10) << "Start to apply inplace_addto_op_pass"; graph = addto_pass->Apply(graph); VLOG(10) << "inplace_addto_op_pass Applied"; } if (build_strategy_.enable_inplace_) { auto inplace_pass = ir::PassRegistry::Instance().Get("buffer_shared_inplace_pass"); inplace_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); inplace_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); inplace_pass->Set(ir::kUseCuda, new bool(use_device_ == p::kCUDA)); VLOG(10) << "Start to apply buffer_shared_inplace_pass"; graph = inplace_pass->Apply(graph); VLOG(10) << "buffer_shared_inplace_pass Applied"; VLOG(1) << "Inplace strategy is enabled, when " "build_strategy.enable_inplace = True"; } if (build_strategy_.memory_optimize_.get()) { auto cross_op_memory_reuse_pass = ir::PassRegistry::Instance().Get( "buffer_shared_cross_op_memory_reuse_pass"); cross_op_memory_reuse_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); cross_op_memory_reuse_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); cross_op_memory_reuse_pass->Set(ir::kUseCuda, new bool(use_device_ == p::kCUDA)); VLOG(10) << "Start to apply buffer_shared_cross_op_memory_reuse_pass"; graph = cross_op_memory_reuse_pass->Apply(graph); VLOG(10) << "buffer_shared_cross_op_memory_reuse_pass Applied"; LOG(INFO) << "Cross op memory reuse strategy is enabled, when " "build_strategy.memory_optimize = True or garbage collection " "strategy is disabled, which is not recommended"; } if (!is_gc_enabled) { return graph; } size_t max_memory_size = static_cast(GetEagerDeletionThreshold()); for (size_t i = 0; i < places_.size(); ++i) { auto &place = places_[i]; if (gcs_.count(place) > 0) { continue; } std::unique_ptr gc; if (platform::is_gpu_place(place)) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (IsFastEagerDeletionModeEnabled()) { gc.reset(new UnsafeFastGPUGarbageCollector(place, max_memory_size)); } else { gc.reset(new StreamGarbageCollector(place, max_memory_size)); } VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; #else PADDLE_THROW(platform::errors::PermissionDenied( "Paddle can't use CUDA device since 
it's not compiled with CUDA," "Please recompile or reinstall Paddle with GPU support.")); #endif } else if (platform::is_mlu_place(place)) { #ifdef PADDLE_WITH_MLU if (IsFastEagerDeletionModeEnabled()) { gc.reset(new MLUUnsafeFastGarbageCollector(place, max_memory_size)); } else { gc.reset(new MLUStreamGarbageCollector(place, max_memory_size)); } VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; #else PADDLE_THROW(platform::errors::PermissionDenied( "Paddle can't use MLU device since it's not compiled with MLU," "Please recompile or reinstall Paddle with MLU support.")); #endif } else if (platform::is_xpu_place(place)) { #if defined(PADDLE_WITH_XPU) gc.reset(new XPUGarbageCollector(place, max_memory_size)); VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; #else PADDLE_THROW(platform::errors::PermissionDenied( "Paddle can't use XPU device since it's not compiled with XPU," "Please recompile or reinstall Paddle with XPU support.")); #endif } else if (platform::is_ipu_place(place)) { #if defined(PADDLE_WITH_IPU) gc.reset(new IPUGarbageCollector(place, max_memory_size)); VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; #else PADDLE_THROW(platform::errors::PermissionDenied( "Paddle can't use IPU device since it's not compiled with IPU," "Please recompile or reinstall Paddle with IPU support.")); #endif } else if (platform::is_custom_place(place)) { #if defined(PADDLE_WITH_CUSTOM_DEVICE) if (IsFastEagerDeletionModeEnabled()) { gc.reset( new CustomDeviceUnsafeFastGarbageCollector(place, max_memory_size)); } else { gc.reset(new CustomStreamGarbageCollector(place, max_memory_size)); } VLOG(10) << "Created " << i << "-th GarbageCollector at " << place; #else PADDLE_THROW(platform::errors::PermissionDenied( "Paddle can't use custom device since it's not compiled with " "CustomDevice," "Please recompile or reinstall Paddle with CustomDevice support.")); #endif } else if (platform::is_cpu_place(place)) { gc.reset(new CPUGarbageCollector(place, max_memory_size)); VLOG(10) << "Created GarbageCollector at " << place; } else { PADDLE_THROW(platform::errors::PreconditionNotMet( "Unsupported place for garbage collection")); } gcs_.emplace(place, std::move(gc)); } if (!gcs_.empty()) { auto eager_deletion_pass = ir::PassRegistry::Instance().Get("eager_deletion_pass"); eager_deletion_pass->SetNotOwned(ir::kMemOptVarInfoMapList, &mem_opt_var_infos_); eager_deletion_pass->SetNotOwned(ir::kGarbageCollector, &gcs_); eager_deletion_pass->SetNotOwned(ir::kLastLiveOpsOfVars, &last_live_ops_of_vars); eager_deletion_pass->SetNotOwned(ir::kAllPlaces, &places_); graph = eager_deletion_pass->Apply(graph); VLOG(10) << "EagerDeletionPass Applied"; VLOG(1) << "Garbage collection strategy is enabled, when " << "FLAGS_eager_delete_tensor_gb = " << FLAGS_eager_delete_tensor_gb; } return graph; } class ResetHasFeedGuard { public: explicit ResetHasFeedGuard(ParallelExecutorPrivate *pe_member) : pe_member_(pe_member) {} ~ResetHasFeedGuard() { for (size_t i = 0; i < pe_member_->places_.size(); ++i) { pe_member_->SetHasFeed(i, false); } } private: ParallelExecutorPrivate *pe_member_; }; size_t ParallelExecutor::DeviceCount() const { return member_->places_.size(); } std::vector &ParallelExecutor::GetLocalScopes() { return member_->local_scopes_; } void ParallelExecutor::DropLocalExeScopes() { auto executor = dynamic_cast( member_->executor_.get()); if (executor) { executor->DropLocalExeScopes(); } } bool ParallelExecutor::NeedCreateLocalExeScope() { auto executor = dynamic_cast( 
member_->executor_.get()); return executor && executor->NeedCreateLocalExeScope(); } void InitP2P(const std::vector &places) { #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) std::call_once(p2p_init_flag, [&]() { int count = places.size(); if (count <= 1) return; std::vector devices; for (int i = 0; i < count; i++) { if (!platform::is_gpu_place(places[i])) return; platform::CUDAPlace device = places[i]; devices.push_back(device.GetDeviceId()); } for (int i = 0; i < count; ++i) { for (int j = 0; j < count; ++j) { if (devices[i] == devices[j]) continue; int can_acess = -1; #ifdef PADDLE_WITH_HIP hipError_t ret = hipDeviceCanAccessPeer(&can_acess, devices[i], devices[j]); if (ret != hipSuccess || can_acess != 1) { #else cudaError_t ret = cudaDeviceCanAccessPeer(&can_acess, devices[i], devices[j]); if (ret != cudaSuccess || can_acess != 1) { #endif LOG(WARNING) << "Cannot enable P2P access from " << devices[i] << " to " << devices[j]; } else { platform::CUDADeviceGuard guard(devices[i]); #ifdef PADDLE_WITH_HIP hipDeviceEnablePeerAccess(devices[j], 0); #else cudaDeviceEnablePeerAccess(devices[j], 0); #endif } } } VLOG(1) << "init p2p"; }); #endif } ParallelExecutor::ParallelExecutor(const std::vector &places, const std::vector &bcast_vars, const std::string &loss_var_name, Scope *scope, const std::vector &local_scopes, const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy, ir::Graph *graph) : member_(new ParallelExecutorPrivate(places, scope)) { PADDLE_ENFORCE_EQ(places.size() > 0 && !platform::is_npu_place(places[0]), true, platform::errors::Unavailable( "NPU is not supported in ParallelExecutor.")); InitP2P(places); ir::InitReaderQueueDeviceCount( graph, *(member_->global_scope_), member_->places_.size()); // Initialize necessary info of member_ with strategy. InitExecutorPrivateMemberInfo( exec_strategy, build_strategy, places.size(), *graph); // Step 1. Create local scopes and Clone graph into multi device CreateLocalScopes(scope, local_scopes, /*create_new*/ true); std::vector graphs = CloneGraphToMultiDevices(graph); PrepareNCCLCommunicator(scope); // broadcast parameters from the 0th device to others: auto need_broadcast = [&]() -> bool { if (member_->build_strategy_.num_trainers_ > 1) { // 1. num_tariners would be grater than 1 for nccl distributed training. return true; } else if (member_->local_scopes_.size() != 1 && local_scopes.empty()) { // 2. Only one trainer process, but ParallelExecutor hold multiple // devices. return true; } return false; }; if (need_broadcast()) { BCastParamsToDevices(bcast_vars, member_->build_strategy_.trainer_id_); } // Step 2. Convert main_program to SSA form and dependency graph. Also, insert // ncclOp std::vector async_graphs = CompileGraphWithBuildStrategy(graph, &graphs, loss_var_name); PrepareForCUDAGraphCapture(graph); graph = member_->ApplyMemoryOptimizePass(graph); async_graphs[0] = graph; // Step 3. Create vars in each scope. Passes may also create new vars. // skip control vars and empty vars std::vector var_infos; CreateVariableInfos(&var_infos, graph); std::unordered_map scope_map = CreateLocalExecScopes(member_->local_scopes_, /*create_new*/ true); // Step 4. 
Create SSAGraph executor std::vector final_graphs = CreateSSAGraphExecutor(exec_strategy, &async_graphs, graph); VLOG(3) << "use ScopeBufferedSSAGraphExecutor"; if (!member_->build_strategy_.async_mode_) { member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, std::move(var_infos), member_->places_, std::move(member_->executor_))); } ResetOpHandleScopeMapOfGraphs(final_graphs, scope_map); SetReaderOpDeviceInfoOfGraphs(final_graphs); } ParallelExecutor::ParallelExecutor(const platform::Place &place, Scope *scope, const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy, ir::Graph *graph) : member_(new ParallelExecutorPrivate({place}, scope)) { // Initialize necessary info of member_ with strategy. InitExecutorPrivateMemberInfo(exec_strategy, build_strategy, /*device_count=*/1, *graph); CreateLocalScopes(scope, /*local_scope=*/{scope}, /*create_new=*/false); // Apply BuildStrategy to compile graph. std::vector graphs = {graph}; std::vector async_graphs = CompileGraphWithBuildStrategy(graph, &graphs, /*loss_var_name=*/""); graph = member_->ApplyMemoryOptimizePass(graph); // Create vars in each scope. Passes may also create new vars. // skip control vars and empty vars CreateVariableInfos(&var_infos_, graph); // Create local execution scopes std::unordered_map scope_map = CreateLocalExecScopes(member_->local_scopes_, /*create_new=*/false); std::vector final_graphs = CreateSSAGraphExecutor(exec_strategy, &async_graphs, graph); // Set scope_map of op from each graph ResetOpHandleScopeMapOfGraphs(final_graphs, scope_map); } void ParallelExecutor::PrepareVariables(Scope *scope) { for (auto &info : var_infos_) { auto var = scope->FindVar(info.name_); if (var != nullptr) { VLOG(2) << info.name_ << " has been initialized beforehand in global scope, skipped."; continue; } framework::InitializeVariable(scope->Var(info.name_), info.type_); } } void ParallelExecutor::BCastParamsToDevices( const std::vector &vars, int trainer_id) const { VLOG(3) << "BCastParamsToDevices"; // the initializing bcast, all vars would be bcast from device(0). 
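  // For each variable below, device 0 (trainer 0) keeps its original buffer;
  // the other local scopes either receive the data via ncclBcast /
  // bkcl_broadcast, get a device-to-device copy, or (for CPU tensors) copy or
  // share the tensor depending on the build strategy.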
for (auto &var : vars) { framework::Variable *main_var = member_->local_scopes_[0]->FindVar(var); if (main_var == nullptr || !main_var->IsType()) { continue; } auto &main_tensor = main_var->Get(); if (!main_tensor.IsInitialized()) { VLOG(3) << "one in var not inited, return!"; continue; } auto &dims = main_tensor.dims(); if (paddle::platform::is_gpu_place(main_tensor.place())) { #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) std::vector buffers; buffers.reserve(member_->places_.size()); size_t numel = main_tensor.numel(); auto dtype = framework::TransToProtoVarType(main_tensor.dtype()); ncclDataType_t data_type = platform::ToNCCLDataType(dtype); for (size_t i = 0; i < member_->places_.size(); ++i) { auto place = member_->places_[i]; void *buffer; if (i == 0 && trainer_id == 0) { buffer = const_cast(main_tensor.data()); } else { auto local_scope = member_->local_scopes_[i]; auto *t = local_scope->Var(var)->GetMutable(); t->Resize(dims); buffer = t->mutable_data(place, main_tensor.dtype()); } buffers.push_back(buffer); } PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(), platform::errors::PreconditionNotMet( "variables' buffer size to bcast is %d, which is " "NOT equal to places size %d", buffers.size(), member_->places_.size())); if (member_->nccl_ctxs_ != nullptr) { auto *nccl_ctxs = member_->nccl_ctxs_->DefaultFlatCtx(); platform::NCCLGroupGuard guard; for (size_t i = 0; i < member_->places_.size(); ++i) { auto &nccl_ctx = nccl_ctxs->at(member_->places_[i]); platform::dynload::ncclBcast(buffers[i], numel, data_type, 0, nccl_ctx.comm_, nccl_ctx.stream()); } nccl_ctxs->WaitAll(); } else { auto src_place = member_->places_[0]; auto src_dev_ctx = static_cast( platform::DeviceContextPool::Instance().Get(src_place)); auto sizeof_dtype = framework::SizeOfType(dtype) * numel; for (size_t i = 1; i < member_->places_.size(); ++i) { auto dst_place = member_->places_[i]; auto dst_dev_ctx = static_cast( platform::DeviceContextPool::Instance().Get(dst_place)); src_dev_ctx->Wait(); dst_dev_ctx->Wait(); memory::Copy(dst_place, buffers[i], src_place, buffers[0], sizeof_dtype, src_dev_ctx->stream()); src_dev_ctx->Wait(); dst_dev_ctx->Wait(); } } #endif } else if (paddle::platform::is_xpu_place(main_tensor.place())) { #if defined(PADDLE_WITH_XPU_BKCL) std::vector buffers; buffers.reserve(member_->places_.size()); size_t numel = main_tensor.numel(); // TODO(liuyuhui): BKCL only support parameters using float type, // other parameters need to be strongly converted to float before // broadcasting, // but broadcast is equivalent to no type of operation, does not affect // correctness. 
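      // All buffers are therefore broadcast as BKCL_FLOAT; INT64 tensors
      // double their element count below so that the transferred byte count
      // stays correct.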
BKCLDataType data_type = BKCL_FLOAT; // BKCLDataType data_type = // platform::ToBKCLDataType(framework::TransToProtoVarType(main_tensor.dtype())); for (size_t i = 0; i < member_->places_.size(); ++i) { auto place = member_->places_[i]; void *buffer; if (i == 0 && trainer_id == 0) { buffer = const_cast(main_tensor.data()); } else { auto local_scope = member_->local_scopes_[i]; auto *t = local_scope->Var(var)->GetMutable(); t->Resize(dims); buffer = t->mutable_data(place, main_tensor.dtype()); } buffers.push_back(buffer); } PADDLE_ENFORCE_EQ(member_->places_.size(), buffers.size(), platform::errors::PreconditionNotMet( "variables' buffer size to bcast is %d, which is " "NOT equal to places size %d", buffers.size(), member_->places_.size())); { auto *bkcl_ctxs = member_->bkcl_ctxs_->DefaultFlatCtx(); PADDLE_ENFORCE_EQ( bkcl_group_start(), BKCL_SUCCESS, platform::errors::Unavailable("bkcl_group_start failed")); for (size_t i = 0; i < member_->places_.size(); ++i) { auto &bkcl_ctx = bkcl_ctxs->at(member_->places_[i]); auto broadcast_numel = numel; if (framework::TransToProtoVarType(main_tensor.dtype()) == framework::proto::VarType::INT64) { broadcast_numel *= 2; } PADDLE_ENFORCE_EQ( bkcl_broadcast(bkcl_ctx.comm(), buffers[i], buffers[i], broadcast_numel, data_type, 0, NULL), BKCL_SUCCESS, platform::errors::Unavailable("bkcl_broadcast failed")); } PADDLE_ENFORCE_EQ( bkcl_group_end(), BKCL_SUCCESS, platform::errors::Unavailable("bkcl_group_end failed")); } #else PADDLE_THROW( platform::errors::PreconditionNotMet("Not compiled with BKCL.")); #endif } else { platform::CPUPlace cpu; for (size_t i = 1; i < member_->places_.size(); ++i) { auto local_scope = member_->local_scopes_[i]; auto *t = local_scope->Var(var)->GetMutable(); auto copy_memory = [&] { t->Resize(dims); t->mutable_data(cpu, main_tensor.dtype()); paddle::framework::TensorCopy(main_tensor, cpu, t); }; auto share_memory = [&] { t->ShareDataWith(main_tensor); }; // FIXME(zcd): LR_DECAY_COUNTER should not be shared. This is a hot fix. 
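        // Decide per variable whether each local scope gets its own copy of
        // the CPU tensor or shares it: async mode always shares; all-reduce
        // mode, CUDA execution, and @LR_DECAY_COUNTER@ get a private copy.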
if (member_->build_strategy_.async_mode_) { share_memory(); } else if (member_->use_all_reduce_ || member_->IsUseCUDA(member_->use_device_) || var == "@LR_DECAY_COUNTER@") { copy_memory(); } else { share_memory(); } } } } } FetchUnmergedList ParallelExecutor::Run( const std::vector &fetch_tensors) { LOG_FIRST_N(INFO, 1) << "ParallelExecutor is Running (Run)."; PreludeToRun(fetch_tensors); platform::RecordBlock b(0); ResetHasFeedGuard reset_has_feed_guard(member_); ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), fetch_tensors, member_->HasGarbageCollectors()); VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run"; auto fetch_data = member_->executor_->Run(fetch_tensors, /*return_merged=*/false); return PADDLE_GET(FetchUnmergedList, fetch_data); } FetchList ParallelExecutor::RunAndMerge( const std::vector &fetch_tensors) { LOG_FIRST_N(INFO, 1) << "ParallelExecutor is Running (RunAndMerge)."; PreludeToRun(fetch_tensors); platform::RecordBlock b(0); ResetHasFeedGuard reset_has_feed_guard(member_); ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), fetch_tensors, member_->HasGarbageCollectors()); VLOG(3) << "ParallelExecutor begin to run member_->executor_->RunAndMerge"; auto fetch_data = member_->executor_->Run(fetch_tensors, /*return_merged=*/true); return PADDLE_GET(FetchList, fetch_data); } void ParallelExecutor::RunWithoutFetch( const std::vector &skip_eager_vars) { VLOG(3) << "enter ParallelExecutor RunWithoutFetch"; #ifdef WITH_GPERFTOOLS if (gProfileStarted) { ProfilerFlush(); } #endif platform::RecordBlock b(0); ResetHasFeedGuard reset_has_feed_guard(member_); ir::SkipMemOptVarsGuard guard(&(member_->mem_opt_var_infos_), skip_eager_vars, member_->HasGarbageCollectors()); VLOG(3) << "ParallelExecutor begin to run member_->executor_->Run"; member_->executor_->Run(/*fetch_tensors*/ {}, /*return_merged*/ false); } void ParallelExecutor::SkipMemoryReuse( size_t scope_idx, const std::vector &skip_vars) { for (auto &var_name : skip_vars) { bool is_persistable = member_->IsPersistable(var_name); if (!is_persistable) { VLOG(3) << "SkipMemoryReuse for var: " << var_name; member_->SetSkipMemoryReuse(scope_idx, var_name); } } } void ParallelExecutor::FeedTensorsIntoLocalScopes( const std::vector> &tensors) { if (platform::IsCUDAGraphCapturing()) { for (auto &tensor : tensors) { PADDLE_ENFORCE_EQ( tensor.empty(), true, platform::errors::PermissionDenied( "Feeding data is not permitted when capturing CUDA Graph.")); } return; } if (!member_->AllowPartialFeed()) { PADDLE_ENFORCE_EQ(tensors.size(), member_->local_scopes_.size(), platform::errors::Unimplemented( "The feed data number %d does not match the device " "number %d. If you are using DataLoader to feed " "data, this may be because you set drop_last=False " "in training network. Currently, drop_last=False for " "DataLoader is not supported for training network. " "Please set drop_last=True when defining DataLoader.", tensors.size(), member_->local_scopes_.size())); } else { PADDLE_ENFORCE_GE(member_->local_scopes_.size(), tensors.size(), platform::errors::InvalidArgument( "The feed tensor number exceeds the device number")); } size_t feed_num = 0; for (size_t i = 0; i < tensors.size(); ++i) { auto &map = tensors[i]; if (map.empty()) { continue; } member_->SetHasFeed(i); ++feed_num; for (auto &pair : map) { bool is_persistable = member_->IsPersistable(pair.first); if (!is_persistable) { member_->SetSkipMemoryReuse(i, pair.first); } auto *feed_scope = is_persistable ? 
member_->local_scopes_[i] : member_->local_exec_scopes_[i]; auto *feed_var = feed_scope->Var(pair.first); auto *trg = feed_var->GetMutable(); trg->ShareDataWith(pair.second); trg->set_lod(pair.second.lod()); } } if (!member_->AllowPartialFeed()) { PADDLE_ENFORCE_EQ(feed_num, member_->local_scopes_.size(), platform::errors::Unimplemented( "The feed data number %d does not match the device " "number %d. If you are using DataLoader to feed " "data, this may be because you set drop_last=False " "in training network. Currently, drop_last=False for " "DataLoader is not supported for training network. " "Please set drop_last=True when defining DataLoader.", feed_num, member_->local_scopes_.size())); } } void ParallelExecutor::FeedAndSplitTensorIntoLocalScopes( const std::unordered_map &tensors) { if (platform::IsCUDAGraphCapturing()) { PADDLE_ENFORCE_EQ( tensors.empty(), true, platform::errors::PermissionDenied( "Feeding data is not permitted when capturing CUDA Graph.")); return; } size_t num_places = member_->places_.size(); bool allow_partial_feed = member_->AllowPartialFeed(); size_t persistable_feed_len = -1UL; size_t non_persistable_feed_len = -1UL; for (auto &pair : tensors) { bool is_persistable = member_->IsPersistable(pair.first); VLOG(3) << "Split " << (is_persistable ? "persistable" : "no persistable") << " data (" << pair.first << "), dim:" << pair.second.dims() << ", place: " << pair.second.place(); auto lod_tensors = SplitLoDTensor(pair.second, member_->places_); bool is_cpu_place = platform::is_cpu_place(member_->places_.front()); if (!is_persistable && num_places != lod_tensors.size() && !allow_partial_feed) { auto error_info = string::Sprintf( "The number(%d) of samples[%s] of current batch is less than the " "count(%d) of devices(%s), currently, it is not allowed. ", lod_tensors.size(), pair.first, num_places, (is_cpu_place ? "CPU" : "GPU")); if (is_cpu_place) { error_info += "You should set the environment variable CPU_NUM in the system " "to determine the number of devices you need."; } PADDLE_THROW(platform::errors::PreconditionNotMet(error_info)); } else if (is_persistable) { if (lod_tensors.size() == 1) { lod_tensors.reserve(num_places); auto &tensor = lod_tensors.front(); PADDLE_ENFORCE_EQ( tensor.dims(), pair.second.dims(), platform::errors::PreconditionNotMet("The dim doesn't match.")); PADDLE_ENFORCE_EQ( tensor.place(), member_->places_.at(0), platform::errors::PreconditionNotMet("The place doesn't match.")); for (size_t i = 1; i < num_places; ++i) { lod_tensors.emplace_back(); auto &tmp = lod_tensors.back(); framework::TensorCopy(pair.second, member_->places_.at(i), &tmp); } } if (lod_tensors.size() != num_places && !allow_partial_feed) { auto error_info = string::Sprintf( "The number(%d) of samples[%s] of the current batch does not match " "the count(%d) of devices(%s). Because that %s is a persistable " "variable, you can feed just one sample, in that case, the input " "sample will be copied in %d copies and be sent to different " "places separately. If you need that different place has different " "value, you should feed %d samples.", lod_tensors.size(), pair.first, num_places, (is_cpu_place ? 
"CPU" : "GPU"), pair.first, num_places, num_places); PADDLE_THROW(platform::errors::PreconditionNotMet(error_info)); } } if (allow_partial_feed) { if (is_persistable) { if (persistable_feed_len == -1UL) { persistable_feed_len = lod_tensors.size(); } else { PADDLE_ENFORCE_EQ( persistable_feed_len, lod_tensors.size(), platform::errors::InvalidArgument( "The feeded number of different persistable variables " "should be the same")); } } else { if (non_persistable_feed_len == -1UL) { non_persistable_feed_len = lod_tensors.size(); } else { PADDLE_ENFORCE_EQ( non_persistable_feed_len, lod_tensors.size(), platform::errors::InvalidArgument( "The feeded number of different non-persistable variables " "should be the same")); } } } for (size_t j = 0; j < lod_tensors.size(); ++j) { auto *feed_scope = is_persistable ? member_->local_scopes_[j] : member_->local_exec_scopes_[j]; auto *feed_var = feed_scope->Var(pair.first); auto t = feed_var->GetMutable(); t->ShareDataWith(lod_tensors[j]); t->set_lod(lod_tensors[j].lod()); } } if (allow_partial_feed && persistable_feed_len != -1UL && non_persistable_feed_len != -1UL) { VLOG(10) << "Persistable len " << persistable_feed_len; VLOG(10) << "Non persistable len " << non_persistable_feed_len; PADDLE_ENFORCE_GE(persistable_feed_len, non_persistable_feed_len, platform::errors::InvalidArgument( "The feeded number of persistable variables should " "not be less than non-persistable variables")); } if (non_persistable_feed_len != -1UL) { for (size_t i = 0; i < non_persistable_feed_len; ++i) { member_->SetHasFeed(i); } } } ParallelExecutor::~ParallelExecutor() { for (auto &p : member_->places_) { platform::DeviceContextPool::Instance().Get(p)->Wait(); } delete member_; } bool ParallelExecutor::EnableParallelGraphExecution( const ir::Graph &graph, const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy) const { if (!FLAGS_enable_parallel_graph) { return false; } bool enable_parallel_graph = true; for (ir::Node *node : graph.Nodes()) { if (node->IsVar() && node->Var()) { // TODO(Yancey1989): support sparse update in ParallelGraph mode. 
if (node->Var()->GetType() == proto::VarType::SELECTED_ROWS) { enable_parallel_graph = false; break; } } else if (node->IsOp() && node->Op()) { // TODO(Yancey1989): support pserver mode if (node->Op()->Type() == "send" || node->Op()->Type() == "recv") { enable_parallel_graph = false; break; } } } if (!member_->use_all_reduce_ || !member_->IsUseCUDA(member_->use_device_)) { if (build_strategy.enable_sequential_execution_ || exec_strategy.type_ == ExecutionStrategy::ExecutorType::kExperimental) { enable_parallel_graph = false; } } #ifdef WIN32 VLOG(1) << "Windows has no support to parallel graph, enable_parallel_graph " "would be forced to false."; enable_parallel_graph = false; #endif return enable_parallel_graph; } void ParallelExecutor::InitExecutorPrivateMemberInfo( const ExecutionStrategy &exec_strategy, const BuildStrategy &build_strategy, size_t device_count, const ir::Graph &graph) { member_->use_device_ = exec_strategy.use_device_; member_->build_strategy_ = build_strategy; member_->use_all_reduce_ = member_->build_strategy_.reduce_ == BuildStrategy::ReduceStrategy::kAllReduce; member_->nranks_ = build_strategy.num_trainers_ * device_count; if (!member_->use_all_reduce_ && member_->nranks_ == 1) { LOG(INFO) << "If you set build_strategy.reduce with 'Reduce'," "the number of places should be greater than 1."; member_->build_strategy_.reduce_ = BuildStrategy::ReduceStrategy::kAllReduce; member_->use_all_reduce_ = true; } #if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && defined(_WIN32) if (member_->IsUseCUDA(member_->use_device_)) { PADDLE_ENFORCE_EQ( device_count, 1, platform::errors::Unavailable("Windows can support Single GPU only.")); } #endif #if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \ (!defined(PADDLE_WITH_NCCL) && !defined(PADDLE_WITH_RCCL)) if (member_->IsUseCUDA(member_->use_device_)) { PADDLE_ENFORCE_EQ( device_count, 1, platform::errors::PermissionDenied( "Your machine has multiple cards, " "but the WITH_NCCL option is not turned on during compilation, " "and you cannot use multi-card training or prediction. " "Please recompile and turn on the WITH_NCCL option.")); } #endif std::string device_name; if (member_->use_device_ == p::kCPU) { device_name = "CPU"; } else if (member_->use_device_ == p::kCUDA) { device_name = "CUDA"; } else { device_name = "XPU"; } VLOG(1) << string::Sprintf( "The Program will be executed on %s using ParallelExecutor, %lu " "cards are used, so %lu programs are executed in parallel.", device_name, device_count, device_count); // FIXME(Yancey1989): parallel graph mode get better performance // in GPU allreduce distributed training. Need an elegant way to // choice the execution strategy. 
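  // Parallel graph mode is only auto-enabled when both the graph and the
  // build/execution strategies allow it (see EnableParallelGraphExecution).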
member_->build_strategy_.enable_parallel_graph_ = EnableParallelGraphExecution( graph, exec_strategy, member_->build_strategy_); if (member_->build_strategy_.enable_parallel_graph_) { LOG(INFO) << "The Executor would execute the graph by ParallelGraph " "Execution which can get better performance," << "you can force it off by env FLAGS_enable_parallel_graph=0"; } } void ParallelExecutor::CreateLocalScopes( Scope *global_scope, const std::vector &local_scopes, bool create_new) { if (local_scopes.empty()) { member_->own_local_scope_ = true; member_->local_scopes_.emplace_back(global_scope); for (size_t i = 1; i < member_->places_.size(); ++i) { member_->local_scopes_.emplace_back(&global_scope->NewScope()); } } else { member_->own_local_scope_ = false; PADDLE_ENFORCE_EQ(member_->places_.size(), local_scopes.size(), platform::errors::PreconditionNotMet( "member_->places_.size() = %d is not equal to " "local_scopes.size() = %d", member_->places_.size(), local_scopes.size())); for (size_t i = 0; i < member_->places_.size(); ++i) { if (create_new) { member_->local_scopes_.emplace_back(&local_scopes[i]->NewScope()); } else { // Use local scopes directly member_->local_scopes_.emplace_back(local_scopes[i]); } } } } std::unordered_map ParallelExecutor::CreateLocalExecScopes( const std::vector &local_scopes, bool create_new) { std::unordered_map scope_map; for (auto *scope : local_scopes) { Scope *local_exec_scope = scope; if (create_new) { local_exec_scope = &scope->NewScope(); } member_->local_exec_scopes_.emplace_back(local_exec_scope); scope_map.emplace(scope, local_exec_scope); } PADDLE_ENFORCE_EQ(member_->local_scopes_.size(), member_->local_exec_scopes_.size(), platform::errors::PreconditionNotMet( "member_->local_scopes_.size() = %d is not equal to " "member_->local_exec_scopes_.size() = %d", member_->local_scopes_.size(), member_->local_exec_scopes_.size())); return scope_map; } std::vector ParallelExecutor::CloneGraphToMultiDevices( ir::Graph *graph) { std::vector graphs; if (member_->build_strategy_.async_mode_) { PADDLE_ENFORCE_EQ(member_->IsUseCUDA(member_->use_device_), false, platform::errors::Unavailable( "gpu mode does not support async_mode_ now!")); graphs.push_back(graph); for (size_t i = 1; i < member_->places_.size(); ++i) { auto *tmp_graph = new ir::Graph(graph->OriginProgram()); async_graphs_.emplace_back(tmp_graph); graphs.push_back(tmp_graph); } } return graphs; } void ParallelExecutor::PreludeToRun( const std::vector &fetch_tensors) { platform::RecordEvent record_run( "ParallelExecutor::Run", platform::TracerEventType::UserDefined, 1); VLOG(3) << "enter ParallelExecutor Run"; #ifdef PADDLE_WITH_CUDA if (platform::IsCUDAGraphCapturing()) { PADDLE_ENFORCE_EQ(fetch_tensors.empty(), true, platform::errors::InvalidArgument( "Cannot fetch data when using CUDA Graph.")); PADDLE_ENFORCE_EQ( member_->build_strategy_.allow_cuda_graph_capture_, true, platform::errors::InvalidArgument( "You must turn on build_strategy.allow_cuda_graph_capture = True " "to enable CUDA Graph capturing.")); PADDLE_ENFORCE_EQ( member_->places_[0], platform::CUDAGraphCapturingPlace(), platform::errors::InvalidArgument("The place to capture CUDAGraph is " "not the same as the place to run.")); } #endif #ifdef WITH_GPERFTOOLS if (gProfileStarted) { ProfilerFlush(); } #endif } void ParallelExecutor::PrepareNCCLCommunicator(Scope *global_scope) { if (member_->build_strategy_.reduce_ == BuildStrategy::ReduceStrategy::kNoReduce) { return; } if (member_->IsUseCUDA(member_->use_device_) && member_->nranks_ > 1) { 
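    // Multi-rank CUDA execution needs an NCCL/RCCL communicator. It is created
    // once (or fetched from the global scope if it already exists) and then
    // handed to each device context for ops such as sync_batch_norm.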
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) member_->InitOrGetNCCLCommunicator(global_scope, &member_->build_strategy_); // Initialize device context's nccl comm, will be used by normal // Operators like sync_batch_norm, and collective ops. // NOTE: more than one ParallelExecutor with same place, the nccl comm will // be rewrite and there will be some problem. // NOTE: NCCL group-calls and non-group-calls can not use the same // NCCL communicator, so for ParallelGraph and Multi-Process mode, re-use // same communicators. auto *nccl_ctxs = member_->nccl_ctxs_->GetSyncBatchNormCtx( global_scope, member_->places_); auto &pool = platform::DeviceContextPool::Instance(); for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) { auto *dev_ctx = static_cast(pool.Get(member_->places_[dev_id])); auto &nccl_ctx = nccl_ctxs->at(member_->places_[dev_id]); dev_ctx->set_nccl_comm(nccl_ctx.comm()); } #else PADDLE_THROW( platform::errors::PreconditionNotMet("Not compiled with CUDA.")); #endif } if (member_->use_device_ == p::kXPU && member_->nranks_ > 1) { #if defined(PADDLE_WITH_XPU_BKCL) member_->InitOrGetBKCLCommunicator(global_scope, member_->build_strategy_); auto *bkcl_ctxs = member_->bkcl_ctxs_->GetSyncBatchNormCtx( global_scope, member_->places_); auto &pool = platform::DeviceContextPool::Instance(); for (size_t dev_id = 0; dev_id < member_->places_.size(); ++dev_id) { auto *dev_ctx = static_cast( pool.Get(member_->places_[dev_id])); auto &bkcl_ctx = bkcl_ctxs->at(member_->places_[dev_id]); dev_ctx->SetBkclContext(bkcl_ctx.comm()); } #else PADDLE_THROW( platform::errors::PreconditionNotMet("Not compiled with XPU.")); #endif } } std::vector ParallelExecutor::CompileGraphWithBuildStrategy( ir::Graph *graph, std::vector *device_graphs, const std::string &loss_var_name) { auto device_count = member_->places_.size(); std::vector async_graphs(device_count); auto &graphs = *device_graphs; #if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) if (member_->build_strategy_.async_mode_) { PADDLE_ENFORCE_EQ(graphs.size(), device_count, platform::errors::PreconditionNotMet( "graphs.size() shoule be %d, but received %d", device_count, graphs.size())); VLOG(3) << "use local async mode"; graph = member_->build_strategy_.Apply(graph, {member_->places_[0]}, loss_var_name, {member_->local_scopes_[0]}, 1, member_->use_device_, member_->nccl_ctxs_); for (size_t i = 1; i < device_count; ++i) { graphs[i] = member_->build_strategy_.Apply(graphs[i], {member_->places_[i]}, loss_var_name, {member_->local_scopes_[i]}, 1, member_->use_device_, member_->nccl_ctxs_); async_graphs[i] = graphs[i]; } } else { graph = member_->build_strategy_.Apply(graph, member_->places_, loss_var_name, member_->local_scopes_, member_->nranks_, member_->use_device_, member_->nccl_ctxs_); } #elif defined(PADDLE_WITH_XPU_BKCL) if (member_->build_strategy_.async_mode_) { PADDLE_ENFORCE_EQ(graphs.size(), device_count, platform::errors::PreconditionNotMet( "graphs.size() shoule be %d, but received %d", device_count, graphs.size())); VLOG(3) << "use local async mode"; graph = member_->build_strategy_.Apply(graph, {member_->places_[0]}, loss_var_name, {member_->local_scopes_[0]}, 1, member_->use_device_, member_->bkcl_ctxs_); for (size_t i = 1; i < device_count; ++i) { graphs[i] = member_->build_strategy_.Apply(graphs[i], {member_->places_[i]}, loss_var_name, {member_->local_scopes_[i]}, 1, member_->use_device_, member_->bkcl_ctxs_); async_graphs[i] = graphs[i]; } } else { graph = member_->build_strategy_.Apply(graph, 
member_->places_, loss_var_name, member_->local_scopes_, member_->nranks_, member_->use_device_, member_->bkcl_ctxs_); } #else if (member_->build_strategy_.async_mode_) { VLOG(3) << "use local async mode"; graph = member_->build_strategy_.Apply(graph, {member_->places_[0]}, loss_var_name, {member_->local_scopes_[0]}, 1, member_->use_device_); for (size_t i = 1; i < device_count; ++i) { graphs[i] = member_->build_strategy_.Apply(graphs[i], {member_->places_[i]}, loss_var_name, {member_->local_scopes_[i]}, 1, member_->use_device_); async_graphs[i] = graphs[i]; } } else { graph = member_->build_strategy_.Apply(graph, member_->places_, loss_var_name, member_->local_scopes_, member_->nranks_, member_->use_device_); } #endif return async_graphs; } void ParallelExecutor::CreateVariableInfos( std::vector *var_infos, ir::Graph *graph) { PADDLE_ENFORCE_EQ( var_infos->size(), 0, platform::errors::PreconditionNotMet( "var_infos->size() shoule be 0, but received %d", var_infos->size())); PADDLE_ENFORCE_EQ( member_->is_persistable_.size(), 0, platform::errors::PreconditionNotMet( "member_->is_persistable_.size() shoule be 0, but received %d", member_->is_persistable_.size())); for (auto &node : graph->Nodes()) { if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { var_infos->emplace_back(); var_infos->back().name_ = node->Var()->Name(); var_infos->back().type_ = node->Var()->GetType(); var_infos->back().persistable_ = node->Var()->Persistable(); member_->is_persistable_.emplace(node->Var()->Name(), node->Var()->Persistable()); } } if (graph->Has(details::kFusedVars)) { auto &fused_vars = graph->Get(details::kFusedVars); for (auto &fused_var : fused_vars) { var_infos->emplace_back(); var_infos->back() = fused_var.second; member_->is_persistable_.emplace(fused_var.first, fused_var.second.persistable_); } } } std::vector ParallelExecutor::CreateSSAGraphExecutor( const ExecutionStrategy &exec_strategy, std::vector *async_graphs, ir::Graph *graph) { std::vector final_graphs; if (member_->build_strategy_.async_mode_) { VLOG(3) << "use AsyncSSAGraphExecutor"; member_->executor_.reset( new details::AsyncSSAGraphExecutor(exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, member_->places_, *async_graphs)); final_graphs = *async_graphs; } else if (member_->build_strategy_.enable_parallel_graph_) { VLOG(3) << "use ParallelSSAGraphExecutor"; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) // TODO(Yancey1989): Remove passing in the main_program when // allreduce_seq_pass doesn't need it as the attr. 
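    // ParallelSSAGraphExecutor keeps one graph per place; partial feed is only
    // enabled for data-parallel inference graphs that contain no drop-last
    // read op.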
bool is_inference = details::IsDataParallelInferenceGraph(*graph); bool has_drop_last_read_op = details::HasDropLastReadOp(*graph); auto *pg_exe = new details::ParallelSSAGraphExecutor(exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, member_->places_, graph); final_graphs = pg_exe->Graphs(); member_->executor_.reset(pg_exe); if (is_inference && member_->places_.size() > 1) { member_->inference_executor_ = pg_exe; if (!has_drop_last_read_op) { VLOG(5) << "Enable partial feed support in inference phase"; pg_exe->EnablePartialFeedSupport(); } } #else PADDLE_THROW(platform::errors::PreconditionNotMet( "Paddle should be compiled with CUDA for ParallelGraph Execution.")); #endif } else { bool has_drop_last_read_op = details::HasDropLastReadOp(*graph); auto possible_inference_graphs = details::TrySeparateToMultipleSingleDeviceGraphs(graph); if (!possible_inference_graphs.empty()) { for (auto &g : possible_inference_graphs) { member_->ApplyFixOpRunOrderPass(g.get()); } VLOG(5) << "Use ParallelSSAGraphExecutor in inference phase"; auto *pg_exe = new details::ParallelSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, member_->places_, std::move(possible_inference_graphs)); if (!has_drop_last_read_op) { VLOG(5) << "Enable partial feed support in inference phase"; pg_exe->EnablePartialFeedSupport(); } final_graphs = pg_exe->Graphs(); member_->executor_.reset(pg_exe); member_->inference_executor_ = pg_exe; } else { if (member_->places_.size() == 1) { member_->ApplyFixOpRunOrderPass(graph); } LOG_IF(WARNING, details::HasKeepLastReadOp(*graph)) << "drop_last=False for DataLoader is not supported in training " "network. It is automatically turned to drop_last=True."; if (exec_strategy.type_ == ExecutionStrategy::kDefault) { VLOG(3) << "use ThreadedSSAGraphExecutor"; member_->executor_.reset( new details::ThreadedSSAGraphExecutor(exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, member_->places_, graph)); } else { if (member_->use_device_ == p::kXPU) { #if defined(PADDLE_WITH_XPU) VLOG(3) << "use BindThreadedSSAGraphExecutor"; member_->executor_.reset(new details::BindThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, member_->places_, graph)); #else PADDLE_THROW(platform::errors::PermissionDenied( "Paddle can't use XPU device since it's not compiled with XPU," "Please recompile or reinstall Paddle with XPU support.")); #endif } else { VLOG(3) << "use FastThreadedSSAGraphExecutor"; member_->executor_.reset(new details::FastThreadedSSAGraphExecutor( exec_strategy, member_->local_scopes_, member_->local_exec_scopes_, member_->places_, graph)); } } final_graphs.emplace_back(graph); } } return final_graphs; } void ParallelExecutor::ResetOpHandleScopeMapOfGraphs( const std::vector &final_graphs, const std::unordered_map &scope_map) { PADDLE_ENFORCE_GE( final_graphs.size(), 1, platform::errors::PreconditionNotMet( "final_graphs shoule contain at least one graph, but received %d", final_graphs.size())); PADDLE_ENFORCE_GT(scope_map.size(), 0, platform::errors::PreconditionNotMet( "scope_map shoule contain at least one " "element, but received %d", scope_map.size())); for (auto *g : final_graphs) { auto ops = ir::FilterByNodeWrapper(*g); for (auto *op : ops) { op->SetLocalExecScopes(scope_map); op->SetIsVariantScope(true); } } } void ParallelExecutor::ResetOpHandleScopeMapOfGraphs( const std::unordered_map &scope_map) { auto inner_graph = const_cast(&Graph()); std::vector graphs = {inner_graph}; 
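  // Reuse the multi-graph overload above for the executor's single inner
  // graph.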
ResetOpHandleScopeMapOfGraphs(graphs, scope_map); } void ParallelExecutor::SetReaderOpDeviceInfoOfGraphs( const std::vector &final_graphs) { if (final_graphs.size() == 1) { ir::SetReaderOpDeviceInfo(final_graphs[0], member_->places_.size()); } else { for (size_t i = 0; i < final_graphs.size(); ++i) { ir::SetReaderOpDeviceInfo(final_graphs[i], member_->places_.size(), i); } } } const ir::Graph &ParallelExecutor::Graph() const { return member_->executor_->Graph(); } void ParallelExecutor::PrepareForCUDAGraphCapture(ir::Graph *graph) { const auto &build_strategy = member_->build_strategy_; if (!build_strategy.allow_cuda_graph_capture_) return; #ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_EQ( build_strategy.async_mode_, false, platform::errors::InvalidArgument( "Async Executor does not support CUDA Graph capturing.")); PADDLE_ENFORCE_EQ( platform::IsCUDAGraphCapturing(), false, platform::errors::PermissionDenied("CUDA Graph is not allowed to capture " "when running the first batch.")); PADDLE_ENFORCE_EQ( member_->places_.size(), 1, platform::errors::InvalidArgument( "CUDA Graph is only supported when one GPU device is running.")); PADDLE_ENFORCE_EQ(platform::is_gpu_place(member_->places_[0]), true, platform::errors::InvalidArgument( "CUDA Graph is only supported on NVIDIA GPU device.")); PADDLE_ENFORCE_EQ(FLAGS_sync_nccl_allreduce, false, platform::errors::InvalidArgument( "FLAGS_sync_nccl_allreduce must be False to support " "CUDA Graph capturing.")); std::unordered_map> all_vars; for (auto &node : graph->Nodes()) { if (node->IsVar() && !node->IsCtrlVar() && node->Var()) { auto *var_desc = node->Var(); all_vars[var_desc->Name()].emplace_back(var_desc); } } auto mark_var_as_persistable = [&all_vars](const std::string &name) { auto iter = all_vars.find(name); if (iter != all_vars.end()) { for (auto *var_desc : iter->second) { var_desc->SetPersistable(true); } } }; // Step 1: All fused vars must be persistable. if (graph->Has(details::kFusedVars)) { auto &fused_vars = graph->Get(details::kFusedVars); for (auto &fused_var : fused_vars) { fused_var.second.persistable_ = true; mark_var_as_persistable(fused_var.first); } } // Step 2: All pinned vars must be persistable. if (graph->Has(details::kPinnedVars)) { auto &pinned_vars = graph->Get(details::kPinnedVars); for (auto &pinned_var : pinned_vars) { mark_var_as_persistable(pinned_var); } } // Step 3: Move all main programs to startup programs to make sure that // the main programs would only be run once. if (graph->Has(details::kProgramDescs)) { auto &startup_programs = graph->GetOrInit(details::kStartupProgramDescs); auto &main_programs = graph->Get(details::kProgramDescs); for (auto &main_program : main_programs) { startup_programs.emplace_back(main_program); } graph->Erase(details::kProgramDescs); } // Step 4: Mark all vars in startup programs to be persistable. if (graph->Has(details::kStartupProgramDescs)) { auto &startup_programs = graph->GetOrInit(details::kStartupProgramDescs); for (auto &startup_program : startup_programs) { for (auto &op_desc : startup_program.Block(0).AllOps()) { for (auto &output : op_desc->OutputArgumentNames()) { mark_var_as_persistable(output); } } } } // Step 5: ScaleLossGrad must be run beforehand to avoid H2D copy. 
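  // The scale_loss_grad op only writes a constant into the loss gradient
  // variable, so run it once here, mark the variable persistable, and skip it
  // during capture so that no host-to-device copy happens inside the
  // CUDA Graph.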
  auto ops = ir::FilterByNodeWrapper<details::OpHandleBase>(*graph);
  auto *scope = member_->local_scopes_[0];
  for (auto *op : ops) {
    auto *loss_grad_op = dynamic_cast<details::ScaleLossGradOpHandle *>(op);
    if (loss_grad_op == nullptr) continue;
    auto loss_grad_name = loss_grad_op->LossGradName();
    mark_var_as_persistable(loss_grad_name);
    loss_grad_op->RunOnVar(scope->Var(loss_grad_name));
    loss_grad_op->SetSkipRunning(true);
  }
#else
  PADDLE_THROW(platform::errors::Unimplemented(
      "CUDA Graph is only supported on NVIDIA GPU device."));
#endif
}

}  // namespace framework
}  // namespace paddle

USE_PASS(reference_count_pass);
USE_PASS(eager_deletion_pass);
USE_PASS(buffer_shared_inplace_pass);
USE_PASS(buffer_shared_cross_op_memory_reuse_pass);
USE_PASS(inplace_addto_op_pass);
USE_PASS(fix_op_run_order_pass);