From 92c2dcbdef28e7c8b04ea20e784666dc4022d7da Mon Sep 17 00:00:00 2001 From: LiYuRio <63526175+LiYuRio@users.noreply.github.com> Date: Mon, 20 Mar 2023 10:15:27 +0800 Subject: [PATCH] Cherry-pick fleet executor and auto parallel (#50071) --- cmake/third_party.cmake | 3 +- .../distributed/fleet_executor/CMakeLists.txt | 6 + .../fleet_executor/amplifier_interceptor.cc | 6 +- .../fleet_executor/amplifier_interceptor.h | 2 +- .../distributed/fleet_executor/carrier.cc | 128 +- .../distributed/fleet_executor/carrier.h | 6 +- .../fleet_executor/compute_interceptor.cc | 386 ++-- .../fleet_executor/compute_interceptor.h | 37 +- .../fleet_executor/cond_interceptor.cc | 167 ++ .../fleet_executor/cond_interceptor.h | 55 + .../fleet_executor/fleet_executor.cc | 174 +- .../fleet_executor/fleet_executor.h | 7 +- .../distributed/fleet_executor/interceptor.h | 4 - .../fleet_executor/interceptor_message.proto | 16 + .../fleet_executor/sink_interceptor.h | 2 +- .../fleet_executor/source_interceptor.h | 2 +- .../fleet_executor/start_interceptor.cc | 115 ++ .../fleet_executor/start_interceptor.h | 39 + .../distributed/fleet_executor/task_node.cc | 62 +- .../distributed/fleet_executor/task_node.h | 63 +- .../test/compute_interceptor_run_op_test.cc | 5 +- .../test/compute_interceptor_test.cc | 60 +- .../test/interceptor_ping_pong_test.cc | 1 - .../interceptor_ping_pong_with_brpc_test.cc | 1 - .../interceptor_pipeline_long_path_test.cc | 14 +- .../interceptor_pipeline_short_path_test.cc | 9 +- .../test/sink_interceptor_test.cc | 7 +- .../test/source_interceptor_test.cc | 5 +- .../operators/collective/c_broadcast_op.cu.cc | 2 + .../operators/collective/c_embedding_op.cu | 100 +- paddle/fluid/pybind/bind_fleet_executor.cc | 15 +- .../phi/kernels/gpu/embedding_grad_kernel.cu | 8 + .../distributed/auto_parallel/completion.py | 1188 ++++++++---- .../distributed/auto_parallel/constants.py | 10 + .../auto_parallel/cost/estimate_cost.py | 4 +- .../distributed/auto_parallel/dist_context.py | 627 +++++-- .../distributed/auto_parallel/dist_op.py | 154 +- .../distributed/auto_parallel/engine.py | 180 +- .../distributed/auto_parallel/interface.py | 90 +- .../auto_parallel/operators/__init__.py | 1 + .../auto_parallel/operators/common.py | 192 +- .../auto_parallel/operators/dist_default.py | 244 ++- .../dist_fill_constant_batch_size_like.py | 54 +- .../auto_parallel/operators/dist_scale.py | 90 + .../distributed/auto_parallel/parallelizer.py | 321 ++-- .../auto_parallel/parallelizer_v2.py | 225 ++- .../auto_parallel/process_group.py | 77 +- .../distributed/auto_parallel/process_mesh.py | 58 +- .../distributed/auto_parallel/reshard.py | 1666 +++++++++++------ .../distributed/auto_parallel/strategy.py | 24 +- .../auto_parallel/tuner/profiler.py | 79 +- .../paddle/distributed/auto_parallel/utils.py | 6 + .../distributed/fleet/fleet_executor_utils.py | 326 ++-- python/paddle/distributed/parallel.py | 7 + python/paddle/distributed/passes/__init__.py | 1 + .../passes/auto_parallel_grad_clip.py | 106 +- .../passes/auto_parallel_pipeline.py | 635 +++++++ python/paddle/fluid/executor.py | 1503 +++++++++------ .../fluid/tests/unittests/CMakeLists.txt | 1 + .../unittests/auto_parallel/CMakeLists.txt | 1 + .../auto_parallel/amp_pass_unittest.py | 10 +- .../auto_parallel/clip_grad_by_global_norm.py | 12 +- .../generation_pipeline_pass_unittest.py | 177 ++ .../gradient_merge_pass_unittest.py | 21 +- .../auto_parallel/recompute_pass_unittest.py | 6 +- .../auto_parallel/sharding_pass_unittest.py | 24 +- .../auto_parallel/test_dist_context.py 
| 190 +- .../test_pass_generation_pipeline.py | 58 + .../unittests/test_auto_parallel_reshard.py | 236 ++- .../test_auto_parallel_reshard_serial.py | 124 +- .../test_fleet_executor_cond_interceptor.py | 217 +++ .../test_fleet_executor_task_node.py | 27 +- .../test_fleet_executor_with_task_nodes.py | 42 +- python/paddle/tensor/stat.py | 1 + 74 files changed, 7436 insertions(+), 3086 deletions(-) create mode 100644 paddle/fluid/distributed/fleet_executor/cond_interceptor.cc create mode 100644 paddle/fluid/distributed/fleet_executor/cond_interceptor.h create mode 100644 paddle/fluid/distributed/fleet_executor/start_interceptor.cc create mode 100644 paddle/fluid/distributed/fleet_executor/start_interceptor.h create mode 100644 python/paddle/distributed/auto_parallel/operators/dist_scale.py create mode 100644 python/paddle/distributed/passes/auto_parallel_pipeline.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/generation_pipeline_pass_unittest.py create mode 100644 python/paddle/fluid/tests/unittests/auto_parallel/test_pass_generation_pipeline.py create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py diff --git a/cmake/third_party.cmake b/cmake/third_party.cmake index 868be06cf82..96a78b527ac 100755 --- a/cmake/third_party.cmake +++ b/cmake/third_party.cmake @@ -426,7 +426,8 @@ endif() if(WITH_DISTRIBUTE AND NOT WITH_PSLIB - AND NOT WITH_PSCORE) + AND NOT WITH_PSCORE + AND NOT WITH_RPC) include(external/snappy) list(APPEND third_party_deps extern_snappy) diff --git a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt index cc5ed287e95..ff8ed811ee6 100755 --- a/paddle/fluid/distributed/fleet_executor/CMakeLists.txt +++ b/paddle/fluid/distributed/fleet_executor/CMakeLists.txt @@ -36,6 +36,8 @@ cc_library( interceptor.cc compute_interceptor.cc amplifier_interceptor.cc + cond_interceptor.cc + start_interceptor.cc source_interceptor.cc sink_interceptor.cc message_service.cc @@ -66,6 +68,10 @@ if(WITH_DISTRIBUTE) set_source_files_properties( amplifier_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties( + cond_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) + set_source_files_properties( + start_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties( source_interceptor.cc PROPERTIES COMPILE_FLAGS ${DISTRIBUTE_COMPILE_FLAGS}) set_source_files_properties( diff --git a/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc b/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc index 72c689732b5..a166ff0b6df 100644 --- a/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.cc @@ -33,7 +33,7 @@ void AmplifierInterceptor::RunOps() { // run_per_steps_, run_at_offset_ // 4, 0 --> run at step 0, 4, 8, 12 // 4, 3 --> run at step 3, 7, 11, 15 - if ((step_ % run_per_steps_) == run_at_offset_) { + if ((cur_scope_id_ % run_per_steps_) == run_at_offset_) { ComputeInterceptor::RunOps(); } } @@ -41,7 +41,7 @@ void AmplifierInterceptor::RunOps() { void AmplifierInterceptor::SendDataReadyToDownStream() { // run multi times, send ready one times to downstream, that is // input multi times, output one times - if (step_ % send_down_per_steps_ == 0) { + if (cur_scope_id_ % send_down_per_steps_ == 0) { ComputeInterceptor::SendDataReadyToDownStream(); } } @@ -49,7 +49,7 @@ void 
AmplifierInterceptor::SendDataReadyToDownStream() { void AmplifierInterceptor::ReplyCompletedToUpStream() { // run multi times, reply one times to upstream, that is // input one times, output multi times - if (step_ % reply_up_per_steps_ == 0) { + if (cur_scope_id_ % reply_up_per_steps_ == 0) { ComputeInterceptor::ReplyCompletedToUpStream(); } } diff --git a/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h b/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h index 776aa8d3e88..93e8ffa1d75 100644 --- a/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/amplifier_interceptor.h @@ -21,7 +21,7 @@ namespace paddle { namespace distributed { -class AmplifierInterceptor : public ComputeInterceptor { +class AmplifierInterceptor final : public ComputeInterceptor { public: AmplifierInterceptor(int64_t interceptor_id, TaskNode* node); diff --git a/paddle/fluid/distributed/fleet_executor/carrier.cc b/paddle/fluid/distributed/fleet_executor/carrier.cc index 6fb0d55a485..9b023e12a88 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.cc +++ b/paddle/fluid/distributed/fleet_executor/carrier.cc @@ -15,6 +15,7 @@ #include "paddle/fluid/distributed/fleet_executor/carrier.h" #include +#include #include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/interceptor.h" @@ -24,6 +25,7 @@ #include "paddle/fluid/framework/garbage_collector.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/framework/variable_helper.h" namespace paddle { @@ -33,6 +35,8 @@ USE_INTERCEPTOR(Source); USE_INTERCEPTOR(Compute); USE_INTERCEPTOR(Amplifier); USE_INTERCEPTOR(Sink); +USE_INTERCEPTOR(Cond); +USE_INTERCEPTOR(Start); void Carrier::Init( int64_t rank, @@ -54,24 +58,38 @@ void Carrier::Init( framework::Scope* scope, int64_t num_micro_batches, const platform::Place& place, - const std::vector& inference_root_scope_vars) { + const std::vector& inference_root_scope_vars, + const std::vector& micro_scope_list) { rank_ = rank; interceptor_id_to_rank_ = interceptor_id_to_rank; interceptor_id_to_node_ = interceptor_id_to_node; place_ = place; root_scope_ = scope; dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); + bool need_create_scope = micro_scope_list.empty(); PADDLE_ENFORCE_NOT_NULL( root_scope_, platform::errors::InvalidArgument("root_scope can not be nullptr")); - minibatch_scope_ = &root_scope_->NewScope(); - microbatch_scopes_.resize(num_micro_batches); - for (int i = 0; i < num_micro_batches; ++i) { - microbatch_scopes_[i] = &minibatch_scope_->NewScope(); - CopyParameters(i, program, inference_root_scope_vars); + + if (need_create_scope) { + minibatch_scope_ = &root_scope_->NewScope(); + microbatch_scopes_.resize(num_micro_batches); + for (int i = 0; i < num_micro_batches; ++i) { + microbatch_scopes_[i] = &minibatch_scope_->NewScope(); + CopyParameters(i, program, inference_root_scope_vars); + } + } else { + microbatch_scopes_ = micro_scope_list; + for (int i = 0; i < num_micro_batches; ++i) { + CopyParameters(i, program, inference_root_scope_vars); + } } + // Add source and sink interceptor id to rank + interceptor_id_to_rank_.emplace(SOURCE_ID, rank); + interceptor_id_to_rank_.emplace(SINK_ID, rank); + // TODO(fleet_exe dev): thread pool thread_num_ = 1; thread_pool_.SetThreadNum(thread_num_); @@ -93,29 +111,30 @@ void Carrier::CopyParameters( int 
microbatch_id, const framework::ProgramDesc& program, const std::vector& inference_root_scope_vars) { - auto& global_block = program.Block(0); - std::map inference_root_scope_var_map; for (auto var_name : inference_root_scope_vars) { inference_root_scope_var_map.insert({var_name, 1}); } - for (auto& var : global_block.AllVars()) { - std::string var_name = var->Name(); - bool force_root = inference_root_scope_var_map.find(var_name) != - inference_root_scope_var_map.end(); - if (force_root) { - VLOG(4) << var_name << " will be forced to be created in the root scope."; - } - if ((var->Persistable() || force_root) && microbatch_id == 0) { - auto* ptr = root_scope_->Var(var->Name()); - InitializeVariable(ptr, var->GetType()); - VLOG(5) << "Create persistable var: " << var->Name() - << ", which pointer is " << ptr; - } else if (!var->Persistable()) { - auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name()); - VLOG(5) << "Create variable " << var->Name() << " for microbatch " - << microbatch_id << ", which pointer is " << ptr << "."; - InitializeVariable(ptr, var->GetType()); + for (size_t i = 0; i < program.Size(); ++i) { + for (auto& var : program.Block(i).AllVars()) { + std::string var_name = var->Name(); + bool force_root = inference_root_scope_var_map.find(var_name) != + inference_root_scope_var_map.end(); + if (force_root) { + VLOG(4) << var_name + << " will be forced to be created in the root scope."; + } + if ((var->Persistable() || force_root) && microbatch_id == 0) { + auto* ptr = root_scope_->Var(var->Name()); + InitializeVariable(ptr, var->GetType()); + VLOG(5) << "Create persistable var: " << var->Name() + << ", which pointer is " << ptr; + } else if (!var->Persistable()) { + auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name()); + VLOG(5) << "Create variable " << var->Name() << " for microbatch " + << microbatch_id << ", which pointer is " << ptr << "."; + InitializeVariable(ptr, var->GetType()); + } } } } @@ -159,16 +178,11 @@ void Carrier::Start() { true, platform::errors::PreconditionNotMet( "Using carrier before initialized.")); - for (int64_t id : source_interceptor_ids_) { - VLOG(3) << "Carrier Start is sending start to source interceptor " << id - << "."; - InterceptorMessage start_msg; - // source node data_is_ready is send by carrier, so set src_id=-1 - start_msg.set_src_id(-1); - start_msg.set_dst_id(id); - start_msg.set_message_type(DATA_IS_READY); - Send(start_msg); - } + InterceptorMessage start_msg; + start_msg.set_src_id(SOURCE_ID); + start_msg.set_dst_id(SOURCE_ID); + start_msg.set_message_type(START); + Send(start_msg); // TODO(wangxi): async step Wait(); dev_ctx_->Wait(); @@ -270,6 +284,38 @@ void Carrier::CreateInterceptors() { auto gc = GetGC(place_); + // create source and sink task node + auto max_run_times = microbatch_scopes_.size(); + TaskNode* source = new TaskNode( + rank_, SOURCE_ID, max_run_times); // rank, task_id, max_run_times + TaskNode* sink = new TaskNode(rank_, SINK_ID, max_run_times); + // find nodes without upstreams or without downstreams + std::vector origin_sources, origin_sinks; + for (const auto& item : interceptor_id_to_node_) { + TaskNode* task_node = item.second; + if (task_node->upstream().empty()) { + origin_sources.emplace_back(task_node); + } + if (task_node->downstream().empty()) { + origin_sinks.emplace_back(task_node); + } + } + // link source node with origin source + for (const auto& node : origin_sources) { + source->AddDownstreamTask(node->task_id(), + std::numeric_limits::max()); + 
node->AddUpstreamTask(SOURCE_ID, std::numeric_limits::max()); + } + // link sink node with origin sink + for (const auto& node : origin_sinks) { + sink->AddUpstreamTask(node->task_id(), std::numeric_limits::max()); + node->AddDownstreamTask(SINK_ID, std::numeric_limits::max()); + } + // create source and sink interceptor + SetInterceptor(SOURCE_ID, + InterceptorFactory::Create("Source", SOURCE_ID, source)); + SetInterceptor(SINK_ID, InterceptorFactory::Create("Sink", SINK_ID, sink)); + // create each Interceptor // no auto init since there is no config for (const auto& item : interceptor_id_to_node_) { @@ -303,9 +349,15 @@ void Carrier::CreateInterceptors() { VLOG(3) << "Create Interceptor with interceptor id: " << interceptor_id << " with type: " << task_node->type() << "."; - if (task_node->upstream().empty()) { - source_interceptor_ids_.emplace_back(interceptor_id); - } + PADDLE_ENFORCE_EQ( + task_node->upstream().empty(), + false, + platform::errors::PreconditionNotMet( + "There should not have normal nodes as source nodes")); + PADDLE_ENFORCE_EQ(task_node->downstream().empty(), + false, + platform::errors::PreconditionNotMet( + "There should not have normal nodes as sink nodes")); } } diff --git a/paddle/fluid/distributed/fleet_executor/carrier.h b/paddle/fluid/distributed/fleet_executor/carrier.h index fe3d4926766..8e7fad3e892 100644 --- a/paddle/fluid/distributed/fleet_executor/carrier.h +++ b/paddle/fluid/distributed/fleet_executor/carrier.h @@ -25,6 +25,7 @@ #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/interceptor_message.pb.h" #include "paddle/fluid/distributed/fleet_executor/task_loop_thread_pool.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/errors.h" @@ -60,7 +61,8 @@ class Carrier final { framework::Scope* scope, int64_t num_micro_batches, const platform::Place& place, - const std::vector& inference_root_scope_vars = {}); + const std::vector& inference_root_scope_vars = {}, + const std::vector& micro_scope_list = {}); void CopyParameters( int microbatch_id, @@ -100,8 +102,6 @@ class Carrier final { std::unordered_map> interceptor_idx_to_interceptor_; - std::vector source_interceptor_ids_; - bool is_init_{false}; std::mutex running_mutex_; diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc index 5b96ee76e71..08b2cb4b6cb 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.cc @@ -18,10 +18,85 @@ #include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/framework/tensor_util.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/errors.h" +#include "paddle/phi/core/serialization.h" +#include "paddle/phi/core/utils/dim.h" namespace paddle { namespace distributed { +namespace { + +template +void SetVarResult(const std::string& name, + T value, + int64_t scope_id, + framework::Scope* scope, + const platform::Place& place, + const std::vector& dim_vec) { + auto* var = scope->FindVar(name); + auto* tensor = var->GetMutable(); + if (!var) { + VLOG(3) << "Create var and memory for var " << name; + 
var = scope->Var(name); + phi::DDim dims = phi::make_ddim(dim_vec); + tensor->Resize(dims); + tensor->mutable_data(dims, place); + } + + PADDLE_ENFORCE_EQ( + tensor->dims().size(), + 1, + platform::errors::OutOfRange("Only support transfer size 1 value.")); + PADDLE_ENFORCE_EQ( + tensor->dims().at(0), + 1, + platform::errors::OutOfRange("Only support transfer size 1 value.")); + if (platform::is_gpu_place(tensor->place())) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + phi::DenseTensor cpu_tensor; + auto dim = phi::make_ddim({1}); + cpu_tensor.mutable_data(dim, platform::CPUPlace()); + auto* cpu_tensor_ptr = cpu_tensor.data(); + cpu_tensor_ptr[0] = value; + framework::TensorCopySync(cpu_tensor, tensor->place(), tensor); +#endif + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupport device for cond interceptor.")); + } +} + +template +T GetVarResult(const std::string& name, + int64_t scope_id, + framework::Scope* scope) { + auto* var = scope->FindVar(name); + PADDLE_ENFORCE(var, + platform::errors::NotFound( + "Variable %s not exists in scope %ld", name, scope_id)); + const auto& tensor = var->Get(); + T res; + if (platform::is_gpu_place(tensor.place())) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + phi::DenseTensor cpu_tensor; + framework::TensorCopySync(tensor, platform::CPUPlace(), &cpu_tensor); + res = cpu_tensor.data()[0]; +#endif + } else if (platform::is_cpu_place(tensor.place())) { + res = tensor.data()[0]; + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupport device for cond interceptor.")); + } + return res; +} +} // namespace + ComputeInterceptor::ComputeInterceptor(int64_t interceptor_id, TaskNode* node) : Interceptor(interceptor_id, node) { PrepareDeps(); @@ -33,57 +108,49 @@ void ComputeInterceptor::PrepareDeps() { auto& downstream = node_->downstream(); for (auto up : upstream) { - in_readys_.emplace(up.first, std::make_pair(up.second, 0)); - in_stops_.emplace(up.first, false); + std::map ready_size_map; + for (int64_t i = 0; i < node_->max_run_times(); ++i) { + ready_size_map.emplace(i, 0); + } + in_readys_.emplace(up.first, std::make_pair(up.second, ready_size_map)); } for (auto down : downstream) { out_buffs_.emplace(down.first, std::make_pair(down.second, 0)); } - - // source compute node, should we add a new SourceInterceptor? - if (upstream.empty()) { - is_source_ = true; - PADDLE_ENFORCE_GT(node_->max_run_times(), - 0, - platform::errors::InvalidArgument( - "Source ComputeInterceptor must run at least one " - "times, but now max_run_times=%ld", - node_->max_run_times())); - in_readys_.emplace(-1, - std::make_pair(std::numeric_limits::max(), 0)); - } - - // If there is no downstream or every downstream is in different rank, - // then this interceptor is the last one for current rank. - // This can be get during init, can be cached for later use. 
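+  // NOTE: the is_source_ / is_last_ special cases that used to live here are
+  // gone: the carrier now wires explicit Source and Sink interceptors around
+  // the task graph, and readiness is tracked per scope_id in in_readys_
+  // instead of with a single step_ counter.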
- is_last_ = downstream.empty(); } -void ComputeInterceptor::IncreaseReady(int64_t up_id) { +void ComputeInterceptor::IncreaseReady(int64_t up_id, int64_t scope_id) { auto it = in_readys_.find(up_id); PADDLE_ENFORCE_NE(it, in_readys_.end(), platform::errors::NotFound( "Cannot find upstream=%lld in in_readys.", up_id)); - // source node has no upstream, data_is_ready is send by carrier or others - if (is_source_ && up_id == -1) { - it->second.second += GetTaskNode()->max_run_times(); - return; - } - auto max_ready_size = it->second.first; - auto ready_size = it->second.second; - ready_size += 1; - PADDLE_ENFORCE_LE(ready_size, - max_ready_size, - platform::errors::OutOfRange( - "upstream=%lld ready_size must <= max_ready_size, but " - "now ready_size=%lld, max_ready_size=%lld", - up_id, - ready_size, - max_ready_size)); - it->second.second = ready_size; + const auto& ready_scope_map = it->second.second; + int64_t ready_size = 0; + for (auto& scope_iter : ready_scope_map) { + ready_size += scope_iter.second; + } + if (max_ready_size != INFINITE_BUFFER_SIZE) { + PADDLE_ENFORCE_LE( + ready_size, + max_ready_size, + platform::errors::OutOfRange( + "upstream=%lld ready_size must <= max_ready_size, but " + "now ready_size=%lld, max_ready_size=%lld", + up_id, + ready_size, + max_ready_size)); + } + PADDLE_ENFORCE_NE( + it->second.second.find(scope_id), + it->second.second.end(), + platform::errors::OutOfRange( + "Interceptor %lld can not find scope %lld in upstream ready map", + interceptor_id_, + scope_id)); + it->second.second.at(scope_id) = ready_scope_map.at(scope_id) + 1; } void ComputeInterceptor::DecreaseBuff(int64_t down_id) { @@ -105,22 +172,40 @@ void ComputeInterceptor::DecreaseBuff(int64_t down_id) { } bool ComputeInterceptor::IsInputReady() { - for (auto& ins : in_readys_) { - auto ready_size = ins.second.second; - // not ready, return false - if (ready_size == 0) { - VLOG(3) << "Interceptor " << GetInterceptorId() + for (int64_t i = 0; i < node_->max_run_times(); ++i) { + bool flag = true; + for (auto& ins : in_readys_) { + auto ready_size_map = ins.second.second; + flag = flag && (ready_size_map.at(i) != 0); + } + if (flag) { + for (auto iter : scope_id_to_finish_flag_) { + if (iter.first == i) { + break; + } else if (!iter.second) { + VLOG(3) << "The previous scope is not ready, waiting for the " + "previous scope " + << iter.first; + return false; + } + } + cur_scope_id_ = i; + return true; + } else { + VLOG(3) << "Interceptor " << GetInterceptorId() << " in scope " << i << "'s upstreams aren't all ready."; - return false; } } - return true; + return false; } bool ComputeInterceptor::CanWriteOutput() { for (auto& outs : out_buffs_) { auto max_buffer_size = outs.second.first; auto used_size = outs.second.second; + if (max_buffer_size == INFINITE_BUFFER_SIZE) { + continue; + } // full, return false if (used_size == max_buffer_size) { VLOG(3) << "Interceptor " << GetInterceptorId() @@ -137,30 +222,76 @@ void ComputeInterceptor::SendDataReadyToDownStream() { auto max_buff_size = outs.second.first; auto used_size = outs.second.second; used_size += 1; - PADDLE_ENFORCE_LE( - used_size, - max_buff_size, - platform::errors::OutOfRange("downstream=%lld used buff size must <= " - "max_buff_size, but now used_size=%lld, " - "max_buff_size=%lld", - down_id, - used_size, - max_buff_size)); + if (max_buff_size != INFINITE_BUFFER_SIZE) { + PADDLE_ENFORCE_LE( + used_size, + max_buff_size, + platform::errors::OutOfRange("downstream=%lld used buff size must <= " + "max_buff_size, but now 
used_size=%lld, " + "max_buff_size=%lld", + down_id, + used_size, + max_buff_size)); + } outs.second.second = used_size; - InterceptorMessage ready_msg; - ready_msg.set_message_type(DATA_IS_READY); - VLOG(3) << "ComputeInterceptor " << interceptor_id_ - << " Send data_is_ready msg to " << down_id - << " for step: " << step_; - Send(down_id, ready_msg); + bool need_send_vars = !(node_->vars_to_dtype().empty()); + if (need_send_vars) { + InterceptorMessage ready_msg = PrepareVarsMsg(); + VLOG(3) << "ComputeInterceptor " << interceptor_id_ + << " Send data_with_vars msg to " << down_id + << " in scope: " << cur_scope_id_; + Send(down_id, ready_msg); + } else { + InterceptorMessage ready_msg; + ready_msg.set_message_type(DATA_IS_READY); + ready_msg.set_scope_idx(cur_scope_id_); + VLOG(3) << "ComputeInterceptor " << interceptor_id_ + << " Send data_is_ready msg to " << down_id + << " in scope: " << cur_scope_id_; + Send(down_id, ready_msg); + } + } +} + +InterceptorMessage ComputeInterceptor::PrepareVarsMsg() { + PADDLE_ENFORCE_LT(cur_scope_id_, + microbatch_scopes_.size(), + platform::errors::InvalidArgument( + "Step out of range. There are %ld " + "microbatch_scopes, but recevice scope index %ld", + microbatch_scopes_.size(), + cur_scope_id_)); + auto* scope = microbatch_scopes_[cur_scope_id_]; + + InterceptorMessage ready_msg; + ready_msg.set_message_type(DATA_WITH_VARS); + ready_msg.set_scope_idx(cur_scope_id_); + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + for (auto iter : node_->vars_to_dtype()) { + VarList* vars = ready_msg.add_vars_list(); + const auto& var_name = iter.first; + vars->set_name(var_name); + std::ostringstream ss; + auto& dev_ctx = *pool.Get(place_); + auto* var = scope->FindVar(var_name); + PADDLE_ENFORCE( + var, + platform::errors::NotFound( + "Variable %s not exists in scope %ld", var_name, cur_scope_id_)); + const auto& tensor = var->Get(); + SerializeToStream(ss, tensor, dev_ctx); + vars->set_stensor(ss.str()); + VLOG(3) << "Prepare vars msg " << var_name << " with dimension " + << tensor.dims() << " dtype " << tensor.dtype(); } + return ready_msg; } void ComputeInterceptor::ReplyCompletedToUpStream() { for (auto& ins : in_readys_) { auto up_id = ins.first; - auto ready_size = ins.second.second; + auto ready_size = ins.second.second.at(cur_scope_id_); ready_size -= 1; PADDLE_ENFORCE_GE( ready_size, @@ -169,109 +300,114 @@ void ComputeInterceptor::ReplyCompletedToUpStream() { "upstream=%lld ready_size must >= 0, but now got %lld", up_id, ready_size)); - ins.second.second = ready_size; + ins.second.second[cur_scope_id_] = ready_size; VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " Reply data_is_useless msg to " << up_id - << " for step: " << step_; - if (is_source_ && up_id == -1) return; + << " in scope: " << cur_scope_id_; InterceptorMessage reply_msg; reply_msg.set_message_type(DATA_IS_USELESS); + reply_msg.set_scope_idx(cur_scope_id_); Send(up_id, reply_msg); } } void ComputeInterceptor::RunOps() { - VLOG(3) << "ComputeInterceptor " << interceptor_id_ << " running ops for the " - << step_ + 1 << " time."; for (auto op : node_->ops()) { - op->Run(*microbatch_scopes_[step_ % node_->max_run_times()], place_); + PADDLE_ENFORCE_LT(cur_scope_id_, + microbatch_scopes_.size(), + platform::errors::InvalidArgument( + "Step out of range. 
There are %ld " + "microbatch_scopes, but recevice scope index %ld", + microbatch_scopes_.size(), + cur_scope_id_)); + op->Run(*microbatch_scopes_[cur_scope_id_], place_); if (gc_) { - framework::DeleteUnusedTensors( - *microbatch_scopes_[step_ % node_->max_run_times()], - op, - node_->unused_vars(), - gc_.get()); + framework::DeleteUnusedTensors(*microbatch_scopes_[cur_scope_id_], + op, + node_->unused_vars(), + gc_.get()); } } } void ComputeInterceptor::Run() { while (IsInputReady() && CanWriteOutput()) { - VLOG(3) << "id=" << GetInterceptorId() << " ComputeInterceptor running"; + VLOG(3) << "id=" << GetInterceptorId() + << " ComputeInterceptor running in scope " << cur_scope_id_; RunOps(); - ++step_; + + if (!scope_id_to_finish_flag_.empty()) { + PADDLE_ENFORCE_NE( + scope_id_to_finish_flag_.find(cur_scope_id_), + scope_id_to_finish_flag_.end(), + platform::errors::NotFound( + "Can not find scope %ld in scope_id_to_finish", cur_scope_id_)); + scope_id_to_finish_flag_.erase(cur_scope_id_); + } // send to downstream and increase buff used SendDataReadyToDownStream(); // reply to upstream and decrease ready data ReplyCompletedToUpStream(); - // Try to stop Carrier - if (is_last_ && (step_ % node_->max_run_times() == 0)) { - VLOG(3) << "Interceptor " << GetInterceptorId() - << " is stopping carrier."; - // FIXME(wangxi): with multi sink interceptor - StopCarrier(); - } } } -void ComputeInterceptor::ReceivedStop(int64_t up_id) { - received_stop_ = true; - - // source node has no upstream, stop is send by carrier or others - if (is_source_ && up_id == -1) return; - - auto it = in_stops_.find(up_id); - PADDLE_ENFORCE_NE(it, - in_stops_.end(), - platform::errors::NotFound( - "Cannot find upstream=%lld in in_stops.", up_id)); - PADDLE_ENFORCE_EQ( - it->second, - false, - platform::errors::AlreadyExists("Already received stop from %lld, stop " - "cannot be send more than once.")); - it->second = true; -} - -void ComputeInterceptor::TryStop() { - if (!received_stop_) return; - - // can stop only when all upstream is stop and - // downstream complete - for (auto& in_stop : in_stops_) { - if (!in_stop.second) return; - } - for (auto& out_buff : out_buffs_) { - auto used_size = out_buff.second.second; - if (used_size != 0) return; - } - - // send stop to downstream - for (auto& out : out_buffs_) { - auto down_id = out.first; - InterceptorMessage stop; - stop.set_message_type(STOP); - Send(down_id, stop); +void ComputeInterceptor::DecodeMsgVars(const InterceptorMessage& msg) { + int64_t scope_id = msg.scope_idx(); + PADDLE_ENFORCE_LT(scope_id, + microbatch_scopes_.size(), + platform::errors::InvalidArgument( + "Step out of range. 
There are %ld " + "microbatch_scopes, but recevice scope index %ld", + microbatch_scopes_.size(), + scope_id)); + auto* scope = microbatch_scopes_[scope_id]; + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + for (const auto& var_iter : msg.vars_list()) { + const std::string& name = var_iter.name(); + auto& dev_ctx = *pool.Get(place_); + std::istringstream ss(var_iter.stensor()); + auto* var = scope->Var(name); + auto* tensor = var->GetMutable(); + DeserializeFromStream(ss, tensor, dev_ctx); + + VLOG(3) << "Set vars " << name << " with value in scope " << scope_id + << " with dims " << tensor->dims() << " with dtype " + << tensor->dtype(); } - stop_ = true; } void ComputeInterceptor::Compute(const InterceptorMessage& msg) { if (msg.message_type() == DATA_IS_READY) { - IncreaseReady(msg.src_id()); + VLOG(3) << "Compute interceptor " << interceptor_id_ + << " receive data_is_ready " << msg.src_id() << " " + << msg.scope_idx() << " "; + IncreaseReady(msg.src_id(), msg.scope_idx()); Run(); } else if (msg.message_type() == DATA_IS_USELESS) { + VLOG(3) << "Compute interceptor " << interceptor_id_ + << " receive data_is_useless " << msg.src_id() << " " + << msg.scope_idx() << " "; DecreaseBuff(msg.src_id()); Run(); - } else if (msg.message_type() == STOP) { - ReceivedStop(msg.src_id()); + } else if (msg.message_type() == DATA_WITH_VARS) { + VLOG(3) << "Compute interceptor " << interceptor_id_ + << " receive data_with_vars " << msg.src_id() << " " + << msg.scope_idx() << " "; + DecodeMsgVars(msg); + IncreaseReady(msg.src_id(), msg.scope_idx()); + Run(); + } else if (msg.message_type() == START_LOOP) { + VLOG(3) << "Compute interceptor " << interceptor_id_ + << " receive start_loop " << msg.src_id() << " " << msg.scope_idx() + << " "; + IncreaseReady(msg.src_id(), msg.scope_idx()); + scope_id_to_finish_flag_.emplace(msg.scope_idx(), false); + Run(); } - - TryStop(); } REGISTER_INTERCEPTOR(Compute, ComputeInterceptor); diff --git a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h index fb82ce76c7b..07e0dd5b025 100644 --- a/paddle/fluid/distributed/fleet_executor/compute_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/compute_interceptor.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include "paddle/fluid/distributed/fleet_executor/interceptor.h" @@ -21,6 +22,8 @@ namespace paddle { namespace distributed { +const int64_t INFINITE_BUFFER_SIZE = -1; + class ComputeInterceptor : public Interceptor { public: ComputeInterceptor(int64_t interceptor_id, TaskNode* node); @@ -29,33 +32,27 @@ class ComputeInterceptor : public Interceptor { virtual void RunOps(); virtual void SendDataReadyToDownStream(); virtual void ReplyCompletedToUpStream(); + virtual void Compute(const InterceptorMessage& msg); + void Run(); + void IncreaseReady(int64_t up_id, int64_t scope_id); + void DecreaseBuff(int64_t down_id); + + int64_t cur_scope_id_; - int64_t step_{0}; + // upstream_id-->(max_ready_size, scope-->ready_size) + std::map>> + in_readys_{}; + // downstream_id-->(max_buffer_size, used_size) + std::map> out_buffs_{}; private: void PrepareDeps(); + InterceptorMessage PrepareVarsMsg(); + void DecodeMsgVars(const InterceptorMessage& msg); - void IncreaseReady(int64_t up_id); - void DecreaseBuff(int64_t down_id); bool IsInputReady(); bool CanWriteOutput(); - - void Run(); - void Compute(const InterceptorMessage& msg); - - void ReceivedStop(int64_t up_id); - void TryStop(); - - bool is_source_{false}; - 
bool is_last_{false}; - - // upstream_id-->(max_ready_size, ready_size) - std::map> in_readys_{}; - // downstream_id-->(max_buffer_size, used_size) - std::map> out_buffs_{}; - - bool received_stop_{false}; - std::map in_stops_{}; + std::map scope_id_to_finish_flag_; }; } // namespace distributed diff --git a/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc b/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc new file mode 100644 index 00000000000..d3412a2443f --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/cond_interceptor.cc @@ -0,0 +1,167 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/distributed/fleet_executor/cond_interceptor.h" +#include +#include "paddle/fluid/distributed/fleet_executor/task_node.h" +#include "paddle/fluid/framework/executor_gc_helper.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/platform/errors.h" +#include "paddle/fluid/platform/place.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/core/errors.h" + +namespace paddle { +namespace distributed { + +CondInterceptor::CondInterceptor(int64_t interceptor_id, TaskNode* node) + : Interceptor(interceptor_id, node) { + PrepareDeps(); + RegisterMsgHandle([this](const InterceptorMessage& msg) { Run(msg); }); +} + +void CondInterceptor::PrepareDeps() { + auto& upstream = node_->upstream(); + auto& downstream = node_->downstream(); + auto& id_to_dep_type = node_->id_to_dep_type(); + + for (const auto& up : upstream) { + if (id_to_dep_type.at(up.first) == DependType::NORMAL) { + normal_in_id_.insert(up.first); + } else if (id_to_dep_type.at(up.first) == DependType::LOOP) { + loop_id_ = up.first; + } + } + + for (const auto& down : downstream) { + if (id_to_dep_type.at(down.first) == DependType::NORMAL) { + normal_out_id_.insert(down.first); + } else if (id_to_dep_type.at(down.first) == DependType::STOP_LOOP) { + stop_loop_id_ = down.first; + } + } +} + +bool CondInterceptor::GetCondResult() { + PADDLE_ENFORCE_LT(cur_scope_id_, + microbatch_scopes_.size(), + platform::errors::InvalidArgument( + "Step out of range. 
There are %ld " + "microbatch_scopes, but recevice scope index %ld", + microbatch_scopes_.size(), + cur_scope_id_)); + auto* cond_var = + microbatch_scopes_[cur_scope_id_]->FindVar(node_->cond_var()); + PADDLE_ENFORCE(cond_var, + platform::errors::NotFound( + "Condition variable %s not exists in scope %ld", + node_->cond_var(), + cur_scope_id_)); + const auto& cond_tensor = cond_var->Get(); + bool res = false; + if (platform::is_gpu_place(cond_tensor.place())) { +#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) + phi::DenseTensor cpu_tensor; + framework::TensorCopy(cond_tensor, platform::CPUPlace(), &cpu_tensor); + platform::DeviceContextPool::Instance().Get(cond_tensor.place())->Wait(); + res = cpu_tensor.data()[0]; +#endif + } else if (platform::is_cpu_place(cond_tensor.place())) { + res = cond_tensor.data()[0]; + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Unsupport device for cond interceptor.")); + } + return res; +} + +void CondInterceptor::SendDataReady(int64_t down_id) { + InterceptorMessage ready_msg; + ready_msg.set_message_type(DATA_IS_READY); + ready_msg.set_scope_idx(cur_scope_id_); + Send(down_id, ready_msg); +} + +void CondInterceptor::SendStartLoop(int64_t down_id) { + InterceptorMessage ready_msg; + ready_msg.set_message_type(START_LOOP); + ready_msg.set_scope_idx(cur_scope_id_); + Send(down_id, ready_msg); +} + +void CondInterceptor::ReplyDataIsUseless(int64_t up_id) { + InterceptorMessage ready_msg; + ready_msg.set_message_type(DATA_IS_USELESS); + ready_msg.set_scope_idx(cur_scope_id_); + Send(up_id, ready_msg); +} + +void CondInterceptor::Compute() { + bool cond = GetCondResult(); + VLOG(3) << "Cond interceptor get condition var " << node_->cond_var() + << " with value " << cond; + if (cond) { + VLOG(3) << "Loop again in scope " << cur_scope_id_; + for (auto& down_id : normal_out_id_) { + SendStartLoop(down_id); + } + ++num_of_scopes_; + } else { + VLOG(3) << "Finish loop in scope " << cur_scope_id_; + SendDataReady(stop_loop_id_); + } +} + +void CondInterceptor::Run(const InterceptorMessage& msg) { + if (msg.message_type() == DATA_IS_READY || + msg.message_type() == DATA_WITH_VARS) { + if (msg.src_id() == loop_id_) { + --num_of_scopes_; + VLOG(3) << "Receving loop again message from " << msg.src_id() + << " waiting other " << num_of_scopes_ << " scopes ready"; + ready_scope_id_.emplace_back(msg.scope_idx()); + if (num_of_scopes_ == 0) { + std::sort(ready_scope_id_.begin(), ready_scope_id_.end()); + for (auto scope_id : ready_scope_id_) { + VLOG(3) << "Start a new loop in scope " << scope_id; + cur_scope_id_ = scope_id; + Compute(); + } + ready_scope_id_.clear(); + } + } else { + cur_scope_id_ = msg.scope_idx(); + Compute(); + } + } else if (msg.message_type() == DATA_IS_USELESS) { + if (node_->id_to_dep_type().at(msg.src_id()) == DependType::STOP_LOOP) { + for (auto& up_id : normal_in_id_) { + ReplyDataIsUseless(up_id); + } + // Gc the variable in while block + int64_t scope_id = msg.scope_idx(); + if (gc_) { + VLOG(3) << "Release vars in while block in scope " << scope_id; + framework::DeleteUnusedTensors(*microbatch_scopes_[scope_id], + node_->while_block_vars(), + gc_.get()); + } + } + } +} + +REGISTER_INTERCEPTOR(Cond, CondInterceptor); + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/cond_interceptor.h b/paddle/fluid/distributed/fleet_executor/cond_interceptor.h new file mode 100644 index 00000000000..a69468b28b4 --- /dev/null +++ 
b/paddle/fluid/distributed/fleet_executor/cond_interceptor.h @@ -0,0 +1,55 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/distributed/fleet_executor/interceptor.h" + +namespace paddle { +namespace distributed { + +/* Condition Interceptor + * This is a special interceptor and only one condition op in the task node. + * This interceptor has two downstreams, + * 1. If the program result is true, select one of the downstreams, otherwise + * select another. + * 2. Used to implement while op in program. + */ +class CondInterceptor final : public Interceptor { + public: + CondInterceptor(int64_t interceptor_id, TaskNode* node); + + private: + void PrepareDeps(); + void Run(const InterceptorMessage& msg); + void Compute(); + bool GetCondResult(); + void SendDataReady(int64_t down_id); + void SendStartLoop(int64_t down_id); + void ReplyDataIsUseless(int64_t up_id); + + int64_t cur_scope_id_; + + std::set normal_in_id_; + std::set normal_out_id_; + int64_t stop_loop_id_; + int64_t loop_id_; + int64_t num_of_scopes_{0}; + std::vector ready_scope_id_; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc index a2d2ecd9bbf..915b1f82804 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.cc +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.cc @@ -14,6 +14,8 @@ #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include +#include +#include #include "paddle/fluid/distributed/fleet_executor/global.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" @@ -24,6 +26,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" +#include "paddle/fluid/framework/variable.h" namespace paddle { namespace distributed { @@ -51,40 +54,40 @@ FleetExecutor::~FleetExecutor() { } } -void FleetExecutor::Init( - const std::string& carrier_id, - const framework::ProgramDesc& program_desc, - framework::Scope* scope, - const platform::Place& place, - int64_t num_micro_batches, - const std::vector& task_nodes, - const std::unordered_map& task_id_to_rank, - const std::vector& inference_root_scope_vars) { - PADDLE_ENFORCE_GT(task_nodes.size(), - 0, - platform::errors::InvalidArgument( - "Fleet executor is inited with empty task node")); - // TODO(fleet_exe devs): the unused_vars should be got from run time graph - std::vector> ops; - for (auto task_node : task_nodes) { - for (auto op : task_node->ops()) { - ops.emplace_back(std::unique_ptr(op)); +namespace { +void GetSubBlockTask(const std::vector& tasks, + TaskNode* cur_task, + std::set* sub_block_task) { + auto& downstream = cur_task->downstream(); + auto& id_to_dep_type = cur_task->id_to_dep_type(); + for (auto& down : downstream) { + int64_t task_id = down.first; 
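+    // Only NORMAL downstream edges are followed: for a Cond task they lead
+    // into the while body, whereas the STOP_LOOP edge points past the loop
+    // and must not pull the exit task into the sub-block set.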
+ if (id_to_dep_type.at(task_id) == DependType::NORMAL) { + for (const auto& task : tasks) { + if (task->task_id() == task_id) { + sub_block_task->emplace(task); + GetSubBlockTask(tasks, task, sub_block_task); + } + } } } - auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); - // NOTE: For inference, the vars in inference_root_scope_vars - // shouldn't be deleted during inf, for that they may be the result of the - // inf. If they are GCed, it will cause error during ZeroCopy the result. +} + +void PreventVarsDelete( + std::unordered_map>* unused_vars, + const std::vector& vars_not_gc) { std::vector changed_ops; - for (auto pair : unused_vars) { + + for (const auto& pair : *unused_vars) { const framework::OperatorBase* op = pair.first; - std::vector unused = pair.second; - for (auto name : inference_root_scope_vars) { - auto iter = std::find(unused.begin(), unused.end(), name); - if (iter != unused.end()) { + std::vector cur_unused = pair.second; + for (auto name : vars_not_gc) { + auto iter = std::find(cur_unused.begin(), cur_unused.end(), name); + if (iter != cur_unused.end()) { VLOG(3) << "Removing var: [" << name << "] from the unused vars list of op: [" << op->Type() << "]"; - unused.erase(iter); + cur_unused.erase(iter); if (std::find(changed_ops.begin(), changed_ops.end(), op) == changed_ops.end()) { // record the op whose unused vars have been updated @@ -93,28 +96,120 @@ void FleetExecutor::Init( } } // update the unused vars list in the map - unused_vars[op] = unused; + unused_vars->at(op) = cur_unused; } for (auto op : changed_ops) { - auto iter = unused_vars.find(op); + const auto& iter = unused_vars->find(op); if (iter->second.empty()) { // remove those ops in the map that have empty unused vars list VLOG(3) << "Removing op: [" << op->Type() << "] from unused_vars map."; - unused_vars.erase(iter); + unused_vars->erase(iter); + } + } +} + +std::vector GetUnusedVarsAfterWhile( + const framework::ProgramDesc& program_desc, + TaskNode* cond_task, + const std::vector& vars_not_gc) { + // NOTE: Since while op won't appear in task node, in order to analyze + // the vars which should be free after calling while op, we rebuild the + // whole program and get the unused vars after calling while op. + // The vars in while block should not be free until the while op is finished. + // In a word, the vars need to be free after while op is: + // 1. Vars in parent block and being used in while block. + // 2. Local vars only defined in while block. + // The unused vars above will be free in cond interceptor. 
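+  // Concretely: rebuild block 0's ops (including the while op itself), run
+  // GetUnusedVars on block 0, drop anything listed in vars_not_gc, keep the
+  // vars that become unused right after the while op, and finally add every
+  // var declared inside the while block (block 1).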
+ std::vector while_block_vars; + std::vector> ops; + for (const auto& desc : program_desc.Block(0).AllOps()) { + ops.emplace_back(framework::OpRegistry::CreateOp(*desc)); + } + auto unused_vars = framework::GetUnusedVars(program_desc.Block(0), ops, {}); + PreventVarsDelete(&unused_vars, vars_not_gc); + for (const auto& pair : unused_vars) { + if (pair.first->Type() == "while") { + for (const auto& var_name : pair.second) { + while_block_vars.emplace_back(var_name); + } + for (auto& var : program_desc.Block(1).AllVars()) { + while_block_vars.emplace_back(var->Name()); + } + } + } + return while_block_vars; +} + +} // namespace + +void FleetExecutor::Init( + const std::string& carrier_id, + const framework::ProgramDesc& program_desc, + framework::Scope* scope, + const platform::Place& place, + int64_t num_micro_batches, + const std::vector& task_nodes, + const std::unordered_map& task_id_to_rank, + const std::vector& inference_root_scope_vars, + const std::vector& micro_scope_list) { + PADDLE_ENFORCE_GT(task_nodes.size(), + 0, + platform::errors::InvalidArgument( + "Fleet executor is inited with empty task node")); + // Set the unused var after running while op + std::set sub_block_tasks; + std::vector while_block_vars; + for (const auto& task_node : task_nodes) { + if (task_node->type() == "Cond") { + GetSubBlockTask(task_nodes, task_node, &sub_block_tasks); + while_block_vars = GetUnusedVarsAfterWhile( + program_desc, task_node, inference_root_scope_vars); + VLOG(3) << "Vars will be gced after while op"; + for (auto var : while_block_vars) { + VLOG(3) << var; + } + task_node->SetWhileBlockVars(while_block_vars); + } + } + std::vector sub_block_ops; + for (const auto& task_node : sub_block_tasks) { + for (const auto& op : task_node->ops()) { + sub_block_ops.emplace_back(op); } } + // Analyse the unused vars in block 0. The operators in block 1 + // should be passed in first for prevent vars been released but removed soon. + // Since the unused vars in block 1 need to analyse separately. + std::vector> ops; + for (const auto& task_node : task_nodes) { + for (const auto& op : task_node->ops()) { + ops.emplace_back(std::unique_ptr(op)); + } + } + auto global_unused_vars = + framework::GetUnusedVars(program_desc.Block(0), ops, {}); + + for (auto& unique_op : ops) { + unique_op.release(); + } + + // NOTE: For inference, the vars in inference_root_scope_vars + // shouldn't be deleted during inf, for that they may be the result of the + // inf. If they are GCed, it will cause error during ZeroCopy the result. 
+ PreventVarsDelete(&global_unused_vars, inference_root_scope_vars); + runtime_graph_ = std::make_shared(); std::unordered_map interceptor_id_to_task; for (auto task_node : task_nodes) { - task_node->SetUnusedVars(unused_vars); + if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) { + task_node->SetUnusedVars(global_unused_vars); + } int64_t interceptor_id = task_node->task_id(); interceptor_id_to_task.emplace(interceptor_id, task_node); } runtime_graph_->SetInterceptorIdToRank(task_id_to_rank); runtime_graph_->SetInterceptorIdToNode(interceptor_id_to_task); - for (auto& unique_op : ops) { - unique_op.release(); - } + VLOG(5) << runtime_graph_->DebugString(); Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); @@ -126,7 +221,8 @@ void FleetExecutor::Init( place, num_micro_batches, program_desc, - inference_root_scope_vars); + inference_root_scope_vars, + micro_scope_list); GlobalVal::Get()->Barrier(); } @@ -136,7 +232,8 @@ void FleetExecutor::InitCarrier( const platform::Place& place, int64_t num_micro_batches, const framework::ProgramDesc& program_desc, - const std::vector& inference_root_scope_vars) { + const std::vector& inference_root_scope_vars, + const std::vector& micro_scope_list) { carrier->Init(exe_desc_.cur_rank(), runtime_graph_->interceptor_id_to_rank(), runtime_graph_->interceptor_id_to_node(), @@ -144,7 +241,8 @@ void FleetExecutor::InitCarrier( scope, num_micro_batches, place, - inference_root_scope_vars); + inference_root_scope_vars, + micro_scope_list); } void FleetExecutor::InitMessageBus() { diff --git a/paddle/fluid/distributed/fleet_executor/fleet_executor.h b/paddle/fluid/distributed/fleet_executor/fleet_executor.h index f633dbbc360..e8123bea1e1 100644 --- a/paddle/fluid/distributed/fleet_executor/fleet_executor.h +++ b/paddle/fluid/distributed/fleet_executor/fleet_executor.h @@ -18,6 +18,7 @@ #include "paddle/fluid/distributed/fleet_executor/carrier.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h" +#include "paddle/fluid/framework/variable.h" #include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/place.h" @@ -45,7 +46,8 @@ class FleetExecutor final { int64_t num_micro_batches, const std::vector& task_nodes, const std::unordered_map& task_id_to_rank, - const std::vector& inference_root_scope_vars = {}); + const std::vector& inference_root_scope_vars = {}, + const std::vector& micro_scope_list = {}); void Run(const std::string& carrier_id); private: @@ -57,7 +59,8 @@ class FleetExecutor final { const platform::Place& place, int64_t num_micro_batches, const framework::ProgramDesc& program_desc, - const std::vector& inference_root_scope_vars = {}); + const std::vector& inference_root_scope_vars = {}, + const std::vector& micro_scope_list = {}); FleetExecutorDesc exe_desc_; std::shared_ptr runtime_graph_; std::unordered_set carrier_ids_; diff --git a/paddle/fluid/distributed/fleet_executor/interceptor.h b/paddle/fluid/distributed/fleet_executor/interceptor.h index 6a761072027..2c20e1ad611 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/interceptor.h @@ -93,7 +93,6 @@ class Interceptor { TaskNode* node_; // for stop - bool stop_{false}; void StopCarrier(); // for runtime @@ -114,9 +113,6 @@ class Interceptor { std::mutex mutex_; std::deque messages_; - - int64_t already_run_times_{0}; - int64_t used_slot_nums_{0}; }; class InterceptorFactory { diff --git a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto 
b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto index 8508bc35f29..4db5a72d897 100644 --- a/paddle/fluid/distributed/fleet_executor/interceptor_message.proto +++ b/paddle/fluid/distributed/fleet_executor/interceptor_message.proto @@ -24,6 +24,21 @@ enum MessageType { ERR = 4; // current Interceptor encounters error RESET = 5; // reset the status START = 6; + DATA_WITH_VARS = 7; + START_LOOP = 8; +} + +enum ValueType { + INT3 = 0; + INT6 = 1; + FLOAT = 2; + DOUBLE = 3; + BOOL = 4; +} + +message VarList { + required string name = 1; + required string stensor = 2; } message InterceptorMessage { @@ -32,6 +47,7 @@ message InterceptorMessage { optional MessageType message_type = 3 [ default = RESET ]; optional bool ctrl_message = 4 [ default = false ]; optional int64 scope_idx = 5 [ default = 0 ]; + repeated VarList vars_list = 6; } message InterceptorResponse { optional bool rst = 1 [ default = false ]; } diff --git a/paddle/fluid/distributed/fleet_executor/sink_interceptor.h b/paddle/fluid/distributed/fleet_executor/sink_interceptor.h index cb1d698a785..1abb7a641e2 100644 --- a/paddle/fluid/distributed/fleet_executor/sink_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/sink_interceptor.h @@ -25,7 +25,7 @@ namespace distributed { * 1. record the num of micro-step * 2. check whether to notify carrier the current step is finished */ -class SinkInterceptor : public Interceptor { +class SinkInterceptor final : public Interceptor { public: SinkInterceptor(int64_t interceptor_id, TaskNode* node); diff --git a/paddle/fluid/distributed/fleet_executor/source_interceptor.h b/paddle/fluid/distributed/fleet_executor/source_interceptor.h index f8b18fb1848..95e8c1b3b03 100644 --- a/paddle/fluid/distributed/fleet_executor/source_interceptor.h +++ b/paddle/fluid/distributed/fleet_executor/source_interceptor.h @@ -25,7 +25,7 @@ namespace distributed { * 1. receive `start` message from carrier * 2. send num_of_steps `data_is_ready` message to downstream */ -class SourceInterceptor : public Interceptor { +class SourceInterceptor final : public Interceptor { public: SourceInterceptor(int64_t interceptor_id, TaskNode* node); diff --git a/paddle/fluid/distributed/fleet_executor/start_interceptor.cc b/paddle/fluid/distributed/fleet_executor/start_interceptor.cc new file mode 100644 index 00000000000..b9ce4fabed4 --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/start_interceptor.cc @@ -0,0 +1,115 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
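+//
+// StartInterceptor sits at the head of the pipeline: it runs its ops once
+// per micro batch but buffers the DATA_IS_READY signals until batch_size_
+// scopes have finished, then releases them to its single downstream in one
+// burst; the matching DATA_IS_USELESS replies are counted back down before
+// the next burst is allowed to start.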
+ +#include "paddle/fluid/distributed/fleet_executor/start_interceptor.h" + +#include "paddle/fluid/distributed/fleet_executor/task_node.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/phi/core/errors.h" + +namespace paddle { +namespace distributed { + +StartInterceptor::StartInterceptor(int64_t interceptor_id, TaskNode* node) + : ComputeInterceptor(interceptor_id, node) { + auto& downstream = node_->downstream(); + PADDLE_ENFORCE_EQ( + downstream.size(), + 1, + platform::errors::OutOfRange( + "The downstream for StartInterceptor only support 1 for now.")); + for (auto down : downstream) { + batch_size_ = down.second; + } + bool evenly_divisible = ((node_->max_run_times() % batch_size_) == 0); + PADDLE_ENFORCE( + evenly_divisible, + platform::errors::Fatal( + "Wrong config: Num of step should be divided by batch_size," + "num_step=%lld, batch_size=%lld", + node_->max_run_times(), + batch_size_)); +} + +void StartInterceptor::RunOps() { + finish_count_++; + ComputeInterceptor::RunOps(); +} + +void StartInterceptor::SendDataReadyToDownStream() { + for (auto& outs : out_buffs_) { + auto down_id = outs.first; + auto max_buff_size = outs.second.first; + auto used_size = outs.second.second; + used_size += 1; + if (max_buff_size != INFINITE_BUFFER_SIZE) { + PADDLE_ENFORCE_LE( + used_size, + max_buff_size, + platform::errors::OutOfRange("downstream=%lld used buff size must <= " + "max_buff_size, but now used_size=%lld, " + "max_buff_size=%lld", + down_id, + used_size, + max_buff_size)); + } + outs.second.second = used_size; + } + if (finish_count_ == batch_size_) { + for (int64_t i = 0; i < batch_size_; ++i) { + int64_t scope_id = step_ % node_->max_run_times(); + for (auto& outs : out_buffs_) { + auto down_id = outs.first; + InterceptorMessage ready_msg; + ready_msg.set_message_type(DATA_IS_READY); + ready_msg.set_scope_idx(scope_id); + VLOG(3) << "StartInterceptor " << interceptor_id_ + << " Send data_is_ready msg to " << down_id + << " in scope: " << scope_id; + Send(down_id, ready_msg); + } + step_++; + } + } +} + +void StartInterceptor::Compute(const InterceptorMessage& msg) { + if (msg.message_type() == DATA_IS_READY) { + VLOG(3) << "Start interceptor " << interceptor_id_ + << " receive data_is_ready " << msg.src_id() << " " + << msg.scope_idx() << " "; + IncreaseReady(msg.src_id(), msg.scope_idx()); + Run(); + } else if (msg.message_type() == DATA_IS_USELESS) { + VLOG(3) << "Start interceptor receive data_is_useless " << msg.src_id() + << " " << finish_count_; + finish_count_--; + if (finish_count_ == 0) { + for (int64_t i = 0; i < batch_size_; ++i) { + for (auto& outs : out_buffs_) { + auto down_id = outs.first; + DecreaseBuff(down_id); + } + } + for (int64_t i = 0; i < batch_size_; ++i) { + Run(); + } + } + } +} + +REGISTER_INTERCEPTOR(Start, StartInterceptor); + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/start_interceptor.h b/paddle/fluid/distributed/fleet_executor/start_interceptor.h new file mode 100644 index 00000000000..f082c48922b --- /dev/null +++ b/paddle/fluid/distributed/fleet_executor/start_interceptor.h @@ -0,0 +1,39 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/distributed/fleet_executor/compute_interceptor.h" + +namespace paddle { +namespace distributed { + +class StartInterceptor final : public ComputeInterceptor { + public: + StartInterceptor(int64_t interceptor_id, TaskNode* node); + + private: + void SendDataReadyToDownStream() override; + void RunOps() override; + void Compute(const InterceptorMessage& msg) override; + + int64_t batch_size_{0}; + int64_t finish_count_{0}; + int64_t step_{0}; +}; + +} // namespace distributed +} // namespace paddle diff --git a/paddle/fluid/distributed/fleet_executor/task_node.cc b/paddle/fluid/distributed/fleet_executor/task_node.cc index 341ffe290a5..60d21986580 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.cc +++ b/paddle/fluid/distributed/fleet_executor/task_node.cc @@ -24,33 +24,14 @@ namespace { using OperatorBase = TaskNode::OperatorBase; } -TaskNode::TaskNode(paddle::framework::ProgramDesc* program, - int64_t rank, - int64_t max_run_times, - int64_t max_slot_nums) - : program_(program), - rank_(rank), - max_run_times_(max_run_times), - max_slot_nums_(max_slot_nums) { - // Should be serially invoked, not thread-safe - // NOTE: when instantiate TaskNode with program, won't init task node - // immediately, since the provided program may be updated later (with - // high probability) by adding_feed_fetch_ops or by RuntimeGraph. - // So, delay the init part to the Init() function. - static int64_t task_node_cnt = 0; - task_id_ = task_node_cnt++; -} - TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank, int64_t task_id, - int64_t max_run_times, - int64_t max_slot_nums) + int64_t max_run_times) : program_(program), rank_(rank), task_id_(task_id), - max_run_times_(max_run_times), - max_slot_nums_(max_slot_nums) { + max_run_times_(max_run_times) { // TODO(liyurui): Will be removed when execute program is supported. Init(); } @@ -58,7 +39,6 @@ TaskNode::TaskNode(paddle::framework::ProgramDesc* program, TaskNode::TaskNode(paddle::framework::ProgramDesc* program, int64_t rank) : program_(program), rank_(rank), task_id_(rank) { max_run_times_ = 1; - max_slot_nums_ = 1; LOG(INFO) << "Constructing TaskNode for DistModelInf. 
The TaskNode's id is: " << rank @@ -69,6 +49,16 @@ void TaskNode::SetProgram(paddle::framework::ProgramDesc* program) { program_ = program; } +void TaskNode::SetVarsToDtype( + const std::map& vars_to_dtype) { + vars_to_dtype_ = vars_to_dtype; +} + +void TaskNode::SetVarsToShape( + const std::map>& vars_to_shape) { + vars_to_shape_ = vars_to_shape; +} + void TaskNode::Init(bool use_feed_fetch_ops) { if (!use_feed_fetch_ops) { VLOG(3) << "TaskNode will be inited without feed and fetch ops"; @@ -98,13 +88,11 @@ TaskNode::TaskNode(int32_t role, const std::vector& op_descs, int64_t rank, int64_t task_id, - int64_t max_run_times, - int64_t max_slot_nums) + int64_t max_run_times) : role_(role), rank_(rank), task_id_(task_id), - max_run_times_(max_run_times), - max_slot_nums_(max_slot_nums) { + max_run_times_(max_run_times) { if (op_descs.empty()) { return; } @@ -121,33 +109,35 @@ TaskNode::TaskNode(int32_t role, const std::vector& ops, int64_t rank, int64_t task_id, - int64_t max_run_times, - int64_t max_slot_nums) + int64_t max_run_times) : ops_(ops), role_(role), rank_(rank), task_id_(task_id), - max_run_times_(max_run_times), - max_slot_nums_(max_slot_nums) {} + max_run_times_(max_run_times) {} TaskNode::TaskNode(int32_t role, int64_t rank, int64_t task_id, - int64_t max_run_times, - int64_t max_slot_nums) + int64_t max_run_times) : role_(role), rank_(rank), task_id_(task_id), - max_run_times_(max_run_times), - max_slot_nums_(max_slot_nums) {} + max_run_times_(max_run_times) {} -bool TaskNode::AddUpstreamTask(int64_t task_id, int64_t buff_size) { +bool TaskNode::AddUpstreamTask(int64_t task_id, + int64_t buff_size, + DependType type) { const auto& ret = upstream_.emplace(task_id, buff_size); + id_to_dep_type_.emplace(task_id, type); return ret.second; } -bool TaskNode::AddDownstreamTask(int64_t task_id, int64_t buff_size) { +bool TaskNode::AddDownstreamTask(int64_t task_id, + int64_t buff_size, + DependType type) { const auto& ret = downstream_.emplace(task_id, buff_size); + id_to_dep_type_.emplace(task_id, type); return ret.second; } diff --git a/paddle/fluid/distributed/fleet_executor/task_node.h b/paddle/fluid/distributed/fleet_executor/task_node.h index 8538ac9ff81..181ab96c242 100644 --- a/paddle/fluid/distributed/fleet_executor/task_node.h +++ b/paddle/fluid/distributed/fleet_executor/task_node.h @@ -14,8 +14,10 @@ #pragma once #include +#include #include #include +#include #include #include @@ -29,38 +31,30 @@ class OpDesc; } // namespace framework namespace distributed { +enum class DependType { NORMAL, LOOP, STOP_LOOP }; + class TaskNode final { public: using OperatorBase = paddle::framework::OperatorBase; TaskNode(int64_t rank, int64_t task_id, int64_t max_run_times); - TaskNode(int32_t role, - int64_t rank, - int64_t task_id, - int64_t max_run_times, - int64_t max_slot_nums); + TaskNode(int32_t role, int64_t rank, int64_t task_id, int64_t max_run_times); TaskNode(int32_t role, const std::vector& op_descs, int64_t rank, int64_t task_id, - int64_t max_run_times, - int64_t max_slot_nums); + int64_t max_run_times); TaskNode(int32_t role, const std::vector& ops, int64_t rank, int64_t task_id, - int64_t max_run_times, - int64_t max_slot_nums); - TaskNode(paddle::framework::ProgramDesc* program, - int64_t rank, - int64_t max_run_times, - int64_t max_slot_nums); + int64_t max_run_times); TaskNode(paddle::framework::ProgramDesc* program, int64_t rank); // TODO(liyurui): This will be the only constructor for task node TaskNode(paddle::framework::ProgramDesc* program, int64_t task_id, int64_t 
rank, - int64_t max_run_times, - int64_t max_slot_nums); + int64_t max_run_times); + ~TaskNode() = default; void SetProgram(paddle::framework::ProgramDesc* program); @@ -69,11 +63,11 @@ class TaskNode final { int64_t task_id() const { return task_id_; } int32_t role() const { return role_; } int64_t max_run_times() const { return max_run_times_; } - int64_t max_slot_nums() const { return max_slot_nums_; } int64_t run_per_steps() const { return run_per_steps_; } int64_t run_at_offset() const { return run_at_offset_; } int64_t reply_up_per_steps() const { return reply_up_per_steps_; } int64_t send_down_per_steps() const { return send_down_per_steps_; } + const std::string& cond_var() const { return cond_var_; } const std::unordered_map& upstream() const { return upstream_; } @@ -86,11 +80,20 @@ class TaskNode final { const std::vector>& unique_ops() const { return ops_vec_; } + const std::unordered_map id_to_dep_type() const { + return id_to_dep_type_; + } const std::unordered_map>& unused_vars() const { return unused_vars_; } + const std::vector while_block_vars() const { + return while_block_vars_; + } + void SetCondVarName(const std::string& cond_var_name) { + cond_var_ = cond_var_name; + } void SetRunPerSteps(int64_t value); void SetRunAtOffset(int64_t value); void SetReplyUpPerSteps(int64_t value); @@ -101,11 +104,27 @@ class TaskNode final { unused_vars) { unused_vars_ = unused_vars; } + void SetWhileBlockVars(const std::vector& vars) { + while_block_vars_ = vars; + } // upstream need buffs? - bool AddUpstreamTask(int64_t task_id, int64_t buff_size = 1); - bool AddDownstreamTask(int64_t task_id, int64_t buff_size = 1); + bool AddUpstreamTask(int64_t task_id, + int64_t buff_size = 1, + DependType type = DependType::NORMAL); + bool AddDownstreamTask(int64_t task_id, + int64_t buff_size = 1, + DependType type = DependType::NORMAL); std::string DebugString() const; + const std::map& vars_to_dtype() const { + return vars_to_dtype_; + } + void SetVarsToDtype(const std::map& vars_to_dtype); + const std::map>& vars_to_shape() const { + return vars_to_shape_; + } + void SetVarsToShape( + const std::map>& vars_to_shape); private: DISABLE_COPY_AND_ASSIGN(TaskNode); @@ -115,16 +134,22 @@ class TaskNode final { // task_id-->buff_size std::unordered_map upstream_; std::unordered_map downstream_; + // task_id-->type + std::unordered_map id_to_dep_type_; + framework::ProgramDesc* program_; + std::string cond_var_; std::vector> ops_vec_; std::unordered_map> unused_vars_; + std::vector while_block_vars_; + std::map vars_to_dtype_; + std::map> vars_to_shape_; int32_t role_; int64_t rank_; int64_t task_id_; int64_t max_run_times_; - int64_t max_slot_nums_; int64_t run_per_steps_{1}; int64_t run_at_offset_{0}; diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc index 86d0609ce09..ace89d63c5e 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_run_op_test.cc @@ -77,9 +77,8 @@ TEST(ComputeInterceptor, Compute) { // FIXME: don't delete, otherwise interceptor will use undefined node TaskNode* source = new TaskNode(0, SOURCE_ID, 2); // rank, task_id, max_run_times - TaskNode* node_a = - new TaskNode(0, ops, 0, 0, 2, 0); // role, ops, rank, task_id - TaskNode* node_b = new TaskNode(0, 0, 1, 2, 0); + TaskNode* node_a = new TaskNode(0, ops, 0, 0, 2); // role, ops, rank, task_id + 
TaskNode* node_b = new TaskNode(0, 0, 1, 2); TaskNode* sink = new TaskNode(0, SINK_ID, 2); // source->a->b->sink diff --git a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc index 4992a8b34c9..1a4f3f2ce9a 100644 --- a/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/compute_interceptor_test.cc @@ -21,61 +21,49 @@ limitations under the License. */ #include "paddle/fluid/distributed/fleet_executor/interceptor.h" #include "paddle/fluid/distributed/fleet_executor/message_bus.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/program_desc.h" +#include "paddle/phi/core/kernel_registry.h" namespace paddle { namespace distributed { -class StartInterceptor : public Interceptor { - public: - StartInterceptor(int64_t interceptor_id, TaskNode* node) - : Interceptor(interceptor_id, node) { - RegisterMsgHandle([this](const InterceptorMessage& msg) { NOP(msg); }); - } - - void NOP(const InterceptorMessage& msg) { - if (msg.message_type() == STOP) { - stop_ = true; - InterceptorMessage stop; - stop.set_message_type(STOP); - Send(1, stop); // stop 1, compute - return; - } - std::cout << GetInterceptorId() << " recv msg from " << msg.src_id() - << std::endl; - } -}; - TEST(ComputeInterceptor, Compute) { std::string carrier_id = "0"; Carrier* carrier = GlobalMap::Create(carrier_id, carrier_id); - carrier->Init(0, {{0, 0}, {1, 0}, {2, 0}}); + carrier->Init(0, {{SOURCE_ID, 0}, {0, 0}, {1, 0}, {SINK_ID, 0}}); MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); // NOTE: don't delete, otherwise interceptor will use undefined node - TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id - TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0); - TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0); - - // a->b->c + TaskNode* source = + new TaskNode(0, SOURCE_ID, 3); // rank, task_id, max_run_times + TaskNode* node_a = new TaskNode(0, 0, 0, 3); + TaskNode* node_b = new TaskNode(0, 0, 1, 3); + TaskNode* sink = new TaskNode(0, SINK_ID, 3); + + // source->a->b->sink + source->AddDownstreamTask(0); + node_a->AddUpstreamTask(SOURCE_ID); node_a->AddDownstreamTask(1, 3); node_b->AddUpstreamTask(0, 3); - node_b->AddDownstreamTask(2); - node_c->AddUpstreamTask(1); + node_b->AddDownstreamTask(SINK_ID); + sink->AddUpstreamTask(1); - Interceptor* a = - carrier->SetInterceptor(0, std::make_unique(0, node_a)); + carrier->SetInterceptor( + SOURCE_ID, InterceptorFactory::Create("Source", SOURCE_ID, source)); + carrier->SetInterceptor(0, InterceptorFactory::Create("Compute", 0, node_a)); carrier->SetInterceptor(1, InterceptorFactory::Create("Compute", 1, node_b)); - carrier->SetInterceptor(2, InterceptorFactory::Create("Compute", 2, node_c)); + carrier->SetInterceptor(SINK_ID, + InterceptorFactory::Create("Sink", SINK_ID, sink)); + // start InterceptorMessage msg; - msg.set_message_type(DATA_IS_READY); - // test run three times - a->Send(1, msg); - a->Send(1, msg); - a->Send(1, msg); + msg.set_message_type(START); + msg.set_dst_id(SOURCE_ID); + carrier->EnqueueInterceptorMessage(msg); carrier->Wait(); carrier->Release(); diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc index 54adf06fb67..f43f3860199 100644 --- 
a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_test.cc @@ -33,7 +33,6 @@ class PingPongInterceptor : public Interceptor { void PingPong(const InterceptorMessage& msg) { if (msg.message_type() == STOP) { - stop_ = true; return; } std::cout << GetInterceptorId() << " recv msg, count=" << count_ diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc index 3828c4478cb..62c23068d7d 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_ping_pong_with_brpc_test.cc @@ -36,7 +36,6 @@ class PingPongInterceptor : public Interceptor { void PingPong(const InterceptorMessage& msg) { if (msg.message_type() == STOP) { - stop_ = true; StopCarrier(); return; } diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc index 3415e377478..12fc77a2717 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_long_path_test.cc @@ -66,17 +66,17 @@ TEST(AmplifierInterceptor, Amplifier) { MessageBus* msg_bus = GlobalVal::Create(); msg_bus->Init(0, {{0, "127.0.0.0:0"}}, "127.0.0.0:0"); - int64_t micro_steps = 3; + int64_t micro_steps = 1; // NOTE: don't delete, otherwise interceptor will use undefined node TaskNode* source = new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times - TaskNode* node_a = new TaskNode(0, 0, 0, 1, 0); // role, rank, task_id - TaskNode* node_b = new TaskNode(0, 0, 1, 1, 0); - TaskNode* node_c = new TaskNode(0, 0, 2, 1, 0); - TaskNode* node_d = new TaskNode(0, 0, 3, 1, 0); - TaskNode* node_e = new TaskNode(0, 0, 4, 1, 0); - TaskNode* node_f = new TaskNode(0, 0, 5, 1, 0); + TaskNode* node_a = new TaskNode(0, 0, 0, 1); // role, rank, task_id + TaskNode* node_b = new TaskNode(0, 0, 1, 1); + TaskNode* node_c = new TaskNode(0, 0, 2, 1); + TaskNode* node_d = new TaskNode(0, 0, 3, 1); + TaskNode* node_e = new TaskNode(0, 0, 4, 1); + TaskNode* node_f = new TaskNode(0, 0, 5, 1); TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps); // source->a->b->c->d->e->f->sink diff --git a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc index fdee01fed1a..4a29f07db5b 100644 --- a/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/interceptor_pipeline_short_path_test.cc @@ -83,11 +83,10 @@ TEST(AmplifierInterceptor, Amplifier) { // NOTE: don't delete, otherwise interceptor will use undefined node TaskNode* source = new TaskNode(0, SOURCE_ID, micro_steps); // rank, task_id, max_run_times - TaskNode* node_a = - new TaskNode(0, 0, 0, micro_steps, 0); // role, rank, task_id - TaskNode* node_b = new TaskNode(0, 0, 1, 3, 0); - TaskNode* node_c = new TaskNode(0, 0, 2, 3, 0); - TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps, 0); + TaskNode* node_a = new TaskNode(0, 0, 0, micro_steps); // role, rank, task_id + TaskNode* node_b = new TaskNode(0, 0, 1, micro_steps); + TaskNode* node_c = new TaskNode(0, 0, 2, micro_steps); + 
TaskNode* node_d = new TaskNode(0, 0, 3, micro_steps); TaskNode* sink = new TaskNode(0, SINK_ID, micro_steps); // source->a->b->c->d->sink diff --git a/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc index 879d7e9b029..b2b1d06634b 100644 --- a/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/sink_interceptor_test.cc @@ -62,10 +62,9 @@ TEST(SourceInterceptor, Source) { msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); // NOTE: don't delete, otherwise interceptor will use undefined node - TaskNode* source = - new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id - TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id - TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3, 0); // role, rank, task_id + TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3); // role, rank, task_id + TaskNode* node_a = new TaskNode(0, 0, 0, 3); // role, rank, task_id + TaskNode* sink = new TaskNode(0, SINK_ID, 0, 3); // role, rank, task_id source->AddDownstreamTask(0, 1); node_a->AddUpstreamTask(SOURCE_ID, 1); diff --git a/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc b/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc index 21a1b4accc9..a707650dfbc 100644 --- a/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc +++ b/paddle/fluid/distributed/fleet_executor/test/source_interceptor_test.cc @@ -61,9 +61,8 @@ TEST(SourceInterceptor, Source) { msg_bus->Init(0, {{0, "127.0.0.0:0"}}, ""); // NOTE: don't delete, otherwise interceptor will use undefined node - TaskNode* source = - new TaskNode(0, SOURCE_ID, 0, 3, 0); // role, rank, task_id - TaskNode* node_a = new TaskNode(0, 0, 0, 3, 0); // role, rank, task_id + TaskNode* source = new TaskNode(0, SOURCE_ID, 0, 3); // role, rank, task_id + TaskNode* node_a = new TaskNode(0, 0, 0, 3); // role, rank, task_id source->AddDownstreamTask(0, 1); node_a->AddUpstreamTask(SOURCE_ID, 1); diff --git a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc index e43c67d7bf3..952fc4058cd 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op.cu.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op.cu.cc @@ -112,5 +112,7 @@ REGISTER_OP_CUDA_KERNEL(c_broadcast, ops::CBroadcastOpCUDAKernel, #endif ops::CBroadcastOpCUDAKernel, + ops::CBroadcastOpCUDAKernel, + ops::CBroadcastOpCUDAKernel, ops::CBroadcastOpCUDAKernel, ops::CBroadcastOpCUDAKernel); diff --git a/paddle/fluid/operators/collective/c_embedding_op.cu b/paddle/fluid/operators/collective/c_embedding_op.cu index 53aef8e8357..cddbd162571 100644 --- a/paddle/fluid/operators/collective/c_embedding_op.cu +++ b/paddle/fluid/operators/collective/c_embedding_op.cu @@ -19,6 +19,8 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/device/gpu/gpu_primitives.h" #include "paddle/fluid/platform/float16.h" +DECLARE_bool(cudnn_deterministic); + namespace paddle { namespace operators { @@ -83,6 +85,32 @@ __global__ void CEmbeddingGrad(T *table, } } +template +__global__ void CEmbeddingGradSerial(T *table, + const T *output, + const IndexT *ids, + const int rows, + const int columns, + const int64_t N, + const int64_t start_idx, + const int64_t end_idx, + const int64_t limit) { + CUDA_KERNEL_LOOP(i, limit) { + if (i == 0) { + for (int j = 0; j < limit; j++) { + size_t row = j / columns; + size_t col = j % columns; + auto id = ids[row]; + if (id >= start_idx && id < end_idx) { + auto real_idx = id - start_idx; + paddle::platform::CudaAtomicAdd(&table[real_idx * columns + col], + output[i]); + } + } + } + } +} + template class CEmbeddingCUDAKernel : public framework::OpKernel { public: @@ -163,28 +191,56 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel { t.device(*dev_ctx.eigen_device()) = t.constant(static_cast(0)); const auto &index_type = framework::TransToProtoVarType(ids_t->dtype()); - if (index_type == framework::proto::VarType::INT32) { - CEmbeddingGrad - <<>>(d_table, - d_output, - ids_t->data(), - K, - D, - N, - start_idx, - end_idx, - limit); - } else if (index_type == framework::proto::VarType::INT64) { - CEmbeddingGrad - <<>>(d_table, - d_output, - ids_t->data(), - K, - D, - N, - start_idx, - end_idx, - limit); + if (FLAGS_cudnn_deterministic) { + VLOG(2) << "Run grad kernel of embedding with single thread."; + blocks = 1; + if (index_type == framework::proto::VarType::INT32) { + CEmbeddingGradSerial + <<>>(d_table, + d_output, + ids_t->data(), + K, + D, + N, + start_idx, + end_idx, + limit); + } else if (index_type == framework::proto::VarType::INT64) { + CEmbeddingGradSerial + <<>>(d_table, + d_output, + ids_t->data(), + K, + D, + N, + start_idx, + end_idx, + limit); + } + } else { + if (index_type == framework::proto::VarType::INT32) { + CEmbeddingGrad + <<>>(d_table, + d_output, + ids_t->data(), + K, + D, + N, + start_idx, + end_idx, + limit); + } else if (index_type == framework::proto::VarType::INT64) { + CEmbeddingGrad + <<>>(d_table, + d_output, + ids_t->data(), + K, + D, + N, + start_idx, + end_idx, + limit); + } } } }; diff --git a/paddle/fluid/pybind/bind_fleet_executor.cc b/paddle/fluid/pybind/bind_fleet_executor.cc index b4a6432e9e5..616dbada4d2 100644 --- a/paddle/fluid/pybind/bind_fleet_executor.cc +++ b/paddle/fluid/pybind/bind_fleet_executor.cc @@ -65,6 +65,7 @@ struct npy_format_descriptor { namespace paddle { namespace pybind { +using paddle::distributed::DependType; using paddle::distributed::DistModel; using paddle::distributed::DistModelConfig; using paddle::distributed::DistModelDataBuf; @@ -164,18 +165,17 @@ void BindFleetExecutor(py::module* m) { .def( "run", &FleetExecutor::Run, py::call_guard()); + py::enum_(*m, "DependType") + .value("NORMAL", DependType::NORMAL) + .value("LOOP", DependType::LOOP) + .value("STOP_LOOP", DependType::STOP_LOOP); + py::class_(*m, "TaskNode") - .def(py::init()) .def(py::init()) .def(py::init&, int64_t, int64_t, - int64_t, int64_t>()) .def("task_id", &TaskNode::task_id) .def("add_upstream_task", &TaskNode::AddUpstreamTask) @@ -183,7 +183,10 @@ void BindFleetExecutor(py::module* m) { .def("set_run_pre_steps", &TaskNode::SetRunPerSteps) .def("set_run_at_offset", &TaskNode::SetRunAtOffset) .def("set_type", &TaskNode::SetType) + .def("set_cond_var_name", &TaskNode::SetCondVarName) .def("role", &TaskNode::role) + 
.def("set_vars_to_shape", &TaskNode::SetVarsToShape) + .def("set_vars_to_dtype", &TaskNode::SetVarsToDtype) .def("init", [](TaskNode& self) { self.Init(); }) .def("set_program", &TaskNode::SetProgram); diff --git a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu index 60672e46d8b..c0896d84bb0 100644 --- a/paddle/phi/kernels/gpu/embedding_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/embedding_grad_kernel.cu @@ -23,6 +23,8 @@ #include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/embedding_util.h" +DECLARE_bool(cudnn_deterministic); + namespace phi { template @@ -101,6 +103,12 @@ struct EmbeddingGradCUDAFunctor { const int gridx = 2 * dev_ctx_.GetSMCount(); dim3 threads(128, 8); dim3 grids(gridx, 1); + + if (FLAGS_cudnn_deterministic) { + VLOG(2) << "Run grad kernel of embedding with single thread."; + grids.x = 1; + threads.y = 1; + } EmbeddingGrad<<>>( d_table, d_output, ids, N, K, D); } diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py index f420a06cfbc..23f16489ad8 100644 --- a/python/paddle/distributed/auto_parallel/completion.py +++ b/python/paddle/distributed/auto_parallel/completion.py @@ -15,6 +15,7 @@ import copy from copy import deepcopy import time +from numpy import sort from paddle.fluid import core from paddle.fluid import framework @@ -59,7 +60,8 @@ def compute_compatible_process_mesh(process_mesh_list): compatible_result = None for process_mesh in process_mesh_list: compatible, compatible_result = _compute_compatible_process_mesh_two( - compatible_result, process_mesh) + compatible_result, process_mesh + ) if not compatible: return None return copy.deepcopy(compatible_result) @@ -82,7 +84,8 @@ def compute_compatible_dim_mapping(dim_mapping_list): compatible_result = -1 for mapping in dim_mapping_list: compatible, compatible_result = _compute_compatible_dim_mapping_two( - compatible_result, mapping) + compatible_result, mapping + ) if not compatible: return None return compatible_result @@ -90,7 +93,7 @@ def compute_compatible_dim_mapping(dim_mapping_list): def compute_compatible_dims_mapping(dims_mapping_list): """Compute the compatible dims mapping given a list of dims mapping. - Each of dims mapping is also a list. + Each of dims mapping is also a list. 
""" if not dims_mapping_list: return None @@ -103,7 +106,8 @@ def compute_compatible_dims_mapping(dims_mapping_list): compatible_result = [] for dim_mappings in zip(*dims_mapping_list): compatible_dim_mapping = compute_compatible_dim_mapping( - list(dim_mappings)) + list(dim_mappings) + ) if compatible_dim_mapping is None: return None compatible_result.append(compatible_dim_mapping) @@ -129,7 +133,8 @@ def _validate_dims_mapping(dims_mapping, process_mesh): return False for i in range(len(dims_mapping)): if dims_mapping[i] < -1 or dims_mapping[i] >= len( - process_mesh.topology): + process_mesh.topology + ): return False for i in range(len(process_mesh.topology)): if dims_mapping.count(i) > 1: @@ -138,7 +143,6 @@ def _validate_dims_mapping(dims_mapping, process_mesh): class Completer: - def __init__(self, dist_context): assert dist_context is not None self._dist_context = dist_context @@ -150,12 +154,15 @@ class Completer: return False tensor_desc = tensor_node.var() # Skip reader tensor - if tensor_desc.type() == core.VarDesc.VarType.READER \ - or tensor_desc.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ - or tensor_desc.type == core.VarDesc.VarType.STEP_SCOPES: + if ( + tensor_desc.type() == core.VarDesc.VarType.READER + or tensor_desc.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY + or tensor_desc.type == core.VarDesc.VarType.STEP_SCOPES + ): return False tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( - tensor_node) + tensor_node + ) assert tensor_dist_attr is not None if tensor_dist_attr.is_annotated("dims_mapping"): return False @@ -164,48 +171,74 @@ class Completer: dims_mapping_list = [] for pred_op_node in tensor_node.inputs: if pred_op_node.op() is not None: - if pred_op_node.op().type() == "create_py_reader" \ - or pred_op_node.op().type() == "create_double_buffer_reader" \ - or pred_op_node.op().type() == "read": + if ( + pred_op_node.op().type() == "create_py_reader" + or pred_op_node.op().type() + == "create_double_buffer_reader" + or pred_op_node.op().type() == "read" + ): continue - op_dist_attr = self._dist_context.get_op_dist_attr_for_graph( - pred_op_node) - if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + op_dist_attr = ( + self._dist_context.get_op_dist_attr_for_graph( + pred_op_node + ) + ) + if ( + op_dist_attr.process_mesh + == tensor_dist_attr.process_mesh + ): op_dims_mapping = op_dist_attr.get_output_dims_mapping( - tensor_desc.name()) + tensor_desc.name() + ) dims_mapping_list.append(op_dims_mapping) dims_mapping_list.append(tensor_dims_mapping) compatible_dims_mapping = compute_compatible_dims_mapping( - dims_mapping_list) - if not _validate_dims_mapping(compatible_dims_mapping, - tensor_dist_attr.process_mesh): + dims_mapping_list + ) + if not _validate_dims_mapping( + compatible_dims_mapping, tensor_dist_attr.process_mesh + ): return False - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != tensor_dims_mapping): + if (compatible_dims_mapping is not None) and ( + compatible_dims_mapping != tensor_dims_mapping + ): tensor_dist_attr.dims_mapping = compatible_dims_mapping changed = True else: dims_mapping_list = [] for succ_op_node in tensor_node.outputs: if succ_op_node.op() is not None: - if succ_op_node.op().type() == "create_py_reader" \ - or succ_op_node.op().type() == "create_double_buffer_reader" \ - or succ_op_node.op().type() == "read": + if ( + succ_op_node.op().type() == "create_py_reader" + or succ_op_node.op().type() + == "create_double_buffer_reader" + or succ_op_node.op().type() == 
"read" + ): continue - op_dist_attr = self._dist_context.get_op_dist_attr_for_graph( - succ_op_node) - if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + op_dist_attr = ( + self._dist_context.get_op_dist_attr_for_graph( + succ_op_node + ) + ) + if ( + op_dist_attr.process_mesh + == tensor_dist_attr.process_mesh + ): op_dims_mapping = op_dist_attr.get_input_dims_mapping( - tensor_desc.name()) + tensor_desc.name() + ) dims_mapping_list.append(op_dims_mapping) dims_mapping_list.append(tensor_dims_mapping) compatible_dims_mapping = compute_compatible_dims_mapping( - dims_mapping_list) - if not _validate_dims_mapping(compatible_dims_mapping, - tensor_dist_attr.process_mesh): + dims_mapping_list + ) + if not _validate_dims_mapping( + compatible_dims_mapping, tensor_dist_attr.process_mesh + ): return False - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != tensor_dims_mapping): + if (compatible_dims_mapping is not None) and ( + compatible_dims_mapping != tensor_dims_mapping + ): tensor_dist_attr.dims_mapping = compatible_dims_mapping changed = True return changed @@ -216,10 +249,12 @@ class Completer: return False # Skip reader op op_desc = op_node.op() - if op_desc.type() == "create_py_reader" \ - or op_desc.type() == "create_double_buffer_reader" \ - or op_desc.type() == "while" \ - or op_desc.type() == "read": + if ( + op_desc.type() == "create_py_reader" + or op_desc.type() == "create_double_buffer_reader" + or op_desc.type() == "while" + or op_desc.type() == "read" + ): return False dist_op = self._dist_context.get_dist_op_for_graph(op_node) op_dist_attr = dist_op.dist_attr @@ -231,28 +266,42 @@ class Completer: continue tensor_desc = tensor_node.var() if op_dist_attr.is_annotated_input_dims_mapping( - tensor_desc.name()): + tensor_desc.name() + ): continue - tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + tensor_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node + ) + ) + if ( + op_dist_attr.process_mesh + == tensor_dist_attr.process_mesh + ): tensor_dims_mapping = tensor_dist_attr.dims_mapping op_dims_mapping = op_dist_attr.get_input_dims_mapping( - tensor_desc.name()) - compatible_dims_mapping = compute_compatible_dims_mapping( - [op_dims_mapping, tensor_dims_mapping]) + tensor_desc.name() + ) + compatible_dims_mapping = ( + compute_compatible_dims_mapping( + [op_dims_mapping, tensor_dims_mapping] + ) + ) if not _validate_dims_mapping( - compatible_dims_mapping, - op_dist_attr.process_mesh): + compatible_dims_mapping, op_dist_attr.process_mesh + ): continue - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != op_dims_mapping): + if (compatible_dims_mapping is not None) and ( + compatible_dims_mapping != op_dims_mapping + ): op_dist_attr.set_input_dims_mapping( - tensor_desc.name(), compatible_dims_mapping) + tensor_desc.name(), compatible_dims_mapping + ) changed = True # Find the most compatible implemenetations from the distributed operator - op_dist_impls = find_compatible_distributed_operator_impls(dist_op, - fwd=True) + op_dist_impls = find_compatible_distributed_operator_impls( + dist_op, fwd=True + ) if op_dist_impls is not None: not_compatible = True backup_op_dist_attr = copy.deepcopy(op_dist_attr) @@ -261,8 +310,10 @@ class Completer: dim_changed = op_dist_impl.update_dims_mapping(dist_op) if dim_changed: changed = True - if op_dist_impl.is_auto_compatible(dist_op) \ - and 
dist_op.validate_dist_attr(): + if ( + op_dist_impl.is_auto_compatible(dist_op) + and dist_op.validate_dist_attr() + ): if op_dist_impl.type == "elementwise": op_dist_attr.impl_type = "default" else: @@ -287,28 +338,42 @@ class Completer: continue tensor_desc = tensor_node.var() if op_dist_attr.is_annotated_output_dims_mapping( - tensor_desc.name()): + tensor_desc.name() + ): continue - tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + tensor_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node + ) + ) + if ( + op_dist_attr.process_mesh + == tensor_dist_attr.process_mesh + ): tensor_dims_mapping = tensor_dist_attr.dims_mapping op_dims_mapping = op_dist_attr.get_output_dims_mapping( - tensor_desc.name()) - compatible_dims_mapping = compute_compatible_dims_mapping( - [op_dims_mapping, tensor_dims_mapping]) + tensor_desc.name() + ) + compatible_dims_mapping = ( + compute_compatible_dims_mapping( + [op_dims_mapping, tensor_dims_mapping] + ) + ) if not _validate_dims_mapping( - compatible_dims_mapping, - op_dist_attr.process_mesh): + compatible_dims_mapping, op_dist_attr.process_mesh + ): continue - if (compatible_dims_mapping is not None) and \ - (compatible_dims_mapping != op_dims_mapping): + if (compatible_dims_mapping is not None) and ( + compatible_dims_mapping != op_dims_mapping + ): op_dist_attr.set_output_dims_mapping( - tensor_desc.name(), compatible_dims_mapping) + tensor_desc.name(), compatible_dims_mapping + ) changed = True # Find the most compatible implemenetations from the distributed operator op_dist_impls = find_compatible_distributed_operator_impls( - dist_op, fwd=False) + dist_op, fwd=False + ) if op_dist_impls is not None: not_compatible = True backup_op_dist_attr = copy.deepcopy(op_dist_attr) @@ -317,8 +382,10 @@ class Completer: dim_changed = op_dist_impl.update_dims_mapping(dist_op) if dim_changed: changed = True - if op_dist_impl.is_auto_compatible(dist_op) \ - and dist_op.validate_dist_attr(): + if ( + op_dist_impl.is_auto_compatible(dist_op) + and dist_op.validate_dist_attr() + ): if op_dist_impl.type == "elementwise": op_dist_attr.impl_type = "default" else: @@ -342,24 +409,33 @@ class Completer: changed = False for parent_node, child_node in self._node_pairs_between_graphs: parent_node_dist_attr = self._dist_context.get_dist_attr_for_graph( - parent_node) + parent_node + ) child_node_dist_attr = self._dist_context.get_dist_attr_for_graph( - child_node) - if parent_node_dist_attr.process_mesh != child_node_dist_attr.process_mesh: + child_node + ) + if ( + parent_node_dist_attr.process_mesh + != child_node_dist_attr.process_mesh + ): continue parent_node_dims_mapping = parent_node_dist_attr.dims_mapping child_node_dims_mapping = child_node_dist_attr.dims_mapping compatible_dims_mapping = compute_compatible_dims_mapping( - [parent_node_dims_mapping, child_node_dims_mapping]) - if not _validate_dims_mapping(compatible_dims_mapping, - parent_node_dist_attr.process_mesh): + [parent_node_dims_mapping, child_node_dims_mapping] + ) + if not _validate_dims_mapping( + compatible_dims_mapping, parent_node_dist_attr.process_mesh + ): return False - if (compatible_dims_mapping is not None) \ - and (compatible_dims_mapping != parent_node_dims_mapping): + if (compatible_dims_mapping is not None) and ( + compatible_dims_mapping != parent_node_dims_mapping + ): parent_node_dist_attr.dims_mapping = compatible_dims_mapping changed = True - if 
(compatible_dims_mapping is not None) \ - and (compatible_dims_mapping != child_node_dims_mapping): + if (compatible_dims_mapping is not None) and ( + compatible_dims_mapping != child_node_dims_mapping + ): child_node_dist_attr.dims_mapping = compatible_dims_mapping changed = True return changed @@ -369,11 +445,15 @@ class Completer: op_nodes = self._dist_context._serial_ordered_op_nodes # NOTE: this list may be changed if Paddle changes the existing rules. related_reader_ops = [ - "create_py_reader", "create_double_buffer_reader", "read" + "create_py_reader", + "create_double_buffer_reader", + "read", ] for op_node in op_nodes: - if op_node.op() is not None \ - and op_node.op().type() in related_reader_ops: + if ( + op_node.op() is not None + and op_node.op().type() in related_reader_ops + ): continue op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node) for tensor_node in op_node.outputs: @@ -381,11 +461,18 @@ class Completer: if tensor_node.var().type() == core.VarDesc.VarType.READER: continue tensor_desc = tensor_node.var() - tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - if op_dist_attr.process_mesh == tensor_dist_attr.process_mesh: + tensor_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node + ) + ) + if ( + op_dist_attr.process_mesh + == tensor_dist_attr.process_mesh + ): op_dims_mapping = op_dist_attr.get_output_dims_mapping( - tensor_desc.name()) + tensor_desc.name() + ) tensor_dist_attr.dims_mapping = op_dims_mapping def _update_dims_mapping(self): @@ -394,17 +481,22 @@ class Completer: while not reach_fix_point: changed = False for is_fwd in [True, False]: - all_nodes = self._dist_context.serial_ordered_nodes \ - if is_fwd else reversed(self._dist_context.serial_ordered_nodes) + all_nodes = ( + self._dist_context.serial_ordered_nodes + if is_fwd + else reversed(self._dist_context.serial_ordered_nodes) + ) for node in all_nodes: if node.is_var() and node.var() is not None: tensor_changed = self._update_tensor_node_dims_mapping( - node, fwd=is_fwd) + node, fwd=is_fwd + ) if tensor_changed: changed = True if node.is_op() and node.op() is not None: op_changed = self._update_op_node_dims_mapping( - node, fwd=is_fwd) + node, fwd=is_fwd + ) if op_changed: changed = True graph_changed = self._update_dims_mapping_between_graphs() @@ -423,12 +515,16 @@ class Completer: if not op_dist_attr.is_annotated("process_mesh"): process_mesh = op_dist_attr.process_mesh nearest_op_dis_attr = self._dist_context.get_dist_attr_for_graph( - nearest_op_node) + nearest_op_node + ) nearest_process_mesh = nearest_op_dis_attr.process_mesh compatible_process_mesh = compute_compatible_process_mesh( - [process_mesh, nearest_process_mesh]) - if compatible_process_mesh is not None \ - and process_mesh != compatible_process_mesh: + [process_mesh, nearest_process_mesh] + ) + if ( + compatible_process_mesh is not None + and process_mesh != compatible_process_mesh + ): op_dist_attr.process_mesh = compatible_process_mesh # Skip the process_mesh setting of inputs and outputs of while_op if op_dist_attr.op_type == "while": @@ -436,43 +532,60 @@ class Completer: # Set the process mesh of the op node's leaf-inputs for tensor_node in op_node.inputs: if tensor_node.is_var() and tensor_node.var() is not None: - tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( - tensor_node) + tensor_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node + ) + ) if tensor_dist_attr.is_annotated("process_mesh"): 
continue # Skip the non-leaf var node if len(tensor_node.inputs) != 0: continue compatible_process_mesh = compute_compatible_process_mesh( - [tensor_dist_attr.process_mesh, op_dist_attr.process_mesh]) - if compatible_process_mesh is not None \ - and tensor_dist_attr.process_mesh != compatible_process_mesh: + [tensor_dist_attr.process_mesh, op_dist_attr.process_mesh] + ) + if ( + compatible_process_mesh is not None + and tensor_dist_attr.process_mesh != compatible_process_mesh + ): tensor_dist_attr.process_mesh = compatible_process_mesh # Set the process mesh of the op node's outputs for tensor_node in op_node.outputs: if tensor_node.is_var() and tensor_node.var() is not None: - tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( - tensor_node) + tensor_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node + ) + ) if tensor_dist_attr.is_annotated("process_mesh"): continue compatible_process_mesh = compute_compatible_process_mesh( - [tensor_dist_attr.process_mesh, op_dist_attr.process_mesh]) - if compatible_process_mesh is not None \ - and tensor_dist_attr.process_mesh != compatible_process_mesh: + [tensor_dist_attr.process_mesh, op_dist_attr.process_mesh] + ) + if ( + compatible_process_mesh is not None + and tensor_dist_attr.process_mesh != compatible_process_mesh + ): tensor_dist_attr.process_mesh = compatible_process_mesh def _update_process_mesh_for_specials(self): - def _find_nearest_tensor_node_before(nodes, idx, var_name): for node in reversed(nodes[:idx]): - if node.is_var() and node.var() is not None \ - and node.var().name() == var_name: + if ( + node.is_var() + and node.var() is not None + and node.var().name() == var_name + ): return node def _find_nearest_tensor_node_after(nodes, idx, var_name): - for node in nodes[idx + 1:]: - if node.is_var() and node.var() is not None \ - and node.var().name() == var_name: + for node in nodes[idx + 1 :]: + if ( + node.is_var() + and node.var() is not None + and node.var().name() == var_name + ): return node def _find_nodes_related_to_cond(source_node): @@ -490,28 +603,42 @@ class Completer: neighbors = cur.inputs + cur.outputs for node in neighbors: if node.is_var() and node.var() is not None: - if node.var().type() != core.VarDesc.VarType.READER \ - and len(node.var().shape()) == 1: + if ( + node.var().type() != core.VarDesc.VarType.READER + and len(node.var().shape()) == 1 + ): frontier.append(node) related_nodes.append(node) if node.is_op() and node.op() is not None: flag = True - if node.op().type() == "create_py_reader" \ - or node.op().type() == "create_double_buffer_reader" \ - or node.op().type() == "read": + if ( + node.op().type() == "create_py_reader" + or node.op().type() == "create_double_buffer_reader" + or node.op().type() == "read" + ): flag = False for tensor_node in node.inputs: - if tensor_node.is_var() and tensor_node.var( - ) is not None: - if tensor_node.var().type() in __not_shape_var_type__ \ - or len(tensor_node.var().shape()) != 1: + if ( + tensor_node.is_var() + and tensor_node.var() is not None + ): + if ( + tensor_node.var().type() + in __not_shape_var_type__ + or len(tensor_node.var().shape()) != 1 + ): flag = False break for tensor_node in node.outputs: - if tensor_node.is_var() and tensor_node.var( - ) is not None: - if tensor_node.var().type() in __not_shape_var_type__ \ - or len(tensor_node.var().shape()) != 1: + if ( + tensor_node.is_var() + and tensor_node.var() is not None + ): + if ( + tensor_node.var().type() + in __not_shape_var_type__ + or 
len(tensor_node.var().shape()) != 1 + ): flag = False break if flag: @@ -536,27 +663,32 @@ class Completer: dims_mapping = dist_attr.get_output_dims_mapping(arg_name) for _ in dims_mapping: new_dims_mapping.append(-1) - dist_attr.set_output_dims_mapping(arg_name, - new_dims_mapping) + dist_attr.set_output_dims_mapping( + arg_name, new_dims_mapping + ) # Amend the process meshes related to while_op for while_op_node, while_op_node_idx in self._while_op_nodes.values(): sub_graph_id = while_op_node.op()._block_attr_id("sub_block") sub_graph = self._dist_context.serial_graph.get_sub_graph( - sub_graph_id) + sub_graph_id + ) sub_graph_nodes = list(sub_graph.all_nodes()) while_dist_op = self._dist_context.get_dist_op_for_graph( - while_op_node) + while_op_node + ) while_op_dist_attr = while_dist_op.dist_attr # Step 1: set the process mesh of while_op to the merged process mesh of its subblock merged_process_mesh = while_op_dist_attr.process_mesh for node in sub_graph_nodes: - if (node.is_var() and node.var() is not None) \ - or (node.is_op() and node.op() is not None): + if (node.is_var() and node.var() is not None) or ( + node.is_op() and node.op() is not None + ): dist_attr = self._dist_context.get_dist_attr_for_graph(node) merged_process_mesh = merge_process_mesh_two( - merged_process_mesh, dist_attr.process_mesh) + merged_process_mesh, dist_attr.process_mesh + ) while_op_dist_attr.process_mesh = merged_process_mesh _make_dims_mapping_replicate(while_op_dist_attr) @@ -566,97 +698,143 @@ class Completer: cond_tensor_name = while_op_node.op().input("Condition")[0] cond_tensor_node = None for node in while_op_node.inputs: - if node.is_var() and node.var() is not None \ - and node.var().name() == cond_tensor_name: + if ( + node.is_var() + and node.var() is not None + and node.var().name() == cond_tensor_name + ): cond_tensor_node = node cond_tensor_related_nodes.append(cond_tensor_node) break cond_tensor_related_nodes.extend( - _find_nodes_related_to_cond(cond_tensor_node)) + _find_nodes_related_to_cond(cond_tensor_node) + ) # Step 2.2: Find related nodes of cond var in the subgraph of while_op cond_tensor_node = None for node in reversed(sub_graph_nodes): - if node.is_var() and node.var() is not None \ - and node.var().name() == cond_tensor_name \ - and len(node.outputs) == 0: + if ( + node.is_var() + and node.var() is not None + and node.var().name() == cond_tensor_name + and len(node.outputs) == 0 + ): cond_tensor_node = node break cond_tensor_related_nodes.extend( - _find_nodes_related_to_cond(cond_tensor_node)) + _find_nodes_related_to_cond(cond_tensor_node) + ) # Step 2.3: Add the StepScops output of while_op stepscopes_tensor_name = while_op_node.op().output("StepScopes")[0] stepscopes_tensor_node = None for output_node in while_op_node.outputs: - if output_node.is_var() and output_node.var() is not None \ - and output_node.var().name() == stepscopes_tensor_name: + if ( + output_node.is_var() + and output_node.var() is not None + and output_node.var().name() == stepscopes_tensor_name + ): stepscopes_tensor_node = output_node cond_tensor_related_nodes.append(stepscopes_tensor_node) # Step 2.4: Set the process meshes of all nodes related to cond var to the process mesh of while op for node in cond_tensor_related_nodes: tensor_dist_attr = self._dist_context.get_dist_attr_for_graph( - node) + node + ) tensor_dist_attr.process_mesh = merged_process_mesh _make_dims_mapping_replicate(tensor_dist_attr) # Step 3: set the process meshes of the inputs in while_op to the process meshes of the outside 
input nodes while_op_inputs_dist_attrs = while_op_dist_attr.inputs_dist_attrs - for tensor_name, tensor_dist_attr in while_op_inputs_dist_attrs.items( - ): + for ( + tensor_name, + tensor_dist_attr, + ) in while_op_inputs_dist_attrs.items(): nearest_tensor_node = _find_nearest_tensor_node_before( - self._dist_context.serial_ordered_nodes, while_op_node_idx, - tensor_name) - nearest_tensor_dist_attr = self._dist_context.get_dist_attr_for_graph( - nearest_tensor_node) - tensor_dist_attr.process_mesh = nearest_tensor_dist_attr.process_mesh + self._dist_context.serial_ordered_nodes, + while_op_node_idx, + tensor_name, + ) + nearest_tensor_dist_attr = ( + self._dist_context.get_dist_attr_for_graph( + nearest_tensor_node + ) + ) + tensor_dist_attr.process_mesh = ( + nearest_tensor_dist_attr.process_mesh + ) # Step 4: set the process meshes of the outputs in while_op to the process meshes of the outside output nodes while_op_outputs_dist_attrs = while_op_dist_attr.outputs_dist_attrs - for tensor_name, tensor_dist_attr in while_op_outputs_dist_attrs.items( - ): + for ( + tensor_name, + tensor_dist_attr, + ) in while_op_outputs_dist_attrs.items(): nearest_tensor_node = _find_nearest_tensor_node_before( - self._dist_context.serial_ordered_nodes, while_op_node_idx, - tensor_name) + self._dist_context.serial_ordered_nodes, + while_op_node_idx, + tensor_name, + ) if nearest_tensor_node is None: nearest_tensor_node = _find_nearest_tensor_node_after( self._dist_context.serial_ordered_nodes, - while_op_node_idx, tensor_name) - nearest_tensor_dist_attr = self._dist_context.get_dist_attr_for_graph( - nearest_tensor_node) - tensor_dist_attr.process_mesh = nearest_tensor_dist_attr.process_mesh + while_op_node_idx, + tensor_name, + ) + nearest_tensor_dist_attr = ( + self._dist_context.get_dist_attr_for_graph( + nearest_tensor_node + ) + ) + tensor_dist_attr.process_mesh = ( + nearest_tensor_dist_attr.process_mesh + ) # Amend the process meshes related to array for array_node_list in self._array_nodes.values(): merged_process_mesh = None for array_node in array_node_list: dist_attr = self._dist_context.get_dist_attr_for_graph( - array_node) + array_node + ) merged_process_mesh = merge_process_mesh_two( - merged_process_mesh, dist_attr.process_mesh) + merged_process_mesh, dist_attr.process_mesh + ) for array_node in array_node_list: dist_attr = self._dist_context.get_dist_attr_for_graph( - array_node) + array_node + ) dist_attr.process_mesh = merged_process_mesh _make_dims_mapping_replicate(dist_attr) def _update_process_mesh_between_graphs(self): for parent_node, child_node in self._node_pairs_between_graphs: parent_node_dist_attr = self._dist_context.get_dist_attr_for_graph( - parent_node) + parent_node + ) child_node_dist_attr = self._dist_context.get_dist_attr_for_graph( - child_node) - parent_node_dist_attr.process_mesh = child_node_dist_attr.process_mesh - compatible_process_mesh = compute_compatible_process_mesh([ - parent_node_dist_attr.process_mesh, + child_node + ) + parent_node_dist_attr.process_mesh = ( child_node_dist_attr.process_mesh - ]) - if compatible_process_mesh is not None \ - and parent_node_dist_attr.process_mesh != compatible_process_mesh: + ) + compatible_process_mesh = compute_compatible_process_mesh( + [ + parent_node_dist_attr.process_mesh, + child_node_dist_attr.process_mesh, + ] + ) + if ( + compatible_process_mesh is not None + and parent_node_dist_attr.process_mesh + != compatible_process_mesh + ): parent_node_dist_attr.process_mesh = compatible_process_mesh - if 
compatible_process_mesh is not None \ - and child_node_dist_attr.process_mesh != compatible_process_mesh: + if ( + compatible_process_mesh is not None + and child_node_dist_attr.process_mesh != compatible_process_mesh + ): child_node_dist_attr.process_mesh = compatible_process_mesh def _update_process_mesh(self): @@ -665,8 +843,9 @@ class Completer: # Step 1: Set the annotated process meshes from tensors to the first ops using them ordered_tensor_nodes = self._dist_context._serial_ordered_tensor_nodes for tensor_node in ordered_tensor_nodes: - tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( - tensor_node) + tensor_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_graph(tensor_node) + ) if not tensor_dist_attr.is_annotated("process_mesh"): continue first_op_node = None @@ -684,13 +863,18 @@ class Completer: if first_op_node is None: continue op_dist_attr = self._dist_context.get_dist_attr_for_graph( - first_op_node) + first_op_node + ) if op_dist_attr is not None and not op_dist_attr.is_annotated( - "process_mesh"): + "process_mesh" + ): compatible_process_mesh = compute_compatible_process_mesh( - [tensor_dist_attr.process_mesh, op_dist_attr.process_mesh]) - if compatible_process_mesh is not None \ - and op_dist_attr.process_mesh != compatible_process_mesh: + [tensor_dist_attr.process_mesh, op_dist_attr.process_mesh] + ) + if ( + compatible_process_mesh is not None + and op_dist_attr.process_mesh != compatible_process_mesh + ): op_dist_attr.process_mesh = compatible_process_mesh # Step 2: set the process meshes of ops with the nearest op before them @@ -698,8 +882,10 @@ class Completer: idx_of_first_op_node_has_process_mesh = -1 for idx, op_node in enumerate(ordered_op_nodes): op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node) - if op_dist_attr.process_mesh is not None \ - and idx_of_first_op_node_has_process_mesh == -1: + if ( + op_dist_attr.process_mesh is not None + and idx_of_first_op_node_has_process_mesh == -1 + ): idx_of_first_op_node_has_process_mesh = idx # Reuse the following method to set the related tensors for same op node self._update_process_mesh_by_nearest(op_node, op_node) @@ -707,17 +893,20 @@ class Completer: if idx_of_first_op_node_has_process_mesh + 1 > len(ordered_op_nodes): return None for idx, op_node in enumerate( - ordered_op_nodes[idx_of_first_op_node_has_process_mesh + 1:]): + ordered_op_nodes[idx_of_first_op_node_has_process_mesh + 1 :] + ): original_idx = idx_of_first_op_node_has_process_mesh + idx + 1 nearest_op_node = ordered_op_nodes[original_idx - 1] nearest_op_dist_attr = self._dist_context.get_dist_attr_for_graph( - nearest_op_node) + nearest_op_node + ) op_dist_attr = self._dist_context.get_dist_attr_for_graph(op_node) assert nearest_op_dist_attr.process_mesh is not None self._update_process_mesh_by_nearest(op_node, nearest_op_node) # Step 2.3: set the process meshes of ops by the nearest op node before the first op node nearest_op_node = ordered_op_nodes[ - idx_of_first_op_node_has_process_mesh] + idx_of_first_op_node_has_process_mesh + ] for op_node in ordered_op_nodes[:idx_of_first_op_node_has_process_mesh]: self._update_process_mesh_by_nearest(op_node, nearest_op_node) @@ -728,6 +917,20 @@ class Completer: self._update_process_mesh_between_graphs() def _prepare(self): + def _find_nearest_parent_nodes(sorted_parent_nodes, child_idx): + before_node = None + after_node = None + pos = -1 + for pos, (parent_idx, parent_node) in enumerate( + sorted_parent_nodes + ): + if parent_idx > child_idx: + after_node = 
parent_node + break + if pos > 0: + _, before_node = sorted_parent_nodes[pos - 1] + return before_node, after_node + if self._has_prepared: return self._while_op_nodes = {} @@ -751,24 +954,27 @@ class Completer: self._array_nodes[array_var_name] = [] self._array_nodes[array_var_name].append(node) self._array_nodes[array_var_name].append(node.outputs[0]) + # TODO: Use dict and name as the key to store the nodes, + # and use the id comparsion to deal with the before or after position if node.is_var() and node.var() is not None: if node.node.graph_id() != 0: - for before_node in reversed(all_nodes[:idx]): - if before_node.is_var() and before_node.var() is not None \ - and before_node.node.graph_id() == node.node.graph_id() - 1 \ - and before_node.var().name() == node.var().name(): - self._node_pairs_between_graphs.append( - (before_node, node)) - for after_node in all_nodes[idx + 1:]: - if after_node.is_var() and after_node.var() is not None \ - and after_node.node.graph_id() == node.node.graph_id() - 1 \ - and after_node.var().name() == node.var().name(): + parent_nodes = ( + self._dist_context._tensor_nodes_with_same_name[ + node.node.graph_id() - 1 + ].get(node.var().name(), None) + ) + if parent_nodes is not None: + sorted_parent_nodes = sorted( + parent_nodes, key=lambda x: x[0] + ) + for _, parent_node in sorted_parent_nodes: self._node_pairs_between_graphs.append( - (after_node, node)) + (parent_node, node) + ) self._has_prepared = True def complete_forward_annotation(self, serial_main_program=None): - """ Complete annotation for the partial annotated serial_main_program. + """Complete annotation for the partial annotated serial_main_program. Arguments: serial_main_program: partial annotated serial_main_program. Returns:e @@ -787,14 +993,22 @@ class Completer: # self._dist_context.validate_dist_attr_for_program() + start_time = time.time() self._prepare() + # print("completion-prepare: ", time.time() - start_time, flush=True) + start_time = time.time() self._update_process_mesh() + # print("completion-mesh: ", time.time() - start_time, flush=True) + start_time = time.time() self._update_dims_mapping() + # print("graph-dims: ", time.time() - start_time, flush=True) + start_time = time.time() # Copy the corresponding distributed attribute from graph to serial_main_program self._dist_context.copy_dist_attr_from_graph_to_program() + # print("completion-copy: ", time.time() - start_time, flush=True) else: self._dist_context.initialize(with_graph=False) @@ -822,8 +1036,9 @@ class Completer: # TODO: we must ensure the world process group contains all ranks ranks = get_world_process_group().ranks process_mesh = ProcessMesh(ranks) - for dist_tensor in self._dist_context._dist_tensors_for_program.values( - ): + for ( + dist_tensor + ) in self._dist_context._dist_tensors_for_program.values(): serial_tensor = dist_tensor.serial_tensor tensor_dist_attr = dist_tensor.dist_attr tensor_dist_attr.process_mesh = process_mesh @@ -842,27 +1057,35 @@ class Completer: if not serial_tensor.is_parameter: if arg_name not in input_xshape_arg_names: old_dims_mapping = op_dist_attr.get_input_dims_mapping( - arg_name) + arg_name + ) if len(old_dims_mapping) > 0: new_dims_mapping = [0] + [ -1 for _ in range(len(old_dims_mapping) - 1) ] op_dist_attr.set_input_dims_mapping( - arg_name, new_dims_mapping) + arg_name, new_dims_mapping + ) else: old_dims_mapping = op_dist_attr.get_input_dims_mapping( - arg_name) + arg_name + ) if len(old_dims_mapping) > 1: new_dims_mapping = [-1, 0] + [ -1 for _ in 
range(len(old_dims_mapping) - 2) ] op_dist_attr.set_input_dims_mapping( - arg_name, new_dims_mapping) + arg_name, new_dims_mapping + ) # Set tensor's dims_mapping by the op's - tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - serial_tensor) - tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping( - arg_name) + tensor_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_program( + serial_tensor + ) + ) + tensor_dist_attr.dims_mapping = ( + op_dist_attr.get_input_dims_mapping(arg_name) + ) output_xshape_arg_names = [] if "XShape" in op_desc.output_names(): output_xshape_arg_names = op_desc.output("XShape") @@ -871,37 +1094,48 @@ class Completer: if not serial_tensor.is_parameter: if arg_name not in output_xshape_arg_names: old_dims_mapping = op_dist_attr.get_output_dims_mapping( - arg_name) + arg_name + ) if len(old_dims_mapping) > 0: new_dims_mapping = [0] + [ -1 for _ in range(len(old_dims_mapping) - 1) ] op_dist_attr.set_output_dims_mapping( - arg_name, new_dims_mapping) + arg_name, new_dims_mapping + ) else: old_dims_mapping = op_dist_attr.get_output_dims_mapping( - arg_name) + arg_name + ) if len(old_dims_mapping) > 1: new_dims_mapping = [-1, 0] + [ -1 for _ in range(len(old_dims_mapping) - 2) ] op_dist_attr.set_output_dims_mapping( - arg_name, new_dims_mapping) + arg_name, new_dims_mapping + ) # Set tensor's dims_mapping by the op's - tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - serial_tensor) - tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping( - arg_name) + tensor_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_program( + serial_tensor + ) + ) + tensor_dist_attr.dims_mapping = ( + op_dist_attr.get_output_dims_mapping(arg_name) + ) op_dist_impls = find_compatible_distributed_operator_impls( - dist_op, partial=False) + dist_op, partial=False + ) if op_dist_impls is not None: not_compatible = True backup_op_dist_attr = copy.deepcopy(op_dist_attr) for op_dist_impl in op_dist_impls: op_dist_impl.update_dims_mapping(dist_op) - if op_dist_impl.is_auto_compatible(dist_op) \ - and dist_op.validate_dist_attr(): + if ( + op_dist_impl.is_auto_compatible(dist_op) + and dist_op.validate_dist_attr() + ): op_dist_attr.impl_type = op_dist_impl.type op_dist_attr.impl_idx = op_dist_impl.idx not_compatible = False @@ -943,24 +1177,36 @@ class Completer: # Use the first op to set the tensor dist attr if tensor_name in has_set_dist_attr: continue - tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - tensor_dist_attr.process_mesh = op_dist_attr.process_mesh - tensor_dist_attr.dims_mapping = op_dist_attr.get_input_dims_mapping( - tensor_name) if tensor.is_parameter else [ - -1 for i in tensor_desc.shape() - ] + tensor_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node + ) + ) + tensor_dist_attr.process_mesh = ( + op_dist_attr.process_mesh + ) + tensor_dist_attr.dims_mapping = ( + op_dist_attr.get_input_dims_mapping(tensor_name) + if tensor.is_parameter + else [-1 for i in tensor_desc.shape()] + ) has_set_dist_attr.add(tensor_name) for tensor_node in node.outputs: if tensor_node.is_var() and tensor_node.var() is not None: tensor_name = tensor_node.var().name() if tensor_name in has_set_dist_attr: continue - tensor_dist_attr = self._dist_context.get_tensor_dist_attr_for_graph( - tensor_node) - tensor_dist_attr.process_mesh = op_dist_attr.process_mesh - tensor_dist_attr.dims_mapping = op_dist_attr.get_output_dims_mapping( - tensor_name) + 
tensor_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_graph( + tensor_node + ) + ) + tensor_dist_attr.process_mesh = ( + op_dist_attr.process_mesh + ) + tensor_dist_attr.dims_mapping = ( + op_dist_attr.get_output_dims_mapping(tensor_name) + ) has_set_dist_attr.add(tensor_name) self._update_process_mesh_for_specials() @@ -981,7 +1227,7 @@ class Completer: def _complete_high_order_grad_annotation(self, serial_main_program=None): """ - NOTE: + NOTE: [HighOrderGrad] Complete the annotation of vars and ops only for high order gradient. This function is temporary to support high order gradient, and will be removed in the future. """ @@ -1011,79 +1257,108 @@ class Completer: for idx in range(0, len(ops)): op = ops[idx] if int(op.attr('op_role')) == int( - core.op_proto_and_checker_maker.OpRole.Forward): + core.op_proto_and_checker_maker.OpRole.Forward + ): continue if int(op.attr('op_role')) == int( - core.op_proto_and_checker_maker.OpRole.Backward) and int( - ops[idx - 1].attr('op_role')) == int( - core.op_proto_and_checker_maker.OpRole.Forward): + core.op_proto_and_checker_maker.OpRole.Backward + ) and int(ops[idx - 1].attr('op_role')) == int( + core.op_proto_and_checker_maker.OpRole.Forward + ): appended_grad_times += 1 if int(op.attr('op_role')) == int( - int(core.op_proto_and_checker_maker.OpRole.Backward) - | int(core.op_proto_and_checker_maker.OpRole.Loss)): + int(core.op_proto_and_checker_maker.OpRole.Backward) + | int(core.op_proto_and_checker_maker.OpRole.Loss) + ): assert op.type == "fill_constant" break # complete the annotation of grad op (xxx_grad op or sum op) # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id grad_op = ops[idx] - if grad_op.desc.original_id( - ) in dist_op_context.grad_op_id_to_op_id: + if ( + grad_op.desc.original_id() + in dist_op_context.grad_op_id_to_op_id + ): # TODO support the case where one forward op corresponding to multiple xxx_grad op forward_op = _get_op_by_id( - ops, dist_op_context.grad_op_id_to_op_id[ - grad_op.desc.original_id()]) + ops, + dist_op_context.grad_op_id_to_op_id[ + grad_op.desc.original_id() + ], + ) assert forward_op is not None - fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( - forward_op) + fwd_op_dist_attr = ( + self._dist_context.get_op_dist_attr_for_program(forward_op) + ) fwd_op_process_mesh = fwd_op_dist_attr.process_mesh grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr.process_mesh = fwd_op_process_mesh for input_name in grad_op.input_arg_names: - if input_name not in forward_op.input_arg_names and input_name not in forward_op.output_arg_names: + if ( + input_name not in forward_op.input_arg_names + and input_name not in forward_op.output_arg_names + ): if input_name in grad_var_to_var[appended_grad_times]: fwd_name = grad_var_to_var[appended_grad_times][ - input_name] - ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping( - fwd_name) + input_name + ] + ref_dims_mapping = ( + fwd_op_dist_attr.get_output_dims_mapping( + fwd_name + ) + ) else: input_var = vars[input_name] ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( - input_var).dims_mapping + input_var + ).dims_mapping else: if fwd_op_dist_attr.get_input_dims_mapping(input_name): - ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping( - input_name) + ref_dims_mapping = ( + fwd_op_dist_attr.get_input_dims_mapping( + input_name + ) + ) else: - ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping( - input_name) - assert ref_dims_mapping is not None, "[{}] 
's dims mapping is NONE".format( - input_name) + ref_dims_mapping = ( + fwd_op_dist_attr.get_output_dims_mapping( + input_name + ) + ) + assert ( + ref_dims_mapping is not None + ), "[{}] 's dims mapping is NONE".format(input_name) grad_op_dist_attr.set_input_dims_mapping( - input_name, ref_dims_mapping) + input_name, ref_dims_mapping + ) for output_name in grad_op.output_arg_names: assert output_name in grad_var_to_var[appended_grad_times] fwd_name = grad_var_to_var[appended_grad_times][output_name] ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping( - fwd_name) + fwd_name + ) # var output_var = vars[output_name] tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr.dims_mapping = ref_dims_mapping tensor_dist_attr.process_mesh = fwd_op_process_mesh self._dist_context.set_tensor_dist_attr_for_program( - output_var, tensor_dist_attr) + output_var, tensor_dist_attr + ) # op grad_op_dist_attr.set_output_dims_mapping( - output_name, ref_dims_mapping) + output_name, ref_dims_mapping + ) self._dist_context.set_op_dist_attr_for_program( - grad_op, grad_op_dist_attr) + grad_op, grad_op_dist_attr + ) # grad ops that have not a corresponding mapping in grad_op_id_to_op_id else: @@ -1091,14 +1366,20 @@ class Completer: if grad_op.type == 'sum': assert all(map(_is_grad_var_name, grad_op.input_arg_names)) output_name = grad_op.output_arg_names[0] - assert output_name in grad_var_to_var[appended_grad_times], \ - "sum op's output '{}' has no corresponding var".format( - output_name) + assert ( + output_name in grad_var_to_var[appended_grad_times] + ), "sum op's output '{}' has no corresponding var".format( + output_name + ) ref_fwd_var_name = grad_var_to_var[appended_grad_times][ - output_name] + output_name + ] ref_fwd_var = vars[ref_fwd_var_name] - ref_fwd_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - ref_fwd_var) + ref_fwd_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_program( + ref_fwd_var + ) + ) ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh # output @@ -1107,21 +1388,27 @@ class Completer: tensor_dist_attr.process_mesh = ref_fwd_process_mesh output_var = vars[output_name] self._dist_context.set_tensor_dist_attr_for_program( - output_var, tensor_dist_attr) + output_var, tensor_dist_attr + ) # op grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr.process_mesh = ref_fwd_process_mesh for var_name in grad_op.input_arg_names: grad_op_dist_attr.set_input_dims_mapping( - var_name, ref_fwd_dims_mapping) + var_name, ref_fwd_dims_mapping + ) grad_op_dist_attr.set_output_dims_mapping( - output_name, ref_fwd_dims_mapping) + output_name, ref_fwd_dims_mapping + ) elif grad_op.type == 'fill_any_like': ref_var_name = grad_op.input_arg_names[0] ref_var = vars[ref_var_name] - ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - ref_var) + ref_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_program( + ref_var + ) + ) ref_dims_mapping = ref_dist_attr.dims_mapping ref_process_mesh = ref_dist_attr.process_mesh # output @@ -1131,24 +1418,29 @@ class Completer: output_var_name = grad_op.output_arg_names[0] output_var = vars[output_var_name] self._dist_context.set_tensor_dist_attr_for_program( - output_var, tensor_dist_attr) + output_var, tensor_dist_attr + ) # op grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr.process_mesh = ref_process_mesh grad_op_dist_attr.set_input_dims_mapping( - ref_var_name, ref_dims_mapping) + ref_var_name, 
ref_dims_mapping + ) grad_op_dist_attr.set_output_dims_mapping( - output_var_name, ref_dims_mapping) + output_var_name, ref_dims_mapping + ) elif grad_op.type in ['shape', 'fill_constant']: continue else: - raise ValueError("got unexpect op [{}]".format( - str(grad_op.type))) + raise ValueError( + "got unexpect op [{}]".format(str(grad_op.type)) + ) self._dist_context.set_op_dist_attr_for_program( - grad_op, grad_op_dist_attr) + grad_op, grad_op_dist_attr + ) def complete_backward_annotation(self, serial_main_program=None): """Complete the annotation of vars and ops in the backward phase for parallel program.""" @@ -1165,9 +1457,9 @@ class Completer: def _get_forward_varname_from_grad_varname(grad_var_name): assert _is_grad_var_name( - grad_var_name), "[{}] is not a grad varnme.".format( - grad_var_name) - return grad_var_name[:grad_var_name.find("@GRAD")] + grad_var_name + ), "[{}] is not a grad varnme.".format(grad_var_name) + return grad_var_name[: grad_var_name.find("@GRAD")] def _get_op_by_id(ops, id): for op in ops: @@ -1178,160 +1470,216 @@ class Completer: first_backward_op_idx = -1 for idx, op in enumerate(serial_main_program.global_block().ops): if int(op.attr('op_role')) == int( - int(core.op_proto_and_checker_maker.OpRole.Backward) - | int(core.op_proto_and_checker_maker.OpRole.Loss)): + int(core.op_proto_and_checker_maker.OpRole.Backward) + | int(core.op_proto_and_checker_maker.OpRole.Loss) + ): assert op.type == "fill_constant" first_backward_op_idx = idx break - assert first_backward_op_idx >= 0, "No backward procedure found in this program." + assert ( + first_backward_op_idx >= 0 + ), "No backward procedure found in this program." ops = list(serial_main_program.global_block().ops) vars = serial_main_program.global_block().vars dist_op_context = self._dist_context.dist_op_context - grad_var_to_var = dist_op_context.grad_var_to_var[len( - dist_op_context.grad_var_to_var)] + grad_var_to_var = dist_op_context.grad_var_to_var[ + len(dist_op_context.grad_var_to_var) + ] for idx in range(first_backward_op_idx, len(ops)): # complete the initial grad loss op if idx == first_backward_op_idx: assert ops[idx].type == "fill_constant" - assert len( - ops[idx].input_arg_names - ) == 0, "first backward op should has only ONE output, but got [{}]".format( - len(ops[idx].input_arg_names)) - assert len( - ops[idx].output_arg_names - ) == 1, "first backward op should has only ONE output, but got [{}]".format( - len(ops[idx].output_arg_names)) + assert ( + len(ops[idx].input_arg_names) == 0 + ), "first backward op should has only ONE output, but got [{}]".format( + len(ops[idx].input_arg_names) + ) + assert ( + len(ops[idx].output_arg_names) == 1 + ), "first backward op should has only ONE output, but got [{}]".format( + len(ops[idx].output_arg_names) + ) grad_var = vars[ops[idx].output_arg_names[0]] forward_var_name = _get_forward_varname_from_grad_varname( - grad_var.name) + grad_var.name + ) forward_var = vars[forward_var_name] # TODO complete other attribte for grad var tensor_dist_attr = TensorDistributedAttribute() - process_mesh = self._dist_context.get_tensor_dist_attr_for_program( - forward_var).process_mesh - dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( - forward_var).dims_mapping + process_mesh = ( + self._dist_context.get_tensor_dist_attr_for_program( + forward_var + ).process_mesh + ) + dims_mapping = ( + self._dist_context.get_tensor_dist_attr_for_program( + forward_var + ).dims_mapping + ) tensor_dist_attr.dims_mapping = dims_mapping 
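# Illustrative sketch (not part of the patch) of the dims_mapping convention this
# completion logic relies on: dims_mapping[i] is the index of the process-mesh
# axis that shards tensor axis i, and -1 means that axis is replicated. The
# helper name and the even-divisibility assumption below are hypothetical.
def _local_shape(global_shape, mesh_shape, dims_mapping):
    # Per-rank shape implied by a dims_mapping, assuming even divisibility.
    shape = list(global_shape)
    for i, mesh_axis in enumerate(dims_mapping):
        if mesh_axis != -1:
            shape[i] //= mesh_shape[mesh_axis]
    return shape

# e.g. _local_shape([8, 1024], [2, 4], [0, -1]) == [4, 1024]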
tensor_dist_attr.process_mesh = process_mesh self._dist_context.set_tensor_dist_attr_for_program( - grad_var, tensor_dist_attr) + grad_var, tensor_dist_attr + ) op_dist_attr = OperatorDistributedAttribute() op_dist_attr.process_mesh = process_mesh - op_dist_attr.set_output_dims_mapping(grad_var.name, - dims_mapping) + op_dist_attr.set_output_dims_mapping( + grad_var.name, dims_mapping + ) self._dist_context.set_op_dist_attr_for_program( - ops[idx], op_dist_attr) + ops[idx], op_dist_attr + ) continue # complete the annotation of grad op (xxx_grad op or sum op) # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id grad_op = ops[idx] - if grad_op.desc.original_id( - ) in dist_op_context.grad_op_id_to_op_id: + if ( + grad_op.desc.original_id() + in dist_op_context.grad_op_id_to_op_id + ): # TODO support the case where one forward op corresponding to multiple xxx_grad op forward_op = _get_op_by_id( ops[:first_backward_op_idx], dist_op_context.grad_op_id_to_op_id[ - grad_op.desc.original_id()]) + grad_op.desc.original_id() + ], + ) assert forward_op is not None if grad_op.type == "concat" and forward_op.type == "split": - forward_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( - forward_op) + forward_op_dist_attr = ( + self._dist_context.get_op_dist_attr_for_program( + forward_op + ) + ) output_var = vars[grad_op.desc.output('Out')[0]] split_input_var_name = forward_op.input("X")[0] - ref_dims_mapping = forward_op_dist_attr.get_input_dims_mapping( - split_input_var_name) + ref_dims_mapping = ( + forward_op_dist_attr.get_input_dims_mapping( + split_input_var_name + ) + ) ref_mesh = forward_op_dist_attr.process_mesh grad_op_dist_attr = OperatorDistributedAttribute() for input_name in grad_op.input_arg_names: grad_op_dist_attr.set_input_dims_mapping( - input_name, ref_dims_mapping) + input_name, ref_dims_mapping + ) output_var_dist_attr = TensorDistributedAttribute() output_var_dist_attr.dims_mapping = ref_dims_mapping output_var_dist_attr.process_mesh = ref_mesh self._dist_context.set_tensor_dist_attr_for_program( - output_var, output_var_dist_attr) + output_var, output_var_dist_attr + ) grad_op_dist_attr.set_output_dims_mapping( - output_var.name, ref_dims_mapping) + output_var.name, ref_dims_mapping + ) grad_op_dist_attr.process_mesh = ref_mesh self._dist_context.set_op_dist_attr_for_program( - grad_op, grad_op_dist_attr) + grad_op, grad_op_dist_attr + ) grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx continue - fwd_op_dist_attr = self._dist_context.get_op_dist_attr_for_program( - forward_op) + fwd_op_dist_attr = ( + self._dist_context.get_op_dist_attr_for_program(forward_op) + ) fwd_op_process_mesh = fwd_op_dist_attr.process_mesh grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr.process_mesh = fwd_op_process_mesh for input_name in grad_op.input_arg_names: - if input_name not in forward_op.input_arg_names and input_name not in forward_op.output_arg_names: + if ( + input_name not in forward_op.input_arg_names + and input_name not in forward_op.output_arg_names + ): if input_name in grad_var_to_var: fwd_name = grad_var_to_var[input_name] - ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping( - fwd_name) + ref_dims_mapping = ( + fwd_op_dist_attr.get_output_dims_mapping( + fwd_name + ) + ) else: input_var = vars[input_name] ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( - input_var).dims_mapping + input_var + ).dims_mapping else: if 
fwd_op_dist_attr.get_input_dims_mapping(input_name): - ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping( - input_name) + ref_dims_mapping = ( + fwd_op_dist_attr.get_input_dims_mapping( + input_name + ) + ) else: - ref_dims_mapping = fwd_op_dist_attr.get_output_dims_mapping( - input_name) - assert ref_dims_mapping is not None, "[{}] 's dims mapping is NONE".format( - input_name) + ref_dims_mapping = ( + fwd_op_dist_attr.get_output_dims_mapping( + input_name + ) + ) + assert ( + ref_dims_mapping is not None + ), "[{}] 's dims mapping is NONE".format(input_name) grad_op_dist_attr.set_input_dims_mapping( - input_name, ref_dims_mapping) + input_name, ref_dims_mapping + ) for output_name in grad_op.output_arg_names: assert output_name in grad_var_to_var fwd_name = grad_var_to_var[output_name] ref_dims_mapping = fwd_op_dist_attr.get_input_dims_mapping( - fwd_name) + fwd_name + ) # var output_var = vars[output_name] tensor_dist_attr = TensorDistributedAttribute() tensor_dist_attr.dims_mapping = ref_dims_mapping tensor_dist_attr.process_mesh = fwd_op_process_mesh self._dist_context.set_tensor_dist_attr_for_program( - output_var, tensor_dist_attr) + output_var, tensor_dist_attr + ) # op grad_op_dist_attr.set_output_dims_mapping( - output_name, ref_dims_mapping) + output_name, ref_dims_mapping + ) grad_op_dist_attr.impl_type = fwd_op_dist_attr.impl_type grad_op_dist_attr.impl_idx = fwd_op_dist_attr.impl_idx self._dist_context.set_op_dist_attr_for_program( - grad_op, grad_op_dist_attr) + grad_op, grad_op_dist_attr + ) # grad ops that have not a corresponding mapping in grad_op_id_to_op_id else: if grad_op.type == 'sum': assert all(map(_is_grad_var_name, grad_op.input_arg_names)) output_name = grad_op.output_arg_names[0] - assert output_name in grad_var_to_var, "sum op's output '{}' has no corresponding var".format( - output_name) + assert ( + output_name in grad_var_to_var + ), "sum op's output '{}' has no corresponding var".format( + output_name + ) ref_fwd_var_name = grad_var_to_var[output_name] ref_fwd_var = vars[ref_fwd_var_name] - ref_fwd_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - ref_fwd_var) + ref_fwd_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_program( + ref_fwd_var + ) + ) ref_fwd_dims_mapping = ref_fwd_dist_attr.dims_mapping ref_fwd_process_mesh = ref_fwd_dist_attr.process_mesh @@ -1341,24 +1689,30 @@ class Completer: tensor_dist_attr.process_mesh = ref_fwd_process_mesh output_var = vars[output_name] self._dist_context.set_tensor_dist_attr_for_program( - output_var, tensor_dist_attr) + output_var, tensor_dist_attr + ) # op grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr.process_mesh = ref_fwd_process_mesh for var_name in grad_op.input_arg_names: grad_op_dist_attr.set_input_dims_mapping( - var_name, ref_fwd_dims_mapping) + var_name, ref_fwd_dims_mapping + ) grad_op_dist_attr.set_output_dims_mapping( - output_name, ref_fwd_dims_mapping) + output_name, ref_fwd_dims_mapping + ) grad_op_dist_attr.impl_type = "default" grad_op_dist_attr.impl_idx = 0 elif grad_op.type == 'fill_any_like': ref_var_name = grad_op.input_arg_names[0] ref_var = vars[ref_var_name] - ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - ref_var) + ref_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_program( + ref_var + ) + ) ref_dims_mapping = ref_dist_attr.dims_mapping ref_process_mesh = ref_dist_attr.process_mesh # output @@ -1368,27 +1722,35 @@ class Completer: output_var_name = grad_op.output_arg_names[0] output_var = 
vars[output_var_name] self._dist_context.set_tensor_dist_attr_for_program( - output_var, tensor_dist_attr) + output_var, tensor_dist_attr + ) # op grad_op_dist_attr = OperatorDistributedAttribute() grad_op_dist_attr.process_mesh = ref_process_mesh grad_op_dist_attr.set_input_dims_mapping( - ref_var_name, ref_dims_mapping) + ref_var_name, ref_dims_mapping + ) grad_op_dist_attr.set_output_dims_mapping( - output_var_name, ref_dims_mapping) + output_var_name, ref_dims_mapping + ) else: - raise ValueError("got unexpect op [{}]".format( - str(grad_op.type))) + raise ValueError( + "got unexpect op [{}]".format(str(grad_op.type)) + ) self._dist_context.set_op_dist_attr_for_program( - grad_op, grad_op_dist_attr) + grad_op, grad_op_dist_attr + ) def complete_update_annotation(self, serial_main_program): """Complete the annotation of vars and ops in the update phase for parallel program.""" # Copy the dist tensors and dist ops annotated by users from the default context # global mesh - from paddle.distributed.auto_parallel.process_group import get_world_process_group + from paddle.distributed.auto_parallel.process_group import ( + get_world_process_group, + ) + world_ranks = get_world_process_group().ranks # Notice: serial_main_program is actually a dist_main_program of current rank, @@ -1407,17 +1769,22 @@ class Completer: if int(op.attr('op_role')) == int(OpRole.Optimize): if is_gradient_clip_op(op): if op.type in [ - "sum", "sqrt", "fill_constant", "elementwise_max", - "elementwise_div" + "sum", + "sqrt", + "fill_constant", + "elementwise_max", + "elementwise_div", ]: op_dist_attr = OperatorDistributedAttribute() op_dist_attr.process_mesh = world_ranks for in_name in op.input_arg_names: in_var = vars[in_name] in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - in_var) + in_var + ) op_dist_attr.set_input_dist_attr( - in_name, in_dist_attr) + in_name, in_dist_attr + ) for out_name in op.output_arg_names: out_var = vars[out_name] out_dist_attr = TensorDistributedAttribute() @@ -1426,22 +1793,30 @@ class Completer: -1 for _ in range(len(out_var.shape)) ] self._dist_context.set_tensor_dist_attr_for_program( - out_var, out_dist_attr) + out_var, out_dist_attr + ) op_dist_attr.set_output_dist_attr( - out_name, out_dist_attr) + out_name, out_dist_attr + ) else: in_var = vars[op.input("X")[0]] - in_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - in_var) + in_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_program( + in_var + ) + ) assert in_dist_attr is not None ref_process_mesh = in_dist_attr.process_mesh ref_dims_mapping = in_dist_attr.dims_mapping - if op.type == "cast" and \ - ops[idx + 1].type == "elementwise_mul": + if ( + op.type == "cast" + and ops[idx + 1].type == "elementwise_mul" + ): ref_var = vars[ops[idx + 1].input("X")[0]] ref_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - ref_var) + ref_var + ) assert ref_dist_attr is not None ref_process_mesh = ref_dist_attr.process_mesh @@ -1451,51 +1826,72 @@ class Completer: if out_var.shape == in_var.shape: out_dist_attr.dims_mapping = ref_dims_mapping else: - assert len( - out_var.shape) == 1 and out_var.shape[0] == 1 + assert ( + len(out_var.shape) == 1 + and out_var.shape[0] == 1 + ) out_dist_attr.dims_mapping = [-1] self._dist_context.set_tensor_dist_attr_for_program( - out_var, out_dist_attr) + out_var, out_dist_attr + ) op_dist_attr = OperatorDistributedAttribute() op_dist_attr.process_mesh = ref_process_mesh op_dist_attr.set_input_dist_attr( - in_var.name, in_dist_attr) + 
in_var.name, in_dist_attr + ) op_dist_attr.set_output_dist_attr( - out_var.name, out_dist_attr) + out_var.name, out_dist_attr + ) self._dist_context.set_op_dist_attr_for_program( - op, op_dist_attr) + op, op_dist_attr + ) if "Grad" in op.input_names and "Param" in ops[idx].input_names: - assert len( - op.input("Param")) == 1, "Only support one-to-one now." - assert len( - op.input("Grad")) == 1, "Only support one-to-one now." + assert ( + len(op.input("Param")) == 1 + ), "Only support one-to-one now." + assert ( + len(op.input("Grad")) == 1 + ), "Only support one-to-one now." param = vars[op.input("Param")[0]] grad_var = vars[op.input("Grad")[0]] - param_dist_attr = self._dist_context.get_tensor_dist_attr_for_program( - param) + param_dist_attr = ( + self._dist_context.get_tensor_dist_attr_for_program( + param + ) + ) assert param_dist_attr is not None - ref_process_mesh = self._dist_context.get_tensor_dist_attr_for_program( - param).process_mesh + ref_process_mesh = ( + self._dist_context.get_tensor_dist_attr_for_program( + param + ).process_mesh + ) assert ref_process_mesh is not None - ref_dims_mapping = self._dist_context.get_tensor_dist_attr_for_program( - param).dims_mapping + ref_dims_mapping = ( + self._dist_context.get_tensor_dist_attr_for_program( + param + ).dims_mapping + ) assert ref_dims_mapping is not None op_dist_attr = OperatorDistributedAttribute() op_dist_attr.process_mesh = ref_process_mesh - op_dist_attr.set_input_dims_mapping(grad_var.name, - ref_dims_mapping) - op_dist_attr.set_input_dims_mapping(param.name, - ref_dims_mapping) + op_dist_attr.set_input_dims_mapping( + grad_var.name, ref_dims_mapping + ) + op_dist_attr.set_input_dims_mapping( + param.name, ref_dims_mapping + ) op_dist_attr.set_output_dims_mapping( - param.name, ref_dims_mapping) + param.name, ref_dims_mapping + ) learning_var = vars[op.input("LearningRate")[0]] op_dist_attr.set_input_dims_mapping(learning_var.name, [-1]) op_dist_attr.set_output_dims_mapping( - learning_var.name, [-1]) + learning_var.name, [-1] + ) if not learning_rate_completed: learning_rate_completed = True @@ -1503,18 +1899,19 @@ class Completer: var_dist_attr.process_mesh = world_ranks var_dist_attr.dims_mapping = [-1] self._dist_context.set_tensor_dist_attr_for_program( - learning_var, var_dist_attr) + learning_var, var_dist_attr + ) for input_name in op.desc.input_names(): if input_name in [ - 'Param', - 'Grad', - 'LearningRate', - "SkipUpdate", - "Beta1Tensor", - "Beta2Tensor", - "EpsilonTensor", + 'Param', + 'Grad', + 'LearningRate', + "SkipUpdate", + "Beta1Tensor", + "Beta2Tensor", + "EpsilonTensor", ]: continue if len(op.desc.input(input_name)) == 0: @@ -1527,22 +1924,28 @@ class Completer: if "Beta1Pow" in input_name or "Beta2Pow" in input_name: input_var_attr.dims_mapping = [-1] op_dist_attr.set_input_dims_mapping( - input_var.name, [-1]) + input_var.name, [-1] + ) op_dist_attr.set_output_dims_mapping( - input_var.name, [-1]) + input_var.name, [-1] + ) else: input_var_attr.dims_mapping = ref_dims_mapping op_dist_attr.set_input_dims_mapping( - input_var.name, ref_dims_mapping) + input_var.name, ref_dims_mapping + ) op_dist_attr.set_output_dims_mapping( - input_var.name, ref_dims_mapping) + input_var.name, ref_dims_mapping + ) input_var_attr.process_mesh = ref_process_mesh self._dist_context.set_tensor_dist_attr_for_program( - input_var, input_var_attr) + input_var, input_var_attr + ) self._dist_context.set_op_dist_attr_for_program( - op, op_dist_attr) + op, op_dist_attr + ) continue def complete_prim_annotation(self, 
serial_main_program=None): @@ -1578,14 +1981,18 @@ class Completer: def _init_global_mesh_for_program(self): # Copy the dist tensors and dist ops annotated by users from the default context # global mesh - from paddle.distributed.auto_parallel.process_group import get_world_process_group + from paddle.distributed.auto_parallel.process_group import ( + get_world_process_group, + ) + world_ranks = get_world_process_group().ranks for block in self._dist_context._serial_main_program.blocks: for tensor in block.vars.values(): # Copy the distributed tensors in the default context dist_tensor = self._dist_context.get_dist_tensor_for_program( - tensor) + tensor + ) assert dist_tensor is not None dist_tensor.dist_attr.process_mesh = world_ranks for op in block.ops: @@ -1596,7 +2003,8 @@ class Completer: # Find the most compatible implemenetations from the distributed operator op_dist_impls = find_compatible_distributed_operator_impls( - dist_op, fwd=True) + dist_op, fwd=True + ) if op_dist_impls is not None: backup_op_dist_attr = copy.deepcopy(dist_op.dist_attr) for op_dist_impl in op_dist_impls: diff --git a/python/paddle/distributed/auto_parallel/constants.py b/python/paddle/distributed/auto_parallel/constants.py index 82c5011faf0..44d804a4816 100644 --- a/python/paddle/distributed/auto_parallel/constants.py +++ b/python/paddle/distributed/auto_parallel/constants.py @@ -94,6 +94,16 @@ set_field_default_config(GRADIENT_MERGE, "enable", False) set_field_default_config(GRADIENT_MERGE, "k_steps", 1) set_field_default_config(GRADIENT_MERGE, "avg", True) +######################################### +# pipeline configuration +######################################### +PIPELINE = "pipeline" +set_field_default_config(PIPELINE, "enable", False) +set_field_default_config(PIPELINE, "schedule_mode", "1F1B") +set_field_default_config(PIPELINE, "micro_batch_size", 1) +set_field_default_config(PIPELINE, "accumulate_steps", 1) +set_field_default_config(PIPELINE, "generation_batch_size", 1) + ######################################### # quantization configuration ######################################### diff --git a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py index a3d737769d0..c4862af3848 100644 --- a/python/paddle/distributed/auto_parallel/cost/estimate_cost.py +++ b/python/paddle/distributed/auto_parallel/cost/estimate_cost.py @@ -556,8 +556,8 @@ def get_cost_from_engine(engine, mode): ) serial_startup_prog = ( - engine._serial_startup_progs[mode].clone() - if mode in engine._serial_startup_progs + engine._fwd_dist_contexts[mode]._original_serial_main_program.clone() + if mode in engine._fwd_dist_contexts else engine._orig_startup_prog.clone() ) losses = ( diff --git a/python/paddle/distributed/auto_parallel/dist_context.py b/python/paddle/distributed/auto_parallel/dist_context.py index 387c964f0aa..437df0723f9 100644 --- a/python/paddle/distributed/auto_parallel/dist_context.py +++ b/python/paddle/distributed/auto_parallel/dist_context.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License +import time import copy from collections import defaultdict import paddle.fluid @@ -24,7 +25,7 @@ from .dist_attribute import OperatorDistributedAttribute from .dist_tensor import DistributedTensor from .dist_op import DistributedOperator from .process_mesh import ProcessMesh -from .utils import is_loss_grad_op, is_loss_op +from .utils import is_loss_grad_op, is_loss_op, 
is_valid_list_index # There always exists a default context for user. And user can set it to another one. _g_default_distributed_context = None @@ -53,15 +54,17 @@ class DistributedContext: One auto-parallel run should use its own DistributedContext to avoid interfering other run. """ - def __init__(self, - serial_main_prog=None, - serial_startup_prog=None, - serial_optimizer=None, - serial_loss=None, - feed_vars={}, - fetch_vars={}, - cluster=None, - strategy=None): + def __init__( + self, + serial_main_prog=None, + serial_startup_prog=None, + serial_optimizer=None, + serial_loss=None, + feed_vars={}, + fetch_vars={}, + cluster=None, + strategy=None, + ): # Data members related to original programs (unchanged) self._original_serial_main_program = serial_main_prog self._original_serial_startup_program = serial_startup_prog @@ -110,7 +113,7 @@ class DistributedContext: # self._tensor_id_to_tensor_node_ids = {} self._is_initialized = False - #TODO: need a better way to remove the following flag + # TODO: need a better way to remove the following flag self._need_copy_dist_attr_to_graph = False self._backup_pass_context_stack = [] self._backup_block_state_stack = [] @@ -125,6 +128,9 @@ class DistributedContext: # A flag indicates whether the used parallelism is data parallel self._data_parallel = False + # record upstream and downstream of cur rank + self._up_down_streams = UpDownStream() + @property def serial_main_program(self): return self._serial_main_program @@ -192,7 +198,8 @@ class DistributedContext: @property def has_annotation(self): return len(self._dist_tensors_for_program) or len( - self._dist_ops_for_program) + self._dist_ops_for_program + ) @property def gradient_scale(self): @@ -206,24 +213,33 @@ class DistributedContext: def data_parallel(self): return self._data_parallel + @property + def up_down_streams(self): + return self._up_down_streams + @data_parallel.setter def data_parallel(self, dp): self._data_parallel = dp def _backup_serial_info(self, mode): self._backup_serial_main_program_stack.append( - self._serial_main_program.clone()) + self._serial_main_program.clone() + ) self._backup_serial_startup_program_stack.append( - self._serial_startup_program.clone()) - self._backup_pass_context_stack.append(copy.deepcopy( - self._pass_context)) + self._serial_startup_program.clone() + ) + self._backup_pass_context_stack.append( + copy.deepcopy(self._pass_context) + ) self._backup_block_state_stack.append(copy.deepcopy(self._block_state)) def _backup_dist_info(self, mode): self._backup_dist_tensors_for_program_stack.append( - copy.deepcopy(self._dist_tensors_for_program)) + copy.deepcopy(self._dist_tensors_for_program) + ) self._backup_dist_ops_for_program_stack.append( - copy.deepcopy(self._dist_ops_for_program)) + copy.deepcopy(self._dist_ops_for_program) + ) def _backup(self, serial=True, serial_mode=None, dist=True, dist_mode=None): # Use this function carefully @@ -240,7 +256,8 @@ class DistributedContext: block_idx = loss.block.idx var_name = loss.name var = self._serial_main_program.blocks[ - block_idx]._var_recursive(var_name) + block_idx + ]._var_recursive(var_name) self._serial_loss = var elif len(self._original_serial_loss) == 0: self._serial_loss = [] @@ -250,7 +267,8 @@ class DistributedContext: block_idx = self._original_serial_loss.block.idx var_name = self._original_serial_loss.name var = self._serial_main_program.blocks[ - block_idx]._var_recursive(var_name) + block_idx + ]._var_recursive(var_name) self._serial_loss = var def _restore_serial_feed_vars(self): @@ 
-260,7 +278,8 @@ class DistributedContext: block_idx = var.block.idx var_name = var.name var = self._serial_main_program.blocks[ - block_idx]._var_recursive(var_name) + block_idx + ]._var_recursive(var_name) new_var_list.append(var) self._serial_feed_vars[key] = new_var_list @@ -275,7 +294,8 @@ class DistributedContext: block_idx = var.block.idx var_name = var.name var = self._serial_main_program.blocks[ - block_idx]._var_recursive(var_name) + block_idx + ]._var_recursive(var_name) new_inner_var_list.append(var) new_var_list.append(new_inner_var_list) else: @@ -283,22 +303,27 @@ class DistributedContext: block_idx = var.block.idx var_name = var.name var = self._serial_main_program.blocks[ - block_idx]._var_recursive(var_name) + block_idx + ]._var_recursive(var_name) new_var_list.append(var) self._serial_fetch_vars[key] = new_var_list def _restore_serial_info(self, mode="to_backup"): if mode == "to_backup": - self._serial_main_program = self._backup_serial_main_program_stack.pop( + self._serial_main_program = ( + self._backup_serial_main_program_stack.pop() ) - self._serial_startup_program = self._backup_serial_startup_program_stack.pop( + self._serial_startup_program = ( + self._backup_serial_startup_program_stack.pop() ) elif mode == "to_original": assert self._original_serial_main_program is not None assert self._original_serial_startup_program is not None - self._serial_main_program = self._original_serial_main_program.clone( + self._serial_main_program = ( + self._original_serial_main_program.clone() ) - self._serial_startup_program = self._original_serial_startup_program.clone( + self._serial_startup_program = ( + self._original_serial_startup_program.clone() ) self._restore_serial_loss() @@ -310,21 +335,27 @@ class DistributedContext: def _restore_dist_info(self, mode="to_backup"): if mode == "to_backup": - self._dist_tensors_for_program = self._backup_dist_tensors_for_program_stack.pop( + self._dist_tensors_for_program = ( + self._backup_dist_tensors_for_program_stack.pop() ) - self._dist_ops_for_program = self._backup_dist_ops_for_program_stack.pop( + self._dist_ops_for_program = ( + self._backup_dist_ops_for_program_stack.pop() ) elif mode == "to_original": assert self._original_dist_tensors_for_program assert self._original_dist_ops_for_program self._dist_tensors_for_program = copy.deepcopy( - self._original_dist_tensors_for_program) + self._original_dist_tensors_for_program + ) self._dist_ops_for_program = copy.deepcopy( - self._original_dist_ops_for_program) + self._original_dist_ops_for_program + ) elif mode == "to_default": new_tensors_ids = [] - for tensor_id, dist_tensor in self._dist_tensors_for_program.items( - ): + for ( + tensor_id, + dist_tensor, + ) in self._dist_tensors_for_program.items(): if tensor_id in self._tensors_ids: dist_tensor.dist_attr.reset() else: @@ -341,8 +372,10 @@ class DistributedContext: self._dist_ops_for_program.pop(op_id) else: new_tensors_ids = [] - for tensor_id, dist_tensor in self._dist_tensors_for_program.items( - ): + for ( + tensor_id, + dist_tensor, + ) in self._dist_tensors_for_program.items(): new_tensors_ids.append(tensor_id) for tensor_id in new_tensors_ids: self._dist_tensors_for_program.pop(tensor_id) @@ -357,11 +390,13 @@ class DistributedContext: self._need_copy_dist_attr_to_graph = True self._process_meshes = [] - def _restore(self, - serial=True, - serial_mode="to_backup", - dist=True, - dist_mode="to_backup"): + def _restore( + self, + serial=True, + serial_mode="to_backup", + dist=True, + dist_mode="to_backup", + ): # Use 
this function carefully if serial: self._restore_serial_info(serial_mode) @@ -372,11 +407,13 @@ class DistributedContext: if not self._is_initialized: if not self._serial_main_program: if self._original_serial_main_program: - self._serial_main_program = self._original_serial_main_program.clone( + self._serial_main_program = ( + self._original_serial_main_program.clone() ) if not self._serial_startup_program: if self._original_serial_startup_program: - self._serial_startup_program = self._original_serial_startup_program.clone( + self._serial_startup_program = ( + self._original_serial_startup_program.clone() ) if not self._serial_loss: self._restore_serial_loss() @@ -390,26 +427,34 @@ class DistributedContext: self._init_dist_attr_for_program() # Backup the original distributed information for later restore self._original_dist_tensors_for_program = copy.deepcopy( - self._dist_tensors_for_program) + self._dist_tensors_for_program + ) self._original_dist_ops_for_program = copy.deepcopy( - self._dist_ops_for_program) + self._dist_ops_for_program + ) self._tensors_ids = list(self._dist_tensors_for_program.keys()) self._ops_ids = list(self._dist_ops_for_program.keys()) self._is_initialized = True if with_graph: set_flags({"FLAGS_convert_all_blocks": True}) + start_time = time.time() self._serial_graph = framework.IrGraph( - core.Graph(self._serial_main_program.desc)) + core.Graph(self._serial_main_program.desc) + ) + # print("context-graph-build: ", time.time() - start_time, flush=True) self._init_dist_attr_for_graph() + start_time = time.time() self._need_copy_dist_attr_to_graph = False + # print("context-graph-dist: ", time.time() - start_time, flush=True) if self._need_copy_dist_attr_to_graph and with_graph: self.copy_dist_attr_from_program_to_graph() def add_process_mesh(self, process_mesh): - assert isinstance(process_mesh, ProcessMesh), \ - 'The type of dim_mapping must be ProcessMesh.' + assert isinstance( + process_mesh, ProcessMesh + ), 'The type of dim_mapping must be ProcessMesh.' 
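# A minimal sketch of the backup/restore pattern used by the _backup_*_info and
# _restore_*_info methods above: state is deep-copied onto stacks by the backup
# step, then either popped back ("to_backup") or reset from the originals
# ("to_original"). Class and attribute names here are illustrative only.
import copy

class _BackupRestoreSketch:
    def __init__(self, state):
        self._original_state = copy.deepcopy(state)
        self._state = state
        self._backup_stack = []

    def backup(self):
        self._backup_stack.append(copy.deepcopy(self._state))

    def restore(self, mode="to_backup"):
        if mode == "to_backup":
            self._state = self._backup_stack.pop()
        elif mode == "to_original":
            self._state = copy.deepcopy(self._original_state)
        else:
            raise ValueError("unknown restore mode: {}".format(mode))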
if process_mesh not in self.process_meshes: self._process_meshes.append(process_mesh) @@ -431,7 +476,8 @@ class DistributedContext: else: serial_tensor_id = serial_tensor.desc.original_id() dist_tensor = self._dist_tensors_for_program.get( - serial_tensor_id, None) + serial_tensor_id, None + ) if dist_tensor: return dist_tensor else: @@ -471,7 +517,8 @@ class DistributedContext: else: serial_tensor_id = serial_tensor.desc.original_id() dist_tensor = self._dist_tensors_for_program.get( - serial_tensor_id, None) + serial_tensor_id, None + ) if dist_tensor: return dist_tensor.dist_attr else: @@ -490,8 +537,9 @@ class DistributedContext: def get_tensor_dist_attr_for_graph(self, serial_tensor_node): serial_tensor_node_id = _node_id(serial_tensor_node) - dist_tensor = self._dist_tensors_for_graph.get(serial_tensor_node_id, - None) + dist_tensor = self._dist_tensors_for_graph.get( + serial_tensor_node_id, None + ) if dist_tensor: return dist_tensor.dist_attr else: @@ -533,7 +581,8 @@ class DistributedContext: if serial_node.is_var() and serial_node.var() is not None: serial_tensor_node_id = _node_id(serial_node) dist_tensor = self._dist_tensors_for_graph.get( - serial_tensor_node_id, None) + serial_tensor_node_id, None + ) if dist_tensor: return dist_tensor.dist_attr else: @@ -560,7 +609,8 @@ class DistributedContext: for tensor in block.vars.values(): # Copy the distributed tensors in the default context default_dist_tensor = default_ctx.get_dist_tensor_for_program( - tensor) + tensor + ) if default_dist_tensor and default_ctx is not self: self.add_dist_tensor_for_program(default_dist_tensor) current_dist_tensor = self.get_dist_tensor_for_program(tensor) @@ -577,115 +627,218 @@ class DistributedContext: dist_op = DistributedOperator(op) self.add_dist_op_for_program(dist_op) self._original_dist_tensors_for_program = copy.deepcopy( - self._dist_tensors_for_program) + self._dist_tensors_for_program + ) self._original_dist_ops_for_program = copy.deepcopy( - self._dist_ops_for_program) + self._dist_ops_for_program + ) def _order_nodes_by_program_order(self): + # def _contains(nodes, target_node): + # for node in nodes: + # if _node_id(node) == _node_id(target_node): + # return True + # return False - def _contains(nodes, target_node): - for node in nodes: - if _node_id(node) == _node_id(target_node): - return True - return False - + start_time = time.time() serial_ordered_tensor_nodes = [] serial_ordered_op_nodes = [] all_nodes = [] + visited = {} for idx, graph in enumerate(self._serial_graph.all_sub_graphs()): for node in graph.all_nodes(): all_nodes.append(node) + # print("context-graph-dist-ordering-0: ", time.time() - start_time, flush=True) + start_time = time.time() for node in all_nodes: if node.is_var() and node.var() is not None: serial_ordered_tensor_nodes.append(node) + visited[_node_id(node)] = False if node.is_op() and node.op() is not None: serial_ordered_op_nodes.append(node) + # print("context-graph-dist-ordering-1: ", time.time() - start_time, flush=True) + start_time = time.time() serial_ordered_tensor_nodes.sort( - key=lambda node: node.node.original_desc_id()) + key=lambda node: node.node.original_desc_id() + ) + # print("context-graph-dist-ordering-2: ", time.time() - start_time, flush=True) + start_time = time.time() serial_ordered_op_nodes.sort( - key=lambda node: node.node.original_desc_id()) + key=lambda node: node.node.original_desc_id() + ) + # print("context-graph-dist-ordering-3: ", time.time() - start_time, flush=True) + start_time = time.time() num_nodes_before = 
len(serial_ordered_tensor_nodes) + len( - serial_ordered_op_nodes) + serial_ordered_op_nodes + ) new_serial_ordered_tensor_nodes = [] new_serial_ordered_op_nodes = [] new_serial_ordered_nodes = [] + tmp_time = 0 + # TODO: user a counter for the following sort for op_node in serial_ordered_op_nodes: tensor_nodes = [] for tensor_node in op_node.inputs: - if tensor_node.is_var() \ - and tensor_node.var() is not None \ - and not _contains(new_serial_ordered_nodes, tensor_node): + # if ( + # tensor_node.is_var() + # and tensor_node.var() is not None + # and not _contains(new_serial_ordered_nodes, tensor_node) + # ): + # tensor_nodes.append(tensor_node) + # new_serial_ordered_tensor_nodes.append(tensor_node) + if ( + tensor_node.is_var() + and tensor_node.var() is not None + and not visited[_node_id(tensor_node)] + ): tensor_nodes.append(tensor_node) new_serial_ordered_tensor_nodes.append(tensor_node) + visited[_node_id(tensor_node)] = True + + inner_start_time = time.time() tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) + tmp_time += time.time() - inner_start_time new_serial_ordered_nodes.extend(tensor_nodes) new_serial_ordered_nodes.append(op_node) new_serial_ordered_op_nodes.append(op_node) tensor_nodes = [] for tensor_node in op_node.outputs: - if tensor_node.is_var() \ - and tensor_node.var() is not None \ - and not _contains(new_serial_ordered_nodes, tensor_node): + # if ( + # tensor_node.is_var() + # and tensor_node.var() is not None + # and not _contains(new_serial_ordered_nodes, tensor_node) + # ): + # tensor_nodes.append(tensor_node) + # new_serial_ordered_tensor_nodes.append(tensor_node) + if ( + tensor_node.is_var() + and tensor_node.var() is not None + and not visited[_node_id(tensor_node)] + ): tensor_nodes.append(tensor_node) new_serial_ordered_tensor_nodes.append(tensor_node) + visited[_node_id(tensor_node)] = True + inner_start_time = time.time() tensor_nodes.sort(key=lambda node: node.node.original_desc_id()) + tmp_time += time.time() - inner_start_time new_serial_ordered_nodes.extend(tensor_nodes) + # print("context-graph-dist-ordering-4: ", tmp_time, flush=True) + # print("context-graph-dist-ordering-5: ", time.time() - start_time, flush=True) + start_time = time.time() new_serial_ordered_tensor_nodes.sort( - key=lambda node: node.node.original_desc_id()) + key=lambda node: node.node.original_desc_id() + ) + # print("context-graph-dist-ordering-6: ", time.time() - start_time, flush=True) + start_time = time.time() new_serial_ordered_op_nodes.sort( - key=lambda node: node.node.original_desc_id()) + key=lambda node: node.node.original_desc_id() + ) + # print("context-graph-dist-ordering-7: ", time.time() - start_time, flush=True) + start_time = time.time() self._serial_ordered_tensor_nodes = new_serial_ordered_tensor_nodes self._serial_ordered_op_nodes = new_serial_ordered_op_nodes self._serial_ordered_nodes = new_serial_ordered_nodes assert len(self._serial_ordered_nodes) == len( - self._serial_ordered_tensor_nodes) + len( - self._serial_ordered_op_nodes) + self._serial_ordered_tensor_nodes + ) + len(self._serial_ordered_op_nodes) + # TODO: Use [graph_id][tensor_name][node] to store the tensor nodes for completion preparation + # graph_id -> tensor->name -> node_lists + self._tensor_nodes_with_same_name = defaultdict(dict) + for idx, node in enumerate(self._serial_ordered_nodes): + if node.is_var() and node.var() is not None: + graph_id = node.node.graph_id() + tensor_name = node.var().name() + if ( + self._tensor_nodes_with_same_name[graph_id].get( + 
tensor_name, None + ) + is None + ): + self._tensor_nodes_with_same_name[graph_id][ + tensor_name + ] = [] + self._tensor_nodes_with_same_name[graph_id][tensor_name].append( + (idx, node) + ) + # print("context-graph-dist-ordering-8: ", time.time() - start_time, flush=True) + + start_time = time.time() self._serial_orphan_tensor_nodes = [] for tensor_node in serial_ordered_tensor_nodes: - if not _contains(self._serial_ordered_tensor_nodes, tensor_node): + # if not _contains(self._serial_ordered_tensor_nodes, tensor_node): + if not visited[_node_id(tensor_node)]: self._serial_orphan_tensor_nodes.append(tensor_node) if len(self._serial_ordered_nodes) != num_nodes_before: print( "WARNING: there are some orphan tensors or ops which are not used in the execution." ) + # print("context-graph-dist-ordering-9: ", time.time() - start_time, flush=True) def _init_dist_attr_for_graph(self): # Convert program to graph and initialize the distributed attributes + start_time = time.time() self._order_nodes_by_program_order() + # print("context-graph-dist-ordering: ", time.time() - start_time, flush=True) + start_time = time.time() + self._tensor_original_id_to_id = {} + self._op_original_id_to_id = {} + for tensor_id, tensor in self._dist_tensors_for_program.items(): + original_id = tensor.serial_tensor.desc.original_id() + self._tensor_original_id_to_id[original_id] = tensor_id + for op_id, op in self._dist_ops_for_program.items(): + original_id = op.serial_op.desc.original_id() + self._op_original_id_to_id[original_id] = op_id + # print("context-graph-dist-mapping: ", time.time() - start_time, flush=True) + start_time = time.time() for node in self.serial_ordered_nodes: if node.is_var() and node.var() is not None: dist_tensor = None tensor_id = node.node.original_desc_id() - for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items( - ): - if tensor_id == cur_tensor_id \ - or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id(): - dist_tensor = cur_dist_tensor - self._node_id_to_tensor_id[_node_id( - node)] = cur_tensor_id - assert dist_tensor is not None, \ - "Tensor must have a distributed tensor after the initialization for program." + cur_dist_tensor = self._dist_tensors_for_program.get( + tensor_id, None + ) + if cur_dist_tensor is not None: + cur_tensor_id = tensor_id + else: + cur_tensor_id = self._tensor_original_id_to_id[tensor_id] + cur_dist_tensor = self._dist_tensors_for_program.get( + cur_tensor_id, None + ) + dist_tensor = cur_dist_tensor + self._node_id_to_tensor_id[_node_id(node)] = cur_tensor_id + assert ( + dist_tensor is not None + ), "Tensor must have a distributed tensor after the initialization for program." serial_tensor_node_id = _node_id(node) - new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, - dist_tensor.dist_attr) + new_dist_tensor = DistributedTensor( + dist_tensor.serial_tensor, dist_tensor.dist_attr + ) self._dist_tensors_for_graph[ - serial_tensor_node_id] = new_dist_tensor + serial_tensor_node_id + ] = new_dist_tensor if node.is_op() and node.op() is not None: dist_op = None op_id = node.node.original_desc_id() - for cur_op_id, cur_dist_op in self._dist_ops_for_program.items( - ): - if op_id == cur_op_id \ - or op_id == cur_dist_op.serial_op.desc.original_id(): - dist_op = cur_dist_op - self._node_id_to_op_id[_node_id(node)] = cur_op_id - assert dist_op is not None, \ - "Operator must have a distributed operator after the initialization for program." 
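# A minimal sketch of the lookup refactor above: instead of scanning every dist
# tensor/op to match a graph node's original_desc_id, a reverse index from
# original ids to current ids is built once, so each node is resolved with at
# most two dict lookups. Function names below are hypothetical.
def build_original_id_index(dist_objects_by_id, get_original_id):
    return {
        get_original_id(obj): cur_id
        for cur_id, obj in dist_objects_by_id.items()
    }

def resolve_by_desc_id(dist_objects_by_id, original_id_index, desc_id):
    obj = dist_objects_by_id.get(desc_id, None)
    if obj is None:
        obj = dist_objects_by_id.get(original_id_index[desc_id], None)
    return obj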
+ cur_dist_op = self._dist_ops_for_program.get(op_id, None) + if cur_dist_op is not None: + cur_op_id = op_id + else: + cur_op_id = self._op_original_id_to_id[op_id] + cur_dist_op = self._dist_ops_for_program.get( + cur_op_id, None + ) + dist_op = cur_dist_op + self._node_id_to_op_id[_node_id(node)] = cur_op_id + assert ( + dist_op is not None + ), "Operator must have a distributed operator after the initialization for program." serial_op_node_id = _node_id(node) - new_dist_op = DistributedOperator(dist_op.serial_op, - dist_op.dist_attr) + new_dist_op = DistributedOperator( + dist_op.serial_op, dist_op.dist_attr + ) self._dist_ops_for_graph[serial_op_node_id] = new_dist_op + # print("context-graph-dist-init: ", time.time() - start_time, flush=True) def clear_dist_info_for_program(self): self._dist_tensors_for_program.clear() @@ -700,36 +853,52 @@ class DistributedContext: if node.is_var() and node.var() is not None: dist_tensor = None tensor_id = node.node.original_desc_id() - for cur_tensor_id, cur_dist_tensor in self._dist_tensors_for_program.items( - ): - if tensor_id == cur_tensor_id \ - or tensor_id == cur_dist_tensor.serial_tensor.desc.original_id(): - dist_tensor = cur_dist_tensor - assert dist_tensor is not None, \ - "Tensor must have a distributed tensor after the initialization for program." + cur_dist_tensor = self._dist_tensors_for_program.get( + tensor_id, None + ) + if cur_dist_tensor is not None: + cur_tensor_id = tensor_id + else: + cur_tensor_id = self._tensor_original_id_to_id[tensor_id] + cur_dist_tensor = self._dist_tensors_for_program.get( + cur_tensor_id, None + ) + dist_tensor = cur_dist_tensor + assert ( + dist_tensor is not None + ), "Tensor must have a distributed tensor after the initialization for program." serial_tensor_node_id = _node_id(node) - new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor, - dist_tensor.dist_attr) + new_dist_tensor = DistributedTensor( + dist_tensor.serial_tensor, dist_tensor.dist_attr + ) self._dist_tensors_for_graph[ - serial_tensor_node_id] = new_dist_tensor + serial_tensor_node_id + ] = new_dist_tensor if node.is_op() and node.op() is not None: dist_op = None op_id = node.node.original_desc_id() - for cur_op_id, cur_dist_op in self._dist_ops_for_program.items( - ): - if op_id == cur_op_id \ - or op_id == cur_dist_op.serial_op.desc.original_id(): - dist_op = cur_dist_op - assert dist_op is not None, \ - "Operator must have a distributed operator after the initialization for program." + cur_dist_op = self._dist_ops_for_program.get(op_id, None) + if cur_dist_op is not None: + cur_op_id = op_id + else: + cur_op_id = self._op_original_id_to_id[op_id] + cur_dist_op = self._dist_ops_for_program.get( + cur_op_id, None + ) + dist_op = cur_dist_op + assert ( + dist_op is not None + ), "Operator must have a distributed operator after the initialization for program." serial_op_node_id = _node_id(node) - new_dist_op = DistributedOperator(dist_op.serial_op, - dist_op.dist_attr) + new_dist_op = DistributedOperator( + dist_op.serial_op, dist_op.dist_attr + ) self._dist_ops_for_graph[serial_op_node_id] = new_dist_op def copy_dist_attr_from_graph_to_program(self): - assert self._is_initialized, \ - "Both program and graph must be initialized." + assert ( + self._is_initialized + ), "Both program and graph must be initialized." 
updated_tensors = {} # all_nodes = self._serial_graph.all_nodes() all_nodes = self._serial_ordered_nodes @@ -739,11 +908,15 @@ class DistributedContext: updated = updated_tensors.get(tensor_id, False) # If a var has multiples var nodes in graph, only use the first one for now if not updated: - tensor_dist_attr_for_graph = self.get_tensor_dist_attr_for_graph( - node) + tensor_dist_attr_for_graph = ( + self.get_tensor_dist_attr_for_graph(node) + ) dist_tensor_for_program = self._dist_tensors_for_program[ - tensor_id] - dist_tensor_for_program.dist_attr = tensor_dist_attr_for_graph + tensor_id + ] + dist_tensor_for_program.dist_attr = ( + tensor_dist_attr_for_graph + ) updated_tensors[tensor_id] = True if node.is_op() and node.op() is not None: op_id = self._node_id_to_op_id[_node_id(node)] @@ -755,22 +928,26 @@ class DistributedContext: for orphan_node in self._serial_orphan_tensor_nodes: serial_tensor_id = orphan_node.var().id() dist_tensor = self._dist_tensors_for_program.get( - serial_tensor_id, None) + serial_tensor_id, None + ) if dist_tensor: dist_tensor.dist_attr.process_mesh = self._process_meshes[0] else: serial_tensor_id = orphan_node.var().original_id() dist_tensor = self._dist_tensors_for_program.get( - serial_tensor_id, None) + serial_tensor_id, None + ) dist_tensor.dist_attr.process_mesh = self._process_meshes[0] def amend_dist_attr_for_program(self): for dist_tensor in self._dist_tensors_for_program.values(): serial_tensor = dist_tensor.serial_tensor dist_attr = dist_tensor.dist_attr - if serial_tensor.type == core.VarDesc.VarType.READER \ - or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ - or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES: + if ( + serial_tensor.type == core.VarDesc.VarType.READER + or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY + or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES + ): tensor_shape = [] else: tensor_shape = serial_tensor.shape @@ -780,8 +957,11 @@ class DistributedContext: # If the dimension of tensor is less than the sharding dimension of process mesh, # we just amend the dimension mapping to -1. (Is this really OK?) for i in range(len(tensor_shape)): - if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ - and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: + if ( + dims_mapping[i] != -1 + and tensor_shape[i] > 0 + and process_mesh_shape[dims_mapping[i]] > tensor_shape[i] + ): dims_mapping[i] = -1 if dims_mapping[i] != -1 and len(process_mesh_processes) == 1: dims_mapping[i] = -1 @@ -795,9 +975,13 @@ class DistributedContext: if dist_op.get_serial_input(arg_name) is None: tensor_shape = [] else: - if dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.READER \ - or dist_op.get_serial_input(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ - or dist_op.serial_op.type == "create_py_reader": + if ( + dist_op.get_serial_input(arg_name).type + == core.VarDesc.VarType.READER + or dist_op.get_serial_input(arg_name).type + == core.VarDesc.VarType.LOD_TENSOR_ARRAY + or dist_op.serial_op.type == "create_py_reader" + ): tensor_shape = [] else: tensor_shape = dist_op.get_serial_input(arg_name).shape @@ -805,16 +989,27 @@ class DistributedContext: # If the dimension of tensor is less than the sharding dimension of process mesh, # we just amend the dimension mapping to -1. (Is this really OK?) 
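# The amendment rule applied in the loops above, written as a standalone sketch
# with a worked example (the function name is hypothetical): a tensor axis falls
# back to replicated (-1) when its mapped mesh axis has more processes than the
# axis has elements, or when the process mesh contains a single process.
def amend_dims_mapping(dims_mapping, tensor_shape, mesh_shape, num_processes):
    for i in range(len(tensor_shape)):
        if (
            dims_mapping[i] != -1
            and tensor_shape[i] > 0
            and mesh_shape[dims_mapping[i]] > tensor_shape[i]
        ):
            dims_mapping[i] = -1
        if dims_mapping[i] != -1 and num_processes == 1:
            dims_mapping[i] = -1
    return dims_mapping

# e.g. amend_dims_mapping([0, -1], [2, 8], [4], 4) -> [-1, -1]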
for i in range(len(tensor_shape)): - if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ - and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: + if ( + dims_mapping[i] != -1 + and tensor_shape[i] > 0 + and process_mesh_shape[dims_mapping[i]] + > tensor_shape[i] + ): dims_mapping[i] = -1 - if dims_mapping[i] != -1 and len( - process_mesh_processes) == 1: + if ( + dims_mapping[i] != -1 + and len(process_mesh_processes) == 1 + ): dims_mapping[i] = -1 for arg_name in serial_op.output_arg_names: - if dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.READER \ - or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ - or dist_op.get_serial_output(arg_name).type == core.VarDesc.VarType.STEP_SCOPES: + if ( + dist_op.get_serial_output(arg_name).type + == core.VarDesc.VarType.READER + or dist_op.get_serial_output(arg_name).type + == core.VarDesc.VarType.LOD_TENSOR_ARRAY + or dist_op.get_serial_output(arg_name).type + == core.VarDesc.VarType.STEP_SCOPES + ): tensor_shape = [] else: tensor_shape = dist_op.get_serial_output(arg_name).shape @@ -822,11 +1017,17 @@ class DistributedContext: # If the dimension of tensor is less than the sharding dimension of process mesh, # we just amend the dimension mapping to -1. (Is this really OK?) for i in range(len(tensor_shape)): - if dims_mapping[i] != -1 and tensor_shape[i] > 0 \ - and process_mesh_shape[dims_mapping[i]] > tensor_shape[i]: + if ( + dims_mapping[i] != -1 + and tensor_shape[i] > 0 + and process_mesh_shape[dims_mapping[i]] + > tensor_shape[i] + ): dims_mapping[i] = -1 - if dims_mapping[i] != -1 and len( - process_mesh_processes) == 1: + if ( + dims_mapping[i] != -1 + and len(process_mesh_processes) == 1 + ): dims_mapping[i] = -1 if len(process_mesh_processes) == 1: dist_op.dist_attr.impl_type = "default" @@ -834,30 +1035,44 @@ class DistributedContext: def validate_dist_attr_for_program(self): if not self._is_initialized: - assert False, \ - "Program must be initialized before validating its distributed attributes" + assert ( + False + ), "Program must be initialized before validating its distributed attributes" for block in self.serial_main_program.blocks: for tensor in block.vars.values(): dist_tensor = self.get_dist_tensor_for_program(tensor) - assert dist_tensor is not None, \ - "Tensor {} does not have a distributed attribute.".format( - dist_tensor.serial_tensor.name) - if (dist_tensor - is not None) and (not dist_tensor.validate_dist_attr()): - assert False, "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format( + assert ( + dist_tensor is not None + ), "Tensor {} does not have a distributed attribute.".format( + dist_tensor.serial_tensor.name + ) + if (dist_tensor is not None) and ( + not dist_tensor.validate_dist_attr() + ): + assert ( + False + ), "Tensor {} (id: {}, original_id: {}) has a wrong distributed attributes {}.".format( dist_tensor.serial_tensor.name, dist_tensor.serial_tensor.desc.id(), dist_tensor.serial_tensor.desc.original_id(), - dist_tensor.dist_attr) + dist_tensor.dist_attr, + ) for op in block.ops: dist_op = self.get_dist_op_for_program(op) - assert dist_op is not None, \ - "Operator {} does not have a distributed attribute.".format( - dist_op.serial_op.type) + assert ( + dist_op is not None + ), "Operator {} does not have a distributed attribute.".format( + dist_op.serial_op.type + ) if (dist_op is not None) and (not dist_op.validate_dist_attr()): - assert False, "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} 
.".format( - dist_op.serial_op.type, dist_op.serial_op.desc.id(), - dist_op.serial_op.desc.original_id(), dist_op.dist_attr) + assert ( + False + ), "Operator {} (id: {}, original_id: {}) has a wrong distributed attributes {} .".format( + dist_op.serial_op.type, + dist_op.serial_op.desc.id(), + dist_op.serial_op.desc.original_id(), + dist_op.dist_attr, + ) return True def __deepcopy__(self, memo): @@ -866,15 +1081,28 @@ class DistributedContext: memo[id(self)] = result for k, v in self.__dict__.items(): if k in [ - "_original_serial_main_program", "_original_serial_startup_program", \ - "_serial_main_program", "_serial_startup_program", "_serial_graph", \ - "_dist_main_programs", "_dist_startup_programs", \ - "_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \ - "_serial_ordered_op_nodes", "_original_serial_loss", \ - "_original_serial_feed_vars", "_original_serial_fetch_vars", \ - "_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_serial_optimizer", \ - "_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \ - "_pass_context"]: + "_original_serial_main_program", + "_original_serial_startup_program", + "_serial_main_program", + "_serial_startup_program", + "_serial_graph", + "_dist_main_programs", + "_dist_startup_programs", + "_serial_ordered_nodes", + "_serial_ordered_tensor_nodes", + "_serial_ordered_op_nodes", + "_original_serial_loss", + "_original_serial_feed_vars", + "_original_serial_fetch_vars", + "_serial_loss", + "_serial_feed_vars", + "_serial_fetch_vars", + "_serial_optimizer", + "_backup_serial_main_program_stack", + "_backup_serial_startup_program_stack", + "_pass_context", + "_tensor_nodes_with_same_name", + ]: setattr(result, k, v) else: setattr(result, k, copy.deepcopy(v, memo)) @@ -916,8 +1144,12 @@ class DistributedOperatorContext: memo[id(self)] = result for k, v in self.__dict__.items(): if k in [ - "_dst_main_program", "_dst_startup_program", "_cur_src_op", - "_work_block", "_main_block", "_startup_block" + "_dst_main_program", + "_dst_startup_program", + "_cur_src_op", + "_work_block", + "_main_block", + "_startup_block", ]: setattr(result, k, v) else: @@ -997,7 +1229,6 @@ class DistributedOperatorContext: class BlockState(object): - def __init__(self): self.nblock = 0 self.forward_indices = [] @@ -1014,8 +1245,11 @@ class BlockState(object): for idx, block in enumerate(program.blocks): assert idx == block.idx, "index doesn't match" - assert block.forward_block_idx == -1, "forward_block_idx of forward block [{}] is not [{}]".format( - idx, block.forward_block_idx) + assert ( + block.forward_block_idx == -1 + ), "forward_block_idx of forward block [{}] is not [{}]".format( + idx, block.forward_block_idx + ) self.forward_indices.append(idx) self.nblock += 1 @@ -1024,7 +1258,8 @@ class BlockState(object): def parse_backward_blocks(self, program): assert 0 in self.forward_indices, "forward block idx are{}".format( - self.forward_indices) + self.forward_indices + ) self.backward_to_forward_index_map[0] = 0 for idx, block in enumerate(program.blocks): @@ -1039,3 +1274,49 @@ class BlockState(object): self.nblock += 1 assert self.nblock == len(program.blocks) + + +class UpDownStream: + def __init__(self): + self._ups = dict() + self._downs = dict() + + def add_up_stream(self, rank, up_stream): + ups = self._ups.get(rank, None) + if not ups: + self._ups[rank] = [up_stream] + elif up_stream != -1: + ups = list(filter(lambda a: a != -1, ups)) + ups.append(up_stream) + self._ups[rank] = ups + + def add_down_stream(self, rank, 
down_stream): + downs = self._downs.get(rank, None) + if not downs: + self._downs[rank] = [down_stream] + elif down_stream != -1: + downs = list(filter(lambda a: a != -1, downs)) + downs.append(down_stream) + self._downs[rank] = downs + + def add_pair_stream(self, up, down): + self.add_up_stream(up, -1) + self.add_up_stream(down, up) + self.add_down_stream(up, down) + self.add_down_stream(down, -1) + # print(up, "'s upstream is ", self.ups(up)) + # print(down, "'s upstream is ", self.ups(down)) + # print(up, "'s downstream is ", self.downs(up)) + # print(down, "'s downstream is ", self.downs(down)) + + def ups(self, rank): + ups = self._ups.get(rank, None) + if not ups: + return None + return list(set(ups)) + + def downs(self, rank): + downs = self._downs.get(rank, None) + if not downs: + return None + return list(set(downs)) diff --git a/python/paddle/distributed/auto_parallel/dist_op.py b/python/paddle/distributed/auto_parallel/dist_op.py index 300c80ec718..1ca49e36473 100644 --- a/python/paddle/distributed/auto_parallel/dist_op.py +++ b/python/paddle/distributed/auto_parallel/dist_op.py @@ -27,7 +27,6 @@ from .utils import convert_to_shard_spec, verify_shard_spec class DistributedOperator: - def __init__(self, serial_op, dist_attr=None): self._serial_op = serial_op self._serial_inputs = {} @@ -78,28 +77,34 @@ class DistributedOperator: if tensor is None: tensor_shape = [] else: - if tensor.type == core.VarDesc.VarType.READER \ - or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY: + if ( + tensor.type == core.VarDesc.VarType.READER + or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY + ): tensor_shape = [] else: tensor_shape = tensor.shape if self._dist_attr.get_input_dims_mapping(tensor_name) is None: tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] - self._dist_attr.set_input_dims_mapping(tensor_name, - tensor_dims_mapping) + self._dist_attr.set_input_dims_mapping( + tensor_name, tensor_dims_mapping + ) for tensor_name in self._serial_op.output_arg_names: tensor = self._serial_op.block._var_recursive(tensor_name) - if tensor.type == core.VarDesc.VarType.READER \ - or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ - or tensor.type == core.VarDesc.VarType.STEP_SCOPES: + if ( + tensor.type == core.VarDesc.VarType.READER + or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY + or tensor.type == core.VarDesc.VarType.STEP_SCOPES + ): tensor_shape = [] else: tensor_shape = tensor.shape self._serial_outputs[tensor_name] = tensor if self._dist_attr.get_output_dims_mapping(tensor_name) is None: tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))] - self._dist_attr.set_output_dims_mapping(tensor_name, - tensor_dims_mapping) + self._dist_attr.set_output_dims_mapping( + tensor_name, tensor_dims_mapping + ) if self._dist_attr.op_type is None: self._dist_attr.op_type = self.serial_op.type if self._dist_attr.impl_type is None: @@ -117,8 +122,10 @@ class DistributedOperator: new_dist_attr = {} for key, value in dist_attr.items(): if isinstance(key, Variable): - if key.name in self._serial_op.input_arg_names \ - or key.name in self._serial_op.output_arg_names: + if ( + key.name in self._serial_op.input_arg_names + or key.name in self._serial_op.output_arg_names + ): new_dist_attr[key] = value else: new_dist_attr[key] = value @@ -129,13 +136,15 @@ class DistributedOperator: for tensor_name in self._serial_op.input_arg_names: tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name) if tensor_dist_attr: - new_dist_attr.set_input_dist_attr(tensor_name, - 
tensor_dist_attr) + new_dist_attr.set_input_dist_attr( + tensor_name, tensor_dist_attr + ) for tensor_name in self._serial_op.output_arg_names: tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name) if tensor_dist_attr: - new_dist_attr.set_output_dist_attr(tensor_name, - tensor_dist_attr) + new_dist_attr.set_output_dist_attr( + tensor_name, tensor_dist_attr + ) else: assert False, "Cannot recognize the {} parameter.".format(dist_attr) return new_dist_attr @@ -146,8 +155,10 @@ class DistributedOperator: for name in self.serial_op.input_arg_names: input_dist_attr = self.dist_attr.get_input_dist_attr(name) dims_mapping = input_dist_attr.dims_mapping - if self.get_serial_input( - name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY: + if ( + self.get_serial_input(name).type + == core.VarDesc.VarType.LOD_TENSOR_ARRAY + ): shape = [] else: shape = self.get_serial_input(name).shape @@ -155,7 +166,8 @@ class DistributedOperator: return False for i in range(len(dims_mapping)): if dims_mapping[i] < -1 or dims_mapping[i] >= len( - self.dist_attr.process_mesh.topology): + self.dist_attr.process_mesh.topology + ): return False for i in range(len(self.dist_attr.process_mesh.topology)): if dims_mapping.count(i) > 1: @@ -166,8 +178,12 @@ class DistributedOperator: for name in self.serial_op.output_arg_names: output_dist_attr = self.dist_attr.get_output_dist_attr(name) dims_mapping = output_dist_attr.dims_mapping - if self.get_serial_output(name).type == core.VarDesc.VarType.LOD_TENSOR_ARRAY\ - or self.get_serial_output(name).type == core.VarDesc.VarType.STEP_SCOPES: + if ( + self.get_serial_output(name).type + == core.VarDesc.VarType.LOD_TENSOR_ARRAY + or self.get_serial_output(name).type + == core.VarDesc.VarType.STEP_SCOPES + ): shape = [] else: shape = self.get_serial_output(name).shape @@ -175,7 +191,8 @@ class DistributedOperator: return False for i in range(len(dims_mapping)): if dims_mapping[i] < -1 or dims_mapping[i] >= len( - self.dist_attr.process_mesh.topology): + self.dist_attr.process_mesh.topology + ): return False for i in range(len(self.dist_attr.process_mesh.topology)): if dims_mapping.count(i) > 1: @@ -185,8 +202,9 @@ class DistributedOperator: return True def __str__(self): - str = "{{op type: {}, op id: {}".format(self.serial_op.desc.type(), - self.serial_op.desc.id()) + str = "{{op type: {}, op id: {}".format( + self.serial_op.desc.type(), self.serial_op.desc.id() + ) # str += ", {}".format(self.dist_attr) # return str @@ -195,8 +213,9 @@ class DistributedOperator: annotated_str = "annotated" else: annotated_str = "non-annotated" - str += ", process_mesh ({}): {}".format(annotated_str, - self.dist_attr.process_mesh) + str += ", process_mesh ({}): {}".format( + annotated_str, self.dist_attr.process_mesh + ) for arg_name in self.serial_op.desc.input_arg_names(): dims_mapping = self.dist_attr.get_input_dims_mapping(arg_name) @@ -212,7 +231,8 @@ class DistributedOperator: else: is_parameter_str = "non-parameter" str += ", {}'s dims_mapping (input, {}, {}): {}".format( - arg_name, annotated_str, is_parameter_str, dims_mapping) + arg_name, annotated_str, is_parameter_str, dims_mapping + ) for arg_name in self.serial_op.desc.output_arg_names(): dims_mapping = self.dist_attr.get_output_dims_mapping(arg_name) @@ -228,12 +248,14 @@ class DistributedOperator: else: is_parameter_str = "non-parameter" str += ", {}'s dims_mapping (output, {}, {}): {}".format( - arg_name, annotated_str, is_parameter_str, dims_mapping) + arg_name, annotated_str, is_parameter_str, dims_mapping + ) str += ", 
pipeline stage: {}".format(None) str += ", dist_impl idx: {} , dist_impl type {} }}".format( - self.dist_attr._impl_idx, self.dist_attr._impl_type) + self.dist_attr._impl_idx, self.dist_attr._impl_type + ) return str @@ -242,7 +264,11 @@ class DistributedOperator: result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k == "_serial_op" or k == "_serial_inputs" or k == "_serial_outputs": + if ( + k == "_serial_op" + or k == "_serial_inputs" + or k == "_serial_outputs" + ): setattr(result, k, v) else: setattr(result, k, copy.deepcopy(v, memo)) @@ -250,9 +276,9 @@ class DistributedOperator: class DistributedOperatorHelper: - - def __init__(self, serial_op, process_mesh, in_dims_mappings, - out_dims_mappings): + def __init__( + self, serial_op, process_mesh, in_dims_mappings, out_dims_mappings + ): self._serial_op = serial_op self._process_mesh = process_mesh self._in_dims_mappings = in_dims_mappings @@ -262,8 +288,11 @@ class DistributedOperatorHelper: tensor_to_dims_mapping = {} index = 0 if self._in_dims_mappings: - assert len(args) + len(kwargs) == len(self._in_dims_mappings), \ - "The length of dims_mapping {} does not matching the length output {}.".format(len(self._in_dims_mappings), len(args) + len(kwargs)) + assert len(args) + len(kwargs) == len( + self._in_dims_mappings + ), "The length of dims_mapping {} does not matching the length output {}.".format( + len(self._in_dims_mappings), len(args) + len(kwargs) + ) for arg in args: if isinstance(arg, Variable) and self._in_dims_mappings: tensor_to_dims_mapping[arg.name] = self._in_dims_mappings[index] @@ -287,13 +316,17 @@ class DistributedOperatorHelper: raise ValueError("Unrecognized outpout.") if self._out_dims_mappings: - assert len(new_output) == len(self._out_dims_mappings), \ - "The length of dims_mapping {} does not matching the length output {}.".format(len(self._out_dims_mappings), len(new_output)) + assert len(new_output) == len( + self._out_dims_mappings + ), "The length of dims_mapping {} does not matching the length output {}.".format( + len(self._out_dims_mappings), len(new_output) + ) for i, item in enumerate(new_output): if isinstance(item, Variable) and self._out_dims_mappings: tensor_to_dims_mapping[item.name] = self._out_dims_mappings[i] from .dist_context import get_default_distributed_context + default_dist_ctx = get_default_distributed_context() for idx in range(op_size, new_op_size): op = cur_block.ops[idx] @@ -302,53 +335,68 @@ class DistributedOperatorHelper: if name in tensor_to_dims_mapping.keys(): tensor = dist_op.get_serial_input(name) tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr( - name) + name + ) dims_mapping = tensor_to_dims_mapping[name] if tensor is None: tensor_shape = [] else: - if tensor.type == core.VarDesc.VarType.READER \ - or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ - or tensor.type == core.VarDesc.VarType.STEP_SCOPES: + if ( + tensor.type == core.VarDesc.VarType.READER + or tensor.type + == core.VarDesc.VarType.LOD_TENSOR_ARRAY + or tensor.type == core.VarDesc.VarType.STEP_SCOPES + ): tensor_shape = [] else: tensor_shape = tensor.shape if dims_mapping is not None: dims_mapping = tensor_to_dims_mapping[name] shard_spec = convert_to_shard_spec( - dims_mapping, self._process_mesh) - assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \ - "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( - name, shard_spec, tensor_shape, self._process_mesh) + dims_mapping, self._process_mesh + 
) + assert verify_shard_spec( + shard_spec, tensor_shape, self._process_mesh + ), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( + name, shard_spec, tensor_shape, self._process_mesh + ) tensor_dist_attr.dims_mapping = dims_mapping tensor_dist_attr.mark_annotated("dims_mapping") for name in dist_op.serial_op.output_arg_names: if name in tensor_to_dims_mapping.keys(): tensor = dist_op.get_serial_output(name) tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr( - name) + name + ) dims_mapping = tensor_to_dims_mapping[name] if tensor is None: tensor_shape = [] else: - if tensor.type == core.VarDesc.VarType.READER \ - or tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ - or tensor.type == core.VarDesc.VarType.STEP_SCOPES: + if ( + tensor.type == core.VarDesc.VarType.READER + or tensor.type + == core.VarDesc.VarType.LOD_TENSOR_ARRAY + or tensor.type == core.VarDesc.VarType.STEP_SCOPES + ): tensor_shape = [] else: tensor_shape = tensor.shape if dims_mapping is not None: dims_mapping = tensor_to_dims_mapping[name] shard_spec = convert_to_shard_spec( - dims_mapping, self._process_mesh) - assert verify_shard_spec(shard_spec, tensor_shape, self._process_mesh), \ - "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( - name, shard_spec, tensor_shape, self._process_mesh) + dims_mapping, self._process_mesh + ) + assert verify_shard_spec( + shard_spec, tensor_shape, self._process_mesh + ), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( + name, shard_spec, tensor_shape, self._process_mesh + ) tensor_dist_attr.dims_mapping = dims_mapping tensor_dist_attr.mark_annotated("dims_mapping") dist_op.dist_attr.process_mesh = self._process_mesh if self._process_mesh is not None: dist_op.dist_attr.mark_annotated("process_mesh") default_dist_ctx.add_dist_op_for_program(dist_op) + default_dist_ctx.add_process_mesh(self._process_mesh) return output diff --git a/python/paddle/distributed/auto_parallel/engine.py b/python/paddle/distributed/auto_parallel/engine.py index d3f1b249d12..5df89a277df 100644 --- a/python/paddle/distributed/auto_parallel/engine.py +++ b/python/paddle/distributed/auto_parallel/engine.py @@ -34,6 +34,7 @@ from paddle.fluid.framework import Operator, _non_static_mode from paddle.fluid.framework import _current_expected_place as _get_device from paddle.fluid.dygraph.parallel import ParallelEnv from paddle.distributed import fleet +from paddle.distributed.parallel import _is_global_parallel_initialize from .callbacks import config_callbacks from .converter import Converter @@ -160,7 +161,6 @@ class Engine: " or `paddle.fluid.optimizer.Optimizer`." 
) self._optimizer = validate_opt(optimizer) - self._orig_optimizer = copy.deepcopy(self._optimizer) metrics = metrics or [] for metric in to_list(metrics): @@ -185,12 +185,18 @@ class Engine: self._strategy = strategy or Strategy() self._logger = get_logger(logging.INFO) - if os.getenv("POD_NAME"): + if os.getenv("POD_NAME") and not _is_global_parallel_initialize(): self._logger.info( "Distribute training by paddle.distributed.launch" ) fleet.init(is_collective=True) + # for compute cost + # TODO: remove _fwd_main_progs and _orig_optimizer + self._fwd_dist_contexts = {} + self._fwd_main_progs = {} + self._orig_optimizer = copy.deepcopy(self._optimizer) + self._executor = None self._cur_rank = paddle.distributed.get_rank() self._nranks = paddle.distributed.get_world_size() @@ -200,14 +206,6 @@ class Engine: self._orig_startup_prog = static.default_startup_program() self._orig_dist_context = get_default_distributed_context() self._dist_contexts = {} - self._fwd_main_progs = {} - self._fwd_dist_contexts = {} - self._serial_main_progs = {} - self._serial_startup_progs = {} - self._dist_main_progs = defaultdict(dict) # dist main programs - self._dist_startup_progs = defaultdict(dict) # dist startup programs - self._feed_vars = {} - self._fetch_vars = {} self._planners = {} self._has_prepared = {"train": False, "eval": False, "predict": False} self._has_prepared_reader = { @@ -338,9 +336,9 @@ class Engine: return inputs, labels - def _prepare_reader(self): - dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] + def _prepare_reader(self, feed_list=[]): dist_context = self._dist_contexts[self._mode] + dist_main_prog = dist_context.dist_main_programs[self._cur_rank] dist_main_block = dist_main_prog.global_block() # NOTE: this list may be changed if Paddle changes the existing rules. 
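# Editor's note (hedged sketch with stand-in callables, not Paddle's API): the
# guard added above only calls fleet.init(is_collective=True) when the process
# was started by paddle.distributed.launch (POD_NAME is set) and no global
# parallel environment has been initialized yet. The same check in isolation:

import os

def maybe_init_collective(already_initialized, init_fn):
    """Call `init_fn` once when launched by the distributed launcher."""
    if os.getenv("POD_NAME") and not already_initialized():
        init_fn(is_collective=True)
        return True
    return False

if __name__ == "__main__":
    # With POD_NAME unset this is a no-op; export POD_NAME to simulate a worker.
    print(maybe_init_collective(lambda: False, lambda **kwargs: None))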
@@ -361,10 +359,13 @@ class Engine: if op.type in related_reader_ops: reader_op_indices.append(idx) # Step 2: insert the new reader ops to cpp + # record the read ops' desc to insert to program of forward task_node + read_ops_desc = [] new_reader_ops = [] for idx in reversed(reader_op_indices): new_op_desc = dist_main_block.desc._prepend_op() new_op_desc.copy_from(dist_main_block.ops[idx].desc) + read_ops_desc.append(new_op_desc) new_op = Operator( dist_main_block, new_op_desc, type=new_op_desc.type() ) @@ -383,6 +384,29 @@ class Engine: dist_main_block._sync_with_cpp() self._has_prepared_reader[self._mode] = True + # Insert read op to forward TaskNode if 1F1B pass is setted + if self.main_program._pipeline_opt: + assert "tasks" in self.main_program._pipeline_opt["fleet_opt"] + fleet_opt = self.main_program._pipeline_opt["fleet_opt"] + fwd_task = fleet_opt["tasks"][0] + fwd_prog = fwd_task.get_program() + fwd_block = fwd_prog.global_block() + + for var in feed_list: + if var.name not in fwd_block.vars: + fwd_block._clone_variable(var) + + for op_desc in read_ops_desc: + new_op_desc = fwd_block.desc._prepend_op() + new_op_desc.copy_from(op_desc) + new_op = Operator( + fwd_block, new_op_desc, type=new_op_desc.type() + ) + fwd_block.ops.insert(0, new_op) + + fwd_block._sync_with_cpp() + fwd_task.set_program(fwd_prog) + def _prepare_feed(self, data, user_feeds, mode): feeds = {} if data is not None: @@ -430,14 +454,16 @@ class Engine: fetch_names.append([]) fetch_indices.append(group_indices) + dist_context = self._dist_contexts[mode] + fetch_vars = dist_context.serial_fetch_vars if mode != "predict": - _process_fetch_group("loss", self._fetch_vars[mode]["loss"]) + _process_fetch_group("loss", fetch_vars["loss"]) if mode != "predict": - metrics = self._fetch_vars[mode]["metrics"] + metrics = fetch_vars["metrics"] for i, var_list in enumerate(metrics): _process_fetch_group("metrics_" + str(i), var_list) if mode == "predict": - _process_fetch_group("outputs", self._fetch_vars[mode]["outputs"]) + _process_fetch_group("outputs", fetch_vars["outputs"]) user_fetches_collection = [ item[1] for item in get_collection(CollectionNames.FETCHES) ] @@ -471,7 +497,8 @@ class Engine: logs["loss"] = outs[idx][0] group_idx += 1 # logging metrics - metric_vars = self._fetch_vars[mode]["metrics"] + dist_context = self._dist_contexts[mode] + metric_vars = dist_context.serial_fetch_vars["metrics"] if metric_vars: for metric in self._metrics: metrics_indices = fetch_indices[group_idx] @@ -502,15 +529,18 @@ class Engine: logs["fetches"] = logs_fetch return logs - def _prepare_program(self, mode): + def _prepare_program(self, mode, init_parameters=True): # Do the build process self._build(mode) # Do the planning process self._plan(mode) # Do the parallel process self._parallel(mode) - # Init comm and startup program - self._initialize(mode) + # Init comm + self._init_comm() + if init_parameters: + # startup program + self._initialize(mode) self._has_prepared[mode] = True def _build(self, mode): @@ -542,8 +572,8 @@ class Engine: paddle.enable_static() else: # build program in static mode - serial_main_prog = self._serial_main_progs.get(mode, None) - if serial_main_prog is not None: + dist_context = self._dist_contexts.get(mode, None) + if dist_context is not None: return outputs = [] @@ -581,7 +611,7 @@ class Engine: metric.compute(*(outputs + self._labels)) ) ) - else: + elif mode == "train": assert isinstance( self._loss, Variable ), "the type of `loss` of the Engine arguments should be Variable." 
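# Editor's note (illustrative skeleton, not the real Engine): the hunk above
# adds an `init_parameters` switch to _prepare_program so that communicator
# setup (_init_comm) always runs while the startup program (_initialize) can
# be skipped, e.g. when parameters will be loaded from a checkpoint instead.
# The control flow, reduced to stand-in methods:

class _EnginePrepareFlow:
    def __init__(self):
        self._has_prepared = {}

    def _build(self, mode): print("build serial program for", mode)
    def _plan(self, mode): print("plan dist attrs for", mode)
    def _parallel(self, mode): print("partition program for", mode)
    def _init_comm(self): print("instantiate process groups")
    def _initialize(self, mode): print("run startup program for", mode)

    def _prepare_program(self, mode, init_parameters=True):
        self._build(mode)
        self._plan(mode)
        self._parallel(mode)
        self._init_comm()
        if init_parameters:
            self._initialize(mode)
        self._has_prepared[mode] = True

if __name__ == "__main__":
    _EnginePrepareFlow()._prepare_program("train", init_parameters=False)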
@@ -724,37 +754,21 @@ class Engine: ) dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr) - def _initialize(self, mode): - # Get the current content from the distributed context - self._serial_main_progs[mode] = self._dist_contexts[ - mode - ].serial_main_program - self._serial_startup_progs[mode] = self._dist_contexts[ - mode - ].serial_startup_program - self._dist_main_progs[mode] = self._dist_contexts[ - mode - ].dist_main_programs - self._dist_startup_progs[mode] = self._dist_contexts[ - mode - ].dist_startup_programs - self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars - self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars - self._optimizer = self._dist_contexts[mode]._serial_optimizer - + def _init_comm(self): if self._nranks > 1: # Traverse different rank programs and traverse each op of them, # instantiate communication by process_mapping. all_process_groups = get_all_process_groups() if self._strategy.auto_mode == "full": - initialize_pg_in_full_mode(all_process_groups, cur_rank) + initialize_pg_in_full_mode(all_process_groups, self._cur_rank) else: for process_group in all_process_groups: if self._cur_rank not in process_group.ranks: continue process_group.instantiate() + def _initialize(self, mode): place = _get_device() if isinstance(place, fluid.CUDAPlace): place = fluid.CUDAPlace(ParallelEnv().dev_id) @@ -764,15 +778,17 @@ class Engine: np.random.seed(self._strategy.seed + self._dp_ranks[0]) random.seed(self._strategy.seed + self._dp_ranks[0]) + dist_context = self._dist_contexts[mode] if self._dygraph_mode: - dist_context = self._dist_contexts[mode] - dist_main_program = self._dist_main_progs[mode][self._cur_rank] + dist_main_program = dist_context.dist_main_programs[self._cur_rank] self.program_helper.init(dist_main_program, place, dist_context) if self._executor is None: self._executor = paddle.static.Executor(place) uninitialized = [] - dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] + dist_startup_prog = dist_context.dist_startup_programs[ + self._cur_rank + ] for var in dist_startup_prog.list_vars(): scope_var = global_scope().find_var(var.name) if scope_var and scope_var.get_tensor()._is_initialized(): @@ -789,7 +805,9 @@ class Engine: if self._strategy.reinit: self._logger.info("NOTE: parameters will be re-initialized.") - dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank] + dist_startup_prog = dist_context.dist_startup_programs[ + self._cur_rank + ] self._executor.run(dist_startup_prog) def fit( @@ -926,7 +944,7 @@ class Engine: ) except core.EOFException: break - lr = get_lr(self._optimizer) + lr = get_lr(self.optimizer) logs = self._prepare_logger( outs, epoch, @@ -1262,6 +1280,7 @@ class Engine: main_program=None, startup_program=None, mode=None, + init_parameters=True, ): if mode is not None: self.to_mode(mode) @@ -1304,7 +1323,7 @@ class Engine: self._inputs_spec, self._labels_spec = inputs_spec, labels_spec self._inputs, self._labels = inputs, labels if not self._has_prepared[self._mode]: - self._prepare_program(self._mode) + self._prepare_program(self._mode, init_parameters) else: self._switch_mode(self._mode) @@ -1355,16 +1374,17 @@ class Engine: ) batch_size //= self._k_steps - dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] - dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] + dist_context = self._dist_contexts[self._mode] + dist_main_prog = dist_context.dist_main_programs[self._cur_rank] + dist_startup_prog = 
dist_context.dist_startup_programs[self._cur_rank] dist_main_block = dist_main_prog.global_block() # NOTE: Get feed_list, then insert dataloader op with sharded var shape. # Cause predict_program does not contain labels var, # then we will add labels var from serial_program to dist_program, # that maintains the length of feed_list equal to the length of dataset's values. - inputs_var = self._feed_vars[self._mode]["inputs"] - labels_var = self._feed_vars[self._mode]["labels"] + inputs_var = dist_context.serial_feed_vars["inputs"] + labels_var = dist_context.serial_feed_vars["labels"] feed_list = [] for var in inputs_var + labels_var: if var.name in dist_main_block.vars: @@ -1423,16 +1443,17 @@ class Engine: ) batch_size //= self._k_steps - dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] - dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank] + dist_context = self._dist_contexts[self._mode] + dist_main_prog = dist_context.dist_main_programs[self._cur_rank] + dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank] dist_main_block = dist_main_prog.global_block() # NOTE: Get feed_list, then insert dataloader op with sharded var shape. # Cause predict_program does not contain labels var, # then we will add labels var from serial_program to dist_program, # that maintains the length of feed_list equal to the length of dataset's values. - inputs_var = self._feed_vars[self._mode]["inputs"] - labels_var = self._feed_vars[self._mode]["labels"] + inputs_var = dist_context.serial_feed_vars["inputs"] + labels_var = dist_context.serial_feed_vars["labels"] feed_list = [] for var in inputs_var + labels_var: if var.name in dist_main_block.vars: @@ -1462,7 +1483,7 @@ class Engine: data_parallel_world_size=self._dp_world_sizes, data_parallel_rank=self._dp_ranks, ) - self._prepare_reader() + self._prepare_reader(feed_list) return dataloader def _tune(self, tune_data, tune_sample_split=None, batch_size=1): @@ -1551,10 +1572,9 @@ class Engine: def _switch_mode(self, mode): assert ( - mode in self._dist_main_progs + mode in self._dist_contexts ), "{} model is not ready, please call `prepare()` first.".format(mode) self.to_mode(mode) - self._optimizer = self._dist_contexts[mode]._serial_optimizer def to_mode(self, mode): assert mode in [ @@ -1565,8 +1585,8 @@ class Engine: self._mode = mode def _set_state_dict(self, mode, strict, state_dict, dist_attr): - program = self._dist_main_progs[mode][self._cur_rank] dist_context = self._dist_contexts[mode] + program = dist_context.dist_main_programs[self._cur_rank] cur_dist_attr = get_dist_attr(program, dist_context) converter = Converter(state_dict, dist_attr, cur_dist_attr) state_dict = converter.convert(strict=strict) @@ -1618,10 +1638,10 @@ class Engine: """ if training: - assert self._mode in self._serial_main_progs - serial_program = self._serial_main_progs[self._mode] - dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank] + assert self._mode in self._dist_contexts dist_context = self._dist_contexts[self._mode] + serial_program = dist_context.serial_main_program + dist_main_prog = dist_context.dist_main_programs[self._cur_rank] self._saver.save( path, serial_program=serial_program, @@ -1629,10 +1649,11 @@ class Engine: dist_context=dist_context, ) else: - assert "predict" in self._dist_main_progs - feed_vars = self._feed_vars["predict"]['inputs'] - fetch_vars = self._fetch_vars["predict"]['outputs'] - dist_main_prog = self._dist_main_progs["predict"][self._cur_rank] + assert "predict" in 
self._dist_contexts + dist_context = self._dist_contexts["predict"] + feed_vars = dist_context.serial_feed_vars['inputs'] + fetch_vars = dist_context.serial_fetch_vars['outputs'] + dist_main_prog = dist_context.dist_main_programs[self._cur_rank] self._saver.save_inference_model( path, feed_vars, @@ -1758,11 +1779,13 @@ class Engine: @property def main_program(self): - return self._dist_main_progs[self._mode][self._cur_rank] + dist_context = self._dist_contexts[self._mode] + return dist_context.dist_main_programs[self._cur_rank] @property def startup_program(self): - return self._dist_startup_progs[self._mode][self._cur_rank] + dist_context = self._dist_contexts[self._mode] + return dist_context.dist_startup_programs[self._cur_rank] @property def dist_context(self): @@ -1770,15 +1793,30 @@ class Engine: @property def serial_main_program(self): - return self._serial_main_progs[self._mode] + dist_context = self._dist_contexts[self._mode] + return dist_context.serial_main_program @property def serial_startup_program(self): - return self._serial_startup_progs[self._mode] + dist_context = self._dist_contexts[self._mode] + return dist_context.serial_startup_program + + @property + def feed_vars(self): + dist_context = self._dist_contexts[self._mode] + return dist_context.serial_feed_vars @property def fetch_vars(self): - return self._fetch_vars[self._mode] + dist_context = self._dist_contexts[self._mode] + return dist_context.serial_fetch_vars + + @property + def optimizer(self): + dist_context = self._dist_contexts[self._mode] + if dist_context._serial_optimizer: + return dist_context._serial_optimizer + return self._optimizer @property def inputs(self): diff --git a/python/paddle/distributed/auto_parallel/interface.py b/python/paddle/distributed/auto_parallel/interface.py index a0dcb488658..b154209700a 100644 --- a/python/paddle/distributed/auto_parallel/interface.py +++ b/python/paddle/distributed/auto_parallel/interface.py @@ -67,29 +67,43 @@ def shard_tensor(x, process_mesh=None, shard_spec=None): """ if process_mesh is not None: - assert isinstance(process_mesh, ProcessMesh), \ - "Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh) + assert isinstance( + process_mesh, ProcessMesh + ), "Argument process_mesh {} is not an instance of ProcessMesh".format( + process_mesh + ) else: process_mesh = get_current_process_mesh() - assert process_mesh is not None, \ - "Specify the process mesh argument or use ProcessMesh context manager first." - assert isinstance(shard_spec, list), \ - "Argument shard_spec {} is not an instance of list".format(shard_spec) - dist_tensor = DistributedTensor(x) + assert ( + process_mesh is not None + ), "Specify the process mesh argument or use ProcessMesh context manager first." 
+ assert isinstance( + shard_spec, list + ), "Argument shard_spec {} is not an instance of list".format(shard_spec) + if isinstance(x, str): + x = paddle.fluid.default_main_program().global_block()._var_recursive(x) + dist_tensor = DistributedTensor(x) + else: + dist_tensor = DistributedTensor(x) serial_tensor = dist_tensor.serial_tensor dist_tensor.dist_attr.process_mesh = process_mesh - if serial_tensor.type == core.VarDesc.VarType.READER \ - or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY \ - or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES: + if ( + serial_tensor.type == core.VarDesc.VarType.READER + or serial_tensor.type == core.VarDesc.VarType.LOD_TENSOR_ARRAY + or serial_tensor.type == core.VarDesc.VarType.STEP_SCOPES + ): tensor_shape = [] else: tensor_shape = serial_tensor.shape if shard_spec is not None: - assert verify_shard_spec(shard_spec, tensor_shape, process_mesh), \ - "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( - serial_tensor.name, shard_spec, tensor_shape, process_mesh) + assert verify_shard_spec( + shard_spec, tensor_shape, process_mesh + ), "For tensor {}, shard_spec {} is invalid with tensor_shape {} and process_mesh {}.".format( + serial_tensor.name, shard_spec, tensor_shape, process_mesh + ) dist_tensor.dist_attr.dims_mapping = convert_to_dims_mapping( - shard_spec, process_mesh) + shard_spec, process_mesh + ) if process_mesh is not None: dist_tensor.dist_attr.mark_annotated("process_mesh") if shard_spec is not None: @@ -97,6 +111,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None): default_dist_ctx = get_default_distributed_context() default_dist_ctx.add_dist_tensor_for_program(dist_tensor) dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x) + default_dist_ctx.add_process_mesh(process_mesh) return x @@ -144,41 +159,54 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None): """ if process_mesh is not None: - assert isinstance(process_mesh, ProcessMesh), \ - "Argument process_mesh {} is not an instance of ProcessMesh".format(process_mesh) + assert isinstance( + process_mesh, ProcessMesh + ), "Argument process_mesh {} is not an instance of ProcessMesh".format( + process_mesh + ) else: process_mesh = get_current_process_mesh() - assert process_mesh is not None, \ - "Specify the process mesh argument or use ProcessMesh context manager first." + assert ( + process_mesh is not None + ), "Specify the process mesh argument or use ProcessMesh context manager first." 
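# Editor's note (hedged sketch, stand-in helper rather than Paddle's
# convert_to_dims_mapping): a shard_spec names, per tensor dimension, the
# process-mesh dimension it is sharded along (or None for replicated). The
# conversion used by shard_tensor above and by shard_op just below turns that
# into a dims_mapping of mesh-axis indices, with -1 meaning "not sharded":

def to_dims_mapping(shard_spec, mesh_dim_names):
    """Map mesh dimension names in `shard_spec` to integer mesh-axis indices."""
    return [
        -1 if dim is None else mesh_dim_names.index(dim)
        for dim in shard_spec
    ]

if __name__ == "__main__":
    mesh_dim_names = ["dp", "mp"]   # a hypothetical 2-D process mesh
    # shard rows along "dp", replicate columns
    assert to_dims_mapping(["dp", None], mesh_dim_names) == [0, -1]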
in_dims_mappings = [] if in_shard_specs is not None: - assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in in_shard_specs), \ - "in_shard_spec {} is not a list of list or None".format(in_shard_specs) + assert all( + (isinstance(shard_spec, list) or shard_spec is None) + for shard_spec in in_shard_specs + ), "in_shard_spec {} is not a list of list or None".format( + in_shard_specs + ) for shard_spec in in_shard_specs: if shard_spec is not None: in_dims_mappings.append( - convert_to_dims_mapping(shard_spec, process_mesh)) + convert_to_dims_mapping(shard_spec, process_mesh) + ) else: in_dims_mappings.append(None) out_dims_mappings = [] if out_shard_specs is not None: - assert all((isinstance(shard_spec, list) or shard_spec is None) for shard_spec in out_shard_specs), \ - "out_shard_spec {} is not a list of list or None".format(out_shard_specs) + assert all( + (isinstance(shard_spec, list) or shard_spec is None) + for shard_spec in out_shard_specs + ), "out_shard_spec {} is not a list of list or None".format( + out_shard_specs + ) for shard_spec in out_shard_specs: if shard_spec is not None: out_dims_mappings.append( - convert_to_dims_mapping(shard_spec, process_mesh)) + convert_to_dims_mapping(shard_spec, process_mesh) + ) else: out_dims_mappings.append(None) - op = DistributedOperatorHelper(op, process_mesh, in_dims_mappings, - out_dims_mappings) + op = DistributedOperatorHelper( + op, process_mesh, in_dims_mappings, out_dims_mappings + ) return op def recompute(op): - class RecomputeOperator: - def __init__(self, op): self._op = op @@ -219,11 +247,13 @@ def add_to_collection(collection_name, value, name=None): _g_collections[collection_name] = [] if name is not None: for _, v in _g_collections[collection_name]: - if v == value: return + if v == value: + return _g_collections[collection_name].append((name, value)) else: for _, v in _g_collections[collection_name]: - if v == value: return + if v == value: + return _g_collections[collection_name].append((None, value)) diff --git a/python/paddle/distributed/auto_parallel/operators/__init__.py b/python/paddle/distributed/auto_parallel/operators/__init__.py index 4a0a05a4f1c..406ec4d8b36 100644 --- a/python/paddle/distributed/auto_parallel/operators/__init__.py +++ b/python/paddle/distributed/auto_parallel/operators/__init__.py @@ -35,3 +35,4 @@ from . import dist_fused_attention from . import dist_reduce_sum_p from . import dist_shape from . import dist_assign +from . 
import dist_scale diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py index e7e7ad1e0ea..9137322cc71 100644 --- a/python/paddle/distributed/auto_parallel/operators/common.py +++ b/python/paddle/distributed/auto_parallel/operators/common.py @@ -14,7 +14,11 @@ import abc import paddle -from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY +from paddle.distributed.fleet.meta_optimizers.common import ( + OpRole, + OP_ROLE_KEY, + OP_ROLE_VAR_KEY, +) from ..dist_attribute import OperatorDistributedAttribute from ..utils import _get_comm_group, _get_corresponding_rank, is_optimize_op from ..process_group import new_process_group @@ -22,16 +26,22 @@ from ..process_group import new_process_group _g_distributed_operator_impl_containers = {} _g_elementwise_ops = [ - "elementwise", "gelu", "dropout", "cast", "gather", "concat", - "fused_softmax_mask_upper_triangle" + "elementwise", + "gelu", + "dropout", + "cast", + "gather", + "concat", + "fused_softmax_mask_upper_triangle", ] BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'} -class ParallelMode(): +class ParallelMode: """ the parallel mode for communication or auxiliary operator """ + DataParallel = "auto_parallel/data_parallel" ModelParallel = "auto_parallel/model_parallel" PipelineParalel = "auto_parallel/pipeline_paralel" @@ -47,7 +57,6 @@ def is_elementwise_op(op_type): class DistributedOperatorImplContainer: - def __init__(self, op_type): self._type = op_type self._impls = [] @@ -65,8 +74,9 @@ class DistributedOperatorImplContainer: return self._impls def register_impl(self, dist_impl): - assert self.type == dist_impl.type, \ - "Op type of container must be same as that of the implementation." + assert ( + self.type == dist_impl.type + ), "Op type of container must be same as that of the implementation." impl_idx = len(self.impls) dist_impl.idx = impl_idx self._impls.append(dist_impl) @@ -97,7 +107,6 @@ class DistributedOperatorImplContainer: class DistributedOperatorImpl(abc.ABC): - def __init__(self, name): self._name = name self._type = None @@ -176,60 +185,75 @@ def register_distributed_operator_impl(op_type, dist_impl): def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True): """ - Here just return the first compatible implemention. + Here just return the first compatible implemention. This will be improved by cost model in the future. 
""" op_type = dist_op.serial_op.type dist_op_impl_container = get_distributed_operator_impl_container(op_type) dist_op_eltwise_impl_container = get_distributed_operator_impl_container( - "elementwise") + "elementwise" + ) dist_op_default_impl_container = get_distributed_operator_impl_container( - "default") + "default" + ) compatible_impls = [] if partial: if fwd: # First, find impls in the corresponding container if dist_op_impl_container: compatible_impls.extend( - dist_op_impl_container.get_input_compatible_impls(dist_op)) + dist_op_impl_container.get_input_compatible_impls(dist_op) + ) # Second, find impls in the elementwise container if dist_op_eltwise_impl_container and is_elementwise_op(op_type): compatible_impls.extend( dist_op_eltwise_impl_container.get_input_compatible_impls( - dist_op)) + dist_op + ) + ) # Third, find impls in the default container if dist_op_default_impl_container: compatible_impls.extend( dist_op_default_impl_container.get_input_compatible_impls( - dist_op)) + dist_op + ) + ) else: # First, find impls in the corresponding container if dist_op_impl_container: compatible_impls.extend( - dist_op_impl_container.get_output_compatible_impls(dist_op)) + dist_op_impl_container.get_output_compatible_impls(dist_op) + ) # Second, find impls in the elementwise container if dist_op_eltwise_impl_container and is_elementwise_op(op_type): compatible_impls.extend( dist_op_eltwise_impl_container.get_output_compatible_impls( - dist_op)) + dist_op + ) + ) # Third, find impls in the default container if dist_op_default_impl_container: compatible_impls.extend( dist_op_default_impl_container.get_output_compatible_impls( - dist_op)) + dist_op + ) + ) else: # First, find impls in the corresponding container if dist_op_impl_container: compatible_impls.extend( - dist_op_impl_container.get_compatible_impls(dist_op)) + dist_op_impl_container.get_compatible_impls(dist_op) + ) # Second, find impls in the elementwise container if dist_op_eltwise_impl_container and is_elementwise_op(op_type): compatible_impls.extend( - dist_op_eltwise_impl_container.get_compatible_impls(dist_op)) + dist_op_eltwise_impl_container.get_compatible_impls(dist_op) + ) # Third, find impls in the default container if dist_op_default_impl_container: compatible_impls.extend( - dist_op_default_impl_container.get_compatible_impls(dist_op)) + dist_op_default_impl_container.get_compatible_impls(dist_op) + ) if compatible_impls: # For now, just return the first compatible impl @@ -242,18 +266,18 @@ def find_compatible_distributed_operator_impls(dist_op, fwd=True, partial=True): def is_parameter_related(varname, block): if ".subprog_" in varname: - varname = varname[:varname.index(".subprog_")] + varname = varname[: varname.index(".subprog_")] if ".cast_fp" in varname: - varname = varname[:varname.index(".cast_fp")] + varname = varname[: varname.index(".cast_fp")] if ".quantized" in varname: - varname = varname[:varname.index(".quantized")] + varname = varname[: varname.index(".quantized")] assert block.has_var(varname) var = block.var(varname) return var.is_parameter def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr): - var_shape = block.var(src_var.name).shape + var_shape = block._var_recursive(src_var.name).shape var_topoloy = src_var_dist_attr.process_mesh.topology var_dims_mapping = src_var_dist_attr.dims_mapping @@ -278,8 +302,9 @@ def infer_shape(block, src_var, src_var_dist_attr, op_input_dist_attr): return exact_shape -def set_comm_op_dist_attr_for_program(new_op, process_mesh, tensor_dist_attr, 
- ctx): +def set_comm_op_dist_attr_for_program( + new_op, process_mesh, tensor_dist_attr, ctx +): assert process_mesh is not None assert tensor_dist_attr is not None @@ -304,9 +329,11 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx): assert len(new_op.input(input_name)) == 1 ref_tensor_dist_attr = ref_dist_attr.get_input_dist_attr( - ref_op.input(input_name)[0]) + ref_op.input(input_name)[0] + ) new_op_dist_attr.set_input_dist_attr( - new_op.input(input_name)[0], ref_tensor_dist_attr) + new_op.input(input_name)[0], ref_tensor_dist_attr + ) for output_name in ref_op.output_names: assert output_name in new_op.output_names @@ -314,9 +341,11 @@ def naive_copy_op_dist_attr_for_program(new_op, ref_op, ctx): assert len(new_op.output(output_name)) == 1 ref_tensor_dist_attr = ref_dist_attr.get_output_dist_attr( - ref_op.output(output_name)[0]) + ref_op.output(output_name)[0] + ) new_op_dist_attr.set_output_dist_attr( - new_op.output(output_name)[0], ref_tensor_dist_attr) + new_op.output(output_name)[0], ref_tensor_dist_attr + ) ctx.set_op_dist_attr_for_program(new_op, new_op_dist_attr) @@ -327,9 +356,9 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank): Args: dist_ctx (DistributedContext): dist context. - op (Operator): the current (backward) operator which might need. - act_grad_names (list): list of input activation grads variable name to the current operator. - out_grad_names (list): list of the output parameter's grads variable name of the current operator. + op (Operator): the current (backward) operator which might need. + act_grad_names (list): list of input activation grads variable name to the current operator. + out_grad_names (list): list of the output parameter's grads variable name of the current operator. rank (int): global ranks index for current process. """ dp_group = None @@ -349,9 +378,12 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank): batch_size_axis = var_dim_mapping[0] if len(var_dim_mapping) > 0 else -1 if batch_size_axis > -1 and mesh_shape[batch_size_axis] > 1: - group_ranks = _get_comm_group(process_mesh.processes, - process_mesh.topology, - batch_size_axis, rank) + group_ranks = _get_comm_group( + process_mesh.processes, + process_mesh.topology, + batch_size_axis, + rank, + ) dp_group = new_process_group(group_ranks) break @@ -360,13 +392,13 @@ def get_data_parallel_group(dist_ctx, op, act_grad_names, rank): def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names): """ - insert the allreudce and scale ops for gradients of model + insert the allreudce and scale ops for gradients of model parameters for operator in data parallelism. Args: dist_ctx (DistributedContext): dist context. - op (Operator): the current (backward) operator which might need. - allreduce_var_names (list): list of the parameter's grads variable name in the current operator output. + op (Operator): the current (backward) operator which might need. + allreduce_var_names (list): list of the parameter's grads variable name in the current operator output. 
""" op_dist_attr = dist_ctx.get_op_dist_attr_for_program(op) @@ -378,33 +410,39 @@ def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names): for var_name in allreduce_var_names: added_ops = [] grad_var = main_block.var(var_name) - allreduce_op = main_block.append_op(type='c_allreduce_sum', - inputs={'X': [grad_var]}, - outputs={'Out': [grad_var]}, - attrs={ - 'ring_id': dp_group.id, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Backward - }) - allreduce_op._set_attr('op_namescope', - str('/') + ParallelMode.DataParallel) + allreduce_op = main_block.append_op( + type='c_allreduce_sum', + inputs={'X': [grad_var]}, + outputs={'Out': [grad_var]}, + attrs={ + 'ring_id': dp_group.id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Backward, + }, + ) + allreduce_op._set_attr( + 'op_namescope', str('/') + ParallelMode.DataParallel + ) added_ops.append(allreduce_op) if dist_ctx.gradient_scale: - scale_op = main_block.append_op(type='scale', - inputs={'X': grad_var}, - outputs={'Out': grad_var}, - attrs={ - 'scale': 1.0 / dp_degree, - OP_ROLE_KEY: OpRole.Backward - }) - scale_op._set_attr('op_namescope', - str('/') + ParallelMode.DataParallel) + scale_op = main_block.append_op( + type='scale', + inputs={'X': grad_var}, + outputs={'Out': grad_var}, + attrs={'scale': 1.0 / dp_degree, OP_ROLE_KEY: OpRole.Backward}, + ) + scale_op._set_attr( + 'op_namescope', str('/') + ParallelMode.DataParallel + ) added_ops.append(scale_op) dims_mapping = op_dist_attr.get_output_dims_mapping(grad_var.name) - assert dims_mapping is not None, "Unexception: dims_mapping of output [{}] of op [{}] is None".format( - grad_var.name, op_dist_attr.op_type) + assert ( + dims_mapping is not None + ), "Unexception: dims_mapping of output [{}] of op [{}] is None".format( + grad_var.name, op_dist_attr.op_type + ) # NOTE auxiliary op's dist attr should follow dist_op not dist_tensor for new_op in added_ops: new_op_attr = OperatorDistributedAttribute() @@ -414,25 +452,29 @@ def sync_and_scale_gradients(dist_ctx, op, dp_group, allreduce_var_names): dist_ctx.set_op_dist_attr_for_program(new_op, new_op_attr) -def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names, - rank): +def gradient_synchronization( + dist_ctx, op, act_grad_names, out_grad_names, rank +): """ - conduct the allreudce and scaling(dp size)for gradients of model + conduct the allreudce and scaling(dp size)for gradients of model parameters for operator in data parallelism. Args: dist_ctx (DistributedContext): dist context. - op (Operator): the current (backward) operator which might need. - act_grad_names (list): list of input activation grads variable name to the current operator. - out_grad_names (list): list of the output parameter's grads variable name of the current operator. + op (Operator): the current (backward) operator which might need. + act_grad_names (list): list of input activation grads variable name to the current operator. + out_grad_names (list): list of the output parameter's grads variable name of the current operator. rank (int): global ranks index for current process. 
""" if not is_in_backward_phase(dist_ctx): return - if is_optimize_op(op) or len(act_grad_names) == 0 or len( - out_grad_names) == 0: + if ( + is_optimize_op(op) + or len(act_grad_names) == 0 + or len(out_grad_names) == 0 + ): return dp_group = get_data_parallel_group(dist_ctx, op, act_grad_names, rank) @@ -444,13 +486,19 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names, def is_data_parallel_scale_op(op): - return op.type == "scale" and op.desc.has_attr("op_namescope") \ - and ParallelMode.DataParallel in op.desc.attr("op_namescope") + return ( + op.type == "scale" + and op.desc.has_attr("op_namescope") + and ParallelMode.DataParallel in op.desc.attr("op_namescope") + ) def is_data_parallel_reduce_op(op): - return op.type in ["c_reduce_sum", "c_allreduce_sum"] and op.desc.has_attr("op_namescope") \ - and ParallelMode.DataParallel in op.desc.attr("op_namescope") + return ( + op.type in ["c_reduce_sum", "c_allreduce_sum"] + and op.desc.has_attr("op_namescope") + and ParallelMode.DataParallel in op.desc.attr("op_namescope") + ) def is_in_backward_phase(dist_ctx): diff --git a/python/paddle/distributed/auto_parallel/operators/dist_default.py b/python/paddle/distributed/auto_parallel/operators/dist_default.py index a5139e00189..69f0288bcf4 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_default.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_default.py @@ -29,7 +29,11 @@ from paddle.fluid import core, unique_name from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import Program, Parameter, Variable, program_guard from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype -from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY +from paddle.distributed.fleet.meta_optimizers.common import ( + OpRole, + OP_ROLE_KEY, + OP_ROLE_VAR_KEY, +) from ..process_group import new_process_group from ..utils import _get_comm_group, _get_corresponding_rank from ..cost import _g_op_cost_factory @@ -37,6 +41,7 @@ from ..cost import build_comp_desc_from_dist_op, build_dp_costs from ..cost import build_comp_costs_from_descs __op_not_need_param_init__ = ["while", "cond"] +__op_has_shape_attr__ = ["fill_constant_batch_size_like", "fill_constant"] def prim_operator_data_parallel_functor(ctx, src_op): @@ -46,35 +51,41 @@ def prim_operator_data_parallel_functor(ctx, src_op): var_name = src_op.output_arg_names[0] if var_name in ctx.grads_params: - assert var_name not in ctx.synced_gradient, "in primtive mode, grad is already {} synced".format( - var_name) + assert ( + var_name not in ctx.synced_gradient + ), "in primtive mode, grad is already {} synced".format(var_name) ctx.synced_gradient.add(var_name) sync_group = new_process_group(ctx.data_parallel_group) - allreduce_op = main_block.append_op(type='c_allreduce_sum', - inputs={'X': [var_name]}, - outputs={'Out': [var_name]}, - attrs={ - 'ring_id': sync_group.id, - 'use_calc_stream': True, - OP_ROLE_KEY: OpRole.Backward - }) + allreduce_op = main_block.append_op( + type='c_allreduce_sum', + inputs={'X': [var_name]}, + outputs={'Out': [var_name]}, + attrs={ + 'ring_id': sync_group.id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Backward, + }, + ) param = ctx.grads_params[var_name] startup_block = dist_op_context.startup_block - new_op = startup_block.append_op(type='c_broadcast', - inputs={'X': [param]}, - outputs={'Out': [param]}, - attrs={ - 'ring_id': sync_group.id, - 'root': 0, - 'use_calc_stream': True, - 
OP_ROLE_KEY: OpRole.Forward - }) + new_op = startup_block.append_op( + type='c_broadcast', + inputs={'X': [param]}, + outputs={'Out': [param]}, + attrs={ + 'ring_id': sync_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward, + }, + ) grad_var = main_block.var(var_name) dims_mapping = ctx.get_tensor_dist_attr_for_program( - grad_var).dims_mapping + grad_var + ).dims_mapping dist_attr = ctx.get_op_dist_attr_for_program(src_op) process_mesh = dist_attr.process_mesh op_attr = OperatorDistributedAttribute() @@ -87,7 +98,6 @@ def prim_operator_data_parallel_functor(ctx, src_op): class DistributedDefault(DistributedOperatorImplContainer): - def __init__(self, op_type): super(DistributedDefault, self).__init__(op_type) @@ -97,7 +107,6 @@ register_distributed_operator_impl_container(DistributedDefault("default")) # Replicated Default class DistributedDefaultImpl0(DistributedOperatorImpl): - def __init__(self, name): super(DistributedDefaultImpl0, self).__init__(name) self._forward_implemented = True @@ -115,13 +124,14 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes op_type = dist_op.serial_op.type - cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type], - ctx, processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster + ) res_cost = [cost_mapping] return res_cost @@ -129,16 +139,17 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): def calc_bwd_cost(self, dist_op, ctx, cluster): # calc comp op cost res = [] - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) dist_attr = dist_op.dist_attr process_mesh = dist_attr.process_mesh processes = process_mesh.processes backward_op = dist_op.serial_op op_type = backward_op.type - cost_mapping = build_comp_costs_from_descs(_g_op_cost_factory[op_type], - ctx, processes, desc_mapping, - cluster) + cost_mapping = build_comp_costs_from_descs( + _g_op_cost_factory[op_type], ctx, processes, desc_mapping, cluster + ) res.append(cost_mapping) main_block = backward_op.block @@ -147,7 +158,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for input_name in backward_op.desc.input_names(): for varname in backward_op.desc.input(input_name): if "@GRAD" not in varname and not is_parameter_related( - varname, main_block): + varname, main_block + ): var_dim_mapping = dist_attr.get_input_dims_mapping(varname) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] @@ -159,16 +171,25 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for input_name in backward_op.desc.input_names(): for varname in backward_op.desc.input(input_name): if "@GRAD" not in varname and is_parameter_related( - varname, main_block): + varname, main_block + ): var_dim_mapping = dist_attr.get_input_dims_mapping( - varname) + varname + ) mesh_shape = process_mesh.topology batch_size_axis = var_dim_mapping[0] parallel_axis = batch_size_axis attrs = {"use_calc_stream": True} var_names = [varname + "@GRAD"] - build_dp_costs(res, dist_op, ctx, var_names, attrs, - parallel_axis, cluster) + build_dp_costs( + res, + dist_op, 
+ ctx, + var_names, + attrs, + parallel_axis, + cluster, + ) return res def is_input_compatible(self, dist_op): @@ -312,8 +333,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): batch_dim_mappings.append(dims_mapping[1]) # Check batch dim mapping compatibility - if not all(batch_dim_mappings[0] == dim_mapping - for dim_mapping in batch_dim_mappings): + if not all( + batch_dim_mappings[0] == dim_mapping + for dim_mapping in batch_dim_mappings + ): return False return True @@ -350,7 +373,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for arg_name in op_desc.output_arg_names(): if op_desc.type() == 'fill_any_like': input_tensor = dist_op.get_serial_input( - op_desc.input_arg_names()[0]) + op_desc.input_arg_names()[0] + ) if input_tensor.is_parameter: continue serial_tensor = dist_op.get_serial_output(arg_name) @@ -367,7 +391,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): return changed compatible_dim_mapping = compute_compatible_dim_mapping( - batch_dim_mappings) + batch_dim_mappings + ) if compatible_dim_mapping is None: return False @@ -377,19 +402,24 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): continue dims_mapping = op_dist_attr.get_input_dims_mapping(arg_name) if arg_name not in input_xshape_arg_names: - if len(dims_mapping) >= 1 and \ - compatible_dim_mapping != dims_mapping[0]: + if ( + len(dims_mapping) >= 1 + and compatible_dim_mapping != dims_mapping[0] + ): dims_mapping[0] = compatible_dim_mapping changed = True else: - if len(dims_mapping) >= 2 and \ - compatible_dim_mapping != dims_mapping[1]: + if ( + len(dims_mapping) >= 2 + and compatible_dim_mapping != dims_mapping[1] + ): dims_mapping[1] = compatible_dim_mapping changed = True for arg_name in op_desc.output_arg_names(): if op_desc.type() == 'fill_any_like': input_tensor = dist_op.get_serial_input( - op_desc.input_arg_names()[0]) + op_desc.input_arg_names()[0] + ) if input_tensor.is_parameter: continue if op_desc.type() in ["shape", "slice"]: @@ -399,13 +429,17 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): continue dims_mapping = op_dist_attr.get_output_dims_mapping(arg_name) if arg_name not in output_xshape_arg_names: - if len(dims_mapping - ) >= 1 and compatible_dim_mapping != dims_mapping[0]: + if ( + len(dims_mapping) >= 1 + and compatible_dim_mapping != dims_mapping[0] + ): dims_mapping[0] = compatible_dim_mapping changed = True else: - if len(dims_mapping - ) >= 2 and compatible_dim_mapping != dims_mapping[1]: + if ( + len(dims_mapping) >= 2 + and compatible_dim_mapping != dims_mapping[1] + ): dims_mapping[1] = compatible_dim_mapping changed = True @@ -422,17 +456,20 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): # check validation of inputs / outputs for input_name in src_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( - input_name) + input_name + ) assert len(kwargs[input_name]) == len( src_op.desc.input(input_name) ), "number of tensor for input [{}] is not match".format(input_name) for output_name in src_op.desc.output_names(): assert output_name in kwargs, "input [{}] is not given".format( - output_name) + output_name + ) assert len(kwargs[output_name]) == len( src_op.desc.output(output_name) ), "number of tensor for input [{}] is not match".format( - output_name) + output_name + ) # replicate op in dist program dist_op_desc = main_block.append_op(type='nop').desc @@ -443,8 +480,29 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for output_name in src_op.desc.output_names(): 
dist_op_desc.set_output(output_name, kwargs[output_name]) + if ( + src_op.has_attr('shape') + and src_op.attr('shape') + and src_op.type in __op_has_shape_attr__ + ): + shape_list = src_op.attr('shape') + Out_var = main_block._var_recursive(kwargs['Out'][0]) + op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) + dim_mapping = op_dist_attr.get_output_dims_mapping(Out_var.name) + process_mesh_shape = op_dist_attr.process_mesh.shape + assert len(shape_list) == len(dim_mapping) + # modify target shape + for idx, axis in enumerate(dim_mapping): + if axis >= 0: + if len(shape_list) > idx: + shape_list[idx] = ( + shape_list[idx] // process_mesh_shape[axis] + ) + dist_op_desc._set_attr('shape', shape_list) + # data parallel synchronization for primtive operators from paddle.incubate.autograd import prim_enabled + if prim_enabled(): assert is_prim_op(src_op) prim_operator_data_parallel_functor(ctx, src_op) @@ -455,9 +513,11 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): return for varname in dist_op_desc.input_arg_names(): - if startup_block.has_var(varname) and startup_block.var( - varname - ).is_parameter and varname not in dist_op_context.already_init_sync_vars: + if ( + startup_block.has_var(varname) + and startup_block.var(varname).is_parameter + and varname not in dist_op_context.already_init_sync_vars + ): dist_op_context.already_init_sync_vars.add(varname) param = startup_block.var(varname) param_dist_attr = ctx.get_tensor_dist_attr_for_program(param) @@ -466,38 +526,41 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): # FIXME (JZ-LIANG) Remove this hack to support any op mesh group for Pipeline Parallelism if rank_id not in process_mesh.processes: - rank_id = _get_corresponding_rank(ctx, process_mesh, - rank_id) + rank_id = _get_corresponding_rank( + ctx, process_mesh, rank_id + ) # NOTE all not splited axis should be presented in mesh for axis, size in enumerate(process_mesh.topology): if size <= 1 or axis in dims_mapping: pass else: - group_ranks = _get_comm_group(process_mesh.processes, - process_mesh.topology, - axis, rank_id) + group_ranks = _get_comm_group( + process_mesh.processes, + process_mesh.topology, + axis, + rank_id, + ) sync_group = new_process_group(group_ranks) - new_op = startup_block.append_op(type='c_broadcast', - inputs={'X': param}, - outputs={'Out': param}, - attrs={ - 'ring_id': - sync_group.id, - 'root': - 0, - 'use_calc_stream': - True, - OP_ROLE_KEY: - OpRole.Forward - }) + new_op = startup_block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': sync_group.id, + 'root': 0, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Forward, + }, + ) # set distributed attribute op_attr = OperatorDistributedAttribute() op_attr.process_mesh = process_mesh - op_attr.set_output_dims_mapping(param.name, - dims_mapping) + op_attr.set_output_dims_mapping( + param.name, dims_mapping + ) op_attr.set_input_dims_mapping(param.name, dims_mapping) ctx.set_op_dist_attr_for_program(new_op, op_attr) @@ -509,24 +572,30 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): main_block = dist_op_context.work_block backward_op = dist_op_context.cur_src_op dist_attr = ctx.get_op_dist_attr_for_program(backward_op) - assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format( - str(backward_op)) + assert ( + dist_attr is not None + ), "backward op [{}] don't have dist attribute !".format( + str(backward_op) + ) rank_id = dist_op_context.rank_id # check validation of inputs / outputs for 
input_name in backward_op.desc.input_names(): assert input_name in kwargs, "input [{}] is not given".format( - input_name) + input_name + ) assert len(kwargs[input_name]) == len( backward_op.desc.input(input_name) ), "number of tensor for input [{}] is not match".format(input_name) for output_name in backward_op.desc.output_names(): assert output_name in kwargs, "input [{}] is not given".format( - output_name) + output_name + ) assert len(kwargs[output_name]) == len( backward_op.desc.output(output_name) ), "number of tensor for input [{}] is not match".format( - output_name) + output_name + ) # replicate op in dist program dist_op_desc = main_block.append_op(type='nop').desc @@ -543,7 +612,8 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): for input_name in backward_op.desc.input_names(): for varname in backward_op.desc.input(input_name): if "@GRAD" not in varname and not is_parameter_related( - varname, main_block): + varname, main_block + ): act_grad_names.append(varname) out_grad_names = [] @@ -556,9 +626,11 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): if is_parameter_related(fwd_name, main_block): out_grad_names.append(varname) - gradient_synchronization(ctx, backward_op, act_grad_names, - out_grad_names, rank_id) + gradient_synchronization( + ctx, backward_op, act_grad_names, out_grad_names, rank_id + ) register_distributed_operator_impl( - "default", DistributedDefaultImpl0("replicate_parallel")) + "default", DistributedDefaultImpl0("replicate_parallel") +) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py index 3b519c2cc5b..68f28a87630 100644 --- a/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py +++ b/python/paddle/distributed/auto_parallel/operators/dist_fill_constant_batch_size_like.py @@ -32,21 +32,22 @@ from .dist_default import DistributedDefaultImpl0 from ..cost import FillConstantBatchSizeLikeOpCost from ..cost import build_comp_desc_from_dist_op, build_dp_costs from ..cost import build_comp_costs_from_descs -from paddle.distributed.auto_parallel.cost.comm_op_cost import AllreduceSumOpCost +from paddle.distributed.auto_parallel.cost.comm_op_cost import ( + AllreduceSumOpCost, +) class DistributedFillConstantBatchSizeLike(DistributedOperatorImplContainer): - def __init__(self, op_type): super(DistributedFillConstantBatchSizeLike, self).__init__(op_type) register_distributed_operator_impl_container( - DistributedFillConstantBatchSizeLike("fill_constant_batch_size_like")) + DistributedFillConstantBatchSizeLike("fill_constant_batch_size_like") +) class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): - def __init__(self, name): super(DistributedFillConstantBatchSizeLikeImpl0, self).__init__(name) self._forward_implemented = True @@ -56,7 +57,8 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): cost = None if int(op_role) == int(OpRole.Backward): raise ValueError( - "The fill_constant_batch_size_like has no grad op.") + "The fill_constant_batch_size_like has no grad op." 
+ ) else: cost = self.calc_fwd_cost(dist_op, ctx, cluster) assert cost is not None @@ -64,13 +66,18 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): def calc_fwd_cost(self, dist_op, ctx, cluster): # calc comp op cost - desc_mapping = build_comp_desc_from_dist_op(dist_op=dist_op, - dist_context=ctx) + desc_mapping = build_comp_desc_from_dist_op( + dist_op=dist_op, dist_context=ctx + ) processes = dist_op.dist_attr.process_mesh.processes op_type = dist_op.serial_op.type cost_mapping = build_comp_costs_from_descs( - FillConstantBatchSizeLikeOpCost, ctx, processes, desc_mapping, - cluster) + FillConstantBatchSizeLikeOpCost, + ctx, + processes, + desc_mapping, + cluster, + ) res_cost = [cost_mapping] return res_cost @@ -92,8 +99,9 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): return True def is_auto_compatible(self, dist_op): - if (not self.is_input_compatible(dist_op)) or \ - (not self.is_output_compatible(dist_op)): + if (not self.is_input_compatible(dist_op)) or ( + not self.is_output_compatible(dist_op) + ): return False op_desc = dist_op.serial_op.desc op_dist_attr = dist_op.dist_attr @@ -116,7 +124,8 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): # only the batch size dimemsion of input and output are relative. dim_changed = compute_compatible_and_update_dim_mapping( - [x_dims_mapping, out_dims_mapping], [0, 0]) + [x_dims_mapping, out_dims_mapping], [0, 0] + ) if dim_changed: changed = True @@ -128,24 +137,6 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): kwargs: inputname_mapping & outputname_mapping """ DistributedDefaultImpl0.forward(ctx, *args, **kwargs) - dist_op_context = ctx.dist_op_context - src_op = dist_op_context.cur_src_op - op_dist_attr = ctx.get_op_dist_attr_for_program(src_op) - main_block = dist_op_context.work_block - op = main_block.ops[-1] - assert op.type == "fill_constant_batch_size_like" - - # modify shape attr according to how output are partitioned - out_name = op.output('Out')[0] - dims_mapping = op_dist_attr.get_output_dims_mapping(out_name) - process_mesh_shape = op_dist_attr.process_mesh.topology - shape_list = op.attr("shape") - # modify target shape - for idx, axis in enumerate(dims_mapping): - if axis >= 0: - shape_list[idx] = shape_list[idx] // process_mesh_shape[axis] - - op._set_attr("shape", shape_list) @staticmethod def backward(ctx, *args, **kwargs): @@ -154,4 +145,5 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): register_distributed_operator_impl( "fill_constant_batch_size_like", - DistributedFillConstantBatchSizeLikeImpl0("fill_by_shape")) + DistributedFillConstantBatchSizeLikeImpl0("fill_by_shape"), +) diff --git a/python/paddle/distributed/auto_parallel/operators/dist_scale.py b/python/paddle/distributed/auto_parallel/operators/dist_scale.py new file mode 100644 index 00000000000..9fc28d05a20 --- /dev/null +++ b/python/paddle/distributed/auto_parallel/operators/dist_scale.py @@ -0,0 +1,90 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils import compute_compatible_and_update_dim_mapping
+from .common import (
+    DistributedOperatorImpl,
+    DistributedOperatorImplContainer,
+    register_distributed_operator_impl,
+    register_distributed_operator_impl_container,
+)
+from .dist_default import DistributedDefaultImpl0
+
+
+class DistributedScale(DistributedOperatorImplContainer):
+    def __init__(self, op_type):
+        super().__init__(op_type)
+
+
+register_distributed_operator_impl_container(DistributedScale("scale"))
+
+
+class DistributedScaleImpl(DistributedOperatorImpl):
+    def __init__(self, name):
+        super().__init__(name)
+        self._forward_implemented = True
+        self._backward_implemented = True
+
+    def is_input_compatible(self, dist_op):
+        return True
+
+    def is_output_compatible(self, dist_op):
+        return True
+
+    def is_auto_compatible(self, dist_op):
+        if (not self.is_input_compatible(dist_op)) or (
+            not self.is_output_compatible(dist_op)
+        ):
+            return False
+
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        out_name = op_desc.output('Out')[0]
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
+
+        if x_dims_mapping != out_dims_mapping:
+            return False
+
+        return True
+
+    def update_dims_mapping(self, dist_op):
+        changed = False
+        op_desc = dist_op.serial_op.desc
+        op_dist_attr = dist_op.dist_attr
+        x_name = op_desc.input('X')[0]
+        out_name = op_desc.output('Out')[0]
+        x_dims_mapping = op_dist_attr.get_input_dims_mapping(x_name)
+        out_dims_mapping = op_dist_attr.get_output_dims_mapping(out_name)
+
+        for i in range(len(x_dims_mapping)):
+            dim_changed = compute_compatible_and_update_dim_mapping(
+                [x_dims_mapping, out_dims_mapping], [i, i]
+            )
+            if dim_changed:
+                changed = True
+
+        return changed
+
+    @staticmethod
+    def forward(ctx, *args, **kwargs):
+        DistributedDefaultImpl0.forward(ctx, *args, **kwargs)
+
+    @staticmethod
+    def backward(ctx, *args, **kwargs):
+        DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
+
+
+register_distributed_operator_impl("scale", DistributedScaleImpl("scale"))
diff --git a/python/paddle/distributed/auto_parallel/parallelizer.py b/python/paddle/distributed/auto_parallel/parallelizer.py
index 75fb3d1ec52..983004694f7 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer.py
@@ -54,9 +54,9 @@ class AutoParallelizer:
     AutoParallelizer is the main controller class to do the auto parallel process. And the auto parallel
     process will be triggered in the wrapped parallelize function.
     To facilitate the auto parallelization, it will contain information about program, cluster and the
-    related context. In this basic version, the program information will be retrevied from
+    related context. In this basic version, the program information will be retrieved from
     Fleet object, and the cluster information can be retrieved in the newly created Cluster object,
-    and the context information can be retrevied in the new created DistributedContext.
+    and the context information can be retrieved in the newly created DistributedContext.
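    A minimal, illustrative sketch of how this class is driven (``fleet``,
    ``loss`` and ``startup_prog`` are assumed to already exist, so this is a
    sketch rather than a runnable script):

    .. code-block:: python

        parallelizer = AutoParallelizer(fleet)
        (
            dist_optimize_ops,
            dist_params_grads,
            dist_startup_prog,
            dist_main_prog,
        ) = parallelizer.parallelize(loss, startup_prog)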
""" def __init__(self, fleet): @@ -79,8 +79,12 @@ class AutoParallelizer: self._pass_context = PassContext() self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING") - self._need_rank_mapping = True if self._need_rank_mapping and \ - self._need_rank_mapping.lower() == 'true' else False + self._need_rank_mapping = ( + True + if self._need_rank_mapping + and self._need_rank_mapping.lower() == 'true' + else False + ) # self._pass_context = None def _remove_distributed_attrs(self, main_program): @@ -93,8 +97,9 @@ class AutoParallelizer: if suffix in attr_name: op._remove_attr(attr_name) - def _apply_pre_optimization_passes(self, main_program, startup_program, - loss, params_grads, no_grad_set): + def _apply_pre_optimization_passes( + self, main_program, startup_program, loss, params_grads, no_grad_set + ): # apply amp pass if self._dist_strategy.amp: config = copy.deepcopy(self._dist_strategy.amp_configs) @@ -104,12 +109,14 @@ class AutoParallelizer: if config["use_pure_fp16"]: config["base_opt"] = self._optimizer auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) - auto_parallel_fp16_pass.apply([main_program], [startup_program], - self._pass_context) + auto_parallel_fp16_pass.apply( + [main_program], [startup_program], self._pass_context + ) else: auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) - auto_parallel_amp_pass.apply([main_program], [startup_program], - self._pass_context) + auto_parallel_amp_pass.apply( + [main_program], [startup_program], self._pass_context + ) # apply recompute pass if self._dist_strategy.recompute: @@ -117,14 +124,22 @@ class AutoParallelizer: config["dist_context"] = self._dist_context config["no_grad_set"] = copy.deepcopy(no_grad_set) config["loss"] = loss - auto_parallel_recompute_pass = new_pass("auto_parallel_recompute", - config) - auto_parallel_recompute_pass.apply([main_program], - [startup_program], - self._pass_context) - - def _generate_backward(self, main_program, startup_program, loss, - parameter_list, no_grad_set, callbacks): + auto_parallel_recompute_pass = new_pass( + "auto_parallel_recompute", config + ) + auto_parallel_recompute_pass.apply( + [main_program], [startup_program], self._pass_context + ) + + def _generate_backward( + self, + main_program, + startup_program, + loss, + parameter_list, + no_grad_set, + callbacks, + ): with program_guard(main_program, startup_program): params_grads = append_backward( @@ -132,7 +147,8 @@ class AutoParallelizer: parameter_list, no_grad_set, callbacks, - distop_context=self._dist_context.dist_op_context) + distop_context=self._dist_context.dist_op_context, + ) self._completer = Completer(self._dist_context) self._completer.complete_backward_annotation(main_program) self._dist_context.block_state.parse_backward_blocks(main_program) @@ -151,18 +167,21 @@ class AutoParallelizer: return optimize_ops - def _apply_post_optimization_passes(self, main_program, startup_program, - rank, params_grads): + def _apply_post_optimization_passes( + self, main_program, startup_program, rank, params_grads + ): if self._dist_strategy.sharding: config = copy.deepcopy(self._dist_strategy.sharding_configs) config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["global_rank"] = rank - auto_parallel_sharding_pass = new_pass("auto_parallel_sharding", - config) - auto_parallel_sharding_pass.apply([main_program], [startup_program], - self._pass_context) + auto_parallel_sharding_pass = new_pass( + "auto_parallel_sharding", config + ) + 
auto_parallel_sharding_pass.apply( + [main_program], [startup_program], self._pass_context + ) params_grads = self._pass_context.get_attr("params_grads") config = copy.deepcopy(self._dist_strategy.sharding_configs) @@ -170,18 +189,20 @@ class AutoParallelizer: config["params_grads"] = params_grads config["rank_id"] = rank auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config) - auto_parallel_clip_pass.apply([main_program], [startup_program], - self._pass_context) + auto_parallel_clip_pass.apply( + [main_program], [startup_program], self._pass_context + ) if self._dist_strategy.gradient_merge: config = copy.deepcopy(self._dist_strategy.gradient_merge_configs) config["dist_context"] = self._dist_context config["params_grads"] = params_grads auto_parallel_gradient_merge_pass = new_pass( - "auto_parallel_gradient_merge_pass", config) - auto_parallel_gradient_merge_pass.apply([main_program], - [startup_program], - self._pass_context) + "auto_parallel_gradient_merge_pass", config + ) + auto_parallel_gradient_merge_pass.apply( + [main_program], [startup_program], self._pass_context + ) def _get_dist_program(self, rank, dist_context=None, relaunch_phase=False): completed_main_program = None @@ -195,8 +216,9 @@ class AutoParallelizer: self._dist_context = DistributedContext() _logger.info("Start annotation dist attr.") self._completer = Completer(self._dist_context) - completed_main_program = self._completer.complete_forward_annotation( - serial_main_program) + completed_main_program = ( + self._completer.complete_forward_annotation(serial_main_program) + ) else: completed_main_program = serial_main_program self._dist_context = copy.deepcopy(dist_context) @@ -206,49 +228,77 @@ class AutoParallelizer: # serial backward pass params_grads = self._generate_backward( - completed_main_program, serial_startup_program, serial_loss, - self._parameter_list, self._no_grad_set, self._callbacks) + completed_main_program, + serial_startup_program, + serial_loss, + self._parameter_list, + self._no_grad_set, + self._callbacks, + ) # serial forward pass - self._apply_pre_optimization_passes(completed_main_program, - serial_startup_program, serial_loss, - params_grads, self._no_grad_set) + self._apply_pre_optimization_passes( + completed_main_program, + serial_startup_program, + serial_loss, + params_grads, + self._no_grad_set, + ) # Logical partition partitioner = Partitioner(self._dist_context, rank) - dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( - completed_main_program, serial_startup_program, params_grads) + ( + dist_main_prog, + dist_startup_prog, + dist_params_grads, + ) = partitioner.partition( + completed_main_program, serial_startup_program, params_grads + ) # TODO refactor the placement of optimizer # generate optimize program - dist_optimize_ops = self._apply_optimize(dist_main_prog, - dist_startup_prog, - dist_params_grads) + dist_optimize_ops = self._apply_optimize( + dist_main_prog, dist_startup_prog, dist_params_grads + ) set_grad_var_shape(dist_main_prog, self._dist_context) make_data_unshard(dist_main_prog, dist_startup_prog, self._dist_context) - resharder = Resharder(dist_main_prog, dist_startup_prog, rank, - self._dist_context, dist_params_grads) + resharder = Resharder( + dist_main_prog, + dist_startup_prog, + rank, + self._dist_context, + dist_params_grads, + ) resharder.reshard() - self._apply_post_optimization_passes(dist_main_prog, dist_startup_prog, - rank, dist_params_grads) + self._apply_post_optimization_passes( + dist_main_prog, 
dist_startup_prog, rank, dist_params_grads + ) g_process_group_map = None if not relaunch_phase: g_process_group_map = copy.deepcopy(_g_process_group_map) _g_process_group_map.clear() - _g_process_group_map[0] = ProcessGroup(0, []) + _g_process_group_map[0] = ProcessGroup(1000, []) for process_mesh in self._dist_context._process_meshes: _g_process_group_map[0].add_ranks(process_mesh.processes) - return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map - - def parallelize(self, - loss, - startup_program, - parameter_list=None, - no_grad_set=None, - callbacks=None): + return ( + dist_optimize_ops, + dist_params_grads, + dist_startup_prog, + dist_main_prog, + g_process_group_map, + ) + + def parallelize( + self, + loss, + startup_program, + parameter_list=None, + no_grad_set=None, + callbacks=None, + ): assert startup_program is not None self._loss = loss self._startup_program = startup_program @@ -259,25 +309,27 @@ class AutoParallelizer: if self._enable_auto_mapping and self._need_rank_mapping: # Do the mapping pass before parallelization - assert self._cluster is not None, \ - "The cluster must not be none when using auto mapping." + assert ( + self._cluster is not None + ), "The cluster must not be none when using auto mapping." dist_programs = {} world_process_group = get_world_process_group() dist_context = None # auto search if self._dist_strategy.auto_search: logging.info("Start searching dist attr.") - serial_program_info = SerialProgramInfo(self._main_program, - self._startup_program, - self._loss, - self._optimizer, - self._cluster) - planner = Planner(serial_program_info, - self, - algorithm_config={ - "name": "mcmc", - "max_search_times": 5 - }) + serial_program_info = SerialProgramInfo( + self._main_program, + self._startup_program, + self._loss, + self._optimizer, + self._cluster, + ) + planner = Planner( + serial_program_info, + self, + algorithm_config={"name": "mcmc", "max_search_times": 5}, + ) dist_context, _ = planner.search() logging.info("End searching dist attr.") @@ -286,31 +338,42 @@ class AutoParallelizer: logging.info("Start serialize searched dist attr") cwd = pathlib.Path().resolve() searched_dist_context_path = os.path.join( - cwd, f"searched_dist_context_{time.time()}.pkl") + cwd, f"searched_dist_context_{time.time()}.pkl" + ) saved_dist_context = {} ops_dist_attr = {} tensors_dist_attr = {} for key, dist_op in dist_context._dist_ops_for_program.items(): ops_dist_attr[key] = dist_op.dist_attr - for key, dist_tensor in dist_context._dist_tensors_for_program.items( - ): + for ( + key, + dist_tensor, + ) in dist_context._dist_tensors_for_program.items(): tensors_dist_attr[key] = dist_tensor.dist_attr saved_dist_context["ops_dist_attr"] = ops_dist_attr saved_dist_context["tensors_dist_attr"] = tensors_dist_attr saved_dist_context[ - "process_meshes"] = dist_context._process_meshes - with open(searched_dist_context_path, - "wb") as dist_context_file: + "process_meshes" + ] = dist_context._process_meshes + with open( + searched_dist_context_path, "wb" + ) as dist_context_file: pickle.dump(saved_dist_context, dist_context_file) os.environ[ - 'PADDLE_SEARCHED_DIST_CONTEXT_PATH'] = searched_dist_context_path + 'PADDLE_SEARCHED_DIST_CONTEXT_PATH' + ] = searched_dist_context_path logging.info( f"End serialize searched dist attr to {searched_dist_context_path}" ) for rank in world_process_group.ranks: - dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, g_process_group_map = self._get_dist_program( - rank, 
dist_context) + ( + dist_optimize_ops, + dist_params_grads, + dist_startup_prog, + dist_main_prog, + g_process_group_map, + ) = self._get_dist_program(rank, dist_context) dist_programs[rank] = [dist_main_prog, g_process_group_map] # Do the mapping between the distributed program graph and the cluster graph @@ -322,27 +385,42 @@ class AutoParallelizer: json.dump(rank_mapping, rank_mapping_file) enable_elastic = os.getenv("PADDLE_ENABLE_ELASTIC") - enable_elastic = True if enable_elastic and enable_elastic.lower( - ) == 'true' else False + enable_elastic = ( + True + if enable_elastic and enable_elastic.lower() == 'true' + else False + ) if enable_elastic: print("Auto mapping finished, now do elastic re-launch") - sys.exit(paddle.distributed.fleet.elastic.manager. - ELASTIC_AUTO_PARALLEL_EXIT_CODE) + sys.exit( + paddle.distributed.fleet.elastic.manager.ELASTIC_AUTO_PARALLEL_EXIT_CODE + ) original_cmd_args = os.getenv("PADDLE_ORIGINAL_CMD_ARGS") rank_mapping_args = " ".join( - ["--rank_mapping_path", self._rank_mapping_path]) + ["--rank_mapping_path", self._rank_mapping_path] + ) if os.environ.get("WITH_COVERAGE", "OFF") == "ON": coverage_args = ["-m", "coverage", "run", "--branch", "-p"] else: coverage_args = [] - new_cmd_args = "-m paddle.distributed.fleet.launch" + " " + rank_mapping_args + " " + original_cmd_args - new_cmd = [sys.executable, "-u" - ] + coverage_args + shlex.split(new_cmd_args) + new_cmd_args = ( + "-m paddle.distributed.fleet.launch" + + " " + + rank_mapping_args + + " " + + original_cmd_args + ) + new_cmd = ( + [sys.executable, "-u"] + + coverage_args + + shlex.split(new_cmd_args) + ) new_process = subprocess.Popen(new_cmd) new_process.wait() - assert new_process.returncode == 0, \ - "Launch failed with rank mapping" + assert ( + new_process.returncode == 0 + ), "Launch failed with rank mapping" print("Successfully do the second launch for auto mapping!") sys.exit(0) else: @@ -350,27 +428,32 @@ class AutoParallelizer: rank = paddle.distributed.get_rank() dist_context = None searched_dist_context_path = os.getenv( - "PADDLE_SEARCHED_DIST_CONTEXT_PATH", None) + "PADDLE_SEARCHED_DIST_CONTEXT_PATH", None + ) if searched_dist_context_path is not None: - with open(searched_dist_context_path, - "rb") as dist_context_file: + with open( + searched_dist_context_path, "rb" + ) as dist_context_file: saved_dist_context = pickle.load(dist_context_file) dist_context = DistributedContext() for op in self._main_program.global_block().ops: dist_attr = saved_dist_context["ops_dist_attr"][ - op.desc.id()] + op.desc.id() + ] dist_op = DistributedOperator(op, dist_attr) dist_context.add_dist_op_for_program(dist_op) vars = self._main_program.global_block().vars for var in vars.values(): dist_attr = saved_dist_context["tensors_dist_attr"][ - var.desc.id()] + var.desc.id() + ] dist_tensor = DistributedTensor(var, dist_attr) dist_context.add_dist_tensor_for_program(dist_tensor) dist_context._process_meshes = saved_dist_context[ - "process_meshes"] + "process_meshes" + ] else: if self._dist_strategy.auto_search: @@ -379,13 +462,16 @@ class AutoParallelizer: self._startup_program, self._loss, self._optimizer, - cluster=self._cluster) - planner = Planner(serial_program_info, - self, - algorithm_config={ - "name": "mcmc", - "max_search_times": 5 - }) + cluster=self._cluster, + ) + planner = Planner( + serial_program_info, + self, + algorithm_config={ + "name": "mcmc", + "max_search_times": 5, + }, + ) dist_context, _ = planner.search() # rebuild g_process_group @@ -393,8 +479,13 @@ class 
AutoParallelizer: pg0 = get_process_group(0) for process_mesh in dist_context._process_meshes: pg0.add_ranks(process_mesh.processes) - dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog, _ = self._get_dist_program( - rank, dist_context, relaunch_phase=True) + ( + dist_optimize_ops, + dist_params_grads, + dist_startup_prog, + dist_main_prog, + _, + ) = self._get_dist_program(rank, dist_context, relaunch_phase=True) # NOTE: This is a trick to fix hang in pipeline mode when dist context is searched by planner if self._dist_strategy.auto_search: @@ -405,12 +496,19 @@ class AutoParallelizer: break if is_pipeline: with paddle.static.program_guard(dist_main_prog): - paddle.distributed.barrier() + paddle.distributed.barrier(get_process_group(0)) # Traverse different rank programs and traverse each op of them, # instantiate communication by process_mapping. all_process_groups = get_all_process_groups() for process_group in all_process_groups: + if len(_g_process_group_map) > 0: + tmp = paddle.to_tensor([1], dtype="int32") + paddle.distributed.all_reduce( + tmp, sync_op=True, group=_g_process_group_map[0] + ) + paddle.device.cuda.synchronize() + if rank not in process_group.ranks: continue process_group.instantiate() @@ -422,14 +520,25 @@ class AutoParallelizer: # with inference. self._remove_distributed_attrs(dist_main_prog) - return dist_optimize_ops, dist_params_grads, dist_startup_prog, dist_main_prog + return ( + dist_optimize_ops, + dist_params_grads, + dist_startup_prog, + dist_main_prog, + ) def __deepcopy__(self, memo): cls = self.__class__ result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k == "_main_program" or k == "_startup_program" or k == "_dist_context" or k == "_fleet" or k == "_loss": + if ( + k == "_main_program" + or k == "_startup_program" + or k == "_dist_context" + or k == "_fleet" + or k == "_loss" + ): setattr(result, k, v) else: setattr(result, k, copy.deepcopy(v, memo)) diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py index e87c401055e..6f77dbd4e07 100644 --- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py +++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py @@ -29,7 +29,6 @@ from ..utils.log_utils import get_logger class Parallelizer: - def __init__(self, mode, completer, dist_context): self._mode = mode self._completer = completer @@ -54,77 +53,139 @@ class Parallelizer: if self._mode == "train" and serial_optimizer: # Generate backward serial_loss = self._dist_context.serial_loss - params_grads = self._generate_backward(serial_main_program, - serial_startup_program, - serial_loss) + params_grads = self._generate_backward( + serial_main_program, serial_startup_program, serial_loss + ) # Apply pre optimization passes time0 = time.time() - serial_main_program, serial_startup_program, params_grads = self._apply_pre_optimization( - serial_main_program, serial_startup_program, serial_loss, - serial_optimizer, params_grads) + ( + serial_main_program, + serial_startup_program, + params_grads, + ) = self._apply_pre_optimization( + serial_main_program, + serial_startup_program, + serial_loss, + serial_optimizer, + params_grads, + ) self._logger.debug( - "within parallel apply_pre_optimization time: {}, mode {}". 
- format(time.time() - time0, self._mode)) + "within parallel apply_pre_optimization time: {}, mode {}".format( + time.time() - time0, self._mode + ) + ) # Do logical partition time0 = time.time() partitioner = Partitioner(self._dist_context, rank) - dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( - serial_main_program, serial_startup_program, params_grads) + ( + dist_main_prog, + dist_startup_prog, + dist_params_grads, + ) = partitioner.partition( + serial_main_program, serial_startup_program, params_grads + ) self._logger.debug( "within parallel partitioner time: {}, mode {}".format( - time.time() - time0, self._mode)) + time.time() - time0, self._mode + ) + ) # Generate optimizer time0 = time.time() - self._generate_optimizer(dist_main_prog, dist_startup_prog, - serial_optimizer, dist_params_grads) + self._generate_optimizer( + dist_main_prog, + dist_startup_prog, + serial_optimizer, + dist_params_grads, + ) self._logger.debug( "within parallel optimizer time: {}, mode {}".format( - time.time() - time0, self._mode)) + time.time() - time0, self._mode + ) + ) # Do reshard process time0 = time.time() set_grad_var_shape(dist_main_prog, self._dist_context) - resharder = Resharder(dist_main_prog, dist_startup_prog, rank, - self._dist_context, dist_params_grads) + resharder = Resharder( + dist_main_prog, + dist_startup_prog, + rank, + self._dist_context, + dist_params_grads, + ) resharder.reshard() self._logger.debug( "within parallel reshard time: {}, mode {}".format( - time.time() - time0, self._mode)) + time.time() - time0, self._mode + ) + ) # Apply post optimization passes time0 = time.time() - self._apply_post_optimization(dist_main_prog, dist_startup_prog, - rank, dist_params_grads) + self._apply_post_optimization( + dist_main_prog, dist_startup_prog, rank, dist_params_grads + ) self._logger.debug( - "within parallel apply_post_optimization time: {}, mode {}". - format(time.time() - time0, self._mode)) + "within parallel apply_post_optimization time: {}, mode {}".format( + time.time() - time0, self._mode + ) + ) else: # Apply pre optimization passes time0 = time.time() - self._apply_pre_optimization(serial_main_program, - serial_startup_program, None, None, - None) + self._apply_pre_optimization( + serial_main_program, serial_startup_program, None, None, None + ) self._logger.debug( - "within parallel apply_pre_optimization time: {}, mode {}". 
- format(time.time() - time0, self._mode)) + "within parallel apply_pre_optimization time: {}, mode {}".format( + time.time() - time0, self._mode + ) + ) # Do logical partition time0 = time.time() partitioner = Partitioner(self._dist_context, rank) - dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( - serial_main_program, serial_startup_program, []) + ( + dist_main_prog, + dist_startup_prog, + dist_params_grads, + ) = partitioner.partition( + serial_main_program, serial_startup_program, [] + ) # Do reshard process self._logger.debug( "within parallel partitioner time: {}, mode {}".format( - time.time() - time0, self._mode)) + time.time() - time0, self._mode + ) + ) time0 = time.time() - resharder = Resharder(dist_main_prog, dist_startup_prog, rank, - self._dist_context, [], 1) + resharder = Resharder( + dist_main_prog, + dist_startup_prog, + rank, + self._dist_context, + [], + 1, + ) resharder.reshard() self._logger.debug( "within parallel reshard time: {}, mode {}".format( - time.time() - time0, self._mode)) + time.time() - time0, self._mode + ) + ) + # Apply post optimization passes + time0 = time.time() + self._apply_post_optimization( + dist_main_prog, dist_startup_prog, rank, dist_params_grads + ) + self._logger.debug( + "within parallel apply_post_optimization time: {}, mode {}".format( + time.time() - time0, self._mode + ) + ) # Clone program for test if self._mode != 'train': + pipeline_opt = dist_main_prog._pipeline_opt dist_main_prog = dist_main_prog.clone(for_test=True) dist_startup_prog = dist_startup_prog.clone(for_test=True) + dist_main_prog._pipeline_opt = pipeline_opt # Store the distributed programs for further usages self._dist_context.dist_main_programs[rank] = dist_main_prog @@ -133,13 +194,15 @@ class Parallelizer: def _generate_backward(self, main_program, startup_program, loss): with program_guard(main_program, startup_program): params_grads = append_backward( - loss, distop_context=self._dist_context.dist_op_context) + loss, distop_context=self._dist_context.dist_op_context + ) self._completer.complete_backward_annotation(main_program) self._dist_context.block_state.parse_backward_blocks(main_program) return params_grads - def _generate_optimizer(self, main_program, startup_program, optimizer, - params_grads): + def _generate_optimizer( + self, main_program, startup_program, optimizer, params_grads + ): # NOTE: `apply_gradients` will add an Accumulator for a parameter only once, # but optimizer will be called repeatedly in re-launch, so optimizer need to be copied. 
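# A self-contained toy (plain Python, not Paddle code) of the point made in the
# NOTE above: deep-copying before `apply_gradients` keeps the user's serial
# optimizer free of per-parameter accumulator state, so it can be reused for
# every rank and every re-launch. `_ToyOptimizer` is a made-up stand-in.
#
#     import copy
#
#     class _ToyOptimizer:
#         def __init__(self):
#             self._accumulators = {}
#
#         def apply_gradients(self, params_grads):
#             for param, _ in params_grads:
#                 # an accumulator is created only once per parameter
#                 self._accumulators.setdefault(param, 0.0)
#
#     serial_opt = _ToyOptimizer()
#     for _ in range(2):  # e.g. building dist programs for two ranks
#         working_opt = copy.deepcopy(serial_opt)
#         working_opt.apply_gradients([("w", "w@GRAD")])
#
#     assert serial_opt._accumulators == {}  # the original stays untouched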
optimizer = copy.deepcopy(optimizer) @@ -150,8 +213,9 @@ class Parallelizer: self._completer.complete_update_annotation(main_program) return optimizer_ops - def _apply_pre_optimization(self, main_program, startup_program, loss, - optimizer, params_grads): + def _apply_pre_optimization( + self, main_program, startup_program, loss, optimizer, params_grads + ): if self._strategy is None: return @@ -162,10 +226,11 @@ class Parallelizer: config["dist_context"] = self._dist_context config["params_grads"] = params_grads auto_parallel_quantization_pass = new_pass( - "auto_parallel_quantization", config) - auto_parallel_quantization_pass.apply([main_program], - [startup_program], - self._pass_context) + "auto_parallel_quantization", config + ) + auto_parallel_quantization_pass.apply( + [main_program], [startup_program], self._pass_context + ) main_program = self._pass_context.get_attr("main_program") startup_program = self._pass_context.get_attr("startup_program") params_grads = self._pass_context.get_attr("params_grads") @@ -176,17 +241,21 @@ class Parallelizer: config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["loss"] = loss - config["input_data"] = self._dist_context.serial_feed_vars["inputs"] \ + config["input_data"] = ( + self._dist_context.serial_feed_vars["inputs"] + self._dist_context.serial_feed_vars["labels"] + ) if config["use_pure_fp16"]: config["base_opt"] = optimizer auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config) - auto_parallel_fp16_pass.apply([main_program], [startup_program], - self._pass_context) + auto_parallel_fp16_pass.apply( + [main_program], [startup_program], self._pass_context + ) else: auto_parallel_amp_pass = new_pass("auto_parallel_amp", config) - auto_parallel_amp_pass.apply([main_program], [startup_program], - self._pass_context) + auto_parallel_amp_pass.apply( + [main_program], [startup_program], self._pass_context + ) # apply recompute pass # recompute is then train-only optimization @@ -195,16 +264,18 @@ class Parallelizer: config["dist_context"] = self._dist_context config["no_grad_set"] = None config["loss"] = loss - auto_parallel_recompute_pass = new_pass("auto_parallel_recompute", - config) - auto_parallel_recompute_pass.apply([main_program], - [startup_program], - self._pass_context) + auto_parallel_recompute_pass = new_pass( + "auto_parallel_recompute", config + ) + auto_parallel_recompute_pass.apply( + [main_program], [startup_program], self._pass_context + ) return main_program, startup_program, params_grads - def _apply_post_optimization(self, main_program, startup_program, rank, - params_grads): + def _apply_post_optimization( + self, main_program, startup_program, rank, params_grads + ): if self._strategy is None: return @@ -221,10 +292,12 @@ class Parallelizer: config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["global_rank"] = rank - auto_parallel_sharding_pass = new_pass("auto_parallel_sharding", - config) - auto_parallel_sharding_pass.apply([main_program], [startup_program], - self._pass_context) + auto_parallel_sharding_pass = new_pass( + "auto_parallel_sharding", config + ) + auto_parallel_sharding_pass.apply( + [main_program], [startup_program], self._pass_context + ) params_grads = self._pass_context.get_attr("params_grads") # GradClip is train-only optimization @@ -233,10 +306,19 @@ class Parallelizer: config["dist_context"] = self._dist_context config["params_grads"] = params_grads config["rank_id"] = rank - auto_parallel_clip_pass = 
new_pass("auto_parallel_grad_clip", - config) - auto_parallel_clip_pass.apply([main_program], [startup_program], - self._pass_context) + auto_parallel_clip_pass = new_pass( + "auto_parallel_grad_clip", config + ) + auto_parallel_clip_pass.apply( + [main_program], [startup_program], self._pass_context + ) + + if self._strategy.pipeline.enable: + self._strategy.gradient_merge.enable = True + self._strategy.gradient_merge.k_steps = ( + self._strategy.pipeline.accumulate_steps + ) + self._strategy.gradient_merge.avg = True # gradient_merge is then train-only optimization if self._mode == "train" and self._strategy.gradient_merge.enable: @@ -244,7 +326,18 @@ class Parallelizer: config["dist_context"] = self._dist_context config["params_grads"] = params_grads auto_parallel_gradient_merge_pass = new_pass( - "auto_parallel_gradient_merge_pass", config) - auto_parallel_gradient_merge_pass.apply([main_program], - [startup_program], - self._pass_context) + "auto_parallel_gradient_merge_pass", config + ) + auto_parallel_gradient_merge_pass.apply( + [main_program], [startup_program], self._pass_context + ) + + if self._strategy.pipeline.enable: + config = copy.deepcopy(self._strategy.pipeline.to_dict()) + config["dist_context"] = self._dist_context + auto_parallel_pipeline_pass = new_pass( + "auto_parallel_pipeline", config + ) + auto_parallel_pipeline_pass.apply( + [main_program], [startup_program], self._pass_context + ) diff --git a/python/paddle/distributed/auto_parallel/process_group.py b/python/paddle/distributed/auto_parallel/process_group.py index 10ff0d36fce..dd1e209835b 100644 --- a/python/paddle/distributed/auto_parallel/process_group.py +++ b/python/paddle/distributed/auto_parallel/process_group.py @@ -31,10 +31,11 @@ def get_all_process_groups(): def get_process_group(group_id, g_process_group_map=None): global _g_process_group_map - return _g_process_group_map.get( - group_id, - None) if g_process_group_map is None else g_process_group_map.get( - group_id, None) + return ( + _g_process_group_map.get(group_id, None) + if g_process_group_map is None + else g_process_group_map.get(group_id, None) + ) def get_world_process_group(): @@ -45,23 +46,23 @@ def get_world_process_group(): def clear_all_process_groups(): global _g_process_group_map _g_process_group_map = {} - _g_process_group_map[0] = ProcessGroup(0, []) + _g_process_group_map[0] = ProcessGroup(1000, []) def new_process_group(ranks, group_id=None): global _g_process_group_map # A key constructed from ranks is used for avoiding duplication - new_key = ''.join(map(str, sorted(ranks))) + new_key = ''.join(map(str, ranks)) for pg_id, pg in _g_process_group_map.items(): - cur_key = ''.join(map(str, sorted(pg.ranks))) + cur_key = ''.join(map(str, pg.ranks)) if pg_id != 0 and new_key == cur_key: return pg # If not matching the existing one, construt a new process group num_groups = len(_g_process_group_map) # Note: our process group may interfere with the original implementation # so the created group id should start from the original _new_ring_id() - if group_id == None: - group_id = _new_ring_id() + num_groups + 1 + if group_id is None: + group_id = _new_ring_id() + num_groups + 1000 new_pg = ProcessGroup(group_id, ranks) _g_process_group_map[group_id] = new_pg @@ -75,14 +76,15 @@ def new_process_group(ranks, group_id=None): # the instantiation process in a more general way. In the future, the process group may # handle the communication implementation choice. 
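# A small, self-contained toy of the duplicate-detection key that
# `new_process_group` above builds from the (no longer sorted) rank list; the
# helper name `_group_key` is made up for illustration:
#
#     def _group_key(ranks):
#         return ''.join(map(str, ranks))
#
#     assert _group_key([0, 1]) == _group_key([0, 1])  # same ranks, same order -> reuse
#     assert _group_key([1, 0]) != _group_key([0, 1])  # rank order now matters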
class ProcessGroup: - def __init__(self, group_id, ranks): - if group_id == 0 and get_process_group(0) is not None: - assert group_id != 0, "Process group id 0 is reserved for all ranks." + if group_id == 1000 and get_process_group(0) is not None: + assert ( + group_id != 1000 + ), "Process group id 1000 is reserved for all ranks." self._group_id = group_id - self._ranks = sorted(ranks) + self._ranks = ranks # Add the current ranks into group 0 - if group_id != 0: + if group_id != 1000: global _g_process_group_map _g_process_group_map[0].add_ranks(ranks) self._is_instantiate = False @@ -103,17 +105,19 @@ class ProcessGroup: if set(new_ranks) <= set(self.ranks): return else: - assert self.is_instantiate() == False, \ - "Cannot add new ranks after instantiating the process group" + assert ( + self.is_instantiate() == False + ), "Cannot add new ranks after instantiating the process group" self._ranks.extend(new_ranks) - self._ranks = sorted(list(set(self.ranks))) + self._ranks = list(set(self.ranks)) def local_rank(self, global_rank): if global_rank in self.ranks: return self.ranks.index(global_rank) else: - assert False, \ - "Rank {} doesn't belong to this group".format(global_rank) + assert False, "Rank {} doesn't belong to this group".format( + global_rank + ) def is_instantiate(self): return self._is_instantiate @@ -137,24 +141,36 @@ class ProcessGroup: if core.is_compiled_with_cuda(): place = core.CUDAPlace(genv.device_id) - core.NCCLParallelContext(strategy, - place).init_with_ring_id(ring_id) + core.NCCLParallelContext(strategy, place).init_with_ring_id( + ring_id + ) else: - assert False, ("No CUDA device found") + assert False, "No CUDA device found" # TODO(shenliang03): This is a temporary solution to solve the problem of # hang caused by cross-creation of new_group paddle.disable_static() _enable_legacy_dygraph() - paddle.set_device('gpu:%d' % - paddle.distributed.ParallelEnv().dev_id) - tmp = paddle.to_tensor( - [1], dtype="int32") if _non_static_mode() else fill_constant( - [0], dtype="int32", value="1") + paddle.set_device( + 'gpu:%d' % paddle.distributed.ParallelEnv().dev_id + ) + tmp = ( + paddle.to_tensor([1], dtype="int32") + if _non_static_mode() + else fill_constant([0], dtype="int32", value="1") + ) paddle.distributed.all_reduce(tmp, sync_op=True, group=self) paddle.distributed.wait(tmp, group=self) - paddle.enable_static() + # TODO(shenliang03) AlltoAll create communicator + alltoall_tmp = paddle.empty( + shape=[self.nranks, self.nranks], dtype="int32" + ) + out = paddle._legacy_C_ops.alltoall( + alltoall_tmp, 'use_calc_stream', True, 'ring_id', ring_id + ) + paddle.device.cuda.synchronize() + paddle.enable_static() self._is_instantiate = True def is_member(self): @@ -172,7 +188,8 @@ class ProcessGroup: def __str__(self): string = "id: {}, nranks: {}, ranks: {}.".format( - self.id, self.nranks, ", ".join(map(str, self.ranks))) + self.id, self.nranks, ", ".join(map(str, self.ranks)) + ) return string def __hash__(self): @@ -182,4 +199,4 @@ class ProcessGroup: # Note that Process group 0 is reserved for representing all ranks. # At the beginning, group 0 is empty and new ranks will be added automatically. 
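# A minimal sketch of how the registry below is expected to behave after this
# change, using only helpers defined in this file (it assumes a Paddle build
# where they are importable and initialized):
#
#     pg0 = get_process_group(0)        # the reserved "all ranks" group
#     assert pg0.id == 1000             # now constructed with ring id 1000
#     pg = new_process_group([0, 1])    # a pair group, id >= 1001
#     assert set([0, 1]) <= set(pg0.ranks)  # its ranks are folded into group 0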
 _g_process_group_map = OrderedDict()
-_g_process_group_map[0] = ProcessGroup(0, [])
+_g_process_group_map[0] = ProcessGroup(1000, [])
diff --git a/python/paddle/distributed/auto_parallel/process_mesh.py b/python/paddle/distributed/auto_parallel/process_mesh.py
index a18fc196477..72dc9043cab 100644
--- a/python/paddle/distributed/auto_parallel/process_mesh.py
+++ b/python/paddle/distributed/auto_parallel/process_mesh.py
@@ -41,19 +41,19 @@ def reset_current_process_mesh():
 
 class ProcessMesh(object):
     """
-    The `Processmesh` object describes the topology of the used processes.
+    The `ProcessMesh` object describes the topology of the used processes.
 
     Args:
         mesh (list|numpy.array): an n-dimensional array describes the topology
             of the processes.
         dim_names (list, optional): the i-th element of this list gives the name of the
             i-th dimension of the mesh.
-
+
     Examples:
         .. code-block:: python

            import paddle
-
+
            mesh = auto.ProcessMesh([[2, 4, 5], [0, 1, 3]], dim_names=["x", "y"])
            assert mesh.shape == [2, 3]
            assert mesh.process_ids == [2, 4, 5, 0, 1, 3]
 
@@ -68,10 +68,10 @@ class ProcessMesh(object):
         assert process_ids is not None
         mesh = np.array(process_ids).reshape(shape)
 
-        if not isinstance(mesh, list) and \
-            not isinstance(mesh, np.ndarray):
+        if not isinstance(mesh, list) and not isinstance(mesh, np.ndarray):
             raise ValueError(
-                'The mesh must be an instance of list or np.ndarray.')
+                'The mesh must be an instance of list or np.ndarray.'
+            )
         if isinstance(mesh, list):
             mesh = np.array(mesh)
 
@@ -79,30 +79,37 @@ class ProcessMesh(object):
         self._shape = list(self._mesh.shape)
         self._process_ids = self._mesh.flatten().tolist()
-        assert all(isinstance(p, int) for p in self._process_ids), \
-            ("All elements of the mesh must be integer")
-        assert min(
-            self._process_ids) >= 0, ('All elements of the mesh must be >= 0.')
+        assert all(
+            isinstance(p, int) for p in self._process_ids
+        ), "All elements of the mesh must be integer"
+        assert (
+            min(self._process_ids) >= 0
+        ), 'All elements of the mesh must be >= 0.'
         unique_process_ids = set(self._process_ids)
         assert len(unique_process_ids) == len(
-            self._process_ids), ('All elements of the mesh must be unique.')
+            self._process_ids
+        ), 'All elements of the mesh must be unique.'
         if dim_names is not None:
-            assert len(dim_names) == len(self._shape), \
-                ("The length of dims_names must be same as the shape of the mesh.")
+            assert len(dim_names) == len(
+                self._shape
+            ), "The length of dims_names must be same as the shape of the mesh."
self._dim_names = copy.deepcopy(dim_names) else: self._dim_names = ["d" + str(i) for i in range(len(self._shape))] unique_dim_names = set(self._dim_names) - assert len(unique_dim_names) == len(self._dim_names), ( - 'All dim_names {} must be unique.'.format(dim_names)) + assert len(unique_dim_names) == len( + self._dim_names + ), 'All dim_names {} must be unique.'.format(dim_names) + + # # Store all process meshes + # from .dist_context import get_default_distributed_context + # default_dist_cxt = get_default_distributed_context() + # default_dist_cxt.add_process_mesh(self) - # Store all process meshes - from .dist_context import get_default_distributed_context - default_dist_cxt = get_default_distributed_context() - default_dist_cxt.add_process_mesh(self) # Add new processes to process group 0 from .process_group import get_process_group + pg0 = get_process_group(0) pg0.add_ranks(self.processes) @@ -183,20 +190,24 @@ class ProcessMesh(object): def __exit__(self, exc_type, exc_value, exc_traceback): from .dist_tensor import DistributedTensor from .dist_op import DistributedOperator + default_prog = paddle.fluid.default_main_program() cur_block = default_prog.current_block() new_var_names = list(cur_block.vars.keys()) new_op_size = len(cur_block.ops) from .dist_context import get_default_distributed_context + default_dist_ctx = get_default_distributed_context() for name in new_var_names: if name not in self._old_var_names: tensor = cur_block.vars[name] dist_tensor = default_dist_ctx.get_dist_tensor_for_program( - tensor) + tensor + ) if dist_tensor is None: - dist_tensor = DistributedTensor(cur_block.vars[name], - {"process_mesh": self}) + dist_tensor = DistributedTensor( + cur_block.vars[name], {"process_mesh": self} + ) dist_tensor.dist_attr.mark_annotated("process_mesh") default_dist_ctx.add_dist_tensor_for_program(dist_tensor) else: @@ -229,5 +240,6 @@ class ProcessMesh(object): def __str__(self): str = "shape {}, process_ids {}, dim_nams {}".format( - self.shape, self.process_ids, self.dim_names) + self.shape, self.process_ids, self.dim_names + ) return str diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py index bb3d2d6cfba..4e5d5b0bf32 100644 --- a/python/paddle/distributed/auto_parallel/reshard.py +++ b/python/paddle/distributed/auto_parallel/reshard.py @@ -24,7 +24,10 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole import paddle.fluid.layers.utils as utils from ..collective import _get_global_env from .dist_context import DistributedContext -from .dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute +from .dist_attribute import ( + OperatorDistributedAttribute, + TensorDistributedAttribute, +) from .process_group import new_process_group, ProcessGroup, _g_process_group_map from .cost import build_comm_desc, CommContext from .cost import AllgatherOpCost, SendOpCost @@ -35,7 +38,11 @@ from .utils import print_program_with_dist_attr, is_gradient_clip_op # NOTE: If op in _g_special_ops or _g_gradient_clip_ops, it will not be resharded. 
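# A self-contained toy of the skip rule stated in the NOTE above; the helper
# name `_skipped_by_reshard` is made up, and the real code additionally relies
# on `is_gradient_clip_op` to decide when the second list applies:
#
#     def _skipped_by_reshard(op_type, in_grad_clip_scope):
#         special = {'check_finite_and_unscale', 'update_loss_scaling'}
#         grad_clip = {'sum', 'sqrt', 'fill_constant',
#                      'elementwise_max', 'elementwise_div'}
#         return op_type in special or (in_grad_clip_scope and op_type in grad_clip)
#
#     assert _skipped_by_reshard('update_loss_scaling', False)
#     assert _skipped_by_reshard('sqrt', True)
#     assert not _skipped_by_reshard('sqrt', False)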
_g_special_ops = ['check_finite_and_unscale', 'update_loss_scaling'] _g_gradient_clip_ops = [ - "sum", "sqrt", "fill_constant", "elementwise_max", "elementwise_div" + "sum", + "sqrt", + "fill_constant", + "elementwise_max", + "elementwise_div", ] _g_subblock_ops = ["while", "conditional_block"] @@ -267,20 +274,25 @@ class Inserter: def insert_cast_op(block, idx, tensor, op_role, tensor_type): # to avoid name conflict with framework new_var_name = paddle.fluid.unique_name.generate_with_ignorable_key( - ".".join(["cast@RESHARD", 'tmp'])) - out = block.create_var(name=new_var_name, - dtype=tensor_type, - type=tensor.type, - lod_level=tensor.lod_level) - cast_op = block._insert_op(idx, - type='cast', - inputs={'X': [tensor]}, - outputs={'Out': [out]}, - attrs={ - 'in_dtype': tensor.dtype, - 'out_dtype': out.dtype, - 'op_role': op_role - }) + ".".join(["cast@RESHARD", 'tmp']) + ) + out = block.create_var( + name=new_var_name, + dtype=tensor_type, + type=tensor.type, + lod_level=tensor.lod_level, + ) + cast_op = block._insert_op( + idx, + type='cast', + inputs={'X': [tensor]}, + outputs={'Out': [out]}, + attrs={ + 'in_dtype': tensor.dtype, + 'out_dtype': out.dtype, + 'op_role': op_role, + }, + ) cast_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @@ -290,16 +302,18 @@ class Inserter: op_type = 'send_v2' # use pair comm group process_group = new_process_group([src, dst]) - send_op = block._insert_op(idx, - type=op_type, - inputs={'X': [tensor]}, - attrs={ - 'ring_id': process_group.id, - 'peer': process_group.ranks.index(dst), - 'use_calc_stream': True, - 'op_role': op_role, - 'dynamic_shape': True - }) + send_op = block._insert_op( + idx, + type=op_type, + inputs={'X': [tensor]}, + attrs={ + 'ring_id': process_group.id, + 'peer': process_group.ranks.index(dst), + 'use_calc_stream': True, + 'op_role': op_role, + 'dynamic_shape': False, + }, + ) send_op._set_attr('op_namescope', "/auto_parallel/reshard") @staticmethod @@ -308,19 +322,21 @@ class Inserter: op_type = 'recv_v2' # use pair group process_group = new_process_group([src, dst]) - recv_op = block._insert_op(idx, - type=op_type, - inputs={'X': [tensor]}, - outputs={'Out': [tensor]}, - attrs={ - 'ring_id': process_group.id, - 'peer': process_group.ranks.index(src), - 'out_shape': tensor.shape, - 'dtype': tensor.dtype, - 'use_calc_stream': True, - 'op_role': op_role, - 'dynamic_shape': True - }) + recv_op = block._insert_op( + idx, + type=op_type, + inputs={'X': [tensor]}, + outputs={'Out': [tensor]}, + attrs={ + 'ring_id': process_group.id, + 'peer': process_group.ranks.index(src), + 'out_shape': tensor.shape, + 'dtype': tensor.dtype, + 'use_calc_stream': True, + 'op_role': op_role, + 'dynamic_shape': False, + }, + ) recv_op._set_attr('op_namescope', "/auto_parallel/reshard") @staticmethod @@ -328,21 +344,23 @@ class Inserter: """Insert reset_lod op into block at the given index.""" new_var_name = paddle.fluid.unique_name.generate_with_ignorable_key( - ".".join(["reset_lod@RESHARD", 'tmp'])) - reset_lod_out = block.create_var(name=new_var_name, - shape=X.shape, - type=X.type, - dtype=X.dtype, - lod_level=X.lod_level) - - reset_op = block._insert_op(idx, - type="lod_reset", - inputs={ - 'X': X, - 'Y': Y - }, - outputs={'Out': reset_lod_out}, - attrs={'op_role': op_role}) + ".".join(["reset_lod@RESHARD", 'tmp']) + ) + reset_lod_out = block.create_var( + name=new_var_name, + shape=X.shape, + type=X.type, + dtype=X.dtype, + lod_level=X.lod_level, + ) + + reset_op = block._insert_op( + idx, + type="lod_reset", + inputs={'X': X, 
'Y': Y}, + outputs={'Out': reset_lod_out}, + attrs={'op_role': op_role}, + ) reset_op._set_attr('op_namescope', "/auto_parallel/reshard") return reset_lod_out @@ -358,24 +376,29 @@ class Inserter: with paddle.static.program_guard(block.program): out = block.create_var( name=paddle.fluid.unique_name.generate_with_ignorable_key( - ".".join([helper.name, 'tmp'])), + ".".join([helper.name, 'tmp']) + ), dtype=tensors[0].dtype, shape=None, lod_level=tensors[0].lod_level, type=tensors[0].type, persistable=False, - stop_gradient=False) - concat_op = block._insert_op(idx, - type='concat', - inputs=inputs, - outputs={'Out': [out]}, - attrs=attrs) + stop_gradient=False, + ) + concat_op = block._insert_op( + idx, + type='concat', + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs, + ) concat_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @staticmethod - def insert_slice_op(block, idx, tensor, starts, ends, axes, new_var_name, - op_role): + def insert_slice_op( + block, idx, tensor, starts, ends, axes, new_var_name, op_role + ): """Insert slice op into block at the given block.""" # This is a hack to insert split op to get slice tensor # 1. [128, 128] => [64, 128]: split @@ -390,19 +413,19 @@ class Inserter: # use assign if len(diff_dims) == 0: - out = block.create_var(name=new_var_name, - dtype=tensor.dtype, - type=tensor.type, - shape=slice_shape, - lod_level=tensor.lod_level) + out = block.create_var( + name=new_var_name, + dtype=tensor.dtype, + type=tensor.type, + shape=slice_shape, + lod_level=tensor.lod_level, + ) inputs = {'X': [tensor]} outputs = {"Out": [out]} attrs = {"in_place": False} - slice_op = block._insert_op(idx, - type="assign", - inputs=inputs, - outputs=outputs, - attrs=attrs) + slice_op = block._insert_op( + idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs + ) slice_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @@ -423,23 +446,27 @@ class Inserter: new_shape.append(item // num_or_sections) with paddle.static.program_guard(block.program): outs = [ - block.create_var(name=paddle.fluid.unique_name. 
- generate_with_ignorable_key(".".join( - ['split@RESHARD', 'tmp'])), - dtype=tensor.dtype, - shape=None, - type=tensor.type, - persistable=False, - lod_level=tensor.lod_level, - stop_gradient=False) + block.create_var( + name=paddle.fluid.unique_name.generate_with_ignorable_key( + ".".join(['split@RESHARD', 'tmp']) + ), + dtype=tensor.dtype, + shape=None, + type=tensor.type, + persistable=False, + lod_level=tensor.lod_level, + stop_gradient=False, + ) for i in range(num_or_sections) ] out = outs[cur_idx] - split_op = block._insert_op(idx, - type="split", - inputs=inputs, - outputs={'Out': outs}, - attrs=attrs) + split_op = block._insert_op( + idx, + type="split", + inputs=inputs, + outputs={'Out': outs}, + attrs=attrs, + ) split_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @@ -452,17 +479,21 @@ class Inserter: "starts": starts, "ends": ends, "infer_flags": infer_flags, - 'op_role': op_role + 'op_role': op_role, } - out = block.create_var(name=new_var_name, - dtype=tensor.dtype, - type=tensor.type, - lod_level=tensor.lod_level) - slice_op = block._insert_op(idx, - type="slice", - inputs=inputs, - outputs={'Out': [out]}, - attrs=attrs) + out = block.create_var( + name=new_var_name, + dtype=tensor.dtype, + type=tensor.type, + lod_level=tensor.lod_level, + ) + slice_op = block._insert_op( + idx, + type="slice", + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs, + ) slice_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @@ -483,19 +514,20 @@ class Inserter: outs = [ block.create_var( name=paddle.fluid.unique_name.generate_with_ignorable_key( - ".".join([helper.name, 'tmp'])), + ".".join([helper.name, 'tmp']) + ), dtype=tensor.dtype, shape=None, lod_level=tensor.lod_level, type=tensor.type, persistable=False, - stop_gradient=False) for i in range(num_or_sections) + stop_gradient=False, + ) + for i in range(num_or_sections) ] - split_op = block._insert_op(idx, - type="split", - inputs=inputs, - outputs={'Out': outs}, - attrs=attrs) + split_op = block._insert_op( + idx, type="split", inputs=inputs, outputs={'Out': outs}, attrs=attrs + ) split_op._set_attr('op_namescope', "/auto_parallel/reshard") return outs @@ -508,27 +540,30 @@ class Inserter: with paddle.static.program_guard(block.program): out = block.create_var( name=paddle.fluid.unique_name.generate_with_ignorable_key( - ".".join([helper.name, 'tmp'])), + ".".join([helper.name, 'tmp']) + ), dtype=paddle.int64, shape=None, type=core.VarDesc.VarType.LOD_TENSOR, persistable=False, - stop_gradient=False) + stop_gradient=False, + ) inputs = {} attrs = {'force_cpu': False} attrs['str_value'] = str(int("1")) attrs['value'] = int("1") attrs['dtype'] = out.dtype attrs['op_role'] = op_role - utils.get_shape_tensor_inputs(inputs=inputs, - attrs=attrs, - shape=[0], - op_type='fill_constant') - fillconstant_op = block._insert_op(idx, - type='fill_constant', - inputs=inputs, - outputs={'Out': [out]}, - attrs=attrs) + utils.get_shape_tensor_inputs( + inputs=inputs, attrs=attrs, shape=[0], op_type='fill_constant' + ) + fillconstant_op = block._insert_op( + idx, + type='fill_constant', + inputs=inputs, + outputs={'Out': [out]}, + attrs=attrs, + ) out.stop_gradient = True fillconstant_op._set_attr('op_namescope', "/auto_parallel/reshard") return out @@ -544,7 +579,8 @@ class Inserter: if not group.is_instantiate(): # insert fill_constant op fill_constant_out = Inserter.insert_fill_constant_op( - block, idx, op_role) + block, idx, op_role + ) fill_constant_out.stop_gradient = True # insert c_allreduce_sum op @@ 
-554,10 +590,11 @@ class Inserter: inputs={'X': [fill_constant_out]}, outputs={'Out': [fill_constant_out]}, attrs={ - 'ring_id': 0, + 'ring_id': 1000, 'use_calc_stream': True, - 'op_role': op_role - }) + 'op_role': op_role, + }, + ) allreduce_op._set_attr('op_namescope', "/auto_parallel/reshard") # insert c_sync_calc_stream op sync_calc_op = block._insert_op( @@ -565,7 +602,8 @@ class Inserter: type="c_sync_calc_stream", inputs={'X': [fill_constant_out]}, outputs={'Out': [fill_constant_out]}, - attrs={'op_role': op_role}) + attrs={'op_role': op_role}, + ) sync_calc_op._set_attr('op_namescope', "/auto_parallel/reshard") idx_offset = 3 @@ -576,37 +614,42 @@ class Inserter: with paddle.static.program_guard(block.program): allgather_out = block.create_var( name=paddle.fluid.unique_name.generate_with_ignorable_key( - ".".join([helper.name, 'tmp'])), + ".".join([helper.name, 'tmp']) + ), dtype=tensor.dtype, shape=None, lod_level=tensor.lod_level, type=tensor.type, persistable=False, - stop_gradient=False) - allgather_op = block._insert_op(idx + idx_offset, - type=op_type, - inputs={'X': [tensor]}, - outputs={'Out': [allgather_out]}, - attrs={ - 'ring_id': group.id, - 'use_calc_stream': True, - 'nranks': group.nranks, - 'op_role': op_role - }) + stop_gradient=False, + ) + allgather_op = block._insert_op( + idx + idx_offset, + type=op_type, + inputs={'X': [tensor]}, + outputs={'Out': [allgather_out]}, + attrs={ + 'ring_id': group.id, + 'use_calc_stream': True, + 'nranks': group.nranks, + 'op_role': op_role, + }, + ) allgather_op._set_attr('op_namescope', "/auto_parallel/reshard") idx_offset += 1 # insert split op - split_out = Inserter.insert_split_op(block, idx + idx_offset, - allgather_out, group.nranks, - op_role) + split_out = Inserter.insert_split_op( + block, idx + idx_offset, allgather_out, group.nranks, op_role + ) idx_offset += 1 tensor_list.extend(split_out) return tensor_list, idx_offset @staticmethod - def concat_partitions_with_op(partition_tensor_list, tensor, - partition_index, block, idx, op_role): + def concat_partitions_with_op( + partition_tensor_list, tensor, partition_index, block, idx, op_role + ): """Concat the tensors and insert concat op.""" if not partition_tensor_list: partition_tensor_list.append((tensor, partition_index)) @@ -614,18 +657,42 @@ class Inserter: i = 0 has_concat = False while i < len(partition_tensor_list): - concat_axis, first_order, new_partition = Resharder.compute_concat_info( - partition_tensor_list[i][1], partition_index) + ( + concat_axis, + first_order, + new_partition, + ) = Resharder.compute_concat_info( + partition_tensor_list[i][1], partition_index + ) if concat_axis != -1: has_concat = True - _ = Inserter.insert_concat_op(block, idx[0], [partition_tensor_list[i][0], tensor], concat_axis, op_role) \ - if first_order == 0 else \ - Inserter.insert_concat_op(block, idx[0], [tensor, partition_tensor_list[i][0]], concat_axis, op_role) + _ = ( + Inserter.insert_concat_op( + block, + idx[0], + [partition_tensor_list[i][0], tensor], + concat_axis, + op_role, + ) + if first_order == 0 + else Inserter.insert_concat_op( + block, + idx[0], + [tensor, partition_tensor_list[i][0]], + concat_axis, + op_role, + ) + ) partition_tensor_list.pop(i) idx[0] += 1 - Inserter.concat_partitions_with_op(partition_tensor_list, _, - new_partition, block, - idx, op_role) + Inserter.concat_partitions_with_op( + partition_tensor_list, + _, + new_partition, + block, + idx, + op_role, + ) break i += 1 if not has_concat: @@ -639,7 +706,9 @@ class Remover: def 
remove_no_need_ops(auto_parallel_main_prog, dist_context, rank_id): """Remove no need ops in the main program""" not_remove_op_ref = [ - "create_py_reader", "create_double_buffer_reader", "read" + "create_py_reader", + "create_double_buffer_reader", + "read", ] # NOTE: The nested sub block is not be supported now. @@ -663,7 +732,9 @@ class Remover: for var_name in op.output_arg_names: dim_list.extend( get_var_with_recursion( - var_name, block, auto_parallel_main_prog).shape) + var_name, block, auto_parallel_main_prog + ).shape + ) for i in range(idx, -1, -1): if ops[i].type == "create_py_reader": ops[i]._set_attr("shape_concat", dim_list) @@ -674,10 +745,13 @@ class Remover: if op.type == "c_sync_comm_stream": need_save = [] for var_name in op.input_arg_names: - process_mesh = dist_context.get_tensor_dist_attr_for_program( - get_var_with_recursion( - var_name, block, - auto_parallel_main_prog)).process_mesh + process_mesh = ( + dist_context.get_tensor_dist_attr_for_program( + get_var_with_recursion( + var_name, block, auto_parallel_main_prog + ) + ).process_mesh + ) if rank_id in process_mesh.processes: need_save.append(var_name) if not need_save: @@ -693,15 +767,20 @@ class Remover: op_dist_attr = dist_context.get_op_dist_attr_for_program(op) if op_dist_attr is not None: op_process_mesh = op_dist_attr.process_mesh - if rank_id not in op_process_mesh.processes and op.type not in not_remove_op_ref: + if ( + rank_id not in op_process_mesh.processes + and op.type not in not_remove_op_ref + ): remove_op_idx.append(idx) for idx in remove_op_idx[::-1]: - block._remove_op(idx) + block._remove_op(idx, sync=False) + block._sync_with_cpp() @staticmethod - def remove_no_need_vars(auto_parallel_main_prog, dist_params_grads, - feed_var_names): + def remove_no_need_vars( + auto_parallel_main_prog, dist_params_grads, feed_var_names + ): """Remove no need vars in the main program""" for block_idx, block in enumerate(auto_parallel_main_prog.blocks): remove_vars = set() @@ -724,7 +803,10 @@ class Remover: param_grad_map = {} for op in ops: if int(op.attr('op_role')) == int(OpRole.Optimize): - if "Param" in op.input_names and "Grad" in op.input_names: + if ( + "Param" in op.input_names + and "Grad" in op.input_names + ): param_name = op.input("Param")[0] grad_name = op.input("Grad")[0] param_grad_map[param_name] = grad_name @@ -743,7 +825,9 @@ class Remover: grad_name = dist_params_grads[idx][1].name if grad_name != param_grad_map[param_name]: dist_params_grads[idx] = ( - vars[param_name], vars[param_grad_map[param_name]]) + vars[param_name], + vars[param_grad_map[param_name]], + ) idx += 1 for var in remove_vars: @@ -752,23 +836,28 @@ class Remover: block._remove_var(var) @staticmethod - def remove_no_need_in_main(auto_parallel_main_prog, dist_context, rank_id, - dist_params_grads): + def remove_no_need_in_main( + auto_parallel_main_prog, dist_context, rank_id, dist_params_grads + ): """Remove no need vars and ops in the main program.""" - Remover.remove_no_need_ops(auto_parallel_main_prog, dist_context, - rank_id) - Resharder.change_while_op_input_and_output(auto_parallel_main_prog, - dist_context) + Remover.remove_no_need_ops( + auto_parallel_main_prog, dist_context, rank_id + ) + Resharder.change_while_op_input_and_output( + auto_parallel_main_prog, dist_context + ) # 'feed_var_names' cannot be removed from auto_parallel_main_prog feed_var_names = [] for var in sum(list(dist_context.serial_feed_vars.values()), []): feed_var_names.append(var.name) - Remover.remove_no_need_vars(auto_parallel_main_prog, 
dist_params_grads, - feed_var_names) + Remover.remove_no_need_vars( + auto_parallel_main_prog, dist_params_grads, feed_var_names + ) @staticmethod - def remove_no_need_in_startup(auto_parallel_main_prog, - auto_parallel_startup_prog): + def remove_no_need_in_startup( + auto_parallel_main_prog, auto_parallel_startup_prog + ): """Remove no need vars and ops in the startup program.""" main_input_vars = set() main_ops = auto_parallel_main_prog.global_block().ops @@ -838,7 +927,8 @@ class Remover: if is_no_need_op: remove_op_idx.append(idx) for idx in remove_op_idx[::-1]: - startup_block._remove_op(idx) + startup_block._remove_op(idx, sync=False) + startup_block._sync_with_cpp() class Resharder: @@ -853,28 +943,43 @@ class Resharder: dist_params_grads (list): The list contains the tuple of param and grad. batch_size (int): The batch size. Default: None. """ + while_block_info = {} - def __init__(self, - auto_parallel_main_prog, - auto_parallel_startup_prog, - rank_id, - dist_context, - dist_params_grads, - batch_size=None): - assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_main_prog should be Program, " \ - "but got {}.".format(type(auto_parallel_main_prog)) + def __init__( + self, + auto_parallel_main_prog, + auto_parallel_startup_prog, + rank_id, + dist_context, + dist_params_grads, + batch_size=None, + ): + assert isinstance(auto_parallel_main_prog, Program), ( + "The type of auto_parallel_main_prog should be Program, " + "but got {}.".format(type(auto_parallel_main_prog)) + ) if auto_parallel_startup_prog is not None: - assert isinstance(auto_parallel_main_prog, Program), "The type of auto_parallel_startup_prog should be Program or None, " \ - "but got {}.".format(type(auto_parallel_startup_prog)) - assert isinstance(rank_id, int), "The type of rank_id should be int, " \ - "but got {}.".format(type(rank_id)) - assert isinstance(dist_context, DistributedContext), "The type of dist_context should be DistributedContext, " \ - "but got {}.".format(type(dist_context)) + assert isinstance(auto_parallel_main_prog, Program), ( + "The type of auto_parallel_startup_prog should be Program or None, " + "but got {}.".format(type(auto_parallel_startup_prog)) + ) + assert isinstance( + rank_id, int + ), "The type of rank_id should be int, " "but got {}.".format( + type(rank_id) + ) + assert isinstance(dist_context, DistributedContext), ( + "The type of dist_context should be DistributedContext, " + "but got {}.".format(type(dist_context)) + ) if batch_size is not None: - assert isinstance(batch_size, int), "The type of batch_size should be int, " \ - "but got {}.".format(type(batch_size)) + assert isinstance( + batch_size, int + ), "The type of batch_size should be int, " "but got {}.".format( + type(batch_size) + ) self._auto_parallel_main_prog = auto_parallel_main_prog self._auto_parallel_startup_prog = auto_parallel_startup_prog @@ -946,29 +1051,37 @@ class Resharder: for i in range(len(process_shape)): idx = relative_process // (product // process_shape[i]) product = product // process_shape[i] - relative_process = relative_process - relative_process // product * product + relative_process = ( + relative_process - relative_process // product * product + ) process_index.append(idx) return process_index @staticmethod - def compute_partition_index(process, complete_shape, dims_mapping, - process_shape, process_group): + def compute_partition_index( + process, complete_shape, dims_mapping, process_shape, process_group + ): """Compute the partition index in complete 
tensor.""" partition_shape = Resharder.compute_partition_shape( - complete_shape, dims_mapping, process_shape) - process_index = Resharder.compute_process_index(process, process_group, - process_shape) + complete_shape, dims_mapping, process_shape + ) + process_index = Resharder.compute_process_index( + process, process_group, process_shape + ) partition_index = [] for i in range(len(complete_shape)): if dims_mapping[i] == -1: partition_index.append([0, partition_shape[i]]) else: - partition_index.append([ - process_index[dims_mapping[i]] * partition_shape[i], - (process_index[dims_mapping[i]] + 1) * partition_shape[i] - ]) + partition_index.append( + [ + process_index[dims_mapping[i]] * partition_shape[i], + (process_index[dims_mapping[i]] + 1) + * partition_shape[i], + ] + ) return partition_index @@ -983,12 +1096,16 @@ class Resharder: for idx, item in enumerate(partition_index_x): if item != partition_index_y[idx]: differ_count += 1 - if item[1] == partition_index_y[idx][ - 0] and item[0] < partition_index_y[idx][1]: + if ( + item[1] == partition_index_y[idx][0] + and item[0] < partition_index_y[idx][1] + ): concat_axis = idx new_partition.append([item[0], partition_index_y[idx][1]]) - elif item[0] == partition_index_y[idx][ - 1] and item[1] > partition_index_y[idx][0]: + elif ( + item[0] == partition_index_y[idx][1] + and item[1] > partition_index_y[idx][0] + ): first_order = 1 concat_axis = idx new_partition.append([partition_index_y[idx][0], item[1]]) @@ -1021,12 +1138,14 @@ class Resharder: has_concat = False while i < len(partition_index_list): concat_axis, _, new_partition = Resharder.compute_concat_info( - partition_index_list[i], partition_index) + partition_index_list[i], partition_index + ) if concat_axis != -1: has_concat = True partition_index_list.pop(i) - Resharder.concat_partitions(partition_index_list, - new_partition) + Resharder.concat_partitions( + partition_index_list, new_partition + ) break i += 1 if not has_concat: @@ -1038,7 +1157,8 @@ class Resharder: for sub_block_idx in Resharder.while_block_info: sub_block = auto_parallel_main_prog.blocks[sub_block_idx] parent_while_op_id = Resharder.while_block_info[sub_block_idx][ - "op_id"] + "op_id" + ] parent_block = auto_parallel_main_prog.blocks[sub_block.parent_idx] sub_block_op_inputs = set() @@ -1046,10 +1166,12 @@ class Resharder: for op in sub_block.ops: # skip the input and output of operators inserted in the reshard phase dist_op = dist_context.get_dist_op_for_program(op) - if dist_op or (op.type == "slice" and not dist_op) or ( - op.type == "split" - and not dist_op) or (op.type == "assign" - and not dist_op): + if ( + dist_op + or (op.type == "slice" and not dist_op) + or (op.type == "split" and not dist_op) + or (op.type == "assign" and not dist_op) + ): for var_name in op.output_arg_names: if var_name not in sub_block_op_outputs: sub_block_op_outputs.append(var_name) @@ -1080,8 +1202,9 @@ class Resharder: for var_name in while_op.output("Out"): for output_name in sub_block_op_outputs[::-1]: if output_name.find(var_name) != -1 and ( - len(var_name) == len(output_name) - or "@RESHARD" in output_name): + len(var_name) == len(output_name) + or "@RESHARD" in output_name + ): if output_name not in new_Out: new_Out.append(output_name) assert new_Out @@ -1090,8 +1213,9 @@ class Resharder: def is_overlapped(self, shape_x, shape_y): """Judge whether two partitions intersect on the specified dimension.""" overlapped = False - if (shape_y[0] <= shape_x[0] < shape_y[1]) or (shape_x[0] <= shape_y[0] - < shape_x[1]): + if 
(shape_y[0] <= shape_x[0] < shape_y[1]) or ( + shape_x[0] <= shape_y[0] < shape_x[1] + ): overlapped = True return overlapped @@ -1119,8 +1243,9 @@ class Resharder: # the dims mapping of condition tensor should be replicative for var_name in input_cond: - var = get_var_with_recursion(var_name, sub_block, - self.auto_parallel_main_prog) + var = get_var_with_recursion( + var_name, sub_block, self.auto_parallel_main_prog + ) dist_tensor = self.dist_context.get_dist_tensor_for_program(var) tensor_dist_attr = dist_tensor.dist_attr var_dims_mapping = tensor_dist_attr.dims_mapping @@ -1143,13 +1268,22 @@ class Resharder: if op_input: op_input_dims_mapping = dist_attr[1] if all( - map(lambda x: x, [ - tensor_dims_mapping, tensor_process_mesh, - op_input_dims_mapping, op_process_mesh - ])): + map( + lambda x: x, + [ + tensor_dims_mapping, + tensor_process_mesh, + op_input_dims_mapping, + op_process_mesh, + ], + ) + ): # judge whether need reshard by dims_mapping if tensor_dims_mapping != op_input_dims_mapping: - if tensor_process_mesh not in self.dist_context.process_meshes: + if ( + tensor_process_mesh + not in self.dist_context.process_meshes + ): # assert whether -1 when union. for item in tensor_dims_mapping: if item != -1: @@ -1173,10 +1307,16 @@ class Resharder: else: op_output_dims_mapping = dist_attr[1] if all( - map(lambda x: x, [ - tensor_dims_mapping, tensor_process_mesh, - op_output_dims_mapping, op_process_mesh - ])): + map( + lambda x: x, + [ + tensor_dims_mapping, + tensor_process_mesh, + op_output_dims_mapping, + op_process_mesh, + ], + ) + ): if tensor_dims_mapping != op_output_dims_mapping: raise ValueError( "It is not supported that tensor dims mapping is different from op output dims mapping." @@ -1193,10 +1333,9 @@ class Resharder: op_process_mesh = dist_op.dist_attr.process_mesh for process_mesh in self.dist_context.process_meshes: - if set(process_mesh.processes) & (set( - op_process_mesh.processes)) and len( - process_mesh.processes) < len( - op_process_mesh.processes): + if set(process_mesh.processes) & ( + set(op_process_mesh.processes) + ) and len(process_mesh.processes) < len(op_process_mesh.processes): process_meshes.append(process_mesh) # it means the process mesh is not a union when process meshes is null @@ -1232,40 +1371,55 @@ class Resharder: target_process_group = target_process_mesh.processes target_process_shape = target_process_mesh.topology + op_role = dist_attr[2] + if source_tensor.shape[0] < 0: assert source_tensor.shape[0] == -1 new_shape = list(source_tensor.shape) new_shape[0] = self.batch_size source_tensor.desc.set_shape(new_shape) - complete_shape = Resharder.compute_complete_shape( - source_tensor.shape, source_process_shape, - source_dims_mapping) if not serial else source_tensor.shape + complete_shape = ( + Resharder.compute_complete_shape( + source_tensor.shape, source_process_shape, source_dims_mapping + ) + if not serial + else source_tensor.shape + ) op_desc_seq = {} # TODO: if the target process group has the same process with source process group - if set(target_process_group).intersection(set( - source_process_group)) and set(target_process_group).difference( - set(source_process_group)): + if set(target_process_group).intersection( + set(source_process_group) + ) and set(target_process_group).difference(set(source_process_group)): pass elif target_process_group != source_process_group: partition_process_mapping_list = [] for source_process in source_process_group: # get partition index of source process - source_partition_index = 
Resharder.compute_partition_index(source_process, complete_shape, source_dims_mapping, \ - source_process_shape, source_process_group) + source_partition_index = Resharder.compute_partition_index( + source_process, + complete_shape, + source_dims_mapping, + source_process_shape, + source_process_group, + ) if not partition_process_mapping_list: # the item in partition_process_mapping_list is source_partition_index, which processes and whether has been used partition_process_mapping_list.append( - [source_partition_index, [source_process], [False]]) + [source_partition_index, [source_process], [False]] + ) else: partition_list = list( - [item[0] for item in partition_process_mapping_list]) + [item[0] for item in partition_process_mapping_list] + ) process_list = list( - [item[1] for item in partition_process_mapping_list]) + [item[1] for item in partition_process_mapping_list] + ) has_used = list( - [item[2] for item in partition_process_mapping_list]) + [item[2] for item in partition_process_mapping_list] + ) if partition_list.count(source_partition_index) == 1: index = partition_list.index(source_partition_index) @@ -1273,32 +1427,52 @@ class Resharder: has_used[index].append(False) else: partition_process_mapping_list.append( - [source_partition_index, [source_process], [False]]) + [source_partition_index, [source_process], [False]] + ) for target_process in target_process_group: # has_sent means the source_partition_index has been sent to target_process has_sent = [] target_partition_index = Resharder.compute_partition_index( - target_process, complete_shape, target_dims_mapping, - target_process_shape, target_process_group) + target_process, + complete_shape, + target_dims_mapping, + target_process_shape, + target_process_group, + ) partition_index_list = [] all_partition_index_list = [] for source_process in source_process_group: source_partition_index = Resharder.compute_partition_index( - source_process, complete_shape, source_dims_mapping, - source_process_shape, source_process_group) + source_process, + complete_shape, + source_dims_mapping, + source_process_shape, + source_process_group, + ) to_send_process = None - if all(_ for _ in list(map(self.is_overlapped, source_partition_index, target_partition_index))) \ - and source_partition_index not in has_sent: - idx = list([ - item[0] for item in partition_process_mapping_list - ]).index(source_partition_index) - has_used = list([ - item[2] for item in partition_process_mapping_list - ])[idx] - process_list = list([ - item[1] for item in partition_process_mapping_list - ])[idx] + if ( + all( + _ + for _ in list( + map( + self.is_overlapped, + source_partition_index, + target_partition_index, + ) + ) + ) + and source_partition_index not in has_sent + ): + idx = list( + [item[0] for item in partition_process_mapping_list] + ).index(source_partition_index) + has_used = list( + [item[2] for item in partition_process_mapping_list] + )[idx] + process_list = list( + [item[1] for item in partition_process_mapping_list] + )[idx] i = 0 while i < len(has_used): if not has_used[i]: @@ -1311,7 +1485,9 @@ class Resharder: has_used = list(map(lambda x: False, has_used)) to_send_process = process_list[0] has_used[0] = True - assert to_send_process is not None, "Failed to find the send process." + assert ( + to_send_process is not None + ), "Failed to find the send process." 
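# The selection above effectively round-robins over the source ranks that own this
# partition: each owner is marked as used once, and when every owner has been used the
# flags are reset so the first owner is picked again. For example (hypothetical ranks),
# with process_list == [0, 1] and three consuming target ranks, the chosen senders
# would be 0, 1 and then 0 again.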
if to_send_process not in op_desc_seq.keys(): op_desc_seq[to_send_process] = [] @@ -1320,25 +1496,34 @@ class Resharder: all_partition_index_list.append(source_partition_index) # append send and recv op desc - is_bool = ( - dist_tensor.serial_tensor.dtype == paddle.bool) - send_op_desc = SendOpDesc(source_partition_index, - to_send_process, - target_process, - is_bool=is_bool) - recv_op_desc = RecvOpDesc(source_partition_index, - to_send_process, - target_process, - is_bool=is_bool) + is_bool = dist_tensor.serial_tensor.dtype == paddle.bool + send_op_desc = SendOpDesc( + source_partition_index, + to_send_process, + target_process, + is_bool=is_bool, + ) + recv_op_desc = RecvOpDesc( + source_partition_index, + to_send_process, + target_process, + is_bool=is_bool, + ) op_desc_seq[to_send_process].append(send_op_desc) op_desc_seq[target_process].append(recv_op_desc) has_sent.append(source_partition_index) - Resharder.concat_partitions(partition_index_list, - source_partition_index) + Resharder.concat_partitions( + partition_index_list, source_partition_index + ) + if int(op_role) == int(OpRole.Forward): + self.dist_context.up_down_streams.add_pair_stream( + to_send_process, target_process + ) # append concat op desc op_desc_seq[target_process].append( - ConcatOpDesc(all_partition_index_list)) + ConcatOpDesc(all_partition_index_list) + ) # append slice op desc slice_starts = [] @@ -1348,17 +1533,21 @@ class Resharder: to_slice_tensor_shape = [] for idx, item in enumerate(concatenated_partition_index): - slice_starts.append(target_partition_index[idx][0] - - item[0]) + slice_starts.append( + target_partition_index[idx][0] - item[0] + ) slice_ends.append(target_partition_index[idx][1] - item[0]) slices_axes.append(idx) to_slice_tensor_shape.append(item[1] - item[0]) op_desc_seq[target_process].append( - SliceOpDesc(slice_starts, - slice_ends, - slices_axes, - shape=to_slice_tensor_shape)) + SliceOpDesc( + slice_starts, + slice_ends, + slices_axes, + shape=to_slice_tensor_shape, + ) + ) # in the same process group, it will use allgahther and slice op. 
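# (Elaborating on the comment above, based on the branch that follows: when the source
# and target process groups coincide, every rank already holds a shard, so instead of
# point-to-point send/recv the plan is to allgather the shards within the group, concat
# them back into the complete tensor, and slice out the piece required by the target
# dims_mapping; if the group ends up holding a single rank, only the slice op is emitted.)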
else: @@ -1368,16 +1557,26 @@ class Resharder: process_index = [] for source_process in source_process_group: source_partition_index = Resharder.compute_partition_index( - source_process, complete_shape, source_dims_mapping, - source_process_shape, source_process_group) + source_process, + complete_shape, + source_dims_mapping, + source_process_shape, + source_process_group, + ) if source_partition_index not in partition_index_list: partition_index_list.append(source_partition_index) - process_index.append([[ - source_process, - ], source_partition_index]) + process_index.append( + [ + [ + source_process, + ], + source_partition_index, + ] + ) else: - process_index[partition_index_list.index( - source_partition_index)][0].append(source_process) + process_index[ + partition_index_list.index(source_partition_index) + ][0].append(source_process) for i in range(len(process_index[0][0])): group = [] @@ -1391,28 +1590,50 @@ class Resharder: slice_ends = [] slices_axes = [] target_partition_index = Resharder.compute_partition_index( - process, complete_shape, target_dims_mapping, - target_process_shape, target_process_group) + process, + complete_shape, + target_dims_mapping, + target_process_shape, + target_process_group, + ) for idx, item in enumerate(target_partition_index): slice_starts.append(item[0]) slice_ends.append(item[1]) slices_axes.append(idx) to_slice_tensor_shape = dist_tensor.global_sizes() - slice_op_desc = SliceOpDesc(starts=slice_starts, - ends=slice_ends, - axes=slices_axes, - shape=to_slice_tensor_shape) - allgather_shape = None if not serial else dist_tensor.local_sizes( - rank=process) - op_desc_seq[process] = [AllGatherOpDesc(group=group, shape=allgather_shape, is_bool=(source_tensor.dtype == paddle.bool)), - ConcatOpDesc(partition_index_list=all_partition_index_list), slice_op_desc] \ - if len(group) > 1 else [slice_op_desc] + slice_op_desc = SliceOpDesc( + starts=slice_starts, + ends=slice_ends, + axes=slices_axes, + shape=to_slice_tensor_shape, + ) + allgather_shape = ( + None + if not serial + else dist_tensor.local_sizes(rank=process) + ) + op_desc_seq[process] = ( + [ + AllGatherOpDesc( + group=group, + shape=allgather_shape, + is_bool=(source_tensor.dtype == paddle.bool), + ), + ConcatOpDesc( + partition_index_list=all_partition_index_list + ), + slice_op_desc, + ] + if len(group) > 1 + else [slice_op_desc] + ) return op_desc_seq - def parse_op_desc(self, block, op_desc_seq, var_name, reshard_op, - dist_attr): + def parse_op_desc( + self, block, op_desc_seq, var_name, reshard_op, dist_attr + ): """Parse op desc sequence and insert op in the block""" tensor_list = [] partition_tensor_list = [] @@ -1425,55 +1646,84 @@ class Resharder: if op.desc.id == reshard_op.desc.id: idx = index break - assert idx is not None, "The op for reshard cannot be found in the rank {} program.".format( - self.rank_id) + assert ( + idx is not None + ), "The op for reshard cannot be found in the rank {} program.".format( + self.rank_id + ) matched_op = block.ops[idx] - source_tensor = get_var_with_recursion(var_name, block, - self.auto_parallel_main_prog) + source_tensor = get_var_with_recursion( + var_name, block, self.auto_parallel_main_prog + ) for op_desc in op_desc_list: if isinstance(op_desc, AllGatherOpDesc): # noqa: F401 if var_name not in self.has_allgather.keys(): self.has_allgather[var_name] = [] - if not self.has_allgather[var_name] or op_desc.group not in list( - map(lambda x: x[0], self.has_allgather[var_name])): + if not self.has_allgather[ + var_name + ] or op_desc.group not 
in list( + map(lambda x: x[0], self.has_allgather[var_name]) + ): if op_desc.is_bool: # for bool data allgather, cast to int64 -> allgather -> cast bool out_cast = Inserter.insert_cast_op( - block, idx, source_tensor, - reshard_op.attr('op_role'), paddle.int64) + block, + idx, + source_tensor, + reshard_op.attr('op_role'), + paddle.int64, + ) tensor_list, idx_offset = Inserter.insert_allgather_op( - block, idx + 1, out_cast, op_desc.group, - reshard_op.attr('op_role')) + block, + idx + 1, + out_cast, + op_desc.group, + reshard_op.attr('op_role'), + ) idx += idx_offset tensor_name_list = [] for var in tensor_list: out_cast = Inserter.insert_cast_op( - block, idx, var, reshard_op.attr('op_role'), - paddle.bool) + block, + idx, + var, + reshard_op.attr('op_role'), + paddle.bool, + ) tensor_name_list.append(out_cast.name) idx += 1 self.has_allgather[var_name].append( - [op_desc.group, tensor_name_list]) + [op_desc.group, tensor_name_list] + ) else: tensor_list, idx_offset = Inserter.insert_allgather_op( - block, idx, source_tensor, op_desc.group, - reshard_op.attr('op_role')) + block, + idx, + source_tensor, + op_desc.group, + reshard_op.attr('op_role'), + ) idx += idx_offset tensor_name_list = [var.name for var in tensor_list] self.has_allgather[var_name].append( - [op_desc.group, tensor_name_list]) + [op_desc.group, tensor_name_list] + ) else: for item in self.has_allgather[var_name]: if op_desc.group == item[0]: tensor_list = [ get_var_with_recursion( - var_name, block, - self.auto_parallel_main_prog) + var_name, + block, + self.auto_parallel_main_prog, + ) for var_name in item[1] ] break - assert tensor_list, "The result of parsing allgather op should not be None." + assert ( + tensor_list + ), "The result of parsing allgather op should not be None." 
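# A rough sketch of the bookkeeping above (illustrative names only): bool tensors are
# first cast to int64, allgathered, then cast back, presumably because the collective
# path does not handle bool directly. The gathered outputs are memoized per
# (var_name, group) in self.has_allgather, e.g. roughly
#     self.has_allgather["tmp_0"] == [[group, ["tmp_0@gathered_0", "tmp_0@gathered_1"]]]
# so a later AllGatherOpDesc with the same group skips re-insertion and simply
# re-resolves those names with get_var_with_recursion.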
elif isinstance(op_desc, SendOpDesc): if var_name not in self.has_sent.keys(): @@ -1481,16 +1731,30 @@ class Resharder: if op_desc.dst not in self.has_sent[var_name]: if op_desc.is_bool: out_cast = Inserter.insert_cast_op( - block, idx, source_tensor, - reshard_op.attr('op_role'), paddle.int64) - Inserter.insert_send_op(block, idx + 1, out_cast, - op_desc.src, op_desc.dst, - reshard_op.attr('op_role')) + block, + idx, + source_tensor, + reshard_op.attr('op_role'), + paddle.int64, + ) + Inserter.insert_send_op( + block, + idx + 1, + out_cast, + op_desc.src, + op_desc.dst, + reshard_op.attr('op_role'), + ) idx += 2 else: - Inserter.insert_send_op(block, idx, source_tensor, - op_desc.src, op_desc.dst, - reshard_op.attr('op_role')) + Inserter.insert_send_op( + block, + idx, + source_tensor, + op_desc.src, + op_desc.dst, + reshard_op.attr('op_role'), + ) idx += 1 self.has_sent[var_name].append(op_desc.dst) @@ -1509,13 +1773,23 @@ class Resharder: shape=shape, lod_level=source_tensor.lod_level, dtype=paddle.int64, - type=source_tensor.type) - Inserter.insert_recv_op(block, idx, recv_tensor, - op_desc.src, op_desc.dst, - reshard_op.attr('op_role')) + type=source_tensor.type, + ) + Inserter.insert_recv_op( + block, + idx, + recv_tensor, + op_desc.src, + op_desc.dst, + reshard_op.attr('op_role'), + ) out_cast = Inserter.insert_cast_op( - block, idx + 1, recv_tensor, - reshard_op.attr('op_role'), paddle.bool) + block, + idx + 1, + recv_tensor, + reshard_op.attr('op_role'), + paddle.bool, + ) tensor_list.append(out_cast) idx += 2 self.has_recv[var_name][op_desc.src] = out_cast @@ -1525,26 +1799,45 @@ class Resharder: shape=shape, lod_level=source_tensor.lod_level, dtype=source_tensor.dtype, - type=source_tensor.type) - Inserter.insert_recv_op(block, idx, recv_tensor, - op_desc.src, op_desc.dst, - reshard_op.attr('op_role')) + type=source_tensor.type, + ) + Inserter.insert_recv_op( + block, + idx, + recv_tensor, + op_desc.src, + op_desc.dst, + reshard_op.attr('op_role'), + ) # for lod tensor, need reset lod after received if recv_tensor.lod_level != 0: set_lod = False # use data lod to reset tensor lod - for tmp_block in self.auto_parallel_main_prog.blocks: + for ( + tmp_block + ) in self.auto_parallel_main_prog.blocks: for tmp_var_name in tmp_block.vars: tmp_var = tmp_block.vars[tmp_var_name] - if tmp_var.is_data and tmp_var.lod_level == recv_tensor.lod_level: - reset_lod_out = Inserter.insert_reset_lod_op( - block, idx + 1, recv_tensor, - tmp_var, reshard_op.attr('op_role')) + if ( + tmp_var.is_data + and tmp_var.lod_level + == recv_tensor.lod_level + ): + reset_lod_out = ( + Inserter.insert_reset_lod_op( + block, + idx + 1, + recv_tensor, + tmp_var, + reshard_op.attr('op_role'), + ) + ) tensor_list.append(reset_lod_out) idx += 2 self.has_recv[var_name][ - op_desc.src] = reset_lod_out + op_desc.src + ] = reset_lod_out set_lod = True break if set_lod: @@ -1562,16 +1855,24 @@ class Resharder: idx_list = [idx] for index, tensor in enumerate(tensor_list): Inserter.concat_partitions_with_op( - partition_tensor_list, tensor, - partition_index_list[index], block, idx_list, - reshard_op.attr('op_role')) + partition_tensor_list, + tensor, + partition_index_list[index], + block, + idx_list, + reshard_op.attr('op_role'), + ) idx = idx_list[0] elif isinstance(op_desc, SliceOpDesc): - assert len( - partition_tensor_list) == 1 or not partition_tensor_list - to_slice_tensor = partition_tensor_list[0][0] if len( - partition_tensor_list) == 1 else source_tensor + assert ( + len(partition_tensor_list) == 1 or not 
partition_tensor_list + ) + to_slice_tensor = ( + partition_tensor_list[0][0] + if len(partition_tensor_list) == 1 + else source_tensor + ) new_name = unique_name.generate(var_name + "@RESHARD") target_tensor = Inserter.insert_slice_op( block, @@ -1581,7 +1882,8 @@ class Resharder: ends=op_desc.ends, axes=op_desc.axes, new_var_name=new_name, - op_role=reshard_op.attr('op_role')) + op_role=reshard_op.attr('op_role'), + ) process_mesh = dist_attr[0] dims_mapping = dist_attr[1] @@ -1590,83 +1892,119 @@ class Resharder: tensor_attr.dims_mapping = dims_mapping tensor_attr.process_mesh = process_mesh self.dist_context.set_tensor_dist_attr_for_program( - target_tensor, tensor_attr) + target_tensor, tensor_attr + ) if matched_op.type == "while": # var_reshard_mapping means the while op input need be changed to - if "var_reshard_mapping" not in Resharder.while_block_info[ - op.attr("sub_block").id].keys(): - Resharder.while_block_info[op.attr( - "sub_block").id]["var_reshard_mapping"] = {} - if var_name not in Resharder.while_block_info[op.attr( - "sub_block").id]["var_reshard_mapping"].keys(): + if ( + "var_reshard_mapping" + not in Resharder.while_block_info[ + op.attr("sub_block").id + ].keys() + ): Resharder.while_block_info[op.attr("sub_block").id][ - "var_reshard_mapping"][var_name] = [] + "var_reshard_mapping" + ] = {} + if ( + var_name + not in Resharder.while_block_info[ + op.attr("sub_block").id + ]["var_reshard_mapping"].keys() + ): + Resharder.while_block_info[op.attr("sub_block").id][ + "var_reshard_mapping" + ][var_name] = [] Resharder.while_block_info[op.attr("sub_block").id][ - "var_reshard_mapping"][var_name].append( - [dist_attr, target_tensor.name]) + "var_reshard_mapping" + ][var_name].append([dist_attr, target_tensor.name]) # rename op input name according to new name for op in block.ops: # just for while op while_op_X_append = [] for name in op.input_arg_names: - op_dist_attr = self.dist_context.get_op_dist_attr_for_program( - op) + op_dist_attr = ( + self.dist_context.get_op_dist_attr_for_program(op) + ) if name == var_name and op_dist_attr is not None: if op.desc.id() == matched_op.desc.id(): if matched_op.type == "while": old_name = name new_name = target_tensor.name assert old_name != new_name - op_input_dist_attr = op_dist_attr.get_input_dist_attr( - old_name) + op_input_dist_attr = ( + op_dist_attr.get_input_dist_attr( + old_name + ) + ) op_dist_attr.set_input_dist_attr( - new_name, op_input_dist_attr) + new_name, op_input_dist_attr + ) op_dist_attr.set_input_dims_mapping( - new_name, dims_mapping) - if old_name in op_dist_attr._inputs_dist_attrs: + new_name, dims_mapping + ) + if ( + old_name + in op_dist_attr._inputs_dist_attrs + ): op_dist_attr.del_input_dist_attr( - old_name) + old_name + ) while_op_X_append.append(new_name) continue else: op.desc._rename_input( - name, target_tensor.name) + name, target_tensor.name + ) old_name = name new_name = target_tensor.name assert old_name != new_name - op_input_dist_attr = op_dist_attr.get_input_dist_attr( - old_name) + op_input_dist_attr = ( + op_dist_attr.get_input_dist_attr( + old_name + ) + ) op_dist_attr.set_input_dist_attr( - new_name, op_input_dist_attr) + new_name, op_input_dist_attr + ) op_dist_attr.set_input_dims_mapping( - new_name, dims_mapping) + new_name, dims_mapping + ) op_dist_attr.del_input_dist_attr(old_name) continue op_process_mesh = op_dist_attr.process_mesh - op_input_dims_mapping = op_dist_attr.get_input_dims_mapping( - var_name) + op_input_dims_mapping = ( + 
op_dist_attr.get_input_dims_mapping(var_name) + ) # NOTE: For op whose process mesh is a union, its input will not be renamed by other op reshard result now which means that it will have more reshard operation. - if op_process_mesh == process_mesh and op_input_dims_mapping == dims_mapping: + if ( + op_process_mesh == process_mesh + and op_input_dims_mapping == dims_mapping + ): op.desc._rename_input(name, target_tensor.name) old_name = name new_name = target_tensor.name assert old_name != new_name - op_input_dist_attr = op_dist_attr.get_input_dist_attr( - old_name) + op_input_dist_attr = ( + op_dist_attr.get_input_dist_attr(old_name) + ) op_dist_attr.set_input_dist_attr( - new_name, op_input_dist_attr) + new_name, op_input_dist_attr + ) op_dist_attr.set_input_dims_mapping( - new_name, dims_mapping) + new_name, dims_mapping + ) op_dist_attr.del_input_dist_attr(old_name) # for while op, the input X should reset if while_op_X_append: proto = OpProtoHolder.instance().get_op_proto(op.type) - op.desc.set_input(proto.inputs[0].name, - op.input("X") + while_op_X_append) + op.desc.set_input( + proto.inputs[0].name, + op.input("X") + while_op_X_append, + ) def _get_subblock_input_attrs(self, op, var_name): # NOTE: Multi while loop is not supported @@ -1684,27 +2022,70 @@ class Resharder: if name == var_name: process_mesh = dist_attr.process_mesh input_dims_mapping = dist_attr.get_input_dims_mapping( - var_name) + var_name + ) has_exist = False for input_attr in input_attrs: - if process_mesh == input_attr[ - 0] and input_dims_mapping == input_attr[1]: + if ( + process_mesh == input_attr[0] + and input_dims_mapping == input_attr[1] + ): has_exist = True break if not has_exist: - input_attrs.append([process_mesh, input_dims_mapping]) + input_attrs.append( + [ + process_mesh, + input_dims_mapping, + op.attr('op_role'), + ] + ) return input_attrs + def _get_subblock_output_attrs(self, op, var_name): + assert op.type in _g_subblock_ops + sub_block = self.auto_parallel_main_prog.blocks[op.attr("sub_block").id] + ops = sub_block.ops + output_attrs = [] + + for op in ops: + dist_op = self.dist_context.get_dist_op_for_program(op) + if not dist_op: + continue + dist_attr = dist_op.dist_attr + for name in op.output_arg_names: + if name == var_name: + process_mesh = dist_attr.process_mesh + output_dims_mapping = dist_attr.get_output_dims_mapping( + var_name + ) + has_exist = False + for output_attr in output_attrs: + if ( + process_mesh == output_attr[0] + and output_dims_mapping == output_attr[1] + ): + has_exist = True + break + if not has_exist: + output_attrs.append( + [ + process_mesh, + output_dims_mapping, + op.attr('op_role'), + ] + ) + return output_attrs + def _get_common_op_input_attrs(self, op, var_name): process_meshes = [] dist_op = self.dist_context.get_dist_op_for_program(op) dist_attr = dist_op.dist_attr op_process_mesh = dist_attr.process_mesh for process_mesh in self.dist_context.process_meshes: - if set(process_mesh.processes) & (set( - op_process_mesh.processes)) and len( - process_mesh.processes) < len( - op_process_mesh.processes): + if set(process_mesh.processes) & ( + set(op_process_mesh.processes) + ) and len(process_mesh.processes) < len(op_process_mesh.processes): process_meshes.append(process_mesh) # it means that the process mesh is not a union when process meshes is none @@ -1714,7 +2095,9 @@ class Resharder: input_dims_mapping = dist_attr.get_input_dims_mapping(var_name) input_attrs = [] for process_mesh in process_meshes: - input_attrs.append([process_mesh, input_dims_mapping]) 
+ input_attrs.append( + [process_mesh, input_dims_mapping, op.attr('op_role')] + ) return input_attrs @@ -1723,6 +2106,8 @@ class Resharder: if op.type in _g_subblock_ops: op_input_attrs = self._get_subblock_input_attrs(op, var_name) + if not op_input_attrs: + op_input_attrs = self._get_subblock_output_attrs(op, var_name) else: op_input_attrs = self._get_common_op_input_attrs(op, var_name) @@ -1735,32 +2120,28 @@ class Resharder: processes = set() process_mesh_count = len(self.dist_context.process_meshes) if process_mesh_count > 1: - global_process_mesh_idx = None + global_process_mesh_idx = [] + has_sub_process_mesh = False for process_mesh in self.dist_context.process_meshes: for process in process_mesh.processes: processes.add(process) for idx, process_mesh in enumerate( - self.dist_context.process_meshes): + self.dist_context.process_meshes + ): if len(set(process_mesh.processes)) == len(processes): - global_process_mesh_idx = idx - break + global_process_mesh_idx.append(idx) + elif set(process_mesh.processes) < processes: + has_sub_process_mesh = True - if global_process_mesh_idx is not None: - is_removed = False - global_mesh = self.dist_context.process_meshes[idx] - for i, mesh in enumerate(self.dist_context.process_meshes): - if i == idx: - continue - if set(mesh.processes) < set(global_mesh.processes): - is_removed = True - - if is_removed: + if has_sub_process_mesh: + for idx in reversed(global_process_mesh_idx): self.dist_context.process_meshes.pop(idx) def _change_subblock_op_input_and_output(self, block_idx, block): if "var_reshard_mapping" in Resharder.while_block_info[block_idx]: var_reshard_mapping = Resharder.while_block_info[block_idx][ - "var_reshard_mapping"] + "var_reshard_mapping" + ] for op in block.ops: for var_name in op.input_arg_names: if var_name in var_reshard_mapping: @@ -1769,9 +2150,11 @@ class Resharder: dist_attr = dist_op.dist_attr target_name = None for item in var_reshard_mapping[var_name]: - if dist_attr.process_mesh == item[0][ - 0] and dist_attr.get_input_dims_mapping( - var_name) == item[0][1]: + if ( + dist_attr.process_mesh == item[0][0] + and dist_attr.get_input_dims_mapping(var_name) + == item[0][1] + ): target_name = item[1] break if target_name is None: @@ -1779,15 +2162,18 @@ class Resharder: else: op.desc._rename_input(var_name, target_name) dist_op = self.dist_context.get_dist_op_for_program( - op) + op + ) op_dist_attr = dist_op.dist_attr old_name = var_name new_name = target_name assert old_name != new_name - op_input_dist_attr = op_dist_attr.get_input_dist_attr( - old_name) + op_input_dist_attr = ( + op_dist_attr.get_input_dist_attr(old_name) + ) op_dist_attr.set_input_dist_attr( - new_name, op_input_dist_attr) + new_name, op_input_dist_attr + ) op_dist_attr.del_input_dist_attr(old_name) # the outputs also need to be renamed when the output name is the same with input name in inplace op @@ -1807,9 +2193,11 @@ class Resharder: new_name = target_name assert old_name != new_name op_output_dist_attr = op_dist_attr.get_output_dist_attr( - old_name) + old_name + ) op_dist_attr.set_output_dist_attr( - new_name, op_output_dist_attr) + new_name, op_output_dist_attr + ) op_dist_attr.del_output_dist_attr(old_name) def _reshard_input(self, block): @@ -1824,18 +2212,22 @@ class Resharder: dist_op = self.dist_context.get_dist_op_for_program(op) if dist_op is not None: - op_input_dist_attrs = [ - ] # [(op_process_mesh, op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)] + op_input_dist_attrs = ( + [] + ) # [(op_process_mesh, 
op_input_dims_mapping), (op_process_mesh, op_input_dims_mapping)] if op.type in _g_subblock_ops: if not self.is_condition_replicative(op): raise ValueError( "Please check the condition due to the dims mapping is not replicative." ) - if op.attr( - "sub_block").id not in Resharder.while_block_info: + if ( + op.attr("sub_block").id + not in Resharder.while_block_info + ): Resharder.while_block_info[op.attr("sub_block").id] = {} - Resharder.while_block_info[op.attr( - "sub_block").id]["op_id"] = op.desc.id() + Resharder.while_block_info[op.attr("sub_block").id][ + "op_id" + ] = op.desc.id() if op.type == "while": # condition var process mesh is the same with op and dims_mapping is replicative, so it do not need reshard @@ -1852,17 +2244,24 @@ class Resharder: # skip lod_tensor_blocking_queue_? name if "lod_tensor_blocking_queue" in var_name: continue - var = get_var_with_recursion(var_name, block, - self.auto_parallel_main_prog) + var = get_var_with_recursion( + var_name, block, self.auto_parallel_main_prog + ) dist_tensor = self.dist_context.get_dist_tensor_for_program( - var) + var + ) # judge whether union tensor dims_mapping all -1 is_union_process_mesh_tensor = False - if dist_tensor.dist_attr.process_mesh not in self.dist_context.process_meshes and self.dist_context.process_meshes: + if ( + dist_tensor.dist_attr.process_mesh + not in self.dist_context.process_meshes + and self.dist_context.process_meshes + ): is_union_process_mesh_tensor = True assert dist_tensor.dist_attr.dims_mapping.count( - -1) == len(dist_tensor.dist_attr.dims_mapping) + -1 + ) == len(dist_tensor.dist_attr.dims_mapping) op_input_attrs = self.get_op_input_attrs(op, var_name) for input_attr in op_input_attrs: @@ -1872,18 +2271,23 @@ class Resharder: if is_union_process_mesh_tensor: # if op process mesh is subset of union tensor process mesh, need no reshard if set(input_attr[0].processes) <= set( - dist_tensor.dist_attr.process_mesh.processes + dist_tensor.dist_attr.process_mesh.processes ): continue if dist_tensor is not None and self.need_reshard( - dist_tensor, input_attr): + dist_tensor, input_attr + ): reshard_op_desc = self.find_op_desc_seq( - dist_tensor, input_attr) - self.parse_op_desc(block, reshard_op_desc, var_name, - op, input_attr) + dist_tensor, input_attr + ) + self.parse_op_desc( + block, reshard_op_desc, var_name, op, input_attr + ) cur_op_count = len(block.ops) - idx_offset = idx_offset + cur_op_count - pre_op_count + idx_offset = ( + idx_offset + cur_op_count - pre_op_count + ) pre_op_count = cur_op_count idx = idx + idx_offset + 1 else: @@ -1898,34 +2302,43 @@ class Resharder: shape=var.shape, lod_level=var.lod_level, dtype=paddle.int64, - type=var.type) - Inserter.insert_recv_op(block, idx + 1, - recv_cast_out, send_rank, recv_rank, - op.attr('op_role')) + type=var.type, + ) + Inserter.insert_recv_op( + block, + idx + 1, + recv_cast_out, + send_rank, + recv_rank, + op.attr('op_role'), + ) reset_lod_out = None if var.lod_level != 0: set_lod = False for tmp_block in self.auto_parallel_main_prog.blocks: for tmp_var_name in tmp_block.vars: tmp_var = tmp_block.vars[tmp_var_name] - if tmp_var.is_data and tmp_var.lod_level == var.lod_level: + if ( + tmp_var.is_data + and tmp_var.lod_level == var.lod_level + ): reset_lod_out = block.create_var( - name=unique_name.generate(var.name + - "@RESETLOD"), + name=unique_name.generate( + var.name + "@RESETLOD" + ), shape=recv_cast_out.shape, type=recv_cast_out.type, dtype=recv_cast_out.dtype, - lod_level=recv_cast_out.lod_level) + 
lod_level=recv_cast_out.lod_level, + ) idx += 1 block._insert_op( idx, type="lod_reset", - inputs={ - 'X': recv_cast_out, - 'Y': tmp_var - }, + inputs={'X': recv_cast_out, 'Y': tmp_var}, outputs={'Out': reset_lod_out}, - attrs={'op_role': op.attr("op_role")}) + attrs={'op_role': op.attr("op_role")}, + ) set_lod = True break if set_lod: @@ -1933,18 +2346,22 @@ class Resharder: assert set_lod is True # cast int64 to bool - block._insert_op(idx + 2, - type='cast', - inputs={ - 'X': [recv_cast_out] if - reset_lod_out is None else [reset_lod_out] - }, - outputs={'Out': [var]}, - attrs={ - 'in_dtype': recv_cast_out.dtype, - 'out_dtype': var.dtype, - 'op_role': op.attr('op_role') - }) + cast_op = block._insert_op( + idx + 2, + type='cast', + inputs={ + 'X': [recv_cast_out] + if reset_lod_out is None + else [reset_lod_out] + }, + outputs={'Out': [var]}, + attrs={ + 'in_dtype': recv_cast_out.dtype, + 'out_dtype': var.dtype, + 'op_role': op.attr('op_role'), + }, + ) + cast_op._set_attr('op_namescope', "/auto_parallel/reshard") else: if var.lod_level != 0: recv_out = block.create_var( @@ -1952,50 +2369,75 @@ class Resharder: shape=var.shape, lod_level=var.lod_level, dtype=var.int64, - type=var.type) - Inserter.insert_recv_op(block, idx + 1, recv_out, send_rank, - recv_rank, op.attr('op_role')) + type=var.type, + ) + Inserter.insert_recv_op( + block, + idx + 1, + recv_out, + send_rank, + recv_rank, + op.attr('op_role'), + ) set_lod = False for tmp_block in self.auto_parallel_main_prog.blocks: for tmp_var_name in tmp_block.vars: tmp_var = tmp_block.vars[tmp_var_name] - if tmp_var.is_data and tmp_var.lod_level == var.lod_level: + if ( + tmp_var.is_data + and tmp_var.lod_level == var.lod_level + ): idx += 1 block._insert_op( idx, type="lod_reset", - inputs={ - 'X': recv_out, - 'Y': tmp_var - }, + inputs={'X': recv_out, 'Y': tmp_var}, outputs={'Out': var}, - attrs={'op_role': op.attr("op_role")}) + attrs={'op_role': op.attr("op_role")}, + ) set_lod = True break if set_lod: break assert set_lod is True else: - Inserter.insert_recv_op(block, idx + 1, var, send_rank, - recv_rank, op.attr('op_role')) + Inserter.insert_recv_op( + block, + idx + 1, + var, + send_rank, + recv_rank, + op.attr('op_role'), + ) def _handle_send(self, block, idx, var, op, send_rank, recv_rank): if var.dtype == paddle.bool: - cast_out = Inserter.insert_cast_op(block, idx + 1, var, - op.attr('op_role'), paddle.int64) - Inserter.insert_send_op(block, idx + 2, cast_out, send_rank, - recv_rank, op.attr('op_role')) + cast_out = Inserter.insert_cast_op( + block, idx + 1, var, op.attr('op_role'), paddle.int64 + ) + Inserter.insert_send_op( + block, + idx + 2, + cast_out, + send_rank, + recv_rank, + op.attr('op_role'), + ) else: - Inserter.insert_send_op(block, idx + 1, var, send_rank, recv_rank, - op.attr('op_role')) + Inserter.insert_send_op( + block, idx + 1, var, send_rank, recv_rank, op.attr('op_role') + ) def _reshard_output(self, block): # insert send and recv op if output process mesh is different from tensor process mesh idx = 0 # skip reader and ops whose process mesh is union skip_ops = [ - "create_py_reader", "create_double_buffer_reader", "read", - "write_to_array", "read_from_array" + "create_py_reader", + "create_double_buffer_reader", + "read", + "write_to_array", + "read_from_array", ] global _g_special_ops skip_ops += _g_special_ops @@ -2007,76 +2449,113 @@ class Resharder: if dist_op is not None and op.type not in skip_ops: idx_offset = 0 for var_name in op.output_arg_names: - var = get_var_with_recursion(var_name, block, - 
self.auto_parallel_main_prog) + var = get_var_with_recursion( + var_name, block, self.auto_parallel_main_prog + ) dist_tensor = self.dist_context.get_dist_tensor_for_program( - var) + var + ) tensor_process_mesh = dist_tensor.dist_attr.process_mesh output_attr = [ dist_op.dist_attr.process_mesh, - dist_op.dist_attr.get_output_dims_mapping(var_name) + dist_op.dist_attr.get_output_dims_mapping(var_name), ] if dist_tensor is not None and self.need_reshard( - dist_tensor, output_attr, False): + dist_tensor, output_attr, False + ): tensor_processes = set( - tensor_process_mesh.processes) - ( - set(tensor_process_mesh.processes) - & set(output_attr[0].processes)) + tensor_process_mesh.processes + ) - ( + set(tensor_process_mesh.processes) + & set(output_attr[0].processes) + ) if tensor_processes: if len(tensor_processes) != len( - output_attr[0].processes): + output_attr[0].processes + ): if dist_tensor.dist_attr.dims_mapping.count( - -1) != len( - dist_tensor.dist_attr.dims_mapping - ) or output_attr[1].count(-1) != len( - output_attr[1]): + -1 + ) != len( + dist_tensor.dist_attr.dims_mapping + ) or output_attr[ + 1 + ].count( + -1 + ) != len( + output_attr[1] + ): raise ValueError( - "The dims_mapping must be -1") + "The dims_mapping must be -1" + ) else: for index, tensor_process in enumerate( - tensor_processes): + tensor_processes + ): recv_rank = tensor_process actual_index = index if index >= len( - output_attr[0].processes): + output_attr[0].processes + ): actual_index = ( - index - - len(output_attr[0].processes) + index + - len(output_attr[0].processes) ) % len(output_attr[0].processes) item = output_attr[0].processes[ - actual_index] + actual_index + ] if recv_rank == item: continue + if var.shape[0] == -1: + new_shape = list(var.shape) + new_shape[0] = self.batch_size + var.desc.set_shape(new_shape) if self.rank_id == item: # if send bool data, cast then send self._handle_send( - block, idx, var, op, item, - recv_rank) + block, + idx, + var, + op, + item, + recv_rank, + ) if self.rank_id == recv_rank: # if recv bool data, recv then cast self._hadnle_recv( - block, idx, var, op, item, - recv_rank) + block, + idx, + var, + op, + item, + recv_rank, + ) else: for index, tensor_process in enumerate( - tensor_processes): + tensor_processes + ): recv_rank = tensor_process item = output_attr[0].processes[index] if recv_rank == item: continue + if var.shape[0] == -1: + new_shape = list(var.shape) + new_shape[0] = self.batch_size + var.desc.set_shape(new_shape) if self.rank_id == item: # if send bool data, cast then send self._handle_send( - block, idx, var, op, item, - recv_rank) + block, idx, var, op, item, recv_rank + ) if self.rank_id == recv_rank: # if recv bool data, recv then cast self._hadnle_recv( - block, idx, var, op, item, - recv_rank) + block, idx, var, op, item, recv_rank + ) cur_op_count = len(block.ops) - idx_offset = idx_offset + cur_op_count - pre_op_count + idx_offset = ( + idx_offset + cur_op_count - pre_op_count + ) pre_op_count = cur_op_count idx = idx + idx_offset + 1 @@ -2098,13 +2577,17 @@ class Resharder: self._reshard_output(block) # remove no need vars and ops in the main program - Remover.remove_no_need_in_main(self.auto_parallel_main_prog, - self.dist_context, self.rank_id, - self.dist_params_grads) + Remover.remove_no_need_in_main( + self.auto_parallel_main_prog, + self.dist_context, + self.rank_id, + self.dist_params_grads, + ) # remove no need vars and ops in the startip program - Remover.remove_no_need_in_startup(self.auto_parallel_main_prog, - 
self.auto_parallel_startup_prog) + Remover.remove_no_need_in_startup( + self.auto_parallel_main_prog, self.auto_parallel_startup_prog + ) # reset some variable when remove operation ended Resharder.while_block_info = {} @@ -2122,47 +2605,72 @@ class Resharder: return reshard_op_cost else: dist_tensor = self.dist_context.get_dist_tensor_for_program( - tensor) + tensor + ) # simplified processing: ignore union process mesh and output reshard dist_op = self.dist_context.get_dist_op_for_program(op) dims_mapping = dist_op.dist_attr.get_input_dims_mapping( - tensor.name) + tensor.name + ) process_mesh = dist_op.dist_attr.process_mesh - dist_attr = [process_mesh, dims_mapping] + dist_attr = [ + process_mesh, + dims_mapping, + dist_op.serial_op.attr('op_role'), + ] if dist_tensor is not None and self.need_reshard( - dist_tensor, dist_attr): + dist_tensor, dist_attr + ): if tensor_name not in self._has_resharded: self._has_resharded[tensor_name] = [dist_op] else: for item in self._has_resharded[tensor_name]: item_dist_attr = item.dist_attr - item_dims_mapping = item_dist_attr.get_input_dims_mapping( - tensor_name) + item_dims_mapping = ( + item_dist_attr.get_input_dims_mapping( + tensor_name + ) + ) item_process_mesh = item_dist_attr.process_mesh - if dims_mapping == item_dims_mapping and item_process_mesh == process_mesh: + if ( + dims_mapping == item_dims_mapping + and item_process_mesh == process_mesh + ): return reshard_op_cost self._has_resharded[tensor_name].append(dist_op) - reshard_op_desc = self.find_op_desc_seq(dist_tensor, - dist_attr, - serial=True) + reshard_op_desc = self.find_op_desc_seq( + dist_tensor, dist_attr, serial=True + ) dtype = dist_tensor.serial_tensor.dtype reshard_op_cost = self.parse_op_desc_for_cost( - reshard_op_desc, dtype, cluster) + reshard_op_desc, dtype, cluster + ) return reshard_op_cost - def _concat_partitions_for_cost(self, partition_tensor_list, - partition_index, dtype, rank_id, - local_rank_comp_cost, cluster): + def _concat_partitions_for_cost( + self, + partition_tensor_list, + partition_index, + dtype, + rank_id, + local_rank_comp_cost, + cluster, + ): if not partition_tensor_list: partition_tensor_list.append(partition_index) else: i = 0 has_concat = False while i < len(partition_tensor_list): - concat_axis, first_order, new_partition = Resharder.compute_concat_info( - partition_tensor_list[i], partition_index) + ( + concat_axis, + first_order, + new_partition, + ) = Resharder.compute_concat_info( + partition_tensor_list[i], partition_index + ) if concat_axis != -1: has_concat = True concat_desc = {} @@ -2170,31 +2678,38 @@ class Resharder: concat_desc["attrs"] = {"axis": concat_axis} if first_order == 0: concat_desc["inputs"] = { - "X": [(dtype, partition_tensor_list[i]), - (dtype, partition_index)] + "X": [ + (dtype, partition_tensor_list[i]), + (dtype, partition_index), + ] } else: concat_desc["inputs"] = { - "X": [(dtype, partition_index), - (dtype, partition_tensor_list[i])] + "X": [ + (dtype, partition_index), + (dtype, partition_tensor_list[i]), + ] } partition_tensor_list.pop(i) if rank_id not in local_rank_comp_cost: local_rank_comp_cost[rank_id] = [] local_rank_comp_cost[rank_id].append( - ConcatOpCost(op_desc=concat_desc, cluster=cluster)) - self._concat_partitions_for_cost(partition_tensor_list, - new_partition, dtype, - rank_id, - local_rank_comp_cost, - cluster) + ConcatOpCost(op_desc=concat_desc, cluster=cluster) + ) + self._concat_partitions_for_cost( + partition_tensor_list, + new_partition, + dtype, + rank_id, + local_rank_comp_cost, 
+ cluster, + ) break i += 1 if not has_concat: partition_tensor_list.append(partition_index) def parse_op_desc_for_cost(self, reshard_op_desc, dtype, cluster): - def _get_idx(comm_ranks, group_ranks): res, is_the_same = None, False idx = 0 @@ -2225,28 +2740,41 @@ class Resharder: if isinstance(op_desc, SendOpDesc): group_ranks = [key, op_desc.dst] shape = op_desc.shape - send_desc = build_comm_desc("send_v2", group_ranks, dtype, - shape) + send_desc = build_comm_desc( + "send_v2", group_ranks, dtype, shape + ) idx, is_the_same = _get_idx(comm_ranks, group_ranks) if idx is None: - comm_costs.append([ - (group_ranks, - SendOpCost(op_desc=send_desc, - comm_context=comm_context)) - ]) + comm_costs.append( + [ + ( + group_ranks, + SendOpCost( + op_desc=send_desc, + comm_context=comm_context, + ), + ) + ] + ) comm_ranks.append(set(group_ranks)) else: if not is_the_same: comm_costs[idx].append( - (group_ranks, - SendOpCost(op_desc=send_desc, - comm_context=comm_context))) + ( + group_ranks, + SendOpCost( + op_desc=send_desc, + comm_context=comm_context, + ), + ) + ) elif isinstance(op_desc, AllGatherOpDesc): # NOTE: fill_const and other unnecessary op is not calculated because those cost is very small group_ranks = op_desc.group shape = op_desc.shape - allgather_desc = build_comm_desc("c_allgather", group_ranks, - dtype, shape) + allgather_desc = build_comm_desc( + "c_allgather", group_ranks, dtype, shape + ) split_inputs_shape = [] for idx, dim in enumerate(shape): if idx == 0: @@ -2255,18 +2783,29 @@ class Resharder: split_inputs_shape.append(dim) idx, is_the_same = _get_idx(comm_ranks, group_ranks) if idx is None: - comm_costs.append([ - (group_ranks, - AllgatherOpCost(op_desc=allgather_desc, - comm_context=comm_context)) - ]) + comm_costs.append( + [ + ( + group_ranks, + AllgatherOpCost( + op_desc=allgather_desc, + comm_context=comm_context, + ), + ) + ] + ) comm_ranks.append(set(group_ranks)) else: if not is_the_same: comm_costs[idx].append( - (group_ranks, - AllgatherOpCost(op_desc=allgather_desc, - comm_context=comm_context))) + ( + group_ranks, + AllgatherOpCost( + op_desc=allgather_desc, + comm_context=comm_context, + ), + ) + ) # calc the split op cost if key not in local_rank_comp_cost: local_rank_comp_cost[key] = [] @@ -2277,19 +2816,27 @@ class Resharder: } split_desc["attrs"] = {"num": len(group_ranks), "axis": 0} local_rank_comp_cost[key].append( - SplitOpCost(op_desc=split_desc, cluster=cluster)) + SplitOpCost(op_desc=split_desc, cluster=cluster) + ) elif isinstance(op_desc, ConcatOpDesc): partition_index_list = op_desc._partition_index_list for idx, partion_idex in enumerate(partition_index_list): self._concat_partitions_for_cost( - partition_tensor_list, partion_idex, dtype, key, - local_rank_comp_cost, cluster) + partition_tensor_list, + partion_idex, + dtype, + key, + local_rank_comp_cost, + cluster, + ) elif isinstance(op_desc, SliceOpDesc): if key not in local_rank_comp_cost: local_rank_comp_cost[key] = [] - assert len( - partition_tensor_list) == 1 or not partition_tensor_list + assert ( + len(partition_tensor_list) == 1 + or not partition_tensor_list + ) to_slice_tensor_shape = [] if len(partition_tensor_list) == 1: for item in partition_tensor_list[0]: @@ -2303,13 +2850,14 @@ class Resharder: "axes": op_desc.axes, "starts": op_desc.starts, "ends": op_desc.ends, - "infer_flags": infer_flags + "infer_flags": infer_flags, } slice_desc["inputs"] = { "Input": [(dtype, to_slice_tensor_shape)] } local_rank_comp_cost[key].append( - SliceOpCost(op_desc=slice_desc, 
cluster=cluster)) + SliceOpCost(op_desc=slice_desc, cluster=cluster) + ) res = (comm_costs, local_rank_comp_cost) diff --git a/python/paddle/distributed/auto_parallel/strategy.py b/python/paddle/distributed/auto_parallel/strategy.py index 927aa25dbfb..f7dd7e6697b 100644 --- a/python/paddle/distributed/auto_parallel/strategy.py +++ b/python/paddle/distributed/auto_parallel/strategy.py @@ -19,7 +19,6 @@ from . import constants class BaseConfig(object): - def __init__(self, category, config_dict=None): self._category = category self._config_dict = None @@ -29,7 +28,9 @@ class BaseConfig(object): else: raise ValueError( "Expected a dictionary. But received: {}".format( - config_dict)) + config_dict + ) + ) # Initialize attributes by the default config config = constants.get_category_default_config(self._category) for field, default_value in config.items(): @@ -75,49 +76,48 @@ class BaseConfig(object): class RecomputeConfig(BaseConfig): - def __init__(self, config_dict=None): category = constants.RECOMPUTE super(RecomputeConfig, self).__init__(category, config_dict) class AMPConfig(BaseConfig): - def __init__(self, config_dict=None): category = constants.AMP super(AMPConfig, self).__init__(category, config_dict) class ShardingConfig(BaseConfig): - def __init__(self, config_dict=None): category = constants.SHARDING super(ShardingConfig, self).__init__(category, config_dict) class GradientMergeConfig(BaseConfig): - def __init__(self, config_dict=None): category = constants.GRADIENT_MERGE super(GradientMergeConfig, self).__init__(category, config_dict) -class QATConfig(BaseConfig): +class PipelineConfig(BaseConfig): + def __init__(self, config_dict=None): + category = constants.PIPELINE + super(PipelineConfig, self).__init__(category, config_dict) + +class QATConfig(BaseConfig): def __init__(self, config_dict=None): category = constants.QAT super(QATConfig, self).__init__(category, config_dict) class TuningConfig(BaseConfig): - def __init__(self, config_dict=None): category = constants.TUNING super(TuningConfig, self).__init__(category, config_dict) class DatasetConfig(BaseConfig): - def __init__(self, config_dict=None): category = constants.DATASET super(DatasetConfig, self).__init__(category, config_dict) @@ -163,7 +163,8 @@ class Strategy(BaseConfig): # self._config_dict = yaml.load(yaml_file, Loader=yaml.Loader) else: raise ValueError( - "Expected a dictionary. But received: {}".format(config)) + "Expected a dictionary. 
But received: {}".format(config) + ) else: self._config_dict = {} @@ -182,6 +183,9 @@ class Strategy(BaseConfig): config_dict = self._config_dict.get(constants.GRADIENT_MERGE, None) self.gradient_merge = GradientMergeConfig(config_dict) + config_dict = self._config_dict.get(constants.PIPELINE, None) + self.pipeline = PipelineConfig(config_dict) + config_dict = self._config_dict.get(constants.QAT, None) self.qat = QATConfig(config_dict) diff --git a/python/paddle/distributed/auto_parallel/tuner/profiler.py b/python/paddle/distributed/auto_parallel/tuner/profiler.py index 4b2655028bf..1409243442d 100644 --- a/python/paddle/distributed/auto_parallel/tuner/profiler.py +++ b/python/paddle/distributed/auto_parallel/tuner/profiler.py @@ -22,8 +22,13 @@ import time import paddle from paddle.fluid.framework import Program, _current_expected_place from paddle.fluid.framework import Operator -from paddle.distributed.auto_parallel.process_group import get_all_process_groups, new_process_group -from paddle.distributed.auto_parallel.dist_loader import DistributedDataLoaderFromGenerator +from paddle.distributed.auto_parallel.process_group import ( + get_all_process_groups, + new_process_group, +) +from paddle.distributed.auto_parallel.dist_loader import ( + DistributedDataLoaderFromGenerator, +) from paddle.distributed.collective import _get_global_env paddle.enable_static() @@ -44,25 +49,32 @@ def parse_args(): "--profile_start_step", default=10, type=int, - help="integer indicates the warmup step before starting profile.") - parser.add_argument("--profile_end_step", - default=30, - type=int, - help="integer indicates at the end step of profile.") - parser.add_argument("--rank", - type=int, - required=True, - help="the rank id of the this process.") - parser.add_argument("--device_id", - type=int, - required=True, - help="the device id of the this process.") + help="integer indicates the warmup step before starting profile.", + ) + parser.add_argument( + "--profile_end_step", + default=30, + type=int, + help="integer indicates at the end step of profile.", + ) + parser.add_argument( + "--rank", + type=int, + required=True, + help="the rank id of the this process.", + ) + parser.add_argument( + "--device_id", + type=int, + required=True, + help="the device id of the this process.", + ) parser.add_argument( "--ctx_filename", type=str, required=True, - help= - "the filename to the profile context file saved by optimizaiton tuner") + help="the filename to the profile context file saved by optimizaiton tuner", + ) args = parser.parse_args() @@ -71,14 +83,14 @@ def parse_args(): def init_process_groups(group_map, rank): for group_id, ranks in group_map.items(): - if group_id == 0: + if group_id == 1000: continue new_process_group(ranks=ranks, group_id=group_id) # TODO should instantiate global group first all_process_groups = get_all_process_groups() for process_group in all_process_groups: - if process_group.id == 0 or rank not in process_group.ranks: + if process_group.id == 1000 or rank not in process_group.ranks: continue print(process_group) process_group.instantiate() @@ -109,11 +121,9 @@ def get_cpp_error_type(error): return error_type -def create_dataloader(main_program, - startup_program, - profile_ctx, - epochs=1, - steps_per_epoch=None): +def create_dataloader( + main_program, startup_program, profile_ctx, epochs=1, steps_per_epoch=None +): dataset = profile_ctx["dataset"] main_block = main_program.global_block() @@ -141,7 +151,8 @@ def create_dataloader(main_program, epochs=epochs, 
steps_per_epoch=steps_per_epoch, data_parallel_world_size=dataset.dp_world_size, - data_parallel_rank=dataset.dp_rank) + data_parallel_rank=dataset.dp_rank, + ) # move read op from the end of program to the start of program new_op_size = len(main_block.ops) @@ -162,8 +173,12 @@ def init_comm(profile_ctx): dist_env = profile_ctx['distributed_env'] genv = _get_global_env() genv = dist_env - print("current process rank: {}, device_id: {}, ip: {}.", genv.rank, - genv.device_id, genv.current_endpoint) + print( + "current process rank: {}, device_id: {}, ip: {}.", + genv.rank, + genv.device_id, + genv.current_endpoint, + ) # init nccl comm group_map = profile_ctx['group_map'] @@ -201,8 +216,9 @@ def profiler(args): """ # load ctx if not os.path.isfile(args.ctx_filename): - raise ValueError("There is no profile context named {}.".format( - args.ctx_filename)) + raise ValueError( + "There is no profile context named {}.".format(args.ctx_filename) + ) with open(args.ctx_filename, 'rb') as f: profile_ctx = pickle.load(f, encoding='latin1') @@ -240,8 +256,9 @@ def profiler(args): print("step: %d, loss_print: %f" % (eval_step, loss[0])) eval_step += 1 - avg_tput = 1.0 * (args.profile_end_step - - args.profile_start_step) / duration + avg_tput = ( + 1.0 * (args.profile_end_step - args.profile_start_step) / duration + ) result_dict = { "Throughtput": avg_tput, diff --git a/python/paddle/distributed/auto_parallel/utils.py b/python/paddle/distributed/auto_parallel/utils.py index 6dc722b53ea..a08a17288a4 100644 --- a/python/paddle/distributed/auto_parallel/utils.py +++ b/python/paddle/distributed/auto_parallel/utils.py @@ -1874,6 +1874,12 @@ def initialize_pg_in_full_mode(all_process_groups, cur_rank): ) ) break + print( + "***process_group: id:", + process_group.id, + "rank:", + process_group.ranks, + ) process_group.instantiate() server_socket.close() diff --git a/python/paddle/distributed/fleet/fleet_executor_utils.py b/python/paddle/distributed/fleet/fleet_executor_utils.py index f5a1d8b1814..a53f2e73511 100644 --- a/python/paddle/distributed/fleet/fleet_executor_utils.py +++ b/python/paddle/distributed/fleet/fleet_executor_utils.py @@ -22,38 +22,47 @@ class TaskNode: Python side TaskNode, connection to the c++ side TaskNode """ - def __init__(self, - rank, - max_run_times, - max_slot_times, - role=None, - node_type=None, - task_id=0, - ops=None, - program=None, - lazy_initialize=False): + def __init__( + self, + rank, + max_run_times, + role=None, + node_type=None, + task_id=0, + ops=None, + program=None, + lazy_initialize=False, + cond_var_name=None, + vars_to_dtype=None, + vars_to_shape=None, + ): """ :param rank (int): Current rank of the task node. :param max_run_times (int): The max run times of the task node. - :param max_slot_times (int): The mas slot times of the task node. :param role (int): The role of the task node. (Will be removed in the future) :param node_type (str): The type of the task node. :param task_id (int): The id of task node. - :param ops (list): A list of op.desc to init the task node. (Will be removed in the future) + :param ops (list): A list of op.desc to init the task node. (Will be removed in the future) :param program (Program): An instance of Program to init the task node. :param lazy_initialize (bool): In user-defined task, the program may change adding feed/fetch op. As efficient consideration, the task node will have the C++ object later. + :param cond_var_name (string): Indicate the cond var name of while. + :param vars_list (list): A list of var name to send. 
""" - assert ((ops is not None) ^ (program is not None)), \ - "Should provide only one of ops or program to task node." - assert (not ((ops is not None) and lazy_initialize)), \ - "Lazy initialization doesn't support with ops list" + assert (ops is not None) ^ ( + program is not None + ), "Should provide only one of ops or program to task node." + assert not ( + (ops is not None) and lazy_initialize + ), "Lazy initialization doesn't support with ops list" self.id = int(task_id) self.rank = rank self.max_run_times = max_run_times - self.max_slot_times = max_slot_times self.node_type = node_type self.program = program self.lazy_initialize = lazy_initialize + self.cond_var_name = cond_var_name + self.vars_to_dtype = vars_to_dtype + self.vars_to_shape = vars_to_shape self.run_pre_steps = None self.run_at_offset = None self.node = None @@ -61,40 +70,63 @@ class TaskNode: self.downstreams = [] if not lazy_initialize: if ops is not None: - assert role is not None and task_id is not None, \ - "If init task node with ops, should provide `role` and `task_id`." - self.node = core.TaskNode(role, ops, rank, task_id, - max_run_times, max_slot_times) + assert ( + role is not None and task_id is not None + ), "If init task node with ops, should provide `role` and `task_id`." + self.node = core.TaskNode( + role, + ops, + rank, + task_id, + max_run_times, + ) else: - self.node = core.TaskNode(program.desc, rank, self.id, - max_run_times, max_slot_times) + self.node = core.TaskNode( + program.desc, + rank, + self.id, + max_run_times, + ) if self.node_type: self.node.set_type(self.node_type) def task_node(self): if self.lazy_initialize: - self.node = core.TaskNode(self.program.desc, self.rank, self.id, - self.max_run_times, self.max_slot_times) + self.node = core.TaskNode( + self.program.desc, + self.rank, + self.id, + self.max_run_times, + ) if self.node_type: self.node.set_type(self.node_type) if self.run_pre_steps: self.node.set_run_pre_steps(self.run_pre_steps) if self.run_at_offset: self.node.set_run_at_offset(self.run_at_offset) + if self.cond_var_name: + self.node.set_cond_var_name(self.cond_var_name) + if self.vars_to_shape: + self.node.set_vars_to_shape(self.vars_to_shape) + if self.vars_to_dtype: + self.node.set_vars_to_dtype(self.vars_to_dtype) for up in self.upstreams: - self.node.add_upstream_task(up[0], up[1]) + self.node.add_upstream_task(up[0], up[1], up[2]) for down in self.downstreams: - self.node.add_downstream_task(down[0], down[1]) + self.node.add_downstream_task(down[0], down[1], down[2]) self.lazy_initialize = False return self.node def set_program(self, program): - assert self.lazy_initialize, \ - "Inside program is unchangable for immediate initialized task node. Set the lazy_initialize to be true if the inside program need to be update. Remember to do all your change before eval node.task_node()." + assert ( + self.lazy_initialize + ), "Inside program is unchangable for immediate initialized task node. Set the lazy_initialize to be true if the inside program need to be update. Remember to do all your change before eval node.task_node()." 
self.program = program def get_program(self): - assert self.program is not None, "The task node is not initialized using program" + assert ( + self.program is not None + ), "The task node is not initialized using program" return self.program def set_run_pre_steps(self, steps): @@ -109,17 +141,21 @@ class TaskNode: else: self.node.set_run_at_offset(offset) - def add_upstream_task(self, upstream, buffer_size=2): + def add_upstream_task( + self, upstream, buffer_size=2, depend_type=core.DependType.NORMAL + ): if self.lazy_initialize: - self.upstreams.append((upstream, buffer_size)) + self.upstreams.append((upstream, buffer_size, depend_type)) else: - self.node.add_upstream_task(upstream, buffer_size) + self.node.add_upstream_task(upstream, buffer_size, depend_type) - def add_downstream_task(self, downstream, buffer_size=2): + def add_downstream_task( + self, downstream, buffer_size=2, depend_type=core.DependType.NORMAL + ): if self.lazy_initialize: - self.downstreams.append((downstream, buffer_size)) + self.downstreams.append((downstream, buffer_size, depend_type)) else: - self.node.add_downstream_task(downstream, buffer_size) + self.node.add_downstream_task(downstream, buffer_size, depend_type) def task_id(self): return self.id @@ -142,10 +178,16 @@ class CoordSys: :param coord: The coord to be tested :return: False if valid, True if invalid. """ - return coord['mp_idx'] < 0 or coord['mp_idx'] >= self.mp_degree or \ - coord['sharding_idx'] < 0 or coord['sharding_idx'] >= self.sharding_degree or \ - coord['pp_idx'] < 0 or coord['pp_idx'] >= self.pp_degree or \ - coord['dp_idx'] < 0 or coord['dp_idx'] >= self.dp_degree + return ( + coord['mp_idx'] < 0 + or coord['mp_idx'] >= self.mp_degree + or coord['sharding_idx'] < 0 + or coord['sharding_idx'] >= self.sharding_degree + or coord['pp_idx'] < 0 + or coord['pp_idx'] >= self.pp_degree + or coord['dp_idx'] < 0 + or coord['dp_idx'] >= self.dp_degree + ) def coord_to_rank(self, coord): """ @@ -155,9 +197,15 @@ class CoordSys: """ if self._invalide_coord(coord): return -1 - return int(coord['dp_idx'] * self.pp_degree * self.sharding_degree * self.mp_degree + \ - coord['pp_idx'] * self.sharding_degree * self.mp_degree + \ - coord['sharding_idx'] * self.mp_degree + coord['mp_idx']) + return int( + coord['dp_idx'] + * self.pp_degree + * self.sharding_degree + * self.mp_degree + + coord['pp_idx'] * self.sharding_degree * self.mp_degree + + coord['sharding_idx'] * self.mp_degree + + coord['mp_idx'] + ) def rank_to_coord(self, rank): """ @@ -176,17 +224,14 @@ class CoordSys: 'mp_idx': int(mp_idx), 'sharding_idx': int(sharding_idx), 'pp_idx': int(pp_idx), - 'dp_idx': int(dp_idx) + 'dp_idx': int(dp_idx), } class FleetExecutorUtils: - - def __init__(self, - dist_strategy=None, - rank=None, - nrank=None, - max_run_times=None): + def __init__( + self, dist_strategy=None, rank=None, nrank=None, max_run_times=None + ): self.dist_strategy = dist_strategy self.rank = rank self.nrank = nrank @@ -206,12 +251,14 @@ class FleetExecutorUtils: return op_role == int(OpRole.Optimize.LRSched) def is_forward_op(self, op_role): - return (op_role == int(OpRole.Forward)) or \ - (op_role == (int(OpRole.Forward) | int(OpRole.Loss))) + return (op_role == int(OpRole.Forward)) or ( + op_role == (int(OpRole.Forward) | int(OpRole.Loss)) + ) def is_backward_op(self, op_role): - return (op_role == int(OpRole.Backward)) or \ - (op_role == (int(OpRole.Backward) | int(OpRole.Loss))) + return (op_role == int(OpRole.Backward)) or ( + op_role == (int(OpRole.Backward) | int(OpRole.Loss)) + ) 
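The coord_to_rank / rank_to_coord pair above is plain mixed-radix arithmetic over the (dp, pp, sharding, mp) axes. A standalone sketch of the same mapping, using arbitrary illustrative degrees rather than values from this patch:

# rank is a mixed-radix number whose digits are, from most to least significant,
# dp_idx, pp_idx, sharding_idx, mp_idx. The degrees below are illustrative only.
mp_degree, sharding_degree, pp_degree, dp_degree = 2, 1, 4, 2

def coord_to_rank(coord):
    return (coord['dp_idx'] * pp_degree * sharding_degree * mp_degree
            + coord['pp_idx'] * sharding_degree * mp_degree
            + coord['sharding_idx'] * mp_degree
            + coord['mp_idx'])

def rank_to_coord(rank):
    mp_idx = rank % mp_degree
    rank //= mp_degree
    sharding_idx = rank % sharding_degree
    rank //= sharding_degree
    pp_idx = rank % pp_degree
    rank //= pp_degree
    dp_idx = rank % dp_degree
    return {'mp_idx': mp_idx, 'sharding_idx': sharding_idx,
            'pp_idx': pp_idx, 'dp_idx': dp_idx}

# Round trip over every rank in the illustrative mesh.
for r in range(dp_degree * pp_degree * sharding_degree * mp_degree):
    assert coord_to_rank(rank_to_coord(r)) == r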
def split_program_to_op_list(self, program): op_list_map = {"lr": [], "fwd": [], "bwd": [], "opt": []} @@ -233,17 +280,19 @@ class FleetExecutorUtils: return op_list_map def convert_op_list_to_program(self, op_list, complete_program): - #TODO(liyurui): Complete this convert logic + # TODO(liyurui): Complete this convert logic program_map = { "lr": Program(), "fwd": Program(), "bwd": Program(), - "opt": Program() + "opt": Program(), } return program_map def build_1f1b_dependency(self, task_node_map): - assert not self.is_auto_parallel, "Handly add dependency should not be invoked in auto parallel mode" + assert ( + not self.is_auto_parallel + ), "Handly add dependency should not be invoked in auto parallel mode" # Generated the dependency based on this graph: # lr(1:m) -> forward -> backward -> (m:1)optimize # ↑ ↓ @@ -253,8 +302,9 @@ class FleetExecutorUtils: # add dependency intra stage cur_start_id = self.rank * self.num_of_functionality - pp_buff_size = int(self.dist_strategy['pp_degree'] - - self.coord['pp_idx']) + pp_buff_size = int( + self.dist_strategy['pp_degree'] - self.coord['pp_idx'] + ) task_node_map["lr"].add_downstream_task(cur_start_id + 1) task_node_map["fwd"].add_upstream_task(cur_start_id) task_node_map["fwd"].add_downstream_task(cur_start_id + 2, pp_buff_size) @@ -267,8 +317,8 @@ class FleetExecutorUtils: downstream_coord['pp_idx'] = downstream_coord['pp_idx'] + 1 pp_upstream = self.coord_sys.coord_to_rank(upstream_coord) pp_downstream = self.coord_sys.coord_to_rank(downstream_coord) - first_stage = (pp_upstream == -1) - last_stage = (pp_downstream == -1) + first_stage = pp_upstream == -1 + last_stage = pp_downstream == -1 prev_pp_start_id = pp_upstream * self.num_of_functionality next_pp_start_id = pp_downstream * self.num_of_functionality if not first_stage: @@ -280,33 +330,36 @@ class FleetExecutorUtils: return task_node_map def construct_task_nodes_1f1b(self, program_map): - max_slot_times = int(self.max_run_times - self.coord['pp_idx']) cur_start_id = int(self.rank * self.num_of_functionality) - lr_task_node = TaskNode(rank=self.rank, - max_run_times=self.max_run_times, - max_slot_times=max_slot_times, - program=program_map["lr"], - task_id=cur_start_id) - fwd_task_node = TaskNode(rank=self.rank, - max_run_times=self.max_run_times, - max_slot_times=max_slot_times, - program=program_map["fwd"], - task_id=cur_start_id + 1) - bwd_task_node = TaskNode(rank=self.rank, - max_run_times=self.max_run_times, - max_slot_times=max_slot_times, - program=program_map["bwd"], - task_id=cur_start_id + 2) - opt_task_node = TaskNode(rank=self.rank, - max_run_times=self.max_run_times, - max_slot_times=max_slot_times, - program=program_map["opt"], - task_id=cur_start_id + 3) + lr_task_node = TaskNode( + rank=self.rank, + max_run_times=self.max_run_times, + program=program_map["lr"], + task_id=cur_start_id, + ) + fwd_task_node = TaskNode( + rank=self.rank, + max_run_times=self.max_run_times, + program=program_map["fwd"], + task_id=cur_start_id + 1, + ) + bwd_task_node = TaskNode( + rank=self.rank, + max_run_times=self.max_run_times, + program=program_map["bwd"], + task_id=cur_start_id + 2, + ) + opt_task_node = TaskNode( + rank=self.rank, + max_run_times=self.max_run_times, + program=program_map["opt"], + task_id=cur_start_id + 3, + ) return { "lr": lr_task_node, "fwd": fwd_task_node, "bwd": bwd_task_node, - "opt": opt_task_node + "opt": opt_task_node, } def task_id_to_rank(self): @@ -317,53 +370,58 @@ class FleetExecutorUtils: return task_id_to_rank def 
construct_task_nodes_1f1b_op_list(self, op_list_map): - max_slot_times = int(self.max_run_times - self.coord['pp_idx']) cur_start_id = int(self.rank * self.num_of_functionality) - lr_task_node = TaskNode(rank=self.rank, - max_run_times=self.max_run_times, - max_slot_times=max_slot_times, - role=int(OpRole.Optimize.LRSched), - ops=op_list_map["lr"], - task_id=cur_start_id, - node_type="Amplifier") + lr_task_node = TaskNode( + rank=self.rank, + max_run_times=self.max_run_times, + role=int(OpRole.Optimize.LRSched), + ops=op_list_map["lr"], + task_id=cur_start_id, + node_type="Amplifier", + ) lr_task_node.set_run_pre_steps(self.max_run_times) - fwd_task_node = TaskNode(rank=self.rank, - max_run_times=self.max_run_times, - max_slot_times=max_slot_times, - role=int(OpRole.Forward), - ops=op_list_map["fwd"], - task_id=cur_start_id + 1, - node_type="Compute") - bwd_task_node = TaskNode(rank=self.rank, - max_run_times=self.max_run_times, - max_slot_times=max_slot_times, - role=int(OpRole.Backward), - ops=op_list_map["bwd"], - task_id=cur_start_id + 2, - node_type="Compute") - opt_task_node = TaskNode(rank=self.rank, - max_run_times=self.max_run_times, - max_slot_times=max_slot_times, - role=int(OpRole.Optimize), - ops=op_list_map["opt"], - task_id=cur_start_id + 3, - node_type="Amplifier") + fwd_task_node = TaskNode( + rank=self.rank, + max_run_times=self.max_run_times, + role=int(OpRole.Forward), + ops=op_list_map["fwd"], + task_id=cur_start_id + 1, + node_type="Compute", + ) + bwd_task_node = TaskNode( + rank=self.rank, + max_run_times=self.max_run_times, + role=int(OpRole.Backward), + ops=op_list_map["bwd"], + task_id=cur_start_id + 2, + node_type="Compute", + ) + opt_task_node = TaskNode( + rank=self.rank, + max_run_times=self.max_run_times, + role=int(OpRole.Optimize), + ops=op_list_map["opt"], + task_id=cur_start_id + 3, + node_type="Amplifier", + ) opt_task_node.set_run_pre_steps(self.max_run_times) opt_task_node.set_run_at_offset(self.max_run_times - 1) return { "lr": lr_task_node, "fwd": fwd_task_node, "bwd": bwd_task_node, - "opt": opt_task_node + "opt": opt_task_node, } -def run1f1b(program, - rank, - max_run_times, - dist_opt, - nrank, - with_standalone_executor=False): +def run1f1b( + program, + rank, + max_run_times, + dist_opt, + nrank, + with_standalone_executor=False, +): """ Split the program to support 1f1b pipeline scheduler. This funct will split the program based on the op_role. 
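For orientation, the 1F1B task numbering used above is purely positional: each rank owns num_of_functionality consecutive task ids covering lr, fwd, bwd and opt, and task_id_to_rank is just integer division by that stride. A standalone sketch of the layout and of the intra-stage chain wired by build_1f1b_dependency (the rank count is illustrative, not taken from the patch):

# Illustrative sketch of the 1F1B task-id layout; nrank is made up.
num_of_functionality = 4
nrank = 2

def task_ids_of(rank):
    start = rank * num_of_functionality
    return {"lr": start, "fwd": start + 1, "bwd": start + 2, "opt": start + 3}

task_id_to_rank = {
    rank * num_of_functionality + j: rank
    for rank in range(nrank)
    for j in range(num_of_functionality)
}

for rank in range(nrank):
    ids = task_ids_of(rank)
    # intra-stage dependency chain: lr -> fwd -> bwd -> opt
    intra_edges = [(ids["lr"], ids["fwd"]),
                   (ids["fwd"], ids["bwd"]),
                   (ids["bwd"], ids["opt"])]
    assert all(task_id_to_rank[src] == rank for src, _ in intra_edges)
    print(rank, ids, intra_edges)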
@@ -380,24 +438,29 @@ def run1f1b(program, task_id_to_rank (dict): task nodes' ids to it's corresponding rank """ print("fleet executor will use python side 1f1b scheduler.") - fleet_executor_utils = FleetExecutorUtils(dist_strategy=dist_opt, - rank=rank, - nrank=nrank, - max_run_times=max_run_times) + fleet_executor_utils = FleetExecutorUtils( + dist_strategy=dist_opt, + rank=rank, + nrank=nrank, + max_run_times=max_run_times, + ) op_list_map = fleet_executor_utils.split_program_to_op_list(program) task_node_map = None if with_standalone_executor: program_map = fleet_executor_utils.convert_op_list_to_program( - op_list_map, program) + op_list_map, program + ) task_node_map = fleet_executor_utils.construct_task_nodes_1f1b( - program_map) + program_map + ) else: op_desc_list_map = {"lr": [], "fwd": [], "bwd": [], "opt": []} for key in op_list_map: for op in op_list_map[key]: op_desc_list_map[key].append(op.desc) task_node_map = fleet_executor_utils.construct_task_nodes_1f1b_op_list( - op_desc_list_map) + op_desc_list_map + ) task_node_map = fleet_executor_utils.build_1f1b_dependency(task_node_map) task_id_to_rank = fleet_executor_utils.task_id_to_rank() task_node_list = [task_node_map[key].task_node() for key in task_node_map] @@ -414,10 +477,11 @@ def origin(program, rank): task_id_to_rank (dict): a fake dict, since there is no upstream or downstream, this dict won't be used """ print("fleet executor will use python side origin scheduler.") - task_node = TaskNode(program=program, - rank=rank, - node_type="Compute", - max_run_times=1, - max_slot_times=1) + task_node = TaskNode( + program=program, + rank=rank, + node_type="Compute", + max_run_times=1, + ) task_id_to_rank = {task_node.task_id(): rank} return [task_node.task_node()], task_id_to_rank diff --git a/python/paddle/distributed/parallel.py b/python/paddle/distributed/parallel.py index 8c7187236d4..fa11aeb0a85 100644 --- a/python/paddle/distributed/parallel.py +++ b/python/paddle/distributed/parallel.py @@ -56,6 +56,13 @@ ParallelStrategy = core.ParallelStrategy _global_parallel_env = None +def _is_global_parallel_initialize(): + global _global_parallel_env + if _global_parallel_env is None: + return False + return True + + def _get_global_parallel_env(): global _global_parallel_env if _global_parallel_env is None: diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py index 5f721a1df50..1d1dab90cf3 100644 --- a/python/paddle/distributed/passes/__init__.py +++ b/python/paddle/distributed/passes/__init__.py @@ -22,6 +22,7 @@ from .auto_parallel_recompute import * from .auto_parallel_quantization import * from .auto_parallel_data_parallel_optimization import * from .auto_parallel_grad_clip import * +from .auto_parallel_pipeline import * from .cpp_pass import * import os from .ps_trainer_pass import * diff --git a/python/paddle/distributed/passes/auto_parallel_grad_clip.py b/python/paddle/distributed/passes/auto_parallel_grad_clip.py index 34c0b7d56a0..727dd0fcd2e 100644 --- a/python/paddle/distributed/passes/auto_parallel_grad_clip.py +++ b/python/paddle/distributed/passes/auto_parallel_grad_clip.py @@ -21,8 +21,17 @@ from paddle.fluid import core from .pass_base import PassBase, register_pass from ..auto_parallel.reshard import Resharder from ..auto_parallel.process_group import get_world_process_group -from ..auto_parallel.utils import is_gradient_clip_op, is_optimize_op, OP_ROLE_KEY, OpRole, _get_comm_group -from ..auto_parallel.dist_attribute import TensorDistributedAttribute, 
OperatorDistributedAttribute +from ..auto_parallel.utils import ( + is_gradient_clip_op, + is_optimize_op, + OP_ROLE_KEY, + OpRole, + _get_comm_group, +) +from ..auto_parallel.dist_attribute import ( + TensorDistributedAttribute, + OperatorDistributedAttribute, +) def _get_params_grads(block): @@ -53,7 +62,8 @@ def _get_dpmp_topology(origin_topology, sharding_group): """ sharding_axis = 1 dp_sharding_topology = [ - origin_topology[0] // sharding_group.nranks, sharding_group.nranks + origin_topology[0] // sharding_group.nranks, + sharding_group.nranks, ] if dp_sharding_topology[0] == 1: sharding_axis = 0 @@ -109,22 +119,24 @@ def _get_dpmp_process_mesh(rank_id, topology, processes, sharding_group): return dpmp_topology, list(dpmp_processes_in_sharding) -def _is_about_global_norm(rank_id, tensor_shape, topology, processes, - dims_mapping, sharding_group): +def _is_about_global_norm( + rank_id, tensor_shape, topology, processes, dims_mapping, sharding_group +): # get current process_mesh where the parameter exist. dpmp_topology, dpmp_processes = _get_dpmp_process_mesh( - rank_id, topology, processes, sharding_group) + rank_id, topology, processes, sharding_group + ) - complete_shape = Resharder.compute_complete_shape(tensor_shape, - dpmp_topology, - dims_mapping) + complete_shape = Resharder.compute_complete_shape( + tensor_shape, dpmp_topology, dims_mapping + ) complete_partitions = [] complete_param_ranks = [] for process in dpmp_processes: partition_index = Resharder.compute_partition_index( - process, complete_shape, dims_mapping, dpmp_topology, - dpmp_processes) + process, complete_shape, dims_mapping, dpmp_topology, dpmp_processes + ) if partition_index not in complete_partitions: complete_partitions.append(partition_index) complete_param_ranks.append(process) @@ -133,7 +145,6 @@ def _is_about_global_norm(rank_id, tensor_shape, topology, processes, class ClipHelper(object): - def __init__(self, params_grads, rank_id, block, dist_context): params, _ = zip(*params_grads) self.params = list(params) @@ -155,9 +166,14 @@ class ClipHelper(object): topology = dist_attr.process_mesh.topology processes = dist_attr.process_mesh.processes dims_mapping = dist_attr.dims_mapping - return _is_about_global_norm(self.rank_id, param.shape, topology, - processes, dims_mapping, - self.sharding_group) + return _is_about_global_norm( + self.rank_id, + param.shape, + topology, + processes, + dims_mapping, + self.sharding_group, + ) def _get_dist_attr(self, name): var = self.block.vars[name] @@ -182,7 +198,8 @@ class ClipHelper(object): in_dist_attr.process_mesh = self.world_ranks in_dist_attr.dims_mapping = [-1] self.dist_context.set_tensor_dist_attr_for_program( - in_var, in_dist_attr) + in_var, in_dist_attr + ) op_dist_attr.set_input_dist_attr(in_name, in_dist_attr) for out_name in op.output_arg_names: out_var = self.block.vars[out_name] @@ -190,7 +207,8 @@ class ClipHelper(object): out_dist_attr.process_mesh = self.world_ranks out_dist_attr.dims_mapping = [-1] self.dist_context.set_tensor_dist_attr_for_program( - out_var, out_dist_attr) + out_var, out_dist_attr + ) op_dist_attr.set_output_dist_attr(out_name, out_dist_attr) self.dist_context.set_op_dist_attr_for_program(op, op_dist_attr) @@ -229,14 +247,18 @@ class ClipGradByGloblNormPass(PassBase): dist_params_grads = self.get_attr("params_grads", None) # dist_params_grads = _get_params_grads(block) - self.clip_helper = ClipHelper(dist_params_grads, rank_id, block, - dist_context) + self.clip_helper = ClipHelper( + dist_params_grads, rank_id, block, 
dist_context + ) self._remove_no_need_ops_vars(block) def _remove_no_need_ops_vars(self, block): removed_op_out_type = [ - 'clip_by_norm', 'squared_l2_norm', 'square', 'reduce_sum' + 'clip_by_norm', + 'squared_l2_norm', + 'square', + 'reduce_sum', ] removed_op_idx = set() @@ -249,12 +271,14 @@ class ClipGradByGloblNormPass(PassBase): input_name = op.input("X")[0] if input_name.find("@GRAD") != -1: #'clip_by_norm', 'squared_l2_norm', 'square' - param_name = input_name[:input_name.find("@GRAD")] + param_name = input_name[: input_name.find("@GRAD")] is_local = self.clip_helper._is_local_param(param_name) is_calculate = self.clip_helper._is_calcuate_norm( - param_name) - if not is_local or (not is_calculate - and op.type != 'clip_by_norm'): + param_name + ) + if not is_local or ( + not is_calculate and op.type != 'clip_by_norm' + ): removed_op_idx.add(idx) removed_tmp_var.update(set(op.output_arg_names)) else: @@ -266,20 +290,23 @@ class ClipGradByGloblNormPass(PassBase): elif op.type == 'elementwise_mul': input_name = op.input("X")[0] if input_name.find("@GRAD") != -1: - param_name = input_name[:input_name.find("@GRAD")] + param_name = input_name[: input_name.find("@GRAD")] is_local = self.clip_helper._is_local_param(param_name) if not is_local: removed_op_idx.add(idx) if block.ops[idx - 1].type == 'cast': removed_op_idx.add(idx - 1) removed_tmp_var.update( - set(block.ops[idx - 1].output_arg_names)) + set(block.ops[idx - 1].output_arg_names) + ) elif op.type == 'sum': reserved_vars = [] for input_name in op.input_arg_names: - if input_name not in removed_tmp_var and \ - self.clip_helper._is_local_var(input_name): + if ( + input_name not in removed_tmp_var + and self.clip_helper._is_local_var(input_name) + ): reserved_vars.append(input_name) if not reserved_vars: removed_op_idx.add(idx) @@ -287,7 +314,8 @@ class ClipGradByGloblNormPass(PassBase): if block.ops[idx + 1].type == 'cast': removed_op_idx.add(idx + 1) removed_tmp_var.update( - set(block.ops[idx + 1].output_arg_names)) + set(block.ops[idx + 1].output_arg_names) + ) else: op.desc.set_input("X", reserved_vars) @@ -321,10 +349,12 @@ class ClipGradByGloblNormPass(PassBase): 'dtype': input_var.dtype, 'value': 0, 'force_cpu': False, - OP_ROLE_KEY: OpRole.Optimize - }) - fill_constant_op._set_attr('op_namescope', - "/gradient_clip_pass") + OP_ROLE_KEY: OpRole.Optimize, + }, + ) + fill_constant_op._set_attr( + 'op_namescope', "/gradient_clip_pass" + ) offset += 1 self.clip_helper._init_dist_attr(fill_constant_op) @@ -334,12 +364,14 @@ class ClipGradByGloblNormPass(PassBase): inputs={'X': [input_var]}, outputs={'Out': [input_var]}, attrs={ - 'ring_id': 0, + 'ring_id': 1000, 'use_calc_stream': True, OP_ROLE_KEY: OpRole.Optimize, - }) - allreduce_op._set_attr('op_namescope', - "/gradient_clip_pass") + }, + ) + allreduce_op._set_attr( + 'op_namescope', "/gradient_clip_pass" + ) self.clip_helper._init_dist_attr(allreduce_op) for varname in removed_tmp_var: diff --git a/python/paddle/distributed/passes/auto_parallel_pipeline.py b/python/paddle/distributed/passes/auto_parallel_pipeline.py new file mode 100644 index 00000000000..982fd7c228a --- /dev/null +++ b/python/paddle/distributed/passes/auto_parallel_pipeline.py @@ -0,0 +1,635 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from logging import exception +import os + +from paddle.fluid import core +from .pass_base import PassBase, register_pass +from paddle.fluid.framework import Program, Parameter +from paddle.distributed.fleet.fleet_executor_utils import TaskNode +from paddle.distributed.fleet.meta_optimizers.common import OpRole + +from paddle.distributed.auto_parallel.utils import ( + is_forward_op, + is_backward_op, + is_optimize_op, + is_lr_sched_op, +) + + +__not_shape_var_type__ = [ + core.VarDesc.VarType.READER, + core.VarDesc.VarType.STEP_SCOPES, + core.VarDesc.VarType.LOD_TENSOR_ARRAY, + core.VarDesc.VarType.FEED_MINIBATCH, + core.VarDesc.VarType.FETCH_LIST, +] + + +@register_pass("auto_parallel_pipeline") +class PipelinePass(PassBase): + def __init__(self): + super(PipelinePass, self).__init__() + self.set_attr("dist_context", None) + + def _check_self(self): + if self.get_attr("dist_context") is None: + return False + return True + + def _check_conflict(self, other_pass): + return True + + def _apply_single_impl(self, main_program, startup_program, context): + self._dist_context = self.get_attr("dist_context") + self._acc_steps = self.get_attr("accumulate_steps") + self._mode = self.get_attr("schedule_mode") + self._gen_bsz = self.get_attr("generation_batch_size") + self._program = main_program + + if self._mode == "1F1B": + raise NotImplementedError("1F1B has not been implemented") + elif self._mode == "F-Then-B": + raise NotImplementedError("F-Then-B has not been implemented") + elif self._mode == "stream": + self._insert_sync_ops_for_stream() + self._task_stream() + else: + raise ValueError( + "Now only 'F-then-B', '1F1B' and 'stream' are supported." 
+ "The given value is {}.".format(self._mode) + ) + + def _insert_sync_ops_for_stream(self): + + for block in self._program.blocks: + offset = 0 + send_vars = [] + # insert sync ops + for index, op in enumerate(list(block.ops)): + if op.type == 'send_v2': + # step1: set 'use_calc_stream' False + op._set_attr("use_calc_stream", False) + op_role = op.attr('op_role') + # step2: insert 'c_sync_calc_stream' op before 'send_v2' op + var_name = op.input_arg_names[0] + var = block.var(var_name) + block._insert_op_without_sync( + index=index + offset, + type="c_sync_calc_stream", + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={'op_role': op_role}, + ) + offset += 1 + send_vars.append(var_name) + + for var_name in send_vars: + nop_op = block.append_op(type='nop') + nop_op.desc.set_input('X', [var_name]) + nop_op.desc.set_output('Out', [var_name]) + + block._sync_with_cpp() + + def _create_param(self, dst_block, src_var): + copied_kwargs = {} + copied_kwargs['trainable'] = src_var.trainable + copied_kwargs['optimize_attr'] = src_var.optimize_attr + copied_kwargs['regularizer'] = src_var.regularizer + copied_kwargs['do_model_average'] = src_var.do_model_average + copied_kwargs['need_clip'] = src_var.need_clip + + Parameter( + block=dst_block, + type=src_var.type, + name=src_var.name, + shape=src_var.shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer, + **copied_kwargs + ) + + def _create_inter(self, dst_block, src_var): + dst_block.create_var( + type=src_var.type, + name=src_var.name, + shape=src_var.shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + persistable=src_var.persistable, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer, + ) + + def _create_var( + self, src_block, dst_block, src_varname, force_create=False + ): + + if not force_create: + src_var = src_block.var(src_varname) + else: + src_var = src_block._var_recursive(src_varname) + if src_var.type in __not_shape_var_type__: + persist = getattr(src_var, 'persistable', False) + dst_block.create_var( + type=src_var.type, + name=src_var.name, + persistable=persist, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer, + ) + else: + if isinstance(src_var, Parameter): + self._create_param(dst_block, src_var) + else: + self._create_inter(dst_block, src_var) + + def _create_program(self, src_block, dst_block, src_op, force_create=False): + dst_op_desc = dst_block.desc.append_op() + dst_op_desc.copy_from(src_op.desc) + for input_varname in src_op.input_arg_names: + if src_block.has_var(input_varname) or ( + force_create and src_block._find_var_recursive(input_varname) + ): + self._create_var( + src_block, dst_block, input_varname, force_create + ) + for output_varname in src_op.output_arg_names: + if src_block.has_var(output_varname) or ( + force_create and src_block._find_var_recursive(output_varname) + ): + self._create_var( + src_block, dst_block, output_varname, force_create + ) + + def _get_pp_stage(self, rank): + pp_idx = None + for idx, process_mesh in enumerate(self._dist_context.process_meshes): + if rank in process_mesh.processes: + pp_idx = idx + break + return pp_idx + + def _task_stream(self): + cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0)) + trainer_endpoints = 
os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',') + nrank = len(trainer_endpoints) + num_of_functionality = 5 + + # compute current pp stage + pp_stages = len(self._dist_context.process_meshes) + cur_pp_stage = self._get_pp_stage(cur_rank) + + start_prog = Program() + cond_prog = Program() + end_prog = Program() + send_prog = Program() + recv_prog = Program() + + cond_var_name = None + send_vars_name = set() + recv_vars_name = dict() + for ib, src_block in enumerate(self._program.blocks): + if ib == 0: + strat_block = start_prog.block(0) + end_block = end_prog.block(0) + + is_after_while_op = False + for op in src_block.ops: + if op.type == "while": + assert len(op.input('Condition')) == 1 + cond_var_name = op.input('Condition')[0] + is_after_while_op = True + continue + + if not is_after_while_op: + self._create_program( + src_block, strat_block, op, force_create=True + ) + else: + self._create_program( + src_block, end_block, op, force_create=True + ) + elif ib == 1: + send_block = send_prog.block(0) + recv_block = recv_prog.block(0) + + is_after_send_op = False + is_after_recv_op = False + for op in src_block.ops: + if op.type == "send_v2" and not is_after_send_op: + is_after_send_op = True + if cur_pp_stage == pp_stages - 1: + if op.type in ["c_sync_calc_stream", "nop"]: + continue + if ( + op.type not in ["recv_2", "assign"] + and op.has_attr('op_namescope') + and "/auto_parallel/reshard" + in op.attr('op_namescope') + ): + if ( + len(op.desc.input_arg_names()) > 0 + and "@RESHARD" + not in op.desc.input_arg_names()[0] + ): + send_vars_name.add( + op.desc.input_arg_names()[0] + ) + continue + if op.type == "send_v2": + continue + self._create_program( + src_block, send_block, op, force_create=True + ) + continue + + if ( + is_after_send_op + and not is_after_recv_op + and op.type == "recv_v2" + ): + is_after_recv_op = True + if op.has_attr( + 'op_namescope' + ) and "/auto_parallel/reshard" in op.attr( + 'op_namescope' + ): + var_name = op.desc.output_arg_names()[0] + index = var_name.find("@") + if index > 0: + old_var_name = var_name[:index] + else: + old_var_name = var_name + recv_vars_name[var_name] = old_var_name + if not src_block._find_var_recursive(old_var_name): + src_var = src_block._var_recursive(var_name) + recv_block.create_var( + type=src_var.type, + name=old_var_name, + shape=src_var.shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + persistable=src_var.persistable, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer, + ) + continue + + self._create_program( + src_block, recv_block, op, force_create=True + ) + continue + + if not is_after_send_op or not is_after_recv_op: + if cur_pp_stage == pp_stages - 1: + if op.type in ["c_sync_calc_stream", "nop"]: + continue + if ( + op.type not in ["recv_2", "assign"] + and op.has_attr('op_namescope') + and "/auto_parallel/reshard" + in op.attr('op_namescope') + ): + if ( + len(op.desc.input_arg_names()) > 0 + and "@RESHARD" + not in op.desc.input_arg_names()[0] + ): + send_vars_name.add( + op.desc.input_arg_names()[0] + ) + continue + if op.type == "send_v2": + continue + self._create_program( + src_block, send_block, op, force_create=True + ) + + if is_after_send_op and is_after_recv_op: + if op.has_attr( + 'op_namescope' + ) and "/auto_parallel/reshard" in op.attr( + 'op_namescope' + ): + var_name = op.desc.output_arg_names()[0] + index = var_name.find("@") + if index > 0: + old_var_name = var_name[:index] + else: + 
old_var_name = var_name + recv_vars_name[var_name] = old_var_name + if not src_block._find_var_recursive(old_var_name): + src_var = src_block._var_recursive(var_name) + recv_block.create_var( + type=src_var.type, + name=old_var_name, + shape=src_var.shape, + dtype=src_var.dtype, + lod_level=src_var.lod_level, + persistable=src_var.persistable, + error_clip=src_var.error_clip, + stop_gradient=src_var.stop_gradient, + is_data=src_var.is_data, + belong_to_optimizer=src_var.belong_to_optimizer, + ) + continue + + for in_name in op.desc.input_arg_names(): + if in_name in recv_vars_name: + op.desc._rename_input( + in_name, recv_vars_name[in_name] + ) + self._create_program( + src_block, recv_block, op, force_create=True + ) + else: + raise Exception("Only support generation condition.") + + start_prog._sync_with_cpp() + end_prog._sync_with_cpp() + send_prog._sync_with_cpp() + recv_prog._sync_with_cpp() + + assert cond_var_name is not None + + send_task_node_var_dtype = dict() + send_task_node_var_shape = dict() + recv_task_node_var_dtype = dict() + recv_task_node_var_shape = dict() + for var_name in list(send_vars_name): + var = send_prog.global_block().vars[var_name] + dtype = str(var.dtype) + send_task_node_var_dtype[var_name] = dtype[ + dtype.find("paddle.") + len("paddle.") : + ] + send_task_node_var_shape[var_name] = var.shape + for var_name in list(list(set(recv_vars_name.values()))): + var = recv_prog.global_block().vars[var_name] + dtype = str(var.dtype) + recv_task_node_var_dtype[var_name] = dtype[ + dtype.find("paddle.") + len("paddle.") : + ] + recv_task_node_var_shape[var_name] = var.shape + + vars_to_dtype = [] + vars_to_shape = [] + if len(send_task_node_var_dtype) > 0: + assert len(recv_task_node_var_dtype) == 0 + vars_to_dtype = send_task_node_var_dtype + vars_to_shape = send_task_node_var_shape + if len(recv_task_node_var_dtype) > 0: + assert len(send_task_node_var_dtype) == 0 + vars_to_dtype = recv_task_node_var_dtype + vars_to_shape = recv_task_node_var_shape + + start_task_node = TaskNode( + rank=cur_rank, + max_run_times=self._acc_steps, + node_type="Start", + task_id=int(cur_rank * num_of_functionality + 0), + program=start_prog, + lazy_initialize=True, + ) + cond_task_node = TaskNode( + rank=cur_rank, + max_run_times=self._acc_steps, + node_type="Cond", + task_id=int(cur_rank * num_of_functionality + 1), + program=cond_prog, + cond_var_name=cond_var_name, + lazy_initialize=True, + ) + send_task_node = TaskNode( + rank=cur_rank, + max_run_times=self._acc_steps, + node_type="Compute", + task_id=int(cur_rank * num_of_functionality + 2), + program=send_prog, + lazy_initialize=True, + ) + recv_task_node = TaskNode( + rank=cur_rank, + max_run_times=self._acc_steps, + node_type="Compute", + task_id=int(cur_rank * num_of_functionality + 3), + program=recv_prog, + lazy_initialize=True, + vars_to_dtype=vars_to_dtype, + vars_to_shape=vars_to_shape, + ) + end_task_node = TaskNode( + rank=cur_rank, + max_run_times=self._acc_steps, + node_type="Compute", + task_id=int(cur_rank * num_of_functionality + 4), + program=end_prog, + lazy_initialize=True, + ) + + # add dependencies for task nodes intra stage + inf = -1 + pp_buff_size = int(pp_stages - cur_pp_stage) + start_task_node.add_downstream_task( + cond_task_node.task_id(), self._gen_bsz + ) + print( + "Task ", + start_task_node.task_id(), + "'s downstream is:", + cond_task_node.task_id(), + ", buffer size is:", + self._gen_bsz, + ) + cond_task_node.add_upstream_task( + start_task_node.task_id(), self._gen_bsz + ) + print( + "Task ", + 
cond_task_node.task_id(), + "'s upstream is:", + start_task_node.task_id(), + ", buffer size is:", + self._gen_bsz, + ) + cond_task_node.add_downstream_task(send_task_node.task_id(), inf) + print( + "Task ", + cond_task_node.task_id(), + "'s downstream is:", + send_task_node.task_id(), + ", buffer size is:", + inf, + ) + send_task_node.add_upstream_task(cond_task_node.task_id(), inf) + print( + "Task ", + send_task_node.task_id(), + "'s upstream is:", + cond_task_node.task_id(), + ", buffer size is:", + inf, + ) + send_task_node.add_downstream_task( + recv_task_node.task_id(), pp_buff_size + ) + print( + "Task ", + send_task_node.task_id(), + "'s downstream is:", + recv_task_node.task_id(), + ", buffer size is:", + pp_buff_size, + ) + recv_task_node.add_upstream_task(send_task_node.task_id(), pp_buff_size) + print( + "Task ", + recv_task_node.task_id(), + "'s upstream is:", + send_task_node.task_id(), + ", buffer size is:", + pp_buff_size, + ) + recv_task_node.add_downstream_task( + cond_task_node.task_id(), inf, core.DependType.LOOP + ) + print( + "Task ", + recv_task_node.task_id(), + "'s downstream is:", + cond_task_node.task_id(), + ", buffer size is:", + inf, + ) + cond_task_node.add_upstream_task( + recv_task_node.task_id(), inf, core.DependType.LOOP + ) + print( + "Task ", + cond_task_node.task_id(), + "'s upstream is:", + recv_task_node.task_id(), + ", buffer size is:", + inf, + ) + cond_task_node.add_downstream_task( + end_task_node.task_id(), inf, core.DependType.STOP_LOOP + ) + print( + "Task ", + cond_task_node.task_id(), + "'s downstream is:", + end_task_node.task_id(), + ", buffer size is:", + inf, + ) + end_task_node.add_upstream_task( + cond_task_node.task_id(), inf, core.DependType.STOP_LOOP + ) + print( + "Task ", + end_task_node.task_id(), + "'s upstream is:", + cond_task_node.task_id(), + ", buffer size is:", + inf, + ) + + # add dependencies for task nodes inter stage + # get upstream ranks and downstream ranks of cur_rank + up_down_streams = self._dist_context.up_down_streams + pp_upstream_ranks = up_down_streams.ups(cur_rank) + pp_downstream_ranks = up_down_streams.downs(cur_rank) + + for upstream_rank in pp_upstream_ranks: + upstream_pp_stage = self._get_pp_stage(upstream_rank) + if upstream_pp_stage < pp_stages - 1: + upstream_task_id = int(upstream_rank * num_of_functionality + 2) + send_task_node.add_upstream_task(upstream_task_id) + print( + "Task ", + send_task_node.task_id(), + "'s upstream is:", + upstream_task_id, + ", buffer size is:", + 2, + ) + else: + upstream_task_id = int(upstream_rank * num_of_functionality + 3) + recv_task_node.add_upstream_task(upstream_task_id) + print( + "Task ", + recv_task_node.task_id(), + "'s upstream is:", + upstream_task_id, + ", buffer size is:", + 2, + ) + for downstream_rank in pp_downstream_ranks: + if cur_pp_stage < pp_stages - 1: + downstream_task_id = int( + downstream_rank * num_of_functionality + 2 + ) + send_task_node.add_downstream_task(downstream_task_id) + print( + "Task ", + send_task_node.task_id(), + "'s downstream is:", + downstream_task_id, + ", buffer size is:", + 2, + ) + else: + downstream_task_id = int( + downstream_rank * num_of_functionality + 3 + ) + recv_task_node.add_downstream_task(downstream_task_id) + print( + "Task ", + recv_task_node.task_id(), + "'s downstream is:", + downstream_task_id, + ", buffer size is:", + 2, + ) + + task_id_to_rank = {} + for i in range(nrank): + for j in range(num_of_functionality): + task_id_to_rank[int(i * num_of_functionality + j)] = i + self._program._pipeline_opt 
= { + "fleet_opt": { + 'tasks': [ + start_task_node, + cond_task_node, + send_task_node, + recv_task_node, + end_task_node, + ], + 'task_id_to_rank': task_id_to_rank, + 'num_micro_batches': self._acc_steps, + 'inference_generation': True, + } + } diff --git a/python/paddle/fluid/executor.py b/python/paddle/fluid/executor.py index c77def7ebf1..21e4cc6644d 100755 --- a/python/paddle/fluid/executor.py +++ b/python/paddle/fluid/executor.py @@ -78,7 +78,7 @@ def _switch_scope(scope): @signature_safe_contextmanager def scope_guard(scope): """ - + This function switches scope through python `with` statement. Scope records the mapping between variable names and variables ( :ref:`api_guide_Variable` ), similar to brackets in programming languages. @@ -96,7 +96,7 @@ def scope_guard(scope): None Examples: - + .. code-block:: python import paddle @@ -147,10 +147,12 @@ def as_numpy(tensor, copy=False): assert isinstance(tensor, core.LoDTensor) lod = tensor.lod() if len(lod) > 0: - raise RuntimeError("Some of your fetched tensors hold LoD information. \ + raise RuntimeError( + "Some of your fetched tensors hold LoD information. \ They can not be completely cast to Python ndarray. \ Please set the parameter 'return_numpy' as 'False' to \ - return LoDTensor itself directly.") + return LoDTensor itself directly." + ) if tensor._is_initialized(): if copy: return np.array(tensor) @@ -164,10 +166,10 @@ def dtype_is_compatible_with(first, second): """ Returns True if the first dtype can be compatible the second one. Currently, we require the two dtype's have to be same. - + Args: dtype (np.dtype|VarType|str): The type of data: float32, int64, etc. - + Returns: True if the two types are same. """ @@ -223,7 +225,7 @@ def check_feed_shape_type(var, feed, num_places=1): 2. Each non-negative number of the two dimensions are same. 3. For negative number or 'None' in a dimension, it means unknown so it is compatible with any number. - + Args: var (Variable): the Variable object feed (LoDTensor): the fed value, which must be a LoDTensor @@ -240,21 +242,29 @@ def check_feed_shape_type(var, feed, num_places=1): if diff_shape is not None: raise ValueError( 'The fed Variable %r should have dimensions = %d, shape = ' - '%r, but received fed shape %r on each device' % - (var.name, len(var.shape), var.shape, diff_shape)) + '%r, but received fed shape %r on each device' + % (var.name, len(var.shape), var.shape, diff_shape) + ) if not dtype_is_compatible_with(feed._dtype(), var.dtype): - var_dtype_format = convert_dtype(var.dtype) if isinstance( - var.dtype, core.VarDesc.VarType) else var.dtype - feed_dtype_format = convert_dtype(feed._dtype()) if isinstance( - feed._dtype(), core.VarDesc.VarType) else feed._dtype() + var_dtype_format = ( + convert_dtype(var.dtype) + if isinstance(var.dtype, core.VarDesc.VarType) + else var.dtype + ) + feed_dtype_format = ( + convert_dtype(feed._dtype()) + if isinstance(feed._dtype(), core.VarDesc.VarType) + else feed._dtype() + ) raise ValueError( - 'The data type of fed Variable %r must be %r, but received %r' % - (var.name, var_dtype_format, feed_dtype_format)) + 'The data type of fed Variable %r must be %r, but received %r' + % (var.name, var_dtype_format, feed_dtype_format) + ) return True def has_feed_operators(block, feed_targets, feed_holder_name): - """ Check whether the block already has feed operators. + """Check whether the block already has feed operators. Return false if the block does not have any feed operators. 
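From the user's perspective, the feed/fetch helpers being reformatted here are what Executor.run drives: each key in feed becomes a 'feed' op reading column i of the feed holder variable, and each entry in fetch_list becomes a 'fetch' op writing column i of the fetch holder. A minimal static-graph run that exercises this path (variable names and shapes are illustrative only):

import numpy as np
import paddle

paddle.enable_static()

main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.static.data(name="x", shape=[None, 4], dtype="float32")
    y = paddle.sum(x)  # something to fetch

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup_prog)

# Executor.run clones the program and inserts one 'feed' op per feed key and
# one 'fetch' op per fetch target, each tagged with its column index 'col'.
(out,) = exe.run(
    main_prog,
    feed={"x": np.ones([2, 4], dtype="float32")},
    fetch_list=[y],
)
print(out)  # prints 8.0, the sum of eight ones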
If some feed operators have been prepended to the block, check that @@ -283,20 +293,22 @@ def has_feed_operators(block, feed_targets, feed_holder_name): if feed_target_name not in feed_targets: raise Exception( "'feed_targets' does not have {} variable".format( - feed_target_name)) + feed_target_name + ) + ) else: break if feed_count > 0 and feed_count != len(feed_targets): raise Exception( - "Feed operators in program desc do not match 'feed_targets'") + "Feed operators in program desc do not match 'feed_targets'" + ) return feed_count > 0 -def has_fetch_operators(block, - fetch_targets, - fetch_holder_name, - fetch_op='fetch'): - """ Check whether the block already has fetch operators. +def has_fetch_operators( + block, fetch_targets, fetch_holder_name, fetch_op='fetch' +): + """Check whether the block already has fetch operators. Return false if the block does not have any fetch operators. If some fetch operators have been appended to the block, check that @@ -324,25 +336,25 @@ def has_fetch_operators(block, assert op.desc.output('Out')[0] == fetch_holder_name fetch_target_name = op.desc.input('X')[0] if fetch_target_name not in [ - var.desc.name() for var in fetch_targets + var.desc.name() for var in fetch_targets ]: raise Exception( "'fetch_targets' does not have {} variable".format( - fetch_target_name)) + fetch_target_name + ) + ) idx = op.desc.attr('col') assert fetch_target_name == fetch_targets[idx].desc.name() if fetch_count > 0 and fetch_count != len(fetch_targets): raise Exception( - "Fetch operators in program desc do not match 'fetch_targets'") + "Fetch operators in program desc do not match 'fetch_targets'" + ) return fetch_count > 0 -def _add_feed_fetch_ops(program, - feed, - fetch_list, - feed_var_name, - fetch_var_name, - use_fetch_v2=False): +def _add_feed_fetch_ops( + program, feed, fetch_list, feed_var_name, fetch_var_name, use_fetch_v2=False +): tmp_program = program.clone() global_block = tmp_program.global_block() @@ -353,7 +365,8 @@ def _add_feed_fetch_ops(program, feed_var = global_block.create_var( name=feed_var_name, type=core.VarDesc.VarType.FEED_MINIBATCH, - persistable=True) + persistable=True, + ) if fetch_var_name in global_block.vars: fetch_var = global_block.var(fetch_var_name) @@ -361,21 +374,25 @@ def _add_feed_fetch_ops(program, fetch_var = global_block.create_var( name=fetch_var_name, type=core.VarDesc.VarType.FETCH_LIST, - persistable=True) + persistable=True, + ) # prepend feed operators if not has_feed_operators(global_block, feed, feed_var_name): for i, name in enumerate(feed): if global_block.has_var(name): out = global_block.var(name) - global_block._prepend_op(type='feed', - inputs={'X': [feed_var]}, - outputs={'Out': [out]}, - attrs={'col': i}) + global_block._prepend_op( + type='feed', + inputs={'X': [feed_var]}, + outputs={'Out': [out]}, + attrs={'col': i}, + ) else: warnings.warn( "The variable %s is not found in program. It is not declared or is pruned." 
- % name) + % name + ) if use_fetch_v2: fetch_op = 'fetch_v2' @@ -383,22 +400,26 @@ def _add_feed_fetch_ops(program, fetch_op = 'fetch' # append fetch_operators - if not has_fetch_operators(global_block, fetch_list, fetch_var_name, - fetch_op): + if not has_fetch_operators( + global_block, fetch_list, fetch_var_name, fetch_op + ): for i, var in enumerate(fetch_list): assert isinstance(var, Variable) or isinstance( - var, six.string_types), ("Wrong type for fetch_list[%s]: %s" % - (i, type(var))) - global_block.append_op(type=fetch_op, - inputs={'X': [var]}, - outputs={'Out': [fetch_var]}, - attrs={'col': i}) + var, six.string_types + ), "Wrong type for fetch_list[%s]: %s" % (i, type(var)) + global_block.append_op( + type=fetch_op, + inputs={'X': [var]}, + outputs={'Out': [fetch_var]}, + attrs={'col': i}, + ) return tmp_program -def _apply_inplace_addto_pass(program, enable_inplace, enable_addto, - skip_var_names): +def _apply_inplace_addto_pass( + program, enable_inplace, enable_addto, skip_var_names +): use_cuda = True if core.is_compiled_with_cuda() else False attrs = {"use_cuda": use_cuda, "mem_opt_skip_vars": skip_var_names} @@ -407,12 +428,14 @@ def _apply_inplace_addto_pass(program, enable_inplace, enable_addto, empty_startup_program = Program() if enable_inplace: pass_name = "buffer_shared_inplace_pass" - _apply_pass(program, empty_startup_program, pass_name, attrs, - attr_types) + _apply_pass( + program, empty_startup_program, pass_name, attrs, attr_types + ) if enable_addto and use_cuda: pass_name = "inplace_addto_op_pass" - _apply_pass(program, empty_startup_program, pass_name, attrs, - attr_types) + _apply_pass( + program, empty_startup_program, pass_name, attrs, attr_types + ) def _fetch_var(name, scope=None, return_numpy=True): @@ -441,7 +464,8 @@ def _fetch_var(name, scope=None, return_numpy=True): assert var is not None, ( "Cannot find " + name + " in scope. Perhaps you need to make the" " variable persistable by using var.persistable = True in your" - " program.") + " program." 
+ ) tensor = var.get_tensor() if return_numpy: tensor = as_numpy(tensor, copy=True) @@ -449,7 +473,6 @@ def _fetch_var(name, scope=None, return_numpy=True): def _to_name_str(var): - def _to_str(var): if isinstance(var, Variable): return var.desc.name() @@ -474,19 +497,26 @@ def _to_name_str(var): def _is_enable_standalone_executor(): - return framework._enable_standalone_executor_ is None or framework._enable_standalone_executor_ in [ - 1, '1', True, 'True', 'true' - ] + return ( + framework._enable_standalone_executor_ is None + or framework._enable_standalone_executor_ + in [1, '1', True, 'True', 'true'] + ) def _is_dy2st_enable_standalone_executor(): return framework._dy2st_enable_standalone_executor_ in [ - 1, '1', True, 'True', 'true' + 1, + '1', + True, + 'True', + 'true', ] def _prepare_fleet_executor(): from ..distributed.fleet.proto import fleet_executor_desc_pb2 + trainer_endpoints_str = os.getenv("PADDLE_TRAINER_ENDPOINTS", "") trainer_endpoints = trainer_endpoints_str.split(',') fleet_exe_desc = fleet_executor_desc_pb2.FleetExecutorDesc() @@ -504,7 +534,8 @@ def _prepare_fleet_executor(): def _get_strong_program_cache_key_for_new_exe(program, feed, fetch_list): return program.desc.cached_hash_str() + _get_program_cache_key( - feed, fetch_list) + feed, fetch_list + ) def _get_strong_program_cache_key(program, feed, fetch_list): @@ -515,10 +546,16 @@ def _get_strong_program_cache_key(program, feed, fetch_list): block_str.append(var_name) return "\n".join(block_str) - inner_program = program._program if isinstance( - program, compiler.CompiledProgram) else program - return _get_varname_from_block(inner_program.blocks[0]) + str( - id(program)) + _get_program_cache_key(feed, fetch_list) + inner_program = ( + program._program + if isinstance(program, compiler.CompiledProgram) + else program + ) + return ( + _get_varname_from_block(inner_program.blocks[0]) + + str(id(program)) + + _get_program_cache_key(feed, fetch_list) + ) def _get_program_cache_key(feed, fetch_list): @@ -534,30 +571,35 @@ def _get_program_cache_key(feed, fetch_list): def _as_lodtensor(data, place, dtype=None): """ - Convert numpy.ndarray to Tensor, its only support Tensor without LoD information. - For higher dimensional sequence data, please use LoDTensor directly. + Convert numpy.ndarray to Tensor, its only support Tensor without LoD information. + For higher dimensional sequence data, please use LoDTensor directly. - Examples: - >>> import paddle.fluid as fluid - >>> place = fluid.CPUPlace() - >>> exe = fluid.executor(place) - >>> data = np.array(size=(100, 200, 300)) - >>> np_outs = map(lambda x: fluid.executor._as_lodtensor(x, place), data) - >>> ... + Examples: + >>> import paddle.fluid as fluid + >>> place = fluid.CPUPlace() + >>> exe = fluid.executor(place) + >>> data = np.array(size=(100, 200, 300)) + >>> np_outs = map(lambda x: fluid.executor._as_lodtensor(x, place), data) + >>> ... 
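Several of the switches above (`_is_enable_standalone_executor`, `_is_dy2st_enable_standalone_executor`, and `FLAGS_CONVERT_GRAPH_TO_PROGRAM` further down) all accept any of 1, '1', True, 'True', 'true' as "on". A small sketch of that membership-test pattern with a hypothetical flag name, only to make the accepted spellings explicit:

import os

_TRUTHY = (1, '1', True, 'True', 'true')

def flag_enabled(name, default=None):
    # mirrors the membership test used in executor.py; the flag name is illustrative
    return os.environ.get(name, default) in _TRUTHY

# e.g. FLAGS_SOME_FEATURE=true -> enabled; unset -> falls back to the given default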
- Args: - data(numpy.ndarray|list|tuple|scalar): a instance of array, scalar, list or tuple - data(core.Place): the place of created tensor - dtype(core.VarDesc.VarType|str): the expected data type of created tensor + Args: + data(numpy.ndarray|list|tuple|scalar): a instance of array, scalar, list or tuple + data(core.Place): the place of created tensor + dtype(core.VarDesc.VarType|str): the expected data type of created tensor - Returns: - LoDTensor - """ - #NOTE(zhiqiu): convert python builtin, like float, int, and list, to numpy ndarray + Returns: + LoDTensor + """ + # NOTE(zhiqiu): convert python builtin, like float, int, and list, to numpy ndarray if not isinstance(data, np.ndarray): - assert dtype is not None, 'The dtype should be given when feed data is not np.ndarray' - dtype = convert_dtype(dtype) if isinstance( - dtype, core.VarDesc.VarType) else dtype + assert ( + dtype is not None + ), 'The dtype should be given when feed data is not np.ndarray' + dtype = ( + convert_dtype(dtype) + if isinstance(dtype, core.VarDesc.VarType) + else dtype + ) if np.isscalar(data): data = np.array([data]).astype(dtype) elif isinstance(data, (list, tuple)): @@ -572,7 +614,9 @@ def _as_lodtensor(data, place, dtype=None): else: raise TypeError( "Convert data of type {} to Tensor is not supported".format( - type(data))) + type(data) + ) + ) # convert numpy.ndarray to tensor tensor = core.LoDTensor() @@ -581,7 +625,6 @@ def _as_lodtensor(data, place, dtype=None): class FetchHandler(object): - def __init__(self, var_dict=None, period_secs=60): assert var_dict != None self.var_dict = var_dict @@ -595,7 +638,8 @@ class FetchHandler(object): @staticmethod def help(): - print(""" + print( + """ class FetchHandlerExample(FetchHandler): def handler(self, res_dict): print(res_dict["auc"]) @@ -604,11 +648,11 @@ class FetchHandlerExample(FetchHandler): auc = Variable() var_dict = {"auc": auc} handler = FetchHandlerExample(var_dict=var_dict) -""") +""" + ) class _StandaloneExecutor(object): - def __init__(self, place, main_program, scope): self._place = core.Place() self._place.set_place(place) @@ -621,15 +665,16 @@ class _StandaloneExecutor(object): Args: feed_names(list): This parameter represents the input names of the model. fetch_list(list): This parameter represents the Tensors that need to be returned - after the model runs. The default is None. + after the model runs. The default is None. return_numpy(bool): This parameter indicates whether convert the fetched Tensors (the Tensor specified in the fetch list) to numpy.ndarray. if it is False, the type of the return value is a list of :code:`LoDTensor`. The default is True. """ fetch_list = self._check_fetch(fetch_list) - tensors = self._new_exe.run(scope, feed_names, - fetch_list)._move_to_list() + tensors = self._new_exe.run( + scope, feed_names, fetch_list + )._move_to_list() if return_numpy: return as_numpy(tensors, copy=True) else: @@ -642,10 +687,10 @@ class _StandaloneExecutor(object): def _update_feed(self, feed): """ - Update the feed dict, remove the feed item which is pruned in program. + Update the feed dict, remove the feed item which is pruned in program. Notes: This is a very low level API. Users should not use this API - directly. + directly. Args: feed(list|dict): feed dict or list. @@ -661,8 +706,9 @@ class _StandaloneExecutor(object): if not isinstance(feed, dict): raise TypeError( - "feed requires dict as its Parameter. But you passed in %s" % - (type(feed))) + "feed requires dict as its Parameter. 
But you passed in %s" + % (type(feed)) + ) global_block = self._main_program.global_block() for feed_name in list(feed.keys()): @@ -670,7 +716,8 @@ class _StandaloneExecutor(object): feed.pop(feed_name) warnings.warn( "The variable %s is not found in program. It is not declared or is pruned." - % feed_name) + % feed_name + ) return feed @@ -684,19 +731,27 @@ class _StandaloneExecutor(object): fetch_var = fetch_var.name elif not isinstance(fetch_var, str): raise TypeError( - "Required fetch_var shall be str|Variable, but received {}". - format(type(fetch_var).__name__)) + "Required fetch_var shall be str|Variable, but received {}".format( + type(fetch_var).__name__ + ) + ) res.append(fetch_var) return res class _ExecutorCache(object): - class _CachedData(object): - - def __init__(self, program, feed, fetch_list, feed_var_name, - fetch_var_name, place, scope): + def __init__( + self, + program, + feed, + fetch_list, + feed_var_name, + fetch_var_name, + place, + scope, + ): self.program = program self.feed = feed self.fetch_list = fetch_list @@ -712,18 +767,25 @@ class _ExecutorCache(object): # The program holds no _program, maybe it is constructed by graph. # Convert graph to program in order to generate key. self.program._program = framework.IrGraph( - self.program._graph).to_program() + self.program._graph + ).to_program() self.key = hash( _get_strong_program_cache_key_for_new_exe( - self.program._program, feed, fetch_list)) + self.program._program, feed, fetch_list + ) + ) else: self.key = hash( _get_strong_program_cache_key_for_new_exe( - self.program, feed, fetch_list)) + self.program, feed, fetch_list + ) + ) def __eq__(self, other): - return isinstance( - other, _ExecutorCache._CachedData) and self.key == other.key + return ( + isinstance(other, _ExecutorCache._CachedData) + and self.key == other.key + ) def __hash__(self): return self.key @@ -733,21 +795,41 @@ class _ExecutorCache(object): # the _ExecutorCache instance, otherwise a global cache may not be released after # the Executor instance deleted self._get_cached_program_and_executor = lru_cache(maxsize=8)( - self._get_program_and_executor) + self._get_program_and_executor + ) def clear(self): self._get_cached_program_and_executor.cache_clear() - def get_program_and_executor(self, program, feed, fetch_list, feed_var_name, - fetch_var_name, place, scope): + def get_program_and_executor( + self, + program, + feed, + fetch_list, + feed_var_name, + fetch_var_name, + place, + scope, + ): return self._get_cached_program_and_executor( - self._CachedData(program, feed, fetch_list, feed_var_name, - fetch_var_name, place, scope)) + self._CachedData( + program, + feed, + fetch_list, + feed_var_name, + fetch_var_name, + place, + scope, + ) + ) def _get_program_and_executor(self, cached_data): program = cached_data.program - inner_program = program._program if isinstance( - program, compiler.CompiledProgram) else program + inner_program = ( + program._program + if isinstance(program, compiler.CompiledProgram) + else program + ) feed = cached_data.feed fetch_list = cached_data.fetch_list feed_var_name = cached_data.feed_var_name @@ -757,9 +839,13 @@ class _ExecutorCache(object): # To apply IR pass, compile the Program to IrGraph and convert it back to Program if isinstance(program, compiler.CompiledProgram) or isinstance( - program._graph, compiler.CompiledProgram): - compiled_program = program if isinstance( - program, compiler.CompiledProgram) else program._graph + program._graph, compiler.CompiledProgram + ): + compiled_program = ( + 
program + if isinstance(program, compiler.CompiledProgram) + else program._graph + ) build_strategy = compiled_program._build_strategy # print(f"Program before convert:\n {inner_program}", flush=True) compiled_program._compile(scope, place) @@ -774,21 +860,26 @@ class _ExecutorCache(object): else: build_strategy = None from paddle.incubate.autograd import prim_enabled, prim2orig + if prim_enabled() and program == default_main_program(): prim2orig() inner_program = program - program = _add_feed_fetch_ops(program=inner_program, - feed=feed, - fetch_list=fetch_list, - feed_var_name=feed_var_name, - fetch_var_name=fetch_var_name, - use_fetch_v2=True) + program = _add_feed_fetch_ops( + program=inner_program, + feed=feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name, + use_fetch_v2=True, + ) - if os.environ.get('FLAGS_CONVERT_GRAPH_TO_PROGRAM', None) in [ - 1, '1', True, 'True', 'true' - ] and not program._is_start_up_program_: + if ( + os.environ.get('FLAGS_CONVERT_GRAPH_TO_PROGRAM', None) + in [1, '1', True, 'True', 'true'] + and not program._is_start_up_program_ + ): if program.num_blocks > 1: # If there are multiple blocks in the program, subblock will not be executed with the new executor in temporary logging.warning("There are more than 1 block in program.") @@ -799,13 +890,22 @@ class _ExecutorCache(object): # standalone executor will apply buffer_shared_inplace_pass and # inplace_addto_op_pass to program according to build_strategy - enable_inplace = True if build_strategy is None or build_strategy.enable_inplace else False - enable_addto = True if build_strategy is not None and build_strategy.enable_addto else False + enable_inplace = ( + True + if build_strategy is None or build_strategy.enable_inplace + else False + ) + enable_addto = ( + True + if build_strategy is not None and build_strategy.enable_addto + else False + ) if enable_inplace or enable_addto: # inplace should skip feed and fetch var skip_var_names = eval(_get_program_cache_key(feed, fetch_list)) - _apply_inplace_addto_pass(program, enable_inplace, enable_addto, - skip_var_names) + _apply_inplace_addto_pass( + program, enable_inplace, enable_addto, skip_var_names + ) new_program = program.clone() new_exe = _StandaloneExecutor(place, new_program, scope) @@ -825,10 +925,10 @@ class Executor(object): will set the default device according to its installation version. If Paddle is CPU version, the default device would be set to `CPUPlace()` . If Paddle is GPU version, the default device would be set to `CUDAPlace(0)` . Default is None. - If ``place`` is string, it can be ``cpu``, and ``gpu:x``, where ``x`` + If ``place`` is string, it can be ``cpu``, and ``gpu:x``, where ``x`` is the index of the GPUs. Note: users only pass one Place or None to initialize Executor when using multiple-cards. Other APIs will override the cards. See - `document for multiple-cards `_ + `document for multiple-cards `_ Returns: Executor @@ -899,6 +999,7 @@ class Executor(object): self.ctx_caches = dict() self.trainer_caches = dict() self.scope_caches = dict() + self.micro_scope_cache = dict() self.var_caches = dict() self.pruned_program_caches = dict() p = core.Place() @@ -909,7 +1010,8 @@ class Executor(object): self._prepare_to_run_called = False self._auto_checkpoint_name = unique_name.generate( - "__auto_checkpoint_executor__") + "__auto_checkpoint_executor__" + ) # NOTE: Whether to use experimental executor `StandaloneExecutor`. 
self._enable_interpreter_core = _is_enable_standalone_executor() @@ -962,6 +1064,12 @@ class Executor(object): def _add_scope_cache(self, scope_cache_key, scope): self.scope_caches[scope_cache_key] = scope + def _add_micro_scopes_cache(self, program_cache_key, micro_scopes: list): + self.micro_scope_cache[program_cache_key] = micro_scopes + + def _get_micro_scopes_cache(self, program_cache_key): + return self.micro_scope_cache.get(program_cache_key, None) + # just for testing, will be removed later @lru_cache() def _log_force_set_program_cache(self, use_program_cache): @@ -979,8 +1087,9 @@ class Executor(object): var = global_block.var(feed_target_name) if var.dtype != core.VarDesc.VarType.STRINGS: if not isinstance(cur_feed, core.LoDTensor): - cur_feed = _as_lodtensor(cur_feed, self.place, - var.dtype) + cur_feed = _as_lodtensor( + cur_feed, self.place, var.dtype + ) check_feed_shape_type(var, cur_feed) idx = op.desc.attr('col') core.set_feed_variable(scope, cur_feed, feed_var_name, idx) @@ -1007,7 +1116,7 @@ class Executor(object): Returns: optimize_ops(list): The optimize operators splited from fetch_list. - fetch_list(list): The updated fetch_list which does not contain optimize operators. + fetch_list(list): The updated fetch_list which does not contain optimize operators. """ _optimize_ops = [] _fetch_list = [] @@ -1018,14 +1127,19 @@ class Executor(object): _optimize_ops.append(item) else: raise TypeError( - "The operator in fetch_list is not an optimize_op") - elif isinstance(item, Variable) or isinstance( - item, str) or isinstance(item, six.string_types): + "The operator in fetch_list is not an optimize_op" + ) + elif ( + isinstance(item, Variable) + or isinstance(item, str) + or isinstance(item, six.string_types) + ): _fetch_list.append(item) else: raise TypeError( "The item in fetch_list should be str, variable or optimize_op, but received %s.", - type(item)) + type(item), + ) for index, item in enumerate(fetch_list): # NOTE(zhiqiu): to support (optimizer_ops, param_and_grads) and optimizer_ops in fetch_list @@ -1037,9 +1151,10 @@ class Executor(object): elif isinstance(item, tuple): if not isinstance(item[0], (list, tuple)): raise TypeError( - "Requires fetch_list[{}][0] shall be one of (list, tuple) when type(fetch_list[{}]) is `tuple`, but received fetch_list[{}][0]'s type is `{}`." - .format(index, index, index, - type(item[0]).__name__)) + "Requires fetch_list[{}][0] shall be one of (list, tuple) when type(fetch_list[{}]) is `tuple`, but received fetch_list[{}][0]'s type is `{}`.".format( + index, index, index, type(item[0]).__name__ + ) + ) for i in item[0]: _get_targets(_optimize_ops, _fetch_list, i) else: @@ -1048,19 +1163,17 @@ class Executor(object): return _fetch_list, _optimize_ops @classmethod - def _prune_program(cls, - program, - feed=None, - fetch_list=None, - optimize_ops=None): + def _prune_program( + cls, program, feed=None, fetch_list=None, optimize_ops=None + ): """ Prune operators and variables which are not needed to generate - :code:`fetch_list` and optimize operators. - Prune operators and variables which are needed - to generate variables to be feeded. + :code:`fetch_list` and optimize operators. + Prune operators and variables which are needed + to generate variables to be feeded. Notes: This is a very low level API. Users should not use this API - directly. + directly. 
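The new `_add_micro_scopes_cache` / `_get_micro_scopes_cache` helpers above keep a list of per-micro-batch scopes keyed by the program cache key, so later runs of the same program can reuse them. A minimal sketch of that caching shape, assuming plain placeholder objects in place of real paddle scopes:

class MicroScopeCache:
    def __init__(self):
        self._cache = {}  # program_cache_key -> list of per-micro-batch scopes

    def add(self, key, micro_scopes):
        self._cache[key] = micro_scopes

    def get(self, key):
        return self._cache.get(key, None)

# usage sketch: create one child scope per micro batch the first time a key is seen
cache = MicroScopeCache()
key = "program-123"        # illustrative cache key
num_micro_batches = 4      # illustrative value
if cache.get(key) is None:
    cache.add(key, [object() for _ in range(num_micro_batches)])  # stand-ins for scope.new_scope()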
Args: program(Program): the origin program @@ -1114,10 +1227,10 @@ class Executor(object): @classmethod def _update_feed(cls, program, feed): """ - Update the feed dict, remove the feed item which is pruned in program. + Update the feed dict, remove the feed item which is pruned in program. Notes: This is a very low level API. Users should not use this API - directly. + directly. Args: program(Program): the pruned program. @@ -1144,7 +1257,8 @@ class Executor(object): feed.pop(feed_name) warnings.warn( "The variable %s is not found in program. It is not declared or is pruned." - % feed_name) + % feed_name + ) elif isinstance(feed, list) or isinstance(feed, tuple): for i, each in enumerate(feed): @@ -1153,7 +1267,8 @@ class Executor(object): each.pop(feed_name) warnings.warn( "The variable %s is not found in program. It is not declared or is pruned." - % feed_name) + % feed_name + ) return feed ''' @@ -1188,9 +1303,18 @@ class Executor(object): del trainer_instance self._default_executor.close() - def _run_parallel(self, program, scope, feed, fetch_list, fetch_var_name, - return_numpy, return_merged): + def _run_parallel( + self, + program, + scope, + feed, + fetch_list, + fetch_var_name, + return_numpy, + return_merged, + ): from paddle.optimizer.lr import LRScheduler + exe = program._executor # TODO(zhenghuihuang): quantization uses Graph in CompiledProgram # instead of program. We will add support for checking Vars in Graph @@ -1205,9 +1329,11 @@ class Executor(object): if not isinstance(feed_tensor, core.LoDTensor): # always set to CPU place, since the tensor need to be split # it is fast in CPU - feed_tensor = _as_lodtensor(feed[feed_name], - core.CPUPlace(), - var.dtype if var else None) + feed_tensor = _as_lodtensor( + feed[feed_name], + core.CPUPlace(), + var.dtype if var else None, + ) if need_check_feed: check_feed_shape_type(var, feed_tensor, exe.device_count()) feed_tensor_dict[feed_name] = feed_tensor @@ -1218,16 +1344,20 @@ class Executor(object): for i, each in enumerate(feed): if not isinstance(each, dict): raise TypeError( - "Each element of feed list should be a dict") + "Each element of feed list should be a dict" + ) res_dict = dict() for feed_name in each: tensor = each[feed_name] - var = global_block.var( - feed_name) if need_check_feed else None + var = ( + global_block.var(feed_name) if need_check_feed else None + ) if not isinstance(tensor, core.LoDTensor): - tensor = _as_lodtensor(each[feed_name], - program._places[i], - var.dtype if var else None) + tensor = _as_lodtensor( + each[feed_name], + program._places[i], + var.dtype if var else None, + ) if need_check_feed: check_feed_shape_type(var, tensor) res_dict[feed_name] = tensor @@ -1248,23 +1378,26 @@ class Executor(object): ) else: exe.feed_and_split_tensor_into_local_scopes( - {lr_sheduler._var_name: lr_tensor}) + {lr_sheduler._var_name: lr_tensor} + ) fetch_var_names = list(map(_to_name_str, fetch_list)) tensors = exe.run(fetch_var_names, return_merged)._move_to_list() return as_numpy(tensors) if return_numpy else tensors - def run(self, - program=None, - feed=None, - fetch_list=None, - feed_var_name='feed', - fetch_var_name='fetch', - scope=None, - return_numpy=True, - use_program_cache=False, - return_merged=True, - use_prune=False): + def run( + self, + program=None, + feed=None, + fetch_list=None, + feed_var_name='feed', + fetch_var_name='fetch', + scope=None, + return_numpy=True, + use_program_cache=False, + return_merged=True, + use_prune=False, + ): """ Run the specified :code:`Program` or 
:code:`CompiledProgram`. It should be noted that the executor will execute all the operators in :code:`Program` or :code:`CompiledProgram` without pruning some @@ -1288,12 +1421,12 @@ class Executor(object): so the length of this list should be equal to the number of places. The default is None. fetch_list(list): This parameter represents the Tensors that need to be returned - after the model runs. The default is None. + after the model runs. The default is None. feed_var_name(str): This parameter represents the name of the input Tensor of the feed operator. The default is "feed". fetch_var_name(str): This parameter represents the name of the output Tensor of the fetch operator. The default is "fetch". - scope(Scope): the scope used to run this program, you can switch + scope(Scope): the scope used to run this program, you can switch it to different scope. default is :code:`paddle.static.global_scope()` return_numpy(bool): This parameter indicates whether convert the fetched Tensors (the Tensor specified in the fetch list) to numpy.ndarray. if it is False, @@ -1314,14 +1447,14 @@ class Executor(object): results are variant, please set :code:`return_merged` as False, which denotes that the fetched results will not be merged. The default is True, but it is just for the compatibility, and may use False as default value in the future version. - use_prune(bool): This parameter indicates whether the input :code:`Program` will be pruned. + use_prune(bool): This parameter indicates whether the input :code:`Program` will be pruned. If the parameter is True, the program will be pruned accroding to the given feed and fetch_list, - which means the operators and variables in program that generate :code:`feed` and are not - needed to generate :code:`fetch_list` will be pruned. The default is False, which means the + which means the operators and variables in program that generate :code:`feed` and are not + needed to generate :code:`fetch_list` will be pruned. The default is False, which means the program will not pruned and all the operators and variables will be executed during running. - Note that if the tuple returned from :code:`Optimizer.minimize()` is passed to :code:`fetch_list`, + Note that if the tuple returned from :code:`Optimizer.minimize()` is passed to :code:`fetch_list`, :code:`use_prune` will be overrided to True, and the program will be pruned. - + Returns: List: The fetched result list. 
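For reference, a typical static-graph call into `run`: `feed` is a dict of numpy arrays keyed by the names given to `paddle.static.data`, and `fetch_list` holds the Variables to return. A minimal usage sketch based on the standard `paddle.static` API (not taken from this patch):

import numpy as np
import paddle

paddle.enable_static()
x = paddle.static.data(name='x', shape=[None, 13], dtype='float32')
y = paddle.static.nn.fc(x, size=1)

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(paddle.static.default_startup_program())

out, = exe.run(
    paddle.static.default_main_program(),
    feed={'x': np.random.rand(8, 13).astype('float32')},
    fetch_list=[y],
)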
@@ -1439,32 +1572,49 @@ class Executor(object): """ # Temporary FLAGS, just for testing the performance of program cache force_use_program_cache = os.environ.get( - 'FLAGS_FORCE_USE_PROGRAM_CACHE', None) + 'FLAGS_FORCE_USE_PROGRAM_CACHE', None + ) if force_use_program_cache is not None: use_program_cache = force_use_program_cache in [ - 1, '1', True, 'True', 'true' + 1, + '1', + True, + 'True', + 'true', ] self._log_force_set_program_cache(use_program_cache) try: - res = self._run_impl(program=program, - feed=feed, - fetch_list=fetch_list, - feed_var_name=feed_var_name, - fetch_var_name=fetch_var_name, - scope=scope, - return_numpy=return_numpy, - use_program_cache=use_program_cache, - use_prune=use_prune, - return_merged=return_merged) + res = self._run_impl( + program=program, + feed=feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name, + scope=scope, + return_numpy=return_numpy, + use_program_cache=use_program_cache, + use_prune=use_prune, + return_merged=return_merged, + ) core.update_autotune_status() return res except Exception as e: six.reraise(*sys.exc_info()) - def _run_impl(self, program, feed, fetch_list, feed_var_name, - fetch_var_name, scope, return_numpy, use_program_cache, - return_merged, use_prune): + def _run_impl( + self, + program, + feed, + fetch_list, + feed_var_name, + fetch_var_name, + scope, + return_numpy, + use_program_cache, + return_merged, + use_prune, + ): if self._closed: raise RuntimeError("Attempted to use a closed Executor") @@ -1483,17 +1633,20 @@ class Executor(object): program=program, feed=feed, fetch_list=fetch_list, - with_standalone_executor=self. - _fleet_executor_with_standalone) + with_standalone_executor=self._fleet_executor_with_standalone, + return_numpy=return_numpy, + ) if "startup_program" in program._pipeline_opt: program = program._pipeline_opt["startup_program"] else: - return self._run_pipeline(program, - fetch_list=fetch_list, - use_program_cache=use_program_cache) + return self._run_pipeline( + program, + fetch_list=fetch_list, + use_program_cache=use_program_cache, + ) if isinstance(program, Program) and program._heter_pipeline_opt: - #print("program._heter_pipeline_opt: {}".format( + # print("program._heter_pipeline_opt: {}".format( # program._heter_pipeline_opt)) ## change default executor heter_place = program._heter_pipeline_opt["heter_place"] @@ -1503,20 +1656,26 @@ class Executor(object): self._default_executor = core.Executor(p) # TODO(zhangminxu): support heterps pipeline training using exe.run if "startup_program" in program._heter_pipeline_opt: - #print("get startup_program from _pipeline_opt") + # print("get startup_program from _pipeline_opt") program = program._heter_pipeline_opt["startup_program"] - if isinstance(program, Program) and \ - len(program.global_block().ops) == 0: + if ( + isinstance(program, Program) + and len(program.global_block().ops) == 0 + ): if use_default_main_program: - error_info = "Now you are using default_main_program, "\ - "but there are no operators in the program to be executed. "\ - "Please ensure you create model correctly or you can pass "\ + error_info = ( + "Now you are using default_main_program, " + "but there are no operators in the program to be executed. " + "Please ensure you create model correctly or you can pass " "the Program or the CompiledProgram manually." + ) else: - error_info = "There are no operators in the program to be executed. 
"\ - "If you pass Program manually, please use fluid.program_guard "\ + error_info = ( + "There are no operators in the program to be executed. " + "If you pass Program manually, please use fluid.program_guard " "to ensure the current Program is being used." + ) warnings.warn(error_info) if scope is None: @@ -1526,27 +1685,36 @@ class Executor(object): _origin_fetch_list = fetch_list _origin_program = program fetch_list, optimize_ops = self._split_optimize_ops_in_fetch_list( - fetch_list) + fetch_list + ) if optimize_ops: use_prune = True if use_prune: - cache_key = _get_strong_program_cache_key(program, feed, - _origin_fetch_list) + cache_key = _get_strong_program_cache_key( + program, feed, _origin_fetch_list + ) cached_pruned_program = self._get_pruned_program_cache(cache_key) if cached_pruned_program is None: if isinstance(program, compiler.CompiledProgram): program_scope_cache = self._get_pruned_program_scope_cache( - str(id(_origin_program))) + str(id(_origin_program)) + ) # copy the original program, so it can be cached. program = copy.copy(program) # share the local scopes for same original CompiledProgram. program._share_vars_from = program_scope_cache - if self._get_pruned_program_scope_cache( - str(id(_origin_program))) is None: + if ( + self._get_pruned_program_scope_cache( + str(id(_origin_program)) + ) + is None + ): self._add_pruned_program_scope_cache( - str(id(_origin_program)), program) - pruned_program = self._prune_program(program, feed, fetch_list, - optimize_ops) + str(id(_origin_program)), program + ) + pruned_program = self._prune_program( + program, feed, fetch_list, optimize_ops + ) self._add_pruned_program_cache(cache_key, pruned_program) else: pruned_program = cached_pruned_program @@ -1556,63 +1724,93 @@ class Executor(object): def _can_use_interpreter_core(program, place): if core.is_compiled_with_mlu() or isinstance( - place, core.CustomPlace): + place, core.CustomPlace + ): return False use_standalone_executor_for_compiled_program = os.environ.get( - 'FLAGS_CONVERT_GRAPH_TO_PROGRAM', - None) in [1, '1', True, 'True', 'true'] + 'FLAGS_CONVERT_GRAPH_TO_PROGRAM', None + ) in [1, '1', True, 'True', 'true'] # Only support fleet when 'FLAGS_CONVERT_GRAPH_TO_PROGRAM' is set to true from paddle.distributed.fleet import fleet - if fleet._role_maker is not None and not use_standalone_executor_for_compiled_program: - warnings.warn("Standalone executor is not used for fleet", - UserWarning) + + if ( + fleet._role_maker is not None + and not use_standalone_executor_for_compiled_program + ): + warnings.warn( + "Standalone executor is not used for fleet", UserWarning + ) return False - compiled = isinstance(program, - compiler.CompiledProgram) or isinstance( - program._graph, compiler.CompiledProgram) + compiled = isinstance( + program, compiler.CompiledProgram + ) or isinstance(program._graph, compiler.CompiledProgram) if compiled: - compiled_program = program if isinstance( - program, compiler.CompiledProgram) else program._graph + compiled_program = ( + program + if isinstance(program, compiler.CompiledProgram) + else program._graph + ) # Unsupported case 1: data parallel - if compiled_program._is_data_parallel and len( + if ( + compiled_program._is_data_parallel + and len( compiled_program._get_places( - place, compiled_program._places)) != 1: + place, compiled_program._places + ) + ) + != 1 + ): warnings.warn( "Standalone executor is not used for data parallel", - UserWarning) + UserWarning, + ) return False # Unsupported case 2: parallel graph if 
core.globals()['FLAGS_enable_parallel_graph'] in [ - 1, '1', True, 'True', 'true' + 1, + '1', + True, + 'True', + 'true', ]: warnings.warn( "Standalone executor is not used for parallel graph", - UserWarning) + UserWarning, + ) return False # Unsupported case 3: inference if compiled_program._is_inference: warnings.warn( "Standalone executor is not used for inference", - UserWarning) + UserWarning, + ) return False # Unsupported case 4: CUDA Graph - if compiled_program._build_strategy is not None and compiled_program._build_strategy.allow_cuda_graph_capture: + if ( + compiled_program._build_strategy is not None + and compiled_program._build_strategy.allow_cuda_graph_capture + ): warnings.warn( "Standalone executor is not used for CUDA Graph", - UserWarning) + UserWarning, + ) return False # Unsupported case 5: async mode - if compiled_program._build_strategy is not None and compiled_program._build_strategy.async_mode: + if ( + compiled_program._build_strategy is not None + and compiled_program._build_strategy.async_mode + ): warnings.warn( "Standalone executor is not used for async mode", - UserWarning) + UserWarning, + ) return False return use_standalone_executor_for_compiled_program @@ -1622,8 +1820,11 @@ class Executor(object): # NOTE: This is an experimental feature. If `export FLAGS_USE_STANDALONE_EXECUTOR=1 `, # use StandaloneExecutor to run the program. - if return_merged and self._enable_interpreter_core and _can_use_interpreter_core( - program, self.place): + if ( + return_merged + and self._enable_interpreter_core + and _can_use_interpreter_core(program, self.place) + ): if feed is None: feed = {} @@ -1633,18 +1834,27 @@ class Executor(object): if not isinstance(feed, dict): raise TypeError( "feed requires dict as its Parameter. But you passed in %s" - % (type(feed))) + % (type(feed)) + ) feed = self._update_feed(program, feed) program, new_exe = self._executor_cache.get_program_and_executor( - program, feed, fetch_list, feed_var_name, fetch_var_name, - self.place, scope) + program, + feed, + fetch_list, + feed_var_name, + fetch_var_name, + self.place, + scope, + ) self._feed_data(program, feed, feed_var_name, scope) if hasattr(program, 'lr_sheduler'): from paddle.optimizer.lr import LRScheduler - assert isinstance(program.lr_sheduler, - LRScheduler), "must be LRScheduler" + + assert isinstance( + program.lr_sheduler, LRScheduler + ), "must be LRScheduler" lr_sheduler = program.lr_sheduler lr_value = lr_sheduler() lr_var = program.global_block().vars[lr_sheduler._var_name] @@ -1658,8 +1868,9 @@ class Executor(object): else: tensor._copy_from(cpu_tensor, self.place) - return new_exe.run(scope, list(feed.keys()), fetch_list, - return_numpy) + return new_exe.run( + scope, list(feed.keys()), fetch_list, return_numpy + ) compiled = isinstance(program, compiler.CompiledProgram) @@ -1674,13 +1885,15 @@ class Executor(object): varobj = global_block.vars[varname] # Can not check var build by fluid.layers.data(), bucause fluid.layers.data() had not set need_check_feed - if vardesc.persistable() == False and \ - vardesc.type() == core.VarDesc.VarType.LOD_TENSOR and \ - vardesc.need_check_feed() == True and \ - varobj.stop_gradient == True and \ - varobj.is_data == True and \ - varobj.belong_to_optimizer == False and \ - varname not in feed: + if ( + vardesc.persistable() == False + and vardesc.type() == core.VarDesc.VarType.LOD_TENSOR + and vardesc.need_check_feed() == True + and varobj.stop_gradient == True + and varobj.is_data == True + and varobj.belong_to_optimizer == False + and 
varname not in feed + ): raise ValueError('Need feed data for variable %s' % varname) acp._auto_checkpoint(self, program) @@ -1688,46 +1901,63 @@ class Executor(object): # For backward compatibility, run directly. if not compiled: # In distributed training, the compiled program is saved in Program._graph - has_compiled_graph = isinstance(program._graph, - compiler.CompiledProgram) + has_compiled_graph = isinstance( + program._graph, compiler.CompiledProgram + ) if has_compiled_graph: program._graph._compile(scope, self.place) # _graph in program does not support inference since the _graph is optimized # through optimizer.minimize function and should not be used as inference graph # assert not program._graph._is_inference - return self._run_parallel(program._graph, - scope=scope, - feed=feed, - fetch_list=fetch_list, - fetch_var_name=fetch_var_name, - return_numpy=return_numpy, - return_merged=return_merged) - - return self._run_program(program, - feed=feed, - fetch_list=fetch_list, - feed_var_name=feed_var_name, - fetch_var_name=fetch_var_name, - scope=scope, - return_numpy=return_numpy, - use_program_cache=use_program_cache) + return self._run_parallel( + program._graph, + scope=scope, + feed=feed, + fetch_list=fetch_list, + fetch_var_name=fetch_var_name, + return_numpy=return_numpy, + return_merged=return_merged, + ) + + return self._run_program( + program, + feed=feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name, + scope=scope, + return_numpy=return_numpy, + use_program_cache=use_program_cache, + ) program._compile(scope, self.place) if program._is_inference: return self._run_inference(program._executor, feed) else: - return self._run_parallel(program, - scope=scope, - feed=feed, - fetch_list=fetch_list, - fetch_var_name=fetch_var_name, - return_numpy=return_numpy, - return_merged=return_merged) - - def _run_program(self, program, feed, fetch_list, feed_var_name, - fetch_var_name, scope, return_numpy, use_program_cache): + return self._run_parallel( + program, + scope=scope, + feed=feed, + fetch_list=fetch_list, + fetch_var_name=fetch_var_name, + return_numpy=return_numpy, + return_merged=return_merged, + ) + + def _run_program( + self, + program, + feed, + fetch_list, + feed_var_name, + fetch_var_name, + scope, + return_numpy, + use_program_cache, + ): from paddle.optimizer.lr import LRScheduler + if feed is None: feed = {} elif isinstance(feed, (list, tuple)): @@ -1736,19 +1966,22 @@ class Executor(object): if not isinstance(feed, dict): raise TypeError( - "feed requires dict as its Parameter. But you passed in %s" % - (type(feed))) + "feed requires dict as its Parameter. But you passed in %s" + % (type(feed)) + ) assert program is not None, "The program should not be Empty" if not isinstance(program, Program): raise TypeError( "Executor requires Program as its Parameter. But you passed in %s" - % (type(program))) + % (type(program)) + ) if not isinstance(fetch_var_name, str): raise TypeError( "The name of fetch variable requires string as its Parameter. 
But you passed in %s" - % (type(fetch_var_name))) + % (type(fetch_var_name)) + ) if use_program_cache: cache_key = _get_strong_program_cache_key(program, feed, fetch_list) @@ -1761,35 +1994,41 @@ class Executor(object): feed=feed, fetch_list=fetch_list, feed_var_name=feed_var_name, - fetch_var_name=fetch_var_name) + fetch_var_name=fetch_var_name, + ) self._add_program_cache(cache_key, cached_program) fetch_list_str = list(map(_to_name_str, fetch_list)) cached_ctx = self._default_executor.prepare( - cached_program.desc, 0, fetch_list_str, False) + cached_program.desc, 0, fetch_list_str, False + ) # currently, we cache program, vars, sub_scope here # we suppose that in a life cycle of training, a user # will not create many programs. So, here the basic # rule of caching is to cache all unseen (program, var, scope) # when a user use use_program_cache. cached_scope = scope.new_scope() - self._default_executor.create_variables(cached_program.desc, - cached_scope, 0) + self._default_executor.create_variables( + cached_program.desc, cached_scope, 0 + ) self._add_ctx_cache(cache_key, cached_ctx) self._add_scope_cache(cache_key, cached_scope) program = cached_program ctx = cached_ctx scope = cached_scope else: - program = _add_feed_fetch_ops(program=program, - feed=feed, - fetch_list=fetch_list, - feed_var_name=feed_var_name, - fetch_var_name=fetch_var_name) + program = _add_feed_fetch_ops( + program=program, + feed=feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name, + ) self._feed_data(program, feed, feed_var_name, scope) if hasattr(program, 'lr_sheduler'): - assert isinstance(program.lr_sheduler, - LRScheduler), "must be LRScheduler" + assert isinstance( + program.lr_sheduler, LRScheduler + ), "must be LRScheduler" lr_sheduler = program.lr_sheduler lr_value = lr_sheduler() lr_var = program.global_block().vars[lr_sheduler._var_name] @@ -1798,11 +2037,13 @@ class Executor(object): tensor.set(data, self.place) if not use_program_cache: - self._default_executor.run(program.desc, scope, 0, True, True, - [fetch_var_name]) + self._default_executor.run( + program.desc, scope, 0, True, True, [fetch_var_name] + ) else: - self._default_executor.run_prepared_ctx(ctx, scope, False, False, - False) + self._default_executor.run_prepared_ctx( + ctx, scope, False, False, False + ) arr = scope.find_var(fetch_var_name).get_fetch_list() tensors = arr._move_to_list() if return_numpy: @@ -1814,17 +2055,21 @@ class Executor(object): return exe.run(feed) def _check_fetch_list(self, fetch_list): - is_fetch_var = lambda var: isinstance(var, - (Variable, str, six.string_types)) + is_fetch_var = lambda var: isinstance( + var, (Variable, str, six.string_types) + ) is_tuple_list = lambda var: isinstance(var, (tuple, list)) - if fetch_list is None: return [] - if is_fetch_var(fetch_list): return [fetch_list] + if fetch_list is None: + return [] + if is_fetch_var(fetch_list): + return [fetch_list] - assert is_tuple_list(fetch_list), \ - "Currently , The fetch_list type only should be list or tuple, \n"\ - "but the input type is {}. For more information please refer to \n"\ + assert is_tuple_list(fetch_list), ( + "Currently , The fetch_list type only should be list or tuple, \n" + "but the input type is {}. 
For more information please refer to \n" "the executor.run(...).".format(type(fetch_list)) + ) res = [] for i, var in enumerate(fetch_list): @@ -1838,9 +2083,10 @@ class Executor(object): res.append(var) else: raise TypeError( - "Require fetch_list[{}] 's type shall be one of (Variable, str), but received {}." - .format(i, - type(var).__name__)) + "Require fetch_list[{}] 's type shall be one of (Variable, str), but received {}.".format( + i, type(var).__name__ + ) + ) return res @@ -1857,25 +2103,30 @@ class Executor(object): pipeline_num = filelist_length print( "Pipeline training: setting the pipeline num to %d is enough because there are only %d files" - % (filelist_length, filelist_length)) + % (filelist_length, filelist_length) + ) if filelist_length < pipeline_num * pipeline_opt["concurrency_list"][0]: print( "Pipeline training: setting the 1st element in concurrency_list to %d is enough because there are only %d files" - % (filelist_length // pipeline_num, filelist_length)) - pipeline_opt["concurrency_list"][ - 0] = filelist_length // pipeline_num + % (filelist_length // pipeline_num, filelist_length) + ) + pipeline_opt["concurrency_list"][0] = ( + filelist_length // pipeline_num + ) dataset.set_thread(pipeline_opt["concurrency_list"][0] * pipeline_num) return pipeline_num - def _prepare_trainer(self, - program=None, - dataset=None, - scope=None, - thread=0, - debug=False, - fetch_list=None, - fetch_info=None, - print_period=100): + def _prepare_trainer( + self, + program=None, + dataset=None, + scope=None, + thread=0, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + ): is_heter = 0 use_ps_gpu = 0 if not program._fleet_opt is None: @@ -1896,16 +2147,19 @@ class Executor(object): if is_heter: from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet from paddle.fluid.incubate.fleet.utils.fleet_util import FleetUtil + fu = FleetUtil() ret = fu.split_program_by_device(program) if not compiled: # TODO: Need a better way to distinguish and specify different execution mode if program._pipeline_opt: trainer = TrainerFactory()._create_trainer( - program._pipeline_opt) + program._pipeline_opt + ) elif program._heter_pipeline_opt: trainer = TrainerFactory()._create_trainer( - program._heter_pipeline_opt) + program._heter_pipeline_opt + ) else: trainer = TrainerFactory()._create_trainer(program._fleet_opt) trainer._set_thread_barrier(program._is_distributed) @@ -1915,13 +2169,16 @@ class Executor(object): else: if program._pipeline_opt: trainer = TrainerFactory()._create_trainer( - program.program._pipeline_opt) + program.program._pipeline_opt + ) elif program._heter_pipeline_opt: trainer = TrainerFactory()._create_trainer( - program.program._heter_pipeline_opt) + program.program._heter_pipeline_opt + ) else: trainer = TrainerFactory()._create_trainer( - program.program._fleet_opt) + program.program._fleet_opt + ) trainer._set_program(program.program) if thread <= 0: @@ -1930,7 +2187,8 @@ class Executor(object): elif dataset.thread_num <= 0: raise RuntimeError( "You should set thread num first, either in Dataset" - "or in Executor.train_from_dataset") + "or in Executor.train_from_dataset" + ) else: trainer._set_thread(dataset.thread_num) else: @@ -1940,19 +2198,22 @@ class Executor(object): trainer._set_fetch_var_and_info(fetch_list, fetch_info, print_period) return scope, trainer - def _run_from_dataset(self, - program=None, - dataset=None, - scope=None, - thread=0, - is_infer=False, - debug=False, - fetch_list=None, - fetch_info=None, - print_period=100, 
- fetch_handler=None): + def _run_from_dataset( + self, + program=None, + dataset=None, + scope=None, + thread=0, + is_infer=False, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + fetch_handler=None, + ): if program._pipeline_opt is not None: import paddle + if dataset is not None: raise RuntimeError("dataset should be None for pipeline mode") # The following fake dataset is created to call @@ -1963,24 +2224,28 @@ class Executor(object): data_vars.append(var) if core.is_compiled_with_npu(): dataset = paddle.fluid.DatasetFactory().create_dataset( - 'InMemoryDataset') + 'InMemoryDataset' + ) else: dataset = paddle.fluid.DatasetFactory().create_dataset( - 'FileInstantDataset') + 'FileInstantDataset' + ) dataset.set_batch_size(1) dataset.set_thread(1) dataset.set_filelist(['None']) dataset.set_use_var(data_vars) elif program._heter_pipeline_opt is not None: stage_id = program._heter_pipeline_opt["pipeline_stage"] - #print("test_fl_stage_id: {}".format(stage_id)) + # print("test_fl_stage_id: {}".format(stage_id)) heter_place = program._heter_pipeline_opt["heter_place"] if stage_id != 0: if "is_fl_mode" not in program._heter_pipeline_opt: import paddle + if dataset is not None: raise RuntimeError( - "dataset should be None for heter pipeline mode") + "dataset should be None for heter pipeline mode" + ) # The following fake dataset is created to call # the _prepare_trainer api, and it is meaningless. data_vars = [] @@ -1988,7 +2253,8 @@ class Executor(object): if var.is_data: data_vars.append(var) dataset = paddle.fluid.DatasetFactory().create_dataset( - 'InMemoryDataset') + 'InMemoryDataset' + ) dataset.set_batch_size(1) dataset.set_thread(1) dataset.set_filelist(['None']) @@ -1996,7 +2262,8 @@ class Executor(object): else: if dataset is None: raise RuntimeError( - "dataset is need and should be initialized") + "dataset is need and should be initialized" + ) ## change default executor heter_place = framework._get_paddle_place(heter_place) p = core.Place() @@ -2023,7 +2290,8 @@ class Executor(object): feed=[], fetch_list=real_fetch_list, feed_var_name='feed', - fetch_var_name='fetch') + fetch_var_name='fetch', + ) main_block = program._pipeline_opt["section_program"].block(0) for op in main_block.ops: # set the op_role of fetch op to Optimize to avoid @@ -2031,16 +2299,19 @@ class Executor(object): if op.type == 'fetch': op._set_attr( 'op_role', - core.op_proto_and_checker_maker.OpRole.Optimize) + core.op_proto_and_checker_maker.OpRole.Optimize, + ) fetch_list = None - scope, trainer = self._prepare_trainer(program=program, - dataset=dataset, - scope=scope, - thread=thread, - debug=debug, - fetch_list=fetch_list, - fetch_info=fetch_info, - print_period=print_period) + scope, trainer = self._prepare_trainer( + program=program, + dataset=dataset, + scope=scope, + thread=thread, + debug=debug, + fetch_list=fetch_list, + fetch_info=fetch_info, + print_period=print_period, + ) trainer._set_infer(is_infer) trainer._gen_trainer_desc() @@ -2055,8 +2326,11 @@ class Executor(object): dataset._dynamic_adjust_before_train(trainer.proto_desc.thread_num) if program._heter_pipeline_opt is None: - trainer_instance = self._default_executor.init_for_dataset( # -->InitForDataset - program.desc, trainer._desc(), scope, dataset.dataset) + trainer_instance = ( + self._default_executor.init_for_dataset( # -->InitForDataset + program.desc, trainer._desc(), scope, dataset.dataset + ) + ) else: # cache trainer instance for heterps pipeline training if fetch_list == None: @@ -2065,8 +2339,9 @@ 
class Executor(object): trainer_instance = self._get_trainer_cache(cache_key) if trainer_instance is None: trainer_instance = self._default_executor.init_for_dataset( - program.desc, trainer._desc(), scope, dataset.dataset) - #print("test_fl_ps - trainer_desc: {}\n".format(trainer)) + program.desc, trainer._desc(), scope, dataset.dataset + ) + # print("test_fl_ps - trainer_desc: {}\n".format(trainer)) self._add_trainer_cache(cache_key, trainer_instance) else: trainer_instance.ResetDataset(dataset.dataset) @@ -2093,18 +2368,20 @@ class Executor(object): return None - def _prepare_pipeline_ctx(self, - program=None, - dataset=None, - scope=None, - thread=0, - is_infer=False, - debug=False, - fetch_list=None, - fetch_info=None, - print_period=100, - fetch_handler=None, - use_program_cache=False): + def _prepare_pipeline_ctx( + self, + program=None, + dataset=None, + scope=None, + thread=0, + is_infer=False, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + fetch_handler=None, + use_program_cache=False, + ): assert program._pipeline_opt is not None assert dataset is None, "dataset should be None for pipeline mode" @@ -2124,10 +2401,12 @@ class Executor(object): data_vars.append(var) if core.is_compiled_with_npu(): dataset = paddle.fluid.DatasetFactory().create_dataset( - 'InMemoryDataset') + 'InMemoryDataset' + ) else: dataset = paddle.fluid.DatasetFactory().create_dataset( - 'FileInstantDataset') + 'FileInstantDataset' + ) dataset.set_batch_size(1) dataset.set_thread(1) dataset.set_filelist(['None']) @@ -2148,11 +2427,13 @@ class Executor(object): if fetch_var_name in real_program.global_block().vars: real_fetch_list.append(fetch_var) - real_program = _add_feed_fetch_ops(program=real_program, - feed=[], - fetch_list=real_fetch_list, - feed_var_name='feed', - fetch_var_name='fetch') + real_program = _add_feed_fetch_ops( + program=real_program, + feed=[], + fetch_list=real_fetch_list, + feed_var_name='feed', + fetch_var_name='fetch', + ) main_block = real_program.block(0) for op in main_block.ops: # set the op_role of fetch op to Optimize to avoid @@ -2160,7 +2441,8 @@ class Executor(object): if op.type == 'fetch': op._set_attr( 'op_role', - core.op_proto_and_checker_maker.OpRole.Optimize) + core.op_proto_and_checker_maker.OpRole.Optimize, + ) return real_program, real_fetch_list real_program, real_fetch_list = _get_real_program_fetch_list() @@ -2168,14 +2450,16 @@ class Executor(object): program._pipeline_opt["section_program"] = real_program fetch_list = None - scope, trainer = self._prepare_trainer(program=program, - dataset=dataset, - scope=scope, - thread=thread, - debug=debug, - fetch_list=fetch_list, - fetch_info=fetch_info, - print_period=print_period) + scope, trainer = self._prepare_trainer( + program=program, + dataset=dataset, + scope=scope, + thread=thread, + debug=debug, + fetch_list=fetch_list, + fetch_info=fetch_info, + print_period=print_period, + ) trainer._set_infer(is_infer) trainer._gen_trainer_desc() @@ -2190,93 +2474,148 @@ class Executor(object): trainer_desc = trainer._desc() # slow, cache trainer_instance = self._default_executor.init_for_dataset( - program.desc, trainer_desc, scope, dataset.dataset) + program.desc, trainer_desc, scope, dataset.dataset + ) ctx = [scope, real_fetch_list, trainer_instance] - if use_program_cache: self._add_ctx_cache(cache_key, ctx) + if use_program_cache: + self._add_ctx_cache(cache_key, ctx) return ctx - def _prepare_fleet_executor_carrier(self, - carrier_id="", - program=None, - scope=None, - fleet_opt=None, - 
with_standalone_executor=False): - num_micro_batches = fleet_opt[ - "num_micro_batches"] if "num_micro_batches" in fleet_opt else 1 + def _prepare_fleet_executor_carrier( + self, + carrier_id="", + program=None, + scope=None, + fleet_opt=None, + micro_scope_list=[], + with_standalone_executor=False, + ): + num_micro_batches = ( + fleet_opt["num_micro_batches"] + if "num_micro_batches" in fleet_opt + else 1 + ) cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0)) trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',') nrank = len(trainer_endpoints) - assert 'scheduler' in fleet_opt or 'tasks' in fleet_opt, \ - "Fleet executor need configuration for scheduler, you can choose from 1F1B or Origin. " \ + assert 'scheduler' in fleet_opt or 'tasks' in fleet_opt, ( + "Fleet executor need configuration for scheduler, you can choose from 1F1B or Origin. " "Or you can provide a list of task nodes to init fleet executor directly." + ) if 'tasks' in fleet_opt: - assert 'task_id_to_rank' in fleet_opt, "If you provide tasks to init fleet executor," \ - " task_id_to_rank should also be provided." + assert 'task_id_to_rank' in fleet_opt, ( + "If you provide tasks to init fleet executor," + " task_id_to_rank should also be provided." + ) print('fleet executor will use user defined task nodes') tasks = [task.task_node() for task in fleet_opt['tasks']] task_id_to_rank = fleet_opt['task_id_to_rank'] else: scheduler = fleet_opt['scheduler'] if scheduler == '1F1B': - from paddle.distributed.fleet.fleet_executor_utils import run1f1b - if "dist_strategy" not in fleet_opt or \ - "pp_degree" not in fleet_opt["dist_strategy"] or \ - fleet_opt["dist_strategy"]["pp_degree"] == 1: + from paddle.distributed.fleet.fleet_executor_utils import ( + run1f1b, + ) + + if ( + "dist_strategy" not in fleet_opt + or "pp_degree" not in fleet_opt["dist_strategy"] + or fleet_opt["dist_strategy"]["pp_degree"] == 1 + ): warnings.warn("Using 1F1B scheduler with pp_degree == 1.") tasks, task_id_to_rank = run1f1b( - program, cur_rank, fleet_opt.get('num_micro_batches', 1), - fleet_opt.get('dist_strategy', {}), nrank, - with_standalone_executor) + program, + cur_rank, + fleet_opt.get('num_micro_batches', 1), + fleet_opt.get('dist_strategy', {}), + nrank, + with_standalone_executor, + ) elif scheduler == 'Origin': from paddle.distributed.fleet.fleet_executor_utils import origin - if "dist_strategy" in fleet_opt and \ - "pp_degree" in fleet_opt["dist_strategy"]: - assert fleet_opt["dist_strategy"]["pp_degree"] == 1, \ - "For pipeline mode, the scheduler should be 1F1B instead of Origin." + + if ( + "dist_strategy" in fleet_opt + and "pp_degree" in fleet_opt["dist_strategy"] + ): + assert ( + fleet_opt["dist_strategy"]["pp_degree"] == 1 + ), "For pipeline mode, the scheduler should be 1F1B instead of Origin." if "num_micro_batches" in fleet_opt: - assert fleet_opt["num_micro_batches"] == 1, \ - "For origin scheduler mode, the num micro batches should be 1." + assert ( + fleet_opt["num_micro_batches"] == 1 + ), "For origin scheduler mode, the num micro batches should be 1." tasks, task_id_to_rank = origin(program, cur_rank) else: - raise "Fleet_executor only supports 1F1B and Origin scheduler, " \ - "but received " + str(scheduler) + "." + raise "Fleet_executor only supports 1F1B and Origin scheduler, " "but received " + str( + scheduler + ) + "." 
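`_prepare_fleet_executor_carrier` derives the current rank and world size from the launcher environment before building the task graph. A short sketch of just that derivation, matching the `PADDLE_TRAINER_ID` / `PADDLE_TRAINER_ENDPOINTS` reads above (the endpoint values are illustrative):

import os

# e.g. set by the distributed launcher; illustrative values
os.environ.setdefault("PADDLE_TRAINER_ID", "0")
os.environ.setdefault("PADDLE_TRAINER_ENDPOINTS", "127.0.0.1:6170,127.0.0.1:6171")

cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
nrank = len(trainer_endpoints)  # world size: one endpoint per trainer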
# NOTE: have to hold these vars, otherwise will be destructed fleet_opt['tasks'] = tasks fleet_opt['task_id_to_rank'] = task_id_to_rank place = core.Place() place.set_place(self.place) - # NOTE: the last argument is used to force create some vars in root scope, - # won't be used during train. - self._fleet_executor.init(carrier_id, program.desc, scope, place, - num_micro_batches, tasks, task_id_to_rank, []) - - def _run_using_fleet_executor(self, - program=None, - feed=None, - feed_var_name="feed", - fetch_var_name="fetch", - fetch_list=None, - with_standalone_executor=False): + + inference_root_scope_vars = ( + fleet_opt["fetch_var"] if "fetch_var" in fleet_opt else [] + ) + self._fleet_executor.init( + carrier_id, + program.desc, + scope, + place, + num_micro_batches, + tasks, + task_id_to_rank, + inference_root_scope_vars, + micro_scope_list, + ) + + def _run_using_fleet_executor( + self, + program=None, + feed=None, + feed_var_name="feed", + fetch_var_name="fetch", + fetch_list=None, + with_standalone_executor=False, + return_numpy=True, + ): cache_key = _get_strong_program_cache_key(program, feed, fetch_list) cached_program = self._get_program_cache(cache_key) cached_scope = self._get_scope_cache(cache_key) + micro_cached_scopes = self._get_micro_scopes_cache(cache_key) + fleet_opt = program._pipeline_opt["fleet_opt"] if cached_scope is None: cached_scope = global_scope() self._add_scope_cache(cache_key, cached_scope) + if micro_cached_scopes is None: + micro_cached_scopes = [] + if ( + "inference_generation" in fleet_opt + and fleet_opt["inference_generation"] + ): + for _ in range(int(fleet_opt["num_micro_batches"])): + micro_cached_scopes.append(cached_scope.new_scope()) + self._add_micro_scopes_cache(cache_key, micro_cached_scopes) if cached_program is None: - assert program._pipeline_opt, "program should have _pipeline_opt to start carrier" + assert ( + program._pipeline_opt + ), "program should have _pipeline_opt to start carrier" real_feed = [] if feed is None else feed real_program = program if "section_program" in program._pipeline_opt: real_program = program._pipeline_opt["section_program"] - cached_program = _add_feed_fetch_ops(program=real_program, - feed=real_feed, - fetch_list=fetch_list, - feed_var_name=feed_var_name, - fetch_var_name=fetch_var_name) + cached_program = _add_feed_fetch_ops( + program=real_program, + feed=real_feed, + fetch_list=fetch_list, + feed_var_name=feed_var_name, + fetch_var_name=fetch_var_name, + ) main_block = cached_program.block(0) for op in main_block.ops: # set the op_role of fetch op to Optimize to avoid @@ -2284,9 +2623,9 @@ class Executor(object): if op.type == 'fetch': op._set_attr( 'op_role', - core.op_proto_and_checker_maker.OpRole.Optimize) + core.op_proto_and_checker_maker.OpRole.Optimize, + ) self._add_program_cache(cache_key, cached_program) - fleet_opt = program._pipeline_opt["fleet_opt"] if 'tasks' in fleet_opt: # Insert feed/fetch op for cloned program in each task node, # these ops has already been inserted into the origin program. 
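# A minimal, hedged sketch of the micro-scope mechanism used above: when
# fleet_opt["inference_generation"] is set, each micro batch gets its own child
# scope created from the cached root scope, so generation results from
# different queries stay isolated and can be fetched scope by scope afterwards.
# The scope count, variable names, and the helper below are placeholders, not
# part of the patch.
import paddle

root_scope = paddle.static.global_scope()
micro_scopes = [root_scope.new_scope() for _ in range(4)]  # one per micro batch


def fetch_from_micro_scopes(scopes, var_names):
    # Mirrors the per-scope fetch loop that follows in this method: collect one
    # result row per micro scope instead of reading a single global fetch list.
    results = []
    for scope in scopes:
        row = []
        for name in var_names:
            var = scope.find_var(name)
            if var is not None:
                row.append(var.get_tensor())
        results.append(row)
    return results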
@@ -2298,9 +2637,11 @@ class Executor(object): feed_task = fleet_opt['tasks'][0] print("Inserting feed ops for task", feed_task.task_id()) feed_program = feed_task.get_program() - feed_program = self._add_feed_ops(program=feed_program, - feed=real_feed, - feed_var_name=feed_var_name) + feed_program = self._add_feed_ops( + program=feed_program, + feed=real_feed, + feed_var_name=feed_var_name, + ) feed_task.set_program(feed_program) # Insert fetch ops @@ -2310,7 +2651,8 @@ class Executor(object): fetch_program = self._add_fetch_ops( program=fetch_program, fetch_list=fetch_list, - fetch_var_name=fetch_var_name) + fetch_var_name=fetch_var_name, + ) main_block = fetch_program.block(0) for op in main_block.ops: # set the op_role of fetch op to Optimize to avoid @@ -2318,7 +2660,8 @@ class Executor(object): if op.type == 'fetch': op._set_attr( 'op_role', - core.op_proto_and_checker_maker.OpRole.Optimize) + core.op_proto_and_checker_maker.OpRole.Optimize, + ) fetch_task.set_program(fetch_program) self._prepare_fleet_executor_carrier( @@ -2326,7 +2669,9 @@ class Executor(object): program=cached_program, scope=cached_scope, fleet_opt=fleet_opt, - with_standalone_executor=with_standalone_executor) + micro_scope_list=micro_cached_scopes, + with_standalone_executor=with_standalone_executor, + ) if feed: # NOTE: don't have to traverse programs in task nodes, @@ -2335,18 +2680,49 @@ class Executor(object): self._feed_data(cached_program, feed, feed_var_name, cached_scope) from paddle.optimizer.lr import LRScheduler + if hasattr(program, 'lr_sheduler'): lr_sheduler = program.lr_sheduler assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler" lr_value = lr_sheduler() lr_var = program.global_block().vars[lr_sheduler._var_name] data = np.array([lr_value]).astype(convert_dtype(lr_var.dtype)) - tensor = core.get_variable_tensor(cached_scope, - lr_sheduler._var_name) + tensor = core.get_variable_tensor( + cached_scope, lr_sheduler._var_name + ) tensor.set(data, self.place) self._fleet_executor.run(cache_key) + if "fetch_var" in fleet_opt: + # If we speed up the generation in evaluation, we need to generate + # multiple queries at the same time. Each query will in separate scope in order + # not mix up. It indicate that final result will in multiple scopes and need to + # fetch each. 
+ result_list = [] + for scope in micro_cached_scopes: + scope_result_list = [] + for varname in fleet_opt["fetch_var"]: + tensor = None + try: + tensor = core.get_variable_tensor(scope, varname) + if return_numpy: + tensor = as_numpy(tensor) + except: + var = scope.find_var(varname) + tensor = var.get_lod_tensor_array() + if return_numpy: + tensor = as_numpy(tensor) + else: + tensor = [t for t in tensor] + + if tensor: + scope_result_list.append(tensor) + + if scope_result_list: + result_list.append(scope_result_list) + return result_list + if fetch_list: arr = cached_scope.find_var(fetch_var_name).get_fetch_list() tensors = arr._move_to_list() @@ -2364,30 +2740,32 @@ class Executor(object): feed_var = global_block.create_var( name=feed_var_name, type=core.VarDesc.VarType.FEED_MINIBATCH, - persistable=True) + persistable=True, + ) # prepend feed operators if not has_feed_operators(global_block, feed, feed_var_name): for i, name in enumerate(feed): if global_block.has_var(name): out = global_block.var(name) - global_block._prepend_op(type='feed', - inputs={'X': [feed_var]}, - outputs={'Out': [out]}, - attrs={'col': i}) + global_block._prepend_op( + type='feed', + inputs={'X': [feed_var]}, + outputs={'Out': [out]}, + attrs={'col': i}, + ) else: warnings.warn( "The variable %s is not found in program. It is not declared or is pruned." - % name) + % name + ) return tmp_program @classmethod - def _add_fetch_ops(cls, - program, - fetch_list, - fetch_var_name, - use_fetch_v2=False): + def _add_fetch_ops( + cls, program, fetch_list, fetch_var_name, use_fetch_v2=False + ): tmp_program = program.clone() global_block = tmp_program.global_block() @@ -2398,7 +2776,8 @@ class Executor(object): fetch_var = global_block.create_var( name=fetch_var_name, type=core.VarDesc.VarType.FETCH_LIST, - persistable=True) + persistable=True, + ) if use_fetch_v2: fetch_op = 'fetch_v2' @@ -2406,17 +2785,19 @@ class Executor(object): fetch_op = 'fetch' # append fetch_operators - if not has_fetch_operators(global_block, fetch_list, fetch_var_name, - fetch_op): + if not has_fetch_operators( + global_block, fetch_list, fetch_var_name, fetch_op + ): for i, var in enumerate(fetch_list): assert isinstance(var, Variable) or isinstance( - var, - six.string_types), ("Wrong type for fetch_list[%s]: %s" % - (i, type(var))) - global_block.append_op(type=fetch_op, - inputs={'X': [var]}, - outputs={'Out': [fetch_var]}, - attrs={'col': i}) + var, six.string_types + ), "Wrong type for fetch_list[%s]: %s" % (i, type(var)) + global_block.append_op( + type=fetch_op, + inputs={'X': [var]}, + outputs={'Out': [fetch_var]}, + attrs={'col': i}, + ) return tmp_program @@ -2431,25 +2812,36 @@ class Executor(object): return tmp_program - def _run_pipeline(self, - program=None, - dataset=None, - scope=None, - thread=0, - is_infer=False, - debug=False, - fetch_list=None, - fetch_info=None, - print_period=100, - fetch_handler=None, - use_program_cache=False): - scope, real_fetch_list, trainer_instance = \ - self._prepare_pipeline_ctx(program, dataset, scope, thread, - is_infer, debug, fetch_list, fetch_info, - print_period, fetch_handler, - use_program_cache) + def _run_pipeline( + self, + program=None, + dataset=None, + scope=None, + thread=0, + is_infer=False, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + fetch_handler=None, + use_program_cache=False, + ): + scope, real_fetch_list, trainer_instance = self._prepare_pipeline_ctx( + program, + dataset, + scope, + thread, + is_infer, + debug, + fetch_list, + fetch_info, + 
print_period, + fetch_handler, + use_program_cache, + ) from paddle.optimizer.lr import LRScheduler + if hasattr(program, 'lr_sheduler'): lr_sheduler = program.lr_sheduler assert isinstance(lr_sheduler, LRScheduler), "must be LRScheduler" @@ -2471,16 +2863,18 @@ class Executor(object): return None - def infer_from_dataset(self, - program=None, - dataset=None, - scope=None, - thread=0, - debug=False, - fetch_list=None, - fetch_info=None, - print_period=100, - fetch_handler=None): + def infer_from_dataset( + self, + program=None, + dataset=None, + scope=None, + thread=0, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + fetch_handler=None, + ): """ Infer from a pre-defined Dataset. Dataset is defined in paddle.fluid.dataset. Given a program, either a program or compiled program, infer_from_dataset will @@ -2536,26 +2930,39 @@ class Executor(object): dataset=dataset) """ - return self._run_from_dataset(program, dataset, scope, thread, True, - debug, fetch_list, fetch_info, - print_period, fetch_handler) - - def start_heter_trainer(self, - program=None, - scope=None, - debug=False, - fetch_list=None, - fetch_info=None, - print_period=100, - fetch_handler=None): - scope, trainer = self._prepare_trainer(program=program, - dataset=None, - scope=scope, - thread=1, - debug=debug, - fetch_list=fetch_list, - fetch_info=fetch_info, - print_period=print_period) + return self._run_from_dataset( + program, + dataset, + scope, + thread, + True, + debug, + fetch_list, + fetch_info, + print_period, + fetch_handler, + ) + + def start_heter_trainer( + self, + program=None, + scope=None, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + fetch_handler=None, + ): + scope, trainer = self._prepare_trainer( + program=program, + dataset=None, + scope=scope, + thread=1, + debug=debug, + fetch_list=fetch_list, + fetch_info=fetch_info, + print_period=print_period, + ) trainer._set_infer(False) trainer._gen_trainer_desc() @@ -2563,32 +2970,35 @@ class Executor(object): self._dump_debug_info(program=program, trainer=trainer) trainer_instance = self._default_executor.init_for_dataset( - program.desc, trainer._desc(), scope, None) + program.desc, trainer._desc(), scope, None + ) - #if fetch_handler is not None: + # if fetch_handler is not None: # scope0 = trainer_instance.get_worker_scope(0) # fetch_monitor = FetchHandlerMonitor(scope0, fetch_handler) # fetch_monitor.start() # self._default_executor.run_from_dataset(trainer_instance) # fetch_monitor.stop() # self._default_executor.release_trainer(trainer_instance) - #else: + # else: self._default_executor.run_from_dataset(trainer_instance) - #self._default_executor.release_trainer(trainer_instance) + # self._default_executor.release_trainer(trainer_instance) return trainer_instance - def train_from_dataset(self, - program=None, - dataset=None, - scope=None, - thread=0, - debug=False, - fetch_list=None, - fetch_info=None, - print_period=100, - fetch_handler=None): + def train_from_dataset( + self, + program=None, + dataset=None, + scope=None, + thread=0, + debug=False, + fetch_list=None, + fetch_info=None, + print_period=100, + fetch_handler=None, + ): """ Train from a pre-defined Dataset. Dataset is defined in paddle.fluid.dataset. Given a program, either a program or compiled program, train_from_dataset will @@ -2610,7 +3020,7 @@ class Executor(object): for each run. default is global_scope thread(int): number of thread a user wants to run in this function. 
Default is 0, which means using thread num of dataset - debug(bool): whether a user wants to run train_from_dataset + debug(bool): whether a user wants to run train_from_dataset fetch_list(Tensor List): fetch Tensor list, each variable will be printed during training fetch_info(String List): print information for each Tensor, its length should be equal @@ -2620,9 +3030,9 @@ class Executor(object): Returns: None - + Examples: - + .. code-block:: python import paddle @@ -2643,6 +3053,15 @@ class Executor(object): dataset=dataset) """ - return self._run_from_dataset(program, dataset, scope, thread, False, - debug, fetch_list, fetch_info, - print_period, fetch_handler) + return self._run_from_dataset( + program, + dataset, + scope, + thread, + False, + debug, + fetch_list, + fetch_info, + print_period, + fetch_handler, + ) diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index ce013815ae4..f6971110789 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -101,6 +101,7 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) list(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node) list(REMOVE_ITEM TEST_OPS test_fleet_exe_dist_model_run) list(REMOVE_ITEM TEST_OPS test_fleet_exe_dist_model_tensor) + list(REMOVE_ITEM TEST_OPS test_fleet_executor_cond_interceptor) endif() list(REMOVE_ITEM TEST_OPS test_deprecated_decorator) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt index bd6ccfd3922..446461a045b 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt @@ -63,6 +63,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU) py_test_modules(test_engine_callbacks MODULES test_engine_callbacks) set_tests_properties(test_engine_callbacks PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50) + py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS ${dist_ENVS}) set_tests_properties(test_parallel_tuner PROPERTIES TIMEOUT 120) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py index ea3bdd32082..6d96cd13773 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py @@ -32,7 +32,9 @@ def apply_pass(use_amp=False, level=None): amp.enable = True amp.custom_white_list = ['softmax', 'layer_norm', 'gelu'] amp.custom_black_list = [ - 'c_softmax_with_cross_entropy', 'elementwise_div', 'reduce_sum' + 'c_softmax_with_cross_entropy', + 'elementwise_div', + 'reduce_sum', ] amp.init_loss_scaling = 32768 amp.use_fp16_guard = False @@ -48,7 +50,6 @@ def reset_prog(): class TestAMPPass(unittest.TestCase): - def setUp(self): self.rtol = 1e-5 self.atol = 1e-8 @@ -61,6 +62,7 @@ class TestAMPPass(unittest.TestCase): paddle.seed(2021) np.random.seed(2021) random.seed(2021) + paddle.distributed.fleet.init(is_collective=True) place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) engine._executor = paddle.static.Executor(place) @@ -83,7 +85,9 @@ class TestAMPPass(unittest.TestCase): rtol=rtol or self.rtol, atol=atol or self.atol, err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( - __class__, ref_losses, check_losses, ref_losses - check_losses)) + __class__, ref_losses, check_losses, ref_losses - 
check_losses + ), + ) def test_amp_pass(self): # mp2 training diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py index 1cbc8aed120..e462751e1a6 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/clip_grad_by_global_norm.py @@ -62,7 +62,6 @@ def reset_prog(): class TestGradientClipByGlobalNorm(unittest.TestCase): - def setUp(self): self.batch_size = 2 self.batch_num = 1 @@ -73,6 +72,7 @@ class TestGradientClipByGlobalNorm(unittest.TestCase): paddle.seed(2022) np.random.seed(2022) random.seed(2022) + paddle.distributed.fleet.init(is_collective=True) place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) engine._executor = paddle.static.Executor(place) @@ -95,9 +95,10 @@ class TestGradientClipByGlobalNorm(unittest.TestCase): sharding_p, rtol=1e-05, atol=1e-08, - err_msg= - 'gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}' - .format(dp_p, sharding_p, dp_p - sharding_p)) + err_msg='gradient clip by global norm has wrong results!, \nu={}\nv={}\ndiff={}'.format( + dp_p, sharding_p, dp_p - sharding_p + ), + ) def test_grad_clip(self): # dp2 training @@ -109,7 +110,8 @@ class TestGradientClipByGlobalNorm(unittest.TestCase): sharding_engine = self.get_engine(True) sharding_engine.fit(self.dataset, 3, batch_size=self.batch_size) sharding_param_values = get_parameter_value( - sharding_engine.main_program) + sharding_engine.main_program + ) self.check_result(dp_param_values, sharding_param_values) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/generation_pipeline_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/generation_pipeline_pass_unittest.py new file mode 100644 index 00000000000..335e139295e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/generation_pipeline_pass_unittest.py @@ -0,0 +1,177 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from paddle.distributed.fleet import auto + +_g_mesh = auto.ProcessMesh([0, 1]) +PP_MESH_0 = auto.ProcessMesh([0]) +PP_MESH_1 = auto.ProcessMesh([1]) + +image_size = 1024 +class_num = 10 + + +class MyDataset(paddle.io.Dataset): + def __init__(self, num_samples): + super(MyDataset, self).__init__() + self.num_samples = num_samples + + def __getitem__(self, index): + input = np.random.uniform(size=image_size).astype("float32") + input = np.random.uniform(size=image_size).astype("float32") + return input, input + + def __len__(self): + return self.num_samples + + +class MLPLayer(nn.Layer): + def __init__( + self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02, + ): + super(MLPLayer, self).__init__() + d_model = hidden_size + dim_feedforward = intermediate_size + weight_attr = paddle.ParamAttr( + initializer=nn.initializer.Normal(mean=0.0, std=initializer_range) + ) + bias_attr = None + + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr + ) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr + ) + self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr) + self.norm = nn.LayerNorm(d_model, epsilon=1e-5) + self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train") + + def forward(self, input): + out = auto.shard_op(self.norm, PP_MESH_0)(input) + out = self.linear0(out) + out = F.gelu(out, approximate=True) + out = auto.shard_op(self.linear1, PP_MESH_1)(out) + out = self.dropout(out) + out = self.linear2(out) + return out + + +class GEN(nn.Layer): + def __init__(self, mlp): + super(GEN, self).__init__() + self.mlp = mlp + + def forward(self, input): + model_kwargs = {} + + output = self.mlp(input) + + cur_step = paddle.full([1], 0, dtype='int64') + total_step = paddle.full([1], 10, dtype='int64') + + model_kwargs['input'] = input + model_kwargs['output'] = output + + while cur_step < total_step: + + out = self.mlp(model_kwargs['input']) + model_kwargs['res'] = out + paddle.increment(cur_step) + + auto.shard_op(paddle.assign, _g_mesh)(model_kwargs['input'], out) + + output = F.gelu(model_kwargs['input'], approximate=True) + + return output, cur_step + + +def get_model(): + + with paddle.LazyGuard(): + mlp = MLPLayer() + gen = GEN(mlp) + return gen + + +class TestGenerationPipeline(unittest.TestCase): + def test_pp2(self): + + model = get_model() + + strategy = auto.Strategy() + pipeline = strategy.pipeline + pipeline.enable = True + pipeline.schedule_mode = "stream" + pipeline.generation_batch_size = 4 + pipeline.accumulate_steps = 4 + engine = auto.Engine(model, strategy=strategy) + + engine.prepare( + inputs_spec=paddle.static.InputSpec( + shape=[2, 1024], name='input', dtype='float32' + ), + labels_spec=paddle.static.InputSpec( + shape=[2, 1024], name='label', dtype='float32' + ), + mode="eval", + ) + + train_data = MyDataset(50 * 2) + train_dataloader = engine._prepare_dataloader_from_generator( + dataset=train_data, + capacity=70, + iterable=False, + batch_size=2, + epochs=1, + steps_per_epoch=100, + ) + engine._prepare_reader() + + fleet_opt = engine.main_program._pipeline_opt['fleet_opt'] + assert len(fleet_opt['tasks']) == 5 + assert fleet_opt['inference_generation'] == True + assert fleet_opt['num_micro_batches'] == 4 + num_task_in_rank = 5 + for idx, (task_id, rank_id) in enumerate( + fleet_opt['task_id_to_rank'].items() + ): + 
assert ( + task_id == rank_id * num_task_in_rank + idx % num_task_in_rank + ) + + train_dataloader._inner_dataloader.start() + try: + engine._executor.run( + engine.main_program, use_program_cache=False, return_numpy=False + ) + except paddle.fluid.core.EOFException: + print("test done") + train_dataloader._inner_dataloader.reset() + train_dataloader._inner_dataloader.start() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py index 438e17d29f7..58bc1143885 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/gradient_merge_pass_unittest.py @@ -13,7 +13,6 @@ # limitations under the License. import unittest -import sys import random import numpy as np import paddle @@ -44,7 +43,6 @@ def reset_prog(): class TestGradientMergePass(unittest.TestCase): - def setUp(self): self.rtol = 1e-5 self.atol = 1e-8 @@ -57,6 +55,7 @@ class TestGradientMergePass(unittest.TestCase): paddle.seed(2021) np.random.seed(2021) random.seed(2021) + paddle.distributed.fleet.init(is_collective=True) place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) engine._executor = paddle.static.Executor(place) @@ -79,23 +78,23 @@ class TestGradientMergePass(unittest.TestCase): rtol=self.rtol, atol=self.atol, err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( - __class__, ref_losses, check_losses, ref_losses - check_losses)) + __class__, ref_losses, check_losses, ref_losses - check_losses + ), + ) def test_gradient_merge_pass(self): # dp2 training dp_engine = self.get_engine() - history = dp_engine.fit(self.dataset, - 3, - batch_size=self.batch_size, - log_freq=1) + history = dp_engine.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) dp_losses = np.array(history.history["loss"]) # dp2 gradient merge training gm_engine = self.get_engine(True) - history = gm_engine.fit(self.dataset, - 3, - batch_size=self.batch_size, - log_freq=1) + history = gm_engine.fit( + self.dataset, 3, batch_size=self.batch_size, log_freq=1 + ) gm_losses = np.array(history.history["loss"]) # avg_loss = 0 diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py b/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py index 1a444353d03..f14d95bfaa7 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/recompute_pass_unittest.py @@ -39,7 +39,6 @@ def reset_prog(): class TestRecomputePass(unittest.TestCase): - def setUp(self): self.rtol = 1e-6 self.atol = 1e-8 @@ -52,6 +51,7 @@ class TestRecomputePass(unittest.TestCase): paddle.seed(2022) np.random.seed(2022) random.seed(2022) + paddle.distributed.fleet.init(is_collective=True) place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) engine._executor = paddle.static.Executor(place) @@ -74,7 +74,9 @@ class TestRecomputePass(unittest.TestCase): rtol=self.rtol, atol=self.atol, err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( - __class__, ref_losses, check_losses, ref_losses - check_losses)) + __class__, ref_losses, check_losses, ref_losses - check_losses + ), + ) def test_recompute_pass(self): # mp2 training diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py 
b/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py index 356c8ec2e14..1f6b84915ab 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/sharding_pass_unittest.py @@ -44,7 +44,6 @@ def reset_prog(): class TestShardingPass(unittest.TestCase): - def setUp(self): self.rtol = 1e-6 self.atol = 1e-8 @@ -57,6 +56,7 @@ class TestShardingPass(unittest.TestCase): paddle.seed(2022) np.random.seed(2022) random.seed(2022) + paddle.distributed.fleet.init(is_collective=True) place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id) engine._executor = paddle.static.Executor(place) @@ -79,7 +79,9 @@ class TestShardingPass(unittest.TestCase): rtol=self.rtol, atol=self.atol, err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format( - __class__, ref_losses, check_losses, ref_losses - check_losses)) + __class__, ref_losses, check_losses, ref_losses - check_losses + ), + ) def test_sharding_pass(self): # dp2 training @@ -89,25 +91,25 @@ class TestShardingPass(unittest.TestCase): # sharding2 stage1 training sharding1_engine = self.get_engine(True, 1) - history = sharding1_engine.fit(self.dataset, - 3, - batch_size=self.batch_size) + history = sharding1_engine.fit( + self.dataset, 3, batch_size=self.batch_size + ) sharding1_losses = np.array(history.history["loss"]) self.check_results(dp_losses, sharding1_losses) # sharding2 stage2 training sharding2_engine = self.get_engine(True, 2) - history = sharding2_engine.fit(self.dataset, - 3, - batch_size=self.batch_size) + history = sharding2_engine.fit( + self.dataset, 3, batch_size=self.batch_size + ) sharding2_losses = np.array(history.history["loss"]) self.check_results(dp_losses, sharding2_losses) # sharding2 stage3 training sharding3_engine = self.get_engine(True, 3) - history = sharding3_engine.fit(self.dataset, - 3, - batch_size=self.batch_size) + history = sharding3_engine.fit( + self.dataset, 3, batch_size=self.batch_size + ) sharding3_losses = np.array(history.history["loss"]) self.check_results(dp_losses, sharding3_losses) diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py index d2047332c9a..1ce23802943 100644 --- a/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_dist_context.py @@ -36,7 +36,7 @@ hidden_size = 1024 sequence_len = 512 _g_process_mesh = [ auto.ProcessMesh([0, 1], dim_names=["x"]), - auto.ProcessMesh([2, 3], dim_names=["x"]) + auto.ProcessMesh([2, 3], dim_names=["x"]), ] @@ -47,41 +47,45 @@ def get_random_inputs_and_labels(input_shape, label_shape): def batch_generator_creator(): - def __reader__(): for _ in range(batch_size): batch_input, batch_label = get_random_inputs_and_labels( [batch_size, sequence_len, hidden_size], - [batch_size, sequence_len, 1]) + [batch_size, sequence_len, 1], + ) yield batch_input, batch_label return __reader__ class MLPLayer(nn.Layer): - - def __init__(self, - hidden_size=1024, - intermediate_size=4 * 1024, - dropout_ratio=0.1, - initializer_range=0.02): + def __init__( + self, + hidden_size=1024, + intermediate_size=4 * 1024, + dropout_ratio=0.1, + initializer_range=0.02, + ): super(MLPLayer, self).__init__() d_model = hidden_size dim_feedforward = intermediate_size - param_initializer = nn.initializer.Normal(mean=0.0, - std=initializer_range) + param_initializer = nn.initializer.Normal( + mean=0.0, 
std=initializer_range + ) self.norm = nn.LayerNorm(d_model, epsilon=1e-5) self.linear0 = nn.Linear( d_model, dim_feedforward, weight_attr=paddle.ParamAttr(initializer=param_initializer), - bias_attr=None) + bias_attr=None, + ) self.linear1 = nn.Linear( dim_feedforward, d_model, weight_attr=paddle.ParamAttr(initializer=param_initializer), - bias_attr=None) + bias_attr=None, + ) def forward(self, input): out = self.norm(input) @@ -103,78 +107,106 @@ def get_program(): start_program = static.Program() with static.program_guard(train_program, start_program): # input - input = static.data(name="input", - shape=[batch_size, sequence_len, hidden_size], - dtype='float32') - label = static.data(name="label", - shape=[batch_size, sequence_len, 1], - dtype='float32') + input = static.data( + name="input", + shape=[batch_size, sequence_len, hidden_size], + dtype='float32', + ) + label = static.data( + name="label", shape=[batch_size, sequence_len, 1], dtype='float32' + ) data_holder = [input, label] # dataloader - dataloader = paddle.io.DataLoader.from_generator(feed_list=data_holder, - capacity=4 * - batch_size, - iterable=False) - dataloader.set_batch_generator(batch_generator_creator(), - places=paddle.static.cuda_places()) + dataloader = paddle.io.DataLoader.from_generator( + feed_list=data_holder, capacity=4 * batch_size, iterable=False + ) + dataloader.set_batch_generator( + batch_generator_creator(), places=paddle.static.cuda_places() + ) # data dist_attr auto.shard_tensor(input, _g_process_mesh[0], ["x", None, None]) auto.shard_tensor(label, _g_process_mesh[0], ["x", None, None]) - mlp_start = MLPLayer(hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) + mlp_start = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02, + ) pred = mlp_start(input) - mlp_mid = MLPLayer(hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) + mlp_mid = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02, + ) pred = mlp_mid(pred) - mlp_end = MLPLayer(hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - dropout_ratio=0.1, - initializer_range=0.02) + mlp_end = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + dropout_ratio=0.1, + initializer_range=0.02, + ) pred = mlp_end(pred) error_cost = paddle.nn.functional.square_error_cost(pred, label) loss = paddle.mean(error_cost) - optimizer = paddle.optimizer.Adam(learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) + optimizer = paddle.optimizer.Adam( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None, + ) feed_vars = {"inputs": [input], "labels": [label]} fetch_vars = {"loss": [loss]} - return train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars + return ( + train_program, + start_program, + dataloader, + loss, + optimizer, + feed_vars, + fetch_vars, + ) class TestDistributedContext(unittest.TestCase): - def test_backup_restore(self): - train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program( + ( + train_program, + start_program, + dataloader, + loss, + optimizer, + feed_vars, + fetch_vars, + ) = get_program() + dist_context = DistributedContext( + train_program, start_program, optimizer, loss, feed_vars, fetch_vars ) - dist_context = 
DistributedContext(train_program, start_program, - optimizer, loss, feed_vars, - fetch_vars) dist_context.initialize() dist_context._backup(serial=True, dist=True) - dist_context._restore(serial=True, - serial_mode="to_backup", - dist=True, - dist_mode="to_backup") + dist_context._restore( + serial=True, + serial_mode="to_backup", + dist=True, + dist_mode="to_backup", + ) dist_context._backup(serial=True, dist=True) - dist_context._restore(serial=True, - serial_mode="to_original", - dist=True, - dist_mode="to_original") + dist_context._restore( + serial=True, + serial_mode="to_original", + dist=True, + dist_mode="to_original", + ) dist_context._backup(serial=True, dist=True) dist_context._restore(serial=True, dist=True, dist_mode="to_default") @@ -183,25 +215,45 @@ class TestDistributedContext(unittest.TestCase): dist_context._restore(serial=True, dist=True, dist_mode="to_nothing") def test_deepcopy(self): - train_program, start_program, dataloader, loss, optimizer, feed_vars, fetch_vars = get_program( + ( + train_program, + start_program, + dataloader, + loss, + optimizer, + feed_vars, + fetch_vars, + ) = get_program() + dist_context = DistributedContext( + train_program, start_program, optimizer, loss, feed_vars, fetch_vars ) - dist_context = DistributedContext(train_program, start_program, - optimizer, loss, feed_vars, - fetch_vars) dist_context.initialize() copy_dist_context = copy.deepcopy(dist_context) copy_list = [ - "_original_serial_main_program", "_original_serial_startup_program", \ - "_serial_main_program", "_serial_startup_program", "_serial_graph", \ - "_dist_main_programs", "_dist_startup_programs", \ - "_serial_ordered_nodes", "_serial_ordered_tensor_nodes", \ - "_serial_ordered_op_nodes", "_original_serial_loss", \ - "_original_serial_feed_vars", "_original_serial_fetch_vars", \ - "_serial_loss", "_serial_feed_vars", "_serial_fetch_vars", "_serial_optimizer", \ - "_backup_serial_main_program_stack", "_backup_serial_startup_program_stack", \ - "_pass_context"] + "_original_serial_main_program", + "_original_serial_startup_program", + "_serial_main_program", + "_serial_startup_program", + "_serial_graph", + "_dist_main_programs", + "_dist_startup_programs", + "_serial_ordered_nodes", + "_serial_ordered_tensor_nodes", + "_serial_ordered_op_nodes", + "_original_serial_loss", + "_original_serial_feed_vars", + "_original_serial_fetch_vars", + "_serial_loss", + "_serial_feed_vars", + "_serial_fetch_vars", + "_serial_optimizer", + "_backup_serial_main_program_stack", + "_backup_serial_startup_program_stack", + "_pass_context", + "_tensor_nodes_with_same_name", + ] for i in range(len(copy_list)): copy_obj = "copy_dist_context." + copy_list[i] diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_generation_pipeline.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_generation_pipeline.py new file mode 100644 index 00000000000..02210483a23 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_pass_generation_pipeline.py @@ -0,0 +1,58 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tempfile +import unittest +import os +import sys +import subprocess +from paddle.distributed.fleet.launch_utils import run_with_coverage + + +class TestGenerationPipeline(unittest.TestCase): + def test_pp2(self): + file_dir = os.path.dirname(os.path.abspath(__file__)) + launch_model_path = os.path.join( + file_dir, "generation_pipeline_pass_unittest.py" + ) + + if os.environ.get("WITH_COVERAGE", "OFF") == "ON": + coverage_args = ["-m", "coverage", "run", "--branch", "-p"] + else: + coverage_args = [] + + tmp_dir = tempfile.TemporaryDirectory() + cmd = ( + [sys.executable, "-u"] + + coverage_args + + [ + "-m", + "paddle.distributed.launch", + "--devices", + "0,1", + "--log_dir", + tmp_dir.name, + launch_model_path, + ] + ) + + process = subprocess.Popen(cmd) + process.wait() + self.assertEqual(process.returncode, 0) + + tmp_dir.cleanup() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py index ef08eda6533..27440cb526b 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard.py @@ -28,7 +28,10 @@ from paddle.distributed import fleet from paddle.distributed.auto_parallel.parallelizer import AutoParallelizer from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import Resharder -from paddle.distributed.auto_parallel.process_group import _g_process_group_map, ProcessGroup +from paddle.distributed.auto_parallel.process_group import ( + _g_process_group_map, + ProcessGroup, +) from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr paddle.enable_static() @@ -39,26 +42,26 @@ PP_MESH_1 = None class MLPLayer(nn.Layer): - - def __init__(self, - hidden_size=1024, - intermediate_size=4 * 1024, - initializer_range=0.02): + def __init__( + self, + hidden_size=1024, + intermediate_size=4 * 1024, + initializer_range=0.02, + ): super(MLPLayer, self).__init__() d_model = hidden_size dim_feedforward = intermediate_size weight_attr = paddle.ParamAttr( - initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)) + initializer=nn.initializer.Normal(mean=0.0, std=initializer_range) + ) bias_attr = None - self.linear0 = nn.Linear(d_model, - dim_feedforward, - weight_attr, - bias_attr=bias_attr) - self.linear1 = nn.Linear(dim_feedforward, - d_model, - weight_attr, - bias_attr=bias_attr) + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr + ) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr + ) self.norm = nn.LayerNorm(d_model, epsilon=1e-5) def forward(self, input): @@ -66,10 +69,12 @@ class MLPLayer(nn.Layer): auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None]) auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None]) else: - auto.shard_tensor(self.linear0.weight, _global_process_mesh, - [None, None]) - auto.shard_tensor(self.linear1.weight, _global_process_mesh, - [None, None]) 
+ auto.shard_tensor( + self.linear0.weight, _global_process_mesh, [None, None] + ) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, [None, None] + ) out = self.norm(input) out = self.linear0(out) @@ -80,17 +85,18 @@ class MLPLayer(nn.Layer): def mlp_forward(train_program, start_program): - with static.program_guard(train_program, - start_program), utils.unique_name.guard(): + with static.program_guard( + train_program, start_program + ), utils.unique_name.guard(): batch_size = 4 hidden_size = 1024 sequence_len = 512 - input = static.data(name="input", - shape=[batch_size, hidden_size], - dtype='float32') - label = static.data(name="label", - shape=[batch_size, 1], - dtype='float32') + input = static.data( + name="input", shape=[batch_size, hidden_size], dtype='float32' + ) + label = static.data( + name="label", shape=[batch_size, 1], dtype='float32' + ) if _global_parallel_strategy == "pp": auto.shard_tensor(input, PP_MESH_0, [None, None]) @@ -100,9 +106,11 @@ def mlp_forward(train_program, start_program): else: auto.shard_tensor(input, _global_process_mesh, [None, None]) - mlp = MLPLayer(hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - initializer_range=0.02) + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02, + ) predict = mlp(input) error_cost = paddle.nn.functional.square_error_cost(predict, label) @@ -111,13 +119,16 @@ def mlp_forward(train_program, start_program): return loss, train_program, start_program -def get_dist_prog(train_program, - startup_program, - dist_context, - rank_id, - change_process_mesh=False): - loss, train_program, startup_program = mlp_forward(train_program, - startup_program) +def get_dist_prog( + train_program, + startup_program, + dist_context, + rank_id, + change_process_mesh=False, +): + loss, train_program, startup_program = mlp_forward( + train_program, startup_program + ) fleet._user_defined_strategy = fleet.DistributedStrategy() fleet.user_defined_optimizer = paddle.fluid.optimizer.AdamOptimizer() @@ -127,30 +138,43 @@ def get_dist_prog(train_program, # serial forward & backward completion completer = Completer(dist_context) complete_train_program = completer.complete_forward_annotation( - train_program) + train_program + ) dist_context.block_state.parse_forward_blocks(complete_train_program) if change_process_mesh: global PP_MESH_1 dist_context.get_tensor_dist_attr_for_program( - train_program.global_block( - ).vars["gelu_0.tmp_0"]).process_mesh = PP_MESH_1 - - params_grads = parallelizer._generate_backward(complete_train_program, - startup_program, - loss, - parameter_list=None, - no_grad_set=None, - callbacks=None) + train_program.global_block().vars["gelu_0.tmp_0"] + ).process_mesh = PP_MESH_1 + + params_grads = parallelizer._generate_backward( + complete_train_program, + startup_program, + loss, + parameter_list=None, + no_grad_set=None, + callbacks=None, + ) # logical partition partitioner = Partitioner(dist_context, rank_id) - auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads = partitioner.partition( - complete_train_program, startup_program, params_grads) + ( + auto_parallel_main_prog, + auto_parallel_startup_prog, + dist_params_grads, + ) = partitioner.partition( + complete_train_program, startup_program, params_grads + ) partitioned_optimize_ops = parallelizer._apply_optimize( - auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads) + auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads + ) - return 
auto_parallel_main_prog, auto_parallel_startup_prog, dist_params_grads + return ( + auto_parallel_main_prog, + auto_parallel_startup_prog, + dist_params_grads, + ) def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check): @@ -162,16 +186,28 @@ def check_backward_dist_attr(dist_context, dist_main_prog, op_need_check): has_dist_attr = False for var_name in op_need_check.input_arg_names: - if not op_dist_attr.get_input_dims_mapping(var_name) or \ - not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).dims_mapping or \ - not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).process_mesh: + if ( + not op_dist_attr.get_input_dims_mapping(var_name) + or not dist_context.get_tensor_dist_attr_for_program( + vars[var_name] + ).dims_mapping + or not dist_context.get_tensor_dist_attr_for_program( + vars[var_name] + ).process_mesh + ): has_dist_attr = False break if has_dist_attr: for var_name in op_need_check.output_arg_names: - if not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).dims_mapping or \ - not dist_context.get_tensor_dist_attr_for_program(vars[var_name]).process_mesh: + if ( + not dist_context.get_tensor_dist_attr_for_program( + vars[var_name] + ).dims_mapping + or not dist_context.get_tensor_dist_attr_for_program( + vars[var_name] + ).process_mesh + ): has_dist_attr = False break @@ -187,14 +223,22 @@ def check_send_recv_result(dist_main_prog, rank_id): for idx, op in enumerate(ops): if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names: send_result = True - if op.type == "recv_v2" and "gelu_0.tmp_0@GRAD" in op.output_arg_names[ - 0]: + if ( + op.type == "recv_v2" + and "gelu_0.tmp_0@GRAD" in op.output_arg_names[0] + ): recv_result = True else: for idx, op in enumerate(ops): - if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names: + if ( + op.type == "send_v2" + and "gelu_0.tmp_0@GRAD" in op.input_arg_names + ): send_result = True - if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[0]: + if ( + op.type == "recv_v2" + and "gelu_0.tmp_0" in op.output_arg_names[0] + ): recv_result = True return send_result and recv_result @@ -203,8 +247,10 @@ def check_send_recv_result(dist_main_prog, rank_id): def check_initialization(dist_startup_prog, rank_id): if rank_id == 0: need_check_params = [ - "layer_norm_0.b_0", "layer_norm_0.w_0", "linear_0.w_0", - "linear_0.b_0" + "layer_norm_0.b_0", + "layer_norm_0.w_0", + "linear_0.w_0", + "linear_0.b_0", ] else: need_check_params = ['linear_1.w_0', 'linear_1.b_0'] @@ -219,7 +265,10 @@ def check_initialization(dist_startup_prog, rank_id): def check_initialization_for_dp(dist_startup_prog): need_check_params = [ - "layer_norm_0.b_0", "layer_norm_0.w_0", "linear_0.w_0", "linear_0.b_0" + "layer_norm_0.b_0", + "layer_norm_0.w_0", + "linear_0.w_0", + "linear_0.b_0", ] + ['linear_1.w_0', 'linear_1.b_0'] params = [] for var_name, var in dist_startup_prog.global_block().vars.items(): @@ -230,12 +279,14 @@ def check_initialization_for_dp(dist_startup_prog): if op.type == "c_broadcast": broadcast_varnames.append(op.output_arg_names[0]) - return sorted(params) == sorted(need_check_params) == sorted( - broadcast_varnames) + return ( + sorted(params) + == sorted(need_check_params) + == sorted(broadcast_varnames) + ) class TestMLPReshard(unittest.TestCase): - def test_complete_backward_annotation(self): global _global_process_mesh _global_process_mesh = auto.ProcessMesh(mesh=[0, 1]) @@ -245,7 +296,8 @@ class TestMLPReshard(unittest.TestCase): dist_context = 
DistributedContext() rank_id = 0 dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( - train_program, startup_program, dist_context, 0) + train_program, startup_program, dist_context, 0 + ) op_need_check = None for op in dist_main_prog.global_block().ops: @@ -255,12 +307,14 @@ class TestMLPReshard(unittest.TestCase): # grad op should have dist attr self.assertTrue( - check_backward_dist_attr(dist_context, dist_main_prog, - op_need_check)) + check_backward_dist_attr( + dist_context, dist_main_prog, op_need_check + ) + ) # clear _g_process_group_map _g_process_group_map.clear() - _g_process_group_map[0] = ProcessGroup(0, []) + _g_process_group_map[0] = ProcessGroup(1000, []) def test_mlp_pp(self): global _global_parallel_strategy @@ -277,9 +331,15 @@ class TestMLPReshard(unittest.TestCase): dist_context = DistributedContext() rank_id = 1 dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( - train_program, startup_program, dist_context, rank_id) - resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, - dist_context, dist_params_grads) + train_program, startup_program, dist_context, rank_id + ) + resharder = Resharder( + dist_main_prog, + dist_startup_prog, + rank_id, + dist_context, + dist_params_grads, + ) resharder.reshard() # check send and recv result @@ -289,7 +349,7 @@ class TestMLPReshard(unittest.TestCase): # clear _g_process_group_map _g_process_group_map.clear() - _g_process_group_map[0] = ProcessGroup(0, []) + _g_process_group_map[0] = ProcessGroup(1000, []) def test_mlp_pp_diff_process_mesh(self): global _global_parallel_strategy @@ -306,9 +366,15 @@ class TestMLPReshard(unittest.TestCase): dist_context = DistributedContext() rank_id = 1 dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( - train_program, startup_program, dist_context, rank_id, True) - resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, - dist_context, dist_params_grads) + train_program, startup_program, dist_context, rank_id, True + ) + resharder = Resharder( + dist_main_prog, + dist_startup_prog, + rank_id, + dist_context, + dist_params_grads, + ) resharder.reshard() # check send and recv result self.assertTrue(check_send_recv_result(dist_main_prog, rank_id)) @@ -316,7 +382,7 @@ class TestMLPReshard(unittest.TestCase): # clear _g_process_group_map _g_process_group_map.clear() - _g_process_group_map[0] = ProcessGroup(0, []) + _g_process_group_map[0] = ProcessGroup(1000, []) def test_mlp_dp(self): global _global_parallel_strategy @@ -329,9 +395,15 @@ class TestMLPReshard(unittest.TestCase): dist_context = DistributedContext() rank_id = 0 dist_main_prog, dist_startup_prog, dist_params_grads = get_dist_prog( - train_program, startup_program, dist_context, rank_id) - resharder = Resharder(dist_main_prog, dist_startup_prog, rank_id, - dist_context, dist_params_grads) + train_program, startup_program, dist_context, rank_id + ) + resharder = Resharder( + dist_main_prog, + dist_startup_prog, + rank_id, + dist_context, + dist_params_grads, + ) resharder.reshard() # send and recv should not exist in dp scene. 
@@ -341,7 +413,7 @@ class TestMLPReshard(unittest.TestCase): # clear _g_process_group_map _g_process_group_map.clear() - _g_process_group_map[0] = ProcessGroup(0, []) + _g_process_group_map[0] = ProcessGroup(1000, []) if __name__ == "__main__": diff --git a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py index 75ec5ad6805..1901687b2dd 100644 --- a/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py +++ b/python/paddle/fluid/tests/unittests/test_auto_parallel_reshard_serial.py @@ -17,6 +17,7 @@ from __future__ import print_function import unittest import os + if os.getenv("CUDA_VISIBLE_DEVICES", None) is None: os.environ["CUDA_VISIBLE_DEVICES"] = '0' @@ -25,8 +26,11 @@ import paddle.nn as nn import paddle.static as static import paddle.nn.functional as F import paddle.utils as utils +import paddle.fluid.core as core from paddle.distributed.fleet import auto -from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context +from paddle.distributed.auto_parallel.dist_context import ( + get_default_distributed_context, +) from paddle.distributed import fleet from paddle.distributed.auto_parallel.partitioner import Partitioner from paddle.distributed.auto_parallel.reshard import Resharder @@ -38,26 +42,26 @@ _global_process_mesh = None class MLPLayer(nn.Layer): - - def __init__(self, - hidden_size=1024, - intermediate_size=4 * 1024, - initializer_range=0.02): + def __init__( + self, + hidden_size=1024, + intermediate_size=4 * 1024, + initializer_range=0.02, + ): super(MLPLayer, self).__init__() d_model = hidden_size dim_feedforward = intermediate_size weight_attr = paddle.ParamAttr( - initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)) + initializer=nn.initializer.Normal(mean=0.0, std=initializer_range) + ) bias_attr = None - self.linear0 = nn.Linear(d_model, - dim_feedforward, - weight_attr, - bias_attr=bias_attr) - self.linear1 = nn.Linear(dim_feedforward, - d_model, - weight_attr, - bias_attr=bias_attr) + self.linear0 = nn.Linear( + d_model, dim_feedforward, weight_attr, bias_attr=bias_attr + ) + self.linear1 = nn.Linear( + dim_feedforward, d_model, weight_attr, bias_attr=bias_attr + ) self.norm = nn.LayerNorm(d_model, epsilon=1e-5) def forward(self, input): @@ -65,10 +69,12 @@ class MLPLayer(nn.Layer): auto.shard_tensor(self.linear0.weight, PP_MESH_0, [None, None]) auto.shard_tensor(self.linear1.weight, PP_MESH_1, [None, None]) else: - auto.shard_tensor(self.linear0.weight, _global_process_mesh, - [None, None]) - auto.shard_tensor(self.linear1.weight, _global_process_mesh, - [None, None]) + auto.shard_tensor( + self.linear0.weight, _global_process_mesh, [None, None] + ) + auto.shard_tensor( + self.linear1.weight, _global_process_mesh, [None, None] + ) out = self.norm(input) out = self.linear0(out) @@ -79,17 +85,18 @@ class MLPLayer(nn.Layer): def mlp_forward(train_program, start_program): - with static.program_guard(train_program, - start_program), utils.unique_name.guard(): + with static.program_guard( + train_program, start_program + ), utils.unique_name.guard(): batch_size = 4 hidden_size = 1024 sequence_len = 512 - input = static.data(name="input", - shape=[batch_size, hidden_size], - dtype='float32') - label = static.data(name="label", - shape=[batch_size, 1], - dtype='float32') + input = static.data( + name="input", shape=[batch_size, hidden_size], dtype='float32' + ) + label = static.data( + name="label", shape=[batch_size, 
1], dtype='float32' + ) if _global_parallel_strategy == "pp": auto.shard_tensor(input, PP_MESH_0, [None, None]) @@ -99,9 +106,11 @@ def mlp_forward(train_program, start_program): else: auto.shard_tensor(input, _global_process_mesh, [None, None]) - mlp = MLPLayer(hidden_size=hidden_size, - intermediate_size=4 * hidden_size, - initializer_range=0.02) + mlp = MLPLayer( + hidden_size=hidden_size, + intermediate_size=4 * hidden_size, + initializer_range=0.02, + ) predict = mlp(input) error_cost = paddle.nn.functional.square_error_cost(predict, label) @@ -110,8 +119,9 @@ def mlp_forward(train_program, start_program): return loss, train_program, start_program -def get_dist_prog_with_parallelizer(train_program, startup_program, - dist_context): +def get_dist_prog_with_parallelizer( + train_program, startup_program, dist_context +): global _global_process_mesh dist_strategy = fleet.DistributedStrategy() @@ -123,18 +133,25 @@ def get_dist_prog_with_parallelizer(train_program, startup_program, dist_strategy.semi_auto = True fleet.init(is_collective=True, strategy=dist_strategy) - loss, train_program, startup_program = mlp_forward(train_program, - startup_program) - - optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001, - beta1=0.9, - beta2=0.999, - epsilon=1e-08, - grad_clip=None) + loss, train_program, startup_program = mlp_forward( + train_program, startup_program + ) + + optimizer = paddle.fluid.optimizer.AdamOptimizer( + learning_rate=0.00001, + beta1=0.9, + beta2=0.999, + epsilon=1e-08, + grad_clip=None, + ) optimizer = fleet.distributed_optimizer(optimizer) - _, _, distributed_startup_program, distributed_main_program = optimizer.minimize( - loss, startup_program) + ( + _, + _, + distributed_startup_program, + distributed_main_program, + ) = optimizer.minimize(loss, startup_program) return distributed_main_program, distributed_startup_program @@ -147,21 +164,31 @@ def check_send_recv_result(dist_main_prog, rank_id): for idx, op in enumerate(ops): if op.type == "send_v2" and "gelu_0.tmp_0" in op.input_arg_names: send_result = True - if op.type == "recv_v2" and "gelu_0.tmp_0@GRAD" in op.output_arg_names[ - 0]: + if ( + op.type == "recv_v2" + and "gelu_0.tmp_0@GRAD" in op.output_arg_names[0] + ): recv_result = True else: for idx, op in enumerate(ops): - if op.type == "send_v2" and "gelu_0.tmp_0@GRAD" in op.input_arg_names: + if ( + op.type == "send_v2" + and "gelu_0.tmp_0@GRAD" in op.input_arg_names + ): send_result = True - if op.type == "recv_v2" and "gelu_0.tmp_0" in op.output_arg_names[0]: + if ( + op.type == "recv_v2" + and "gelu_0.tmp_0" in op.output_arg_names[0] + ): recv_result = True return send_result and recv_result +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" +) class TestMLPReshard(unittest.TestCase): - def test_mlp_serial(self): global _global_parallel_strategy _global_parallel_strategy = None @@ -173,7 +200,8 @@ class TestMLPReshard(unittest.TestCase): dist_context = get_default_distributed_context() rank_id = 0 dist_main_prog, dist_startup_prog = get_dist_prog_with_parallelizer( - train_program, startup_program, dist_context) + train_program, startup_program, dist_context + ) # send and recv should not exist in serial scene. 
self.assertFalse(check_send_recv_result(dist_main_prog, rank_id)) diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py b/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py new file mode 100644 index 00000000000..a1e0b65b725 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_executor_cond_interceptor.py @@ -0,0 +1,217 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +import paddle.fluid.core as core +from paddle.distributed.fleet.fleet_executor_utils import TaskNode + +paddle.enable_static() + + +def cond(i, ten, data): + return i < ten + + +def body(i, ten, data): + i = i + 1 + data = data + 1 + return [i, ten, data] + + +num_micro_batches = 4 + + +def batch_generator_creator(): + def __reader__(): + for i in range(num_micro_batches): + data = np.full(shape=[1, 1], fill_value=i, dtype=np.float32) + yield data + + return __reader__ + + +class TestFleetExecutor(unittest.TestCase): + def test_cond_interceptor(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + i = paddle.full( + shape=[1], fill_value=0, dtype='int64' + ) # loop counter + ten = paddle.full( + shape=[1], fill_value=10, dtype='int64' + ) # loop length + data = paddle.static.data(name='x', shape=[1]) + + loader = paddle.fluid.io.DataLoader.from_generator( + feed_list=[data], capacity=num_micro_batches * 4, iterable=False + ) + loader.set_batch_generator( + batch_generator_creator(), paddle.CUDAPlace(0) + ) + + paddle.static.nn.while_loop(cond, body, [i, ten, data]) + + program_a = paddle.static.Program() + program_b = paddle.static.Program() + + for var_name in main_program.block(0).vars: + if var_name != "_generated_var_0": + var = main_program.block(0).var(var_name) + if ( + var_name == "create_py_reader_0" + or var_name == "double_buffer_0" + ): + program_a.block(0).create_var( + name=var_name, + persistable=var.persistable, + ) + else: + program_a.block(0).create_var( + name=var_name, + shape=var.shape, + dtype=var.dtype, + stop_gradient=var.stop_gradient, + ) + program_b.block(0).create_var( + name=var_name, + shape=var.shape, + dtype=var.dtype, + stop_gradient=var.stop_gradient, + ) + + for op in main_program.block(0).ops: + if op.type != "while": + program_a.block(0).append_op( + type=op.type, + inputs=op.desc.inputs(), + outputs=op.desc.outputs(), + attrs=op.all_attrs(), + ) + + for var_name in main_program.block(1).vars: + var = main_program.block(1).var(var_name) + program_b.block(0).create_var( + name=var_name, + shape=var.shape, + dtype=var.dtype, + stop_gradient=var.stop_gradient, + ) + + for op in main_program.block(1).ops: + program_b.block(0).append_op( + type=op.type, + inputs=op.desc.inputs(), + outputs=op.desc.outputs(), + attrs=op.all_attrs(), + ) + + cond_var_name = "tmp_0" + + task_a = TaskNode( + 0, + 
+            num_micro_batches,
+            node_type="Start",
+            task_id=0,
+            program=program_a,
+            lazy_initialize=True,
+        )
+        task_b = TaskNode(
+            0,
+            num_micro_batches,
+            node_type="Cond",
+            task_id=1,
+            program=paddle.static.Program(),
+            cond_var_name=cond_var_name,
+            lazy_initialize=True,
+        )
+        task_c = TaskNode(
+            0,
+            num_micro_batches,
+            node_type="Compute",
+            task_id=2,
+            program=program_b,
+            lazy_initialize=True,
+        )
+        task_d = TaskNode(
+            0,
+            num_micro_batches,
+            node_type="Compute",
+            task_id=3,
+            program=paddle.static.Program(),
+            vars_to_dtype={'x': 'float32', 'tmp_1': 'int64'},
+            vars_to_shape={'x': (1,), 'tmp_1': (1,)},
+            lazy_initialize=True,
+        )
+        task_e = TaskNode(
+            0,
+            num_micro_batches,
+            node_type="Compute",
+            task_id=4,
+            program=paddle.static.Program(),
+            lazy_initialize=True,
+        )
+
+        infinite_buff_size = -1
+        task_a.add_downstream_task(task_b.task_id(), 2)
+        task_b.add_upstream_task(task_a.task_id(), 2)
+        task_b.add_downstream_task(task_c.task_id(), infinite_buff_size)
+        task_c.add_upstream_task(task_b.task_id(), infinite_buff_size)
+        task_c.add_downstream_task(task_d.task_id(), 2)
+        task_d.add_upstream_task(task_c.task_id(), 2)
+        task_d.add_downstream_task(
+            task_b.task_id(), infinite_buff_size, core.DependType.LOOP
+        )
+        task_b.add_upstream_task(
+            task_d.task_id(), infinite_buff_size, core.DependType.LOOP
+        )
+        task_b.add_downstream_task(
+            task_e.task_id(), infinite_buff_size, core.DependType.STOP_LOOP
+        )
+        task_e.add_upstream_task(
+            task_b.task_id(), infinite_buff_size, core.DependType.STOP_LOOP
+        )
+
+        main_program._pipeline_opt = {
+            "fleet_opt": {
+                'tasks': [task_a, task_b, task_c, task_d, task_e],
+                'task_id_to_rank': {
+                    task_a.task_id(): 0,
+                    task_b.task_id(): 0,
+                    task_c.task_id(): 0,
+                    task_d.task_id(): 0,
+                    task_e.task_id(): 0,
+                },
+                'num_micro_batches': num_micro_batches,
+                'inference_generation': True,
+                'fetch_var': ['x'],
+            },
+        }
+
+        place = paddle.CUDAPlace(0)
+        exe = paddle.static.Executor(place)
+        loader.start()
+        res = exe.run(main_program)
+        ref_res = np.full([1, 1], 10, dtype="float32")
+        for data in res:
+            np.testing.assert_allclose(data, ref_res, rtol=1e-05)
+            ref_res = ref_res + 1
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py b/python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py
index 07ecf85c3db..1d6b426bde1
--- a/python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_executor_task_node.py
@@ -21,26 +21,33 @@ paddle.enable_static()
 
 
 class TestFleetExecutorTaskNode(unittest.TestCase):
-
     def test_task_node(self):
         program = paddle.static.Program()
-        task_node_0 = core.TaskNode(program.desc, 0, 1, 1)
+        task_node_0 = core.TaskNode(program.desc, 0, 0, 1)
         task_node_1 = core.TaskNode(program.desc, 0, 1, 1)
-        task_node_2 = core.TaskNode(program.desc, 0, 1, 1)
+        task_node_2 = core.TaskNode(program.desc, 0, 2, 1)
         self.assertEqual(task_node_0.task_id(), 0)
         self.assertEqual(task_node_1.task_id(), 1)
         self.assertEqual(task_node_2.task_id(), 2)
         self.assertTrue(
-            task_node_0.add_downstream_task(task_node_1.task_id(), 1))
-        self.assertTrue(task_node_1.add_upstream_task(task_node_0.task_id(), 1))
+            task_node_0.add_downstream_task(
+                task_node_1.task_id(), 1, core.DependType.NORMAL
+            )
+        )
+        self.assertTrue(
+            task_node_1.add_upstream_task(
+                task_node_0.task_id(), 1, core.DependType.NORMAL
+            )
+        )
 
     def test_lazy_task_node(self):
         program = paddle.static.Program()
-        task = TaskNode(program=program,
-                        rank=0,
-                        max_run_times=1,
-                        max_slot_times=1,
-                        lazy_initialize=True)
+        task = TaskNode(
+            program=program,
+            rank=0,
+            max_run_times=1,
+            lazy_initialize=True,
+        )
         task_node = task.task_node()
diff --git a/python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py b/python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py
index f80f998c047..398ba59539a
--- a/python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py
+++ b/python/paddle/fluid/tests/unittests/test_fleet_executor_with_task_nodes.py
@@ -22,17 +22,16 @@ paddle.enable_static()
 
 
 class TestFleetExecutor(unittest.TestCase):
-
     def run_fleet_executor(self, place, x_data, y_data):
         exe = paddle.static.Executor(place)
         empty_program = paddle.static.Program()
         with fluid.program_guard(empty_program, empty_program):
-            x = fluid.layers.data(name='x',
-                                  shape=x_data.shape,
-                                  dtype=x_data.dtype)
-            y = fluid.layers.data(name='y',
-                                  shape=y_data.shape,
-                                  dtype=y_data.dtype)
+            x = fluid.layers.data(
+                name='x', shape=x_data.shape, dtype=x_data.dtype
+            )
+            y = fluid.layers.data(
+                name='y', shape=y_data.shape, dtype=y_data.dtype
+            )
             z = x + y
             a = 2 * x + 3 * y
             loss = paddle.mean(a)
@@ -41,11 +40,13 @@ class TestFleetExecutor(unittest.TestCase):
             steps_per_pass = 10
             bd = [steps_per_pass * p for p in passes]
             lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
-            lr_val = paddle.optimizer.lr.PiecewiseDecay(boundaries=bd,
-                                                        values=lr)
+            lr_val = paddle.optimizer.lr.PiecewiseDecay(
+                boundaries=bd, values=lr
+            )
             opt = paddle.optimizer.AdamW(
                 learning_rate=lr_val,
-                grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
+                grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0),
+            )
             opt.minimize(loss)
             # TODO: section_program will be removed in the future
             task_node = TaskNode(
@@ -54,23 +55,20 @@ class TestFleetExecutor(unittest.TestCase):
                 rank=0,
                 node_type="Compute",
                 max_run_times=1,
-                max_slot_times=1,
-                lazy_initialize=True)
+                lazy_initialize=True,
+            )
             empty_program._pipeline_opt = {
                 "fleet_opt": {
                     'tasks': [task_node],
-                    'task_id_to_rank': {
-                        task_node.task_id(): 0
-                    }
+                    'task_id_to_rank': {task_node.task_id(): 0},
                 },
-                "section_program": empty_program
+                "section_program": empty_program,
             }
-            res = exe.run(empty_program,
-                          feed={
-                              'x': x_data,
-                              'y': y_data
-                          },
-                          fetch_list=[z.name, a.name])
+            res = exe.run(
+                empty_program,
+                feed={'x': x_data, 'y': y_data},
+                fetch_list=[z.name, a.name],
+            )
             return res
 
     def test_executor_on_single_device(self):
diff --git a/python/paddle/tensor/stat.py b/python/paddle/tensor/stat.py
index 99dd877e480..a8d3a89460c
--- a/python/paddle/tensor/stat.py
+++ b/python/paddle/tensor/stat.py
@@ -182,6 +182,7 @@ def var(x, axis=None, unbiased=True, keepdim=False, name=None):
         if unbiased:
             one_const = paddle.ones([1], x.dtype)
             n = where(n > one_const, n - 1.0, one_const)
+            n.stop_gradient = True
         out /= n
         return out
--
GitLab