diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h index 6ecc02bbae61697c9d8f1c4eea11fe7210884569..ae0a6a2bc4e719b79813c29705986322247e9628 100644 --- a/paddle/fluid/framework/device_worker.h +++ b/paddle/fluid/framework/device_worker.h @@ -28,6 +28,7 @@ limitations under the License. */ #include #include "paddle/fluid/framework/data_feed.h" +#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/framework/heter_service.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_registry.h" @@ -451,7 +452,7 @@ class HeterBoxWorker : public HogwildWorker { virtual void CacheProgram(const ProgramDesc& main_program) { new (&program_) ProgramDesc(main_program); } - virtual void ProduceTasks() override; + void ProduceTasks() override; virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; } virtual void SetEvent(const cudaEvent_t event) { event_ = event; } virtual void TrainFilesWithProfiler() {} @@ -550,7 +551,7 @@ class PSGPUWorker : public HogwildWorker { virtual void CacheProgram(const ProgramDesc& main_program) { new (&program_) ProgramDesc(main_program); } - virtual void ProduceTasks() override; + void ProduceTasks() override; virtual void SetStream(const cudaStream_t stream) { copy_stream_ = stream; } virtual void SetEvent(const cudaEvent_t event) { event_ = event; } virtual void TrainFilesWithProfiler() {} @@ -654,6 +655,9 @@ class SectionWorker : public DeviceWorker { void SetDeviceIndex(int tid) override {} void SetThreadIndex(int thread_id) { thread_id_ = thread_id; } void SetMicrobatchNum(int num) { num_microbatches_ = num; } + void SetPipelineStageNum(int num) { num_pipeline_stages_ = num; } + void SetPipelineStage(int stage) { pipeline_stage_ = stage; } + void SetScheduleMode(int mode) { schedule_mode_ = mode; } void SetMicrobatchScopes(const std::vector& scope) { microbatch_scopes_ = scope; } @@ -661,11 +665,23 @@ class SectionWorker : public DeviceWorker { void SetSkipVars(const std::vector& skip_vars) { skip_vars_ = skip_vars; } + void RunBackward( + int micro_id, std::unique_ptr&, + std::unordered_map>&); + void RunForward( + int micro_id, std::unique_ptr&, + std::unordered_map>&); + void RunUpdate( + std::unique_ptr&, + std::unordered_map>&); protected: int section_id_; int thread_id_; int num_microbatches_; + int num_pipeline_stages_; + int pipeline_stage_; + int schedule_mode_; // 0 for GPipe and 1 for deepspeed std::vector microbatch_scopes_; std::vector skip_vars_; const Scope* minibatch_scope_; diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto old mode 100644 new mode 100755 index 7cf8d55aeeb1d99acd2f501461f0563f87a25e78..aae9515a565e2ffed2965628c9c19cd8e6f6e6c0 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -32,6 +32,14 @@ message ShardingConfig { optional float fuse_broadcast_MB = 1 [ default = 32.0 ]; optional bool hybrid_dp = 2 [ default = false ]; optional int32 sharding_group_size = 3 [ default = 8 ]; + optional bool as_outer_parallelism = 4 [ default = false ]; + optional int32 parallelism = 5 [ default = 1 ]; + optional bool use_pipeline = 6 [ default = false ]; + optional int32 acc_steps = 7 [ default = 1 ]; + optional int32 schedule_mode = 8 [ default = 0 ]; + optional int32 pp_bz = 9 [ default = 1 ]; + optional bool pp_allreduce_in_optimize = 10 [ default = false ]; + optional bool optimize_offload = 11 [ default = false 
]; } message AMPConfig { @@ -44,6 +52,8 @@ message AMPConfig { repeated string custom_white_list = 7; repeated string custom_black_list = 8; repeated string custom_black_varnames = 9; + optional bool use_pure_fp16 = 10 [ default = false ]; + optional bool use_fp16_guard = 11 [ default = true ]; } message LocalSGDConfig { @@ -117,6 +127,8 @@ message AsyncConfig { message PipelineConfig { optional int32 micro_batch = 1 [ default = 1 ]; } +message ModelParallelConfig { optional int32 parallelism = 1 [ default = 1 ]; } + message DistributedStrategy { // bool options optional Mode mode = 1 [ default = COLLECTIVE ]; @@ -140,12 +152,13 @@ message DistributedStrategy { optional int32 fuse_grad_size_in_MB = 19 [ default = 32 ]; optional float fuse_grad_size_in_TFLOPS = 20 [ default = 50 ]; optional bool cudnn_exhaustive_search = 21 [ default = true ]; - optional int32 conv_workspace_size_limit = 22 [ default = 4000 ]; + optional int32 conv_workspace_size_limit = 22 [ default = 512 ]; optional bool cudnn_batchnorm_spatial_persistent = 23 [ default = true ]; optional bool adaptive_localsgd = 24 [ default = false ]; optional bool fp16_allreduce = 25 [ default = false ]; optional bool sharding = 26 [ default = false ]; optional float last_comm_group_size_MB = 27 [ default = 1 ]; + optional bool model_parallel = 28 [ default = false ]; optional RecomputeConfig recompute_configs = 101; optional AMPConfig amp_configs = 102; @@ -158,6 +171,7 @@ message DistributedStrategy { optional LambConfig lamb_configs = 109; optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110; optional ShardingConfig sharding_configs = 111; + optional ModelParallelConfig model_parallel_configs = 112; optional BuildStrategy build_strategy = 201; optional ExecutionStrategy execution_strategy = 202; } diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc index 58e09203299e87e6b0ec21042b4af1fb26e38d0a..dbcc993aee8272dbf01addf2cb235e35ee4c51d2 100644 --- a/paddle/fluid/framework/pipeline_trainer.cc +++ b/paddle/fluid/framework/pipeline_trainer.cc @@ -25,6 +25,9 @@ namespace framework { void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) { const auto& section_params = trainer_desc.section_param(); + const auto num_pipeline_stages_ = section_params.num_pipeline_stages(); + const auto pipeline_stage_ = section_params.pipeline_stage(); + const auto schedule_mode_ = section_params.schedule_mode(); num_microbatches_ = section_params.num_microbatches(); VLOG(3) << "Number of microbatches per minibatch: " << num_microbatches_; trainer_desc_ = trainer_desc; @@ -40,6 +43,9 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc, this_worker->SetPlace(place_); this_worker->Initialize(trainer_desc); this_worker->SetMicrobatchNum(num_microbatches_); + this_worker->SetPipelineStageNum(num_pipeline_stages_); + this_worker->SetPipelineStage(pipeline_stage_); + this_worker->SetScheduleMode(schedule_mode_); } void PipelineTrainer::InitOtherEnv(const ProgramDesc& main_program) { @@ -76,7 +82,10 @@ void PipelineTrainer::CopyParameters(int microbatch_id, for (auto& var : global_block.AllVars()) { bool is_param_grad = false; size_t pos = 0; - if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos) { + // A magic suffix to indicated the merged gradient. 
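The CopyParameters hunk above now treats a variable as a parameter gradient only when its name carries both the gradient suffix and the merged-gradient marker. Below is a minimal Python sketch of that name test, assuming kGradVarSuffix is "@GRAD"; the helper and the example variable names are illustrative, not part of the patch.

```python
GRAD_SUFFIX = "@GRAD"      # assumed value of kGradVarSuffix
MERGED_SUFFIX = "MERGED"   # the "magic" marker checked above

def is_merged_param_grad(var_name, param_names):
    """Return the owning parameter name if var_name is a merged
    gradient of a known parameter, otherwise None."""
    pos = var_name.find(GRAD_SUFFIX)
    if pos == -1 or MERGED_SUFFIX not in var_name:
        return None
    prefix = var_name[:pos]
    return prefix if prefix in param_names else None

assert is_merged_param_grad("fc_0.w_0@GRAD@MERGED", {"fc_0.w_0"}) == "fc_0.w_0"
assert is_merged_param_grad("fc_0.w_0@GRAD", {"fc_0.w_0"}) is None
```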
+ std::string magicSuffix = "MERGED"; + if ((pos = var->Name().find(kGradVarSuffix)) != std::string::npos && + var->Name().find(magicSuffix) != std::string::npos) { auto prefix_name = var->Name().substr(0, pos); if (param_map.find(prefix_name) != param_map.end()) { is_param_grad = true; diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc index 6634cb98d67413087f6a9acb4bac3378bf15dcab..87bd2ebad2afff2107b2e20bbc0a1aa139c3de9d 100644 --- a/paddle/fluid/framework/section_worker.cc +++ b/paddle/fluid/framework/section_worker.cc @@ -11,36 +11,90 @@ limitations under the License. */ #if defined(PADDLE_WITH_NCCL) #include -#include "paddle/fluid/framework/executor_gc_helper.h" -#include "paddle/fluid/framework/garbage_collector.h" -#include "paddle/fluid/framework/program_desc.h" - -#include "google/protobuf/io/zero_copy_stream_impl.h" -#include "google/protobuf/message.h" -#include "google/protobuf/text_format.h" - #include "paddle/fluid/framework/device_worker.h" -#include "paddle/fluid/framework/fleet/box_wrapper.h" -#include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/framework/trainer_desc.pb.h" -#include "paddle/fluid/platform/cpu_helper.h" +#include "paddle/fluid/framework/executor_gc_helper.h" #include "paddle/fluid/platform/device_context.h" -#include "paddle/fluid/platform/lodtensor_printer.h" namespace paddle { namespace framework { +class TrainerDesc; + uint64_t SectionWorker::batch_id_(0); -void SectionWorker::Initialize(const TrainerDesc& desc) { +void SectionWorker::Initialize(const TrainerDesc &desc) { dev_ctx_ = platform::DeviceContextPool::Instance().Get(place_); program_.reset( new ProgramDesc(desc.section_param().section_config().program_desc())); - for (auto& op_desc : program_->Block(0).AllOps()) { + for (auto &op_desc : program_->Block(0).AllOps()) { ops_.push_back(OpRegistry::CreateOp(*op_desc)); } } +void SectionWorker::RunForward( + int micro_id, std::unique_ptr &gc, + std::unordered_map> + &unused_vars_) { + for (auto &op : ops_) { + int op_role = op->Attr(std::string("op_role")); + // We run op with op_role = kLRSched only for the first microbatch + // to avoid increasing the @LR_DECAY_STEP@ multiple times. 
+ bool run_first_mbatch = op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)) || + op_role == static_cast(OpRole::kLRSched); + bool run_others = op_role == static_cast(OpRole::kForward) || + op_role == (static_cast(OpRole::kForward) | + static_cast(OpRole::kLoss)); + if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) { + VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch " + << micro_id; + op->Run(*microbatch_scopes_[micro_id], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(), + unused_vars_, gc.get()); + } + } + } +} + +void SectionWorker::RunBackward( + int micro_id, std::unique_ptr &gc, + std::unordered_map> + &unused_vars_) { + for (auto &op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kBackward) || + op_role == (static_cast(OpRole::kBackward) | + static_cast(OpRole::kLoss))) { + VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch " + << micro_id; + op->Run(*microbatch_scopes_[micro_id], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[micro_id], op.get(), + unused_vars_, gc.get()); + } + } + } +} + +void SectionWorker::RunUpdate( + std::unique_ptr &gc, + std::unordered_map> + &unused_vars_) { + for (auto &op : ops_) { + int op_role = op->Attr(std::string("op_role")); + if (op_role == static_cast(OpRole::kOptimize)) { + VLOG(3) << "Update: running op " << op->Type(); + op->Run(*microbatch_scopes_[num_microbatches_ - 1], place_); + if (gc) { + DeleteUnusedTensors(*microbatch_scopes_[num_microbatches_ - 1], + op.get(), unused_vars_, gc.get()); + } + } + } +} + void SectionWorker::TrainFiles() { VLOG(5) << "begin section_worker TrainFiles"; @@ -58,61 +112,49 @@ void SectionWorker::TrainFiles() { #endif } - for (int i = 0; i < num_microbatches_; ++i) { - for (auto& op : ops_) { - int op_role = op->Attr(std::string("op_role")); - // We run op with op_role = kLRSched only for the first microbatch - // to avoid increasing the @LR_DECAY_STEP@ multiple times. 
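The rewritten TrainFiles below selects between two schedules: schedule_mode_ == 0 runs all forwards, then all backwards, then the update (GPipe style), while the other branch uses 1F1B with num_pipeline_stages_ - pipeline_stage_ - 1 warm-up forwards. A standalone Python sketch of the step orderings the two branches produce:

```python
# Sketch of the two pipeline schedules, mirroring the TrainFiles logic below
# (0 = GPipe / F-then-B, otherwise 1F1B).
def pipeline_schedule(num_microbatches, num_stages, stage, mode):
    steps = []
    if mode == 0:  # GPipe: all forwards, then all backwards, then update
        steps += [("F", i) for i in range(num_microbatches)]
        steps += [("B", i) for i in range(num_microbatches)]
    else:          # 1F1B: warm-up forwards, alternate F/B, then drain backwards
        warmup = min(num_stages - stage - 1, num_microbatches)
        fw = bw = 0
        for _ in range(warmup):
            steps.append(("F", fw)); fw += 1
        while fw < num_microbatches:
            steps.append(("F", fw)); fw += 1
            steps.append(("B", bw)); bw += 1
        while bw < num_microbatches:
            steps.append(("B", bw)); bw += 1
    steps.append(("U", None))  # optimizer update once per mini-batch
    return steps

# e.g. 4 micro-batches, 4 stages, stage 0 under 1F1B:
# [('F',0), ('F',1), ('F',2), ('F',3), ('B',0), ('B',1), ('B',2), ('B',3), ('U',None)]
```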
- bool run_first_mbatch = op_role == static_cast(OpRole::kForward) || - op_role == (static_cast(OpRole::kForward) | - static_cast(OpRole::kLoss)) || - op_role == static_cast(OpRole::kLRSched); - bool run_others = op_role == static_cast(OpRole::kForward) || - op_role == (static_cast(OpRole::kForward) | - static_cast(OpRole::kLoss)); - if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) { - VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch " - << i; - op->Run(*microbatch_scopes_[i], place_); - if (gc) { - DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, - gc.get()); - } - } + if (schedule_mode_ == 0) { + // GPipe scheduler which runs all forwards first, then backwards, then + // update + // step1: run forward + for (int i = 0; i < num_microbatches_; ++i) { + RunForward(i, gc, unused_vars_); } - cudaDeviceSynchronize(); - } - - // backward pass - for (int i = 0; i < num_microbatches_; ++i) { - for (auto& op : ops_) { - int op_role = op->Attr(std::string("op_role")); - if (op_role == static_cast(OpRole::kBackward) || - op_role == (static_cast(OpRole::kBackward) | - static_cast(OpRole::kLoss))) { - VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch " - << i; - op->Run(*microbatch_scopes_[i], place_); - if (gc) { - DeleteUnusedTensors(*microbatch_scopes_[i], op.get(), unused_vars_, - gc.get()); - } - } + // step2: run backward + for (int i = 0; i < num_microbatches_; ++i) { + RunBackward(i, gc, unused_vars_); + } + // step3: run update + RunUpdate(gc, unused_vars_); + } else { + // 1F1B scheduler + auto startup_steps = num_pipeline_stages_ - pipeline_stage_ - 1; + VLOG(3) << "startup_steps:" << startup_steps + << ", num_stages: " << num_pipeline_stages_ + << ", stage:" << pipeline_stage_; + if (startup_steps > num_microbatches_) { + startup_steps = num_microbatches_; + } + int fw_step = 0; + int bw_step = 0; + // startup phase + while (fw_step < startup_steps) { + RunForward(fw_step, gc, unused_vars_); + fw_step += 1; } - cudaDeviceSynchronize(); - } - // update pass - for (auto& op : ops_) { - int op_role = op->Attr(std::string("op_role")); - if (op_role == static_cast(OpRole::kOptimize)) { - VLOG(3) << "Update: running op " << op->Type(); - op->Run(*microbatch_scopes_[0], place_); - if (gc) { - DeleteUnusedTensors(*microbatch_scopes_[0], op.get(), unused_vars_, - gc.get()); - } + // 1F1B phase + while (fw_step < num_microbatches_) { + RunForward(fw_step, gc, unused_vars_); + fw_step += 1; + RunBackward(bw_step, gc, unused_vars_); + bw_step += 1; + } + // backward phase + while (bw_step < num_microbatches_) { + RunBackward(bw_step, gc, unused_vars_); + bw_step += 1; } + RunUpdate(gc, unused_vars_); } dev_ctx_->Wait(); ++batch_id_; diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto index 70481cf3727012e4cf41d235154eb277d92cc92f..504885ff5ccbce760c0a659aedabef6790de5f1a 100644 --- a/paddle/fluid/framework/trainer_desc.proto +++ b/paddle/fluid/framework/trainer_desc.proto @@ -93,6 +93,9 @@ message SectionWorkerParameter { optional int32 start_cpu_core_id = 4 [ default = 1 ]; repeated string param_need_sync = 5; optional int32 num_microbatches = 6; + optional int32 num_pipeline_stages = 7 [ default = 1 ]; + optional int32 pipeline_stage = 8 [ default = 1 ]; + optional int32 schedule_mode = 9 [ default = 0 ]; } message SectionConfig { diff --git a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc index
18f2a40f3ddd0b8726f698e0484dc00c7302cf72..9446c38dcba32b79f8ecb3b4d5c915d10732620e 100644 --- a/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc +++ b/paddle/fluid/operators/collective/c_gen_nccl_id_op.cc @@ -55,7 +55,8 @@ class CGenNCCLIdOp : public framework::OperatorBase { SendBroadCastNCCLID(endpoint_list, 1, func, local_scope); } else { std::string endpoint = Attr("endpoint"); - RecvBroadCastNCCLID(endpoint, 1, func, local_scope); + int server_fd = platform::SocketServer::GetInstance(endpoint).socket(); + platform::RecvBroadCastCommID(server_fd, endpoint, &nccl_ids); } scope.DeleteScope(&local_scope); } @@ -71,8 +72,7 @@ class CGenNCCLIdOp : public framework::OperatorBase { : OperatorBase(type, inputs, outputs, attrs) {} void RunImpl(const framework::Scope& scope, - const platform::Place& dev_place) const override { - } + const platform::Place& dev_place) const override {} }; #endif diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc index a0df244000be26b73378be04eecb35e4c8c2bf39..94f471e4456778ce6d8ffbf600c6c5733a75edda 100644 --- a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc +++ b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.cc @@ -31,7 +31,9 @@ limitations under the License. */ #include "paddle/fluid/string/split.h" namespace paddle { -namespace operators { +namespace platform { + +std::once_flag SocketServer::init_flag_; constexpr char COMM_HEAD[] = "_pd_gen_comm_id_"; @@ -340,5 +342,34 @@ void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, CloseSocket(client); } -} // namespace operators +SocketServer& SocketServer::GetInstance(const std::string& end_point) { + static SocketServer instance; + std::call_once(init_flag_, [&]() { + instance.server_fd_ = CreateListenSocket(end_point); + instance.end_point_ = end_point; + }); + PADDLE_ENFORCE_NE(instance.server_fd_, -1, + platform::errors::Unavailable( + "listen socket failed with end_point=%s", end_point)); + PADDLE_ENFORCE_EQ(instance.end_point_, end_point, + platform::errors::InvalidArgument( + "old end_point=%s must equal with new end_point=%s", + instance.end_point_, end_point)); + return instance; +} + +/// template instantiation +#define INSTANT_TEMPLATE(Type) \ + template void SendBroadCastCommID(std::vector servers, \ + std::vector * nccl_ids); \ + template void RecvBroadCastCommID(std::string endpoint, \ + std::vector * nccl_ids); + +#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL) +INSTANT_TEMPLATE(ncclUniqueId) +#endif +#ifdef PADDLE_WITH_XPU_BKCL +INSTANT_TEMPLATE(BKCLUniqueId) +#endif +} // namespace platform } // namespace paddle diff --git a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h index 38751805191e3e300e8ba2d1762c31c39427a3df..8db9bcee4da5e92499c0a0f920ad41ae05c7af4c 100644 --- a/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h +++ b/paddle/fluid/operators/collective/gen_nccl_id_op_helper.h @@ -15,6 +15,8 @@ limitations under the License. 
*/ #pragma once #include +#include +#include // NOLINT #include #include @@ -25,7 +27,7 @@ class Scope; } // namespace paddle namespace paddle { -namespace operators { +namespace platform { int CreateListenSocket(const std::string& ep); @@ -41,8 +43,26 @@ void RecvBroadCastNCCLID(std::string endpoint, int nccl_comm_num, const framework::Scope& scope); // recv nccl id from socket -void RecvBroadCastNCCLID(int server_fd, std::string endpoint, int nccl_comm_num, - std::function func, - const framework::Scope& scope); -} // namespace operators +template +void RecvBroadCastCommID(int server_fd, std::string endpoint, + std::vector* nccl_ids); + +class SocketServer { + public: + SocketServer() = default; + + ~SocketServer() { CloseSocket(server_fd_); } + + int socket() const { return server_fd_; } + + static SocketServer& GetInstance(const std::string& end_point); + + private: + int server_fd_{-1}; + std::string end_point_; + + static std::once_flag init_flag_; +}; + +} // namespace platform } // namespace paddle diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index f7a28f15e9b70be3280ce29eb97487a238e78ce6..3bad28bbd148aba2e794ec7d163d4f26865c3644 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -736,6 +736,60 @@ class DistributedStrategy(object): "sharding_configs") assign_configs_value(self.strategy.sharding_configs, configs) + @property + def model_parallel(self): + """ + Indicating whether we are using model parallelism for distributed training. + + Examples: + + .. code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.model_parallel = True + + """ + return self.strategy.model_parallel + + @model_parallel.setter + @is_strict_auto + def model_parallel(self, flag): + if isinstance(flag, bool): + self.strategy.model_parallel = flag + else: + print("WARNING: model_parallel should have value of bool type") + + @property + def model_parallel_configs(self): + """ + Set model parallelism configurations. + + **Notes**: + **Detailed arguments for model_parallel_configs** + + **parallelism**: degree of model parallelism + + Examples: + + ..
code-block:: python + + import paddle.distributed.fleet as fleet + strategy = fleet.DistributedStrategy() + strategy.model_parallel = True + strategy.model_parallel_configs = {"parallelism": 12} + + """ + + return get_msg_dict(self.strategy.model_parallel_configs) + + @model_parallel_configs.setter + @is_strict_auto + def model_parallel_configs(self, configs): + check_configs_key(self.strategy.model_parallel_configs, configs, + "model_parallel_configs") + assign_configs_value(self.strategy.model_parallel_configs, configs) + @property def pipeline(self): """ diff --git a/python/paddle/distributed/fleet/meta_optimizers/__init__.py b/python/paddle/distributed/fleet/meta_optimizers/__init__.py index cdc8162f6dee54db24007b4485706b57545aea54..eb4dcc8ef843a8adbf0c9848817e2eb3d7ef6a31 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/__init__.py +++ b/python/paddle/distributed/fleet/meta_optimizers/__init__.py @@ -17,6 +17,7 @@ from .gradient_merge_optimizer import GradientMergeOptimizer from .graph_execution_optimizer import GraphExecutionOptimizer from .parameter_server_optimizer import ParameterServerOptimizer from .pipeline_optimizer import PipelineOptimizer +from .model_parallel_optimizer import ModelParallelOptimizer from .localsgd_optimizer import LocalSGDOptimizer from .localsgd_optimizer import AdaptiveLocalSGDOptimizer from .lars_optimizer import LarsOptimizer diff --git a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py index c751e229cbbe2b900ead900297ff9956946b9e75..cf6962357cb36b9399832bf5cbfcf97a8eceb588 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/amp_optimizer.py @@ -50,15 +50,17 @@ class AMPOptimizer(MetaOptimizerBase): self.inner_opt, amp_lists, config['init_loss_scaling'], config['incr_every_n_steps'], config['decr_every_n_nan_or_inf'], config['incr_ratio'], config['decr_ratio'], - config['use_dynamic_loss_scaling']) + config['use_dynamic_loss_scaling'], config['use_pure_fp16'], + config['use_fp16_guard']) # if worker_num > 1, all cards will communication with each other, # add is_distributed to optimize amp, overlap communication and # computation by split the check_finite_and_unscale op. is_distributed = self.role_maker._worker_num() > 1 - if self.user_defined_strategy.sharding: - # FIXME(wangxi). sharding failed when split check_finite_and_unscale - is_distributed = False + #if self.user_defined_strategy.sharding or self.user_defined_strategy.model_parallel: + # # FIXME(wangxi). sharding failed when split check_finite_and_unscale + # # FIXME(JZ-LIANG). 
To support Sharding-Megatron-AMP, Megatron should follow Sharding's behavior + # is_distributed = False self.wrapped_opt._set_distributed(is_distributed) def _can_apply(self): @@ -112,3 +114,11 @@ class AMPOptimizer(MetaOptimizerBase): self.wrapped_opt.minimize(loss, startup_program, parameter_list, no_grad_set) return optimize_ops, params_grads + + def amp_init(self, + place, + scope=None, + test_program=None, + use_fp16_test=False): + return self.wrapped_opt.amp_init(place, scope, test_program, + use_fp16_test) diff --git a/python/paddle/distributed/fleet/meta_optimizers/common.py b/python/paddle/distributed/fleet/meta_optimizers/common.py index 9befdfff04bef14e756f41681f838ea0959d3759..9e1ccc5f82752806201cc0676909b778d03dc5d4 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/common.py +++ b/python/paddle/distributed/fleet/meta_optimizers/common.py @@ -25,6 +25,24 @@ OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() OP_ROLE_VAR_KEY = core.op_proto_and_checker_maker.kOpRoleVarAttrName() +class Topology: + """A 4-D structure to describe the process group.""" + + def __init__(self, axes, dims): + pass + + +class ParallelGrid: + """Initialize each process group.""" + + def __init__(self, topology): + self.build_global_group() + self.build_mp_group() + self.build_sharding_group() + self.build_pp_group() + self.build_dp_group() + + def is_update_op(op): return 'Param' in op.input_names and 'Grad' in op.input_names and \ "LearningRate" in op.input_names @@ -66,16 +84,49 @@ class CollectiveHelper(object): self.role_maker._worker_index(), ring_id, self.wait_port) self._broadcast_params() - def _init_communicator(self, program, current_endpoint, endpoints, rank, - ring_id, wait_port): + def _init_communicator(self, + program, + current_endpoint, + endpoints, + rank, + ring_id, + wait_port, + sync=True): nranks = len(endpoints) other_endpoints = endpoints[:] other_endpoints.remove(current_endpoint) block = program.global_block() if core.is_compiled_with_cuda(): - if rank == 0 and wait_port: - wait_server_ready(other_endpoints) - nccl_id_var = block.create_var( + if not wait_port and sync: + temp_var = block.create_var( + name=unique_name.generate('temp_var'), + dtype=core.VarDesc.VarType.INT32, + persistable=False, + stop_gradient=True) + block.append_op( + type='fill_constant', + inputs={}, + outputs={'Out': [temp_var]}, + attrs={ + 'shape': [1], + 'dtype': temp_var.dtype, + 'value': 1, + 'force_cpu': False, + OP_ROLE_KEY: OpRole.Forward + }) + block.append_op( + type='c_allreduce_sum', + inputs={'X': [temp_var]}, + outputs={'Out': [temp_var]}, + attrs={'ring_id': 3, + OP_ROLE_KEY: OpRole.Forward}) + block.append_op( + type='c_sync_comm_stream', + inputs={'X': temp_var}, + outputs={'Out': temp_var}, + attrs={'ring_id': 3, + OP_ROLE_KEY: OpRole.Forward}) + comm_id_var = block.create_var( name=unique_name.generate('nccl_id'), persistable=True, type=core.VarDesc.VarType.RAW) @@ -100,9 +151,7 @@ class CollectiveHelper(object): OP_ROLE_KEY: OpRole.Forward }) elif core.is_compiled_with_npu(): - endpoint_to_index_map = { - e: idx for idx, e in enumerate(endpoints) - } + endpoint_to_index_map = {e: idx for idx, e in enumerate(endpoints)} block.append_op( type='c_comm_init_hcom', inputs={}, diff --git a/python/paddle/distributed/fleet/meta_optimizers/model_parallel_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/model_parallel_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..1511769350477094f858688f41ddc31593e4b39a --- /dev/null +++ 
b/python/paddle/distributed/fleet/meta_optimizers/model_parallel_optimizer.py @@ -0,0 +1,281 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +from __future__ import print_function +from __future__ import division + +import paddle.fluid as fluid +from paddle.fluid import core, unique_name +from ..base.private_helper_function import wait_server_ready +from .meta_optimizer_base import MetaOptimizerBase +from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op + + +class ModelParallelHelper(object): + def __init__(self, role_maker, wait_port=True, megatron_dp=False): + self.wait_port = wait_port + self.role_maker = role_maker + self.megatron_dp = megatron_dp + + def update_startup_program(self, + startup_program=None, + inner_parallelism=None): + self.startup_program = startup_program + + nranks = self.role_maker._worker_num() + rank = self.role_maker._worker_index() + endpoints = self.role_maker._get_trainer_endpoints() + current_endpoint = endpoints[rank] + + # Create ring 0 for all model parallel parts within a single model + mp_endpoints = [] + mp_rank = rank % inner_parallelism + mp_id = rank // inner_parallelism + for idx, ep in enumerate(endpoints): + if idx // inner_parallelism == mp_id: + mp_endpoints.append(ep) + print("model parallel eps:{}, rank{}".format(mp_endpoints, mp_rank)) + self._init_communicator(self.startup_program, current_endpoint, + mp_endpoints, mp_rank, 0, self.wait_port) + self._broadcast_params(0, broadcast_distributed_weight=False) + + print("megatron group size: {}".format(inner_parallelism)) + print("megatron rank: {}".format(mp_rank)) + print("megatron endpoints: {}".format(mp_endpoints)) + + if self.megatron_dp: + mp_num = len(endpoints) // inner_parallelism + if mp_num == 1: return + # Create rings for gpus as the same model parallel part + eps = [] + dp_rank = rank // inner_parallelism + dp_id = rank % inner_parallelism + #if dp_rank == 1: dp_rank =0 + #if dp_rank == 0: dp_rank =1 + ring_id = 1 + for idx, ep in enumerate(endpoints): + if idx % inner_parallelism == dp_id: + eps.append(ep) + #ep = eps.pop(0) + #eps.insert(1, ep) + print("data parallel eps:{}, rank{}".format(eps, dp_rank)) + self._init_communicator(self.startup_program, current_endpoint, eps, + dp_rank, ring_id, self.wait_port) + self._broadcast_params(ring_id, broadcast_distributed_weight=True) + + def _init_communicator(self, program, current_endpoint, endpoints, rank, + ring_id, wait_port): + nranks = len(endpoints) + other_endpoints = endpoints[:] + other_endpoints.remove(current_endpoint) + if rank == 0 and wait_port: + wait_server_ready(other_endpoints) + + block = program.global_block() + nccl_id_var = block.create_var( + name=unique_name.generate('nccl_id'), + persistable=True, + type=core.VarDesc.VarType.RAW) + block.append_op( + type='c_gen_nccl_id', + inputs={}, + outputs={'Out': nccl_id_var}, + attrs={ + 'rank': rank, + 'endpoint': current_endpoint, + 'other_endpoints': 
other_endpoints, + OP_ROLE_KEY: OpRole.Forward, + }) + block.append_op( + type='c_comm_init', + inputs={'X': nccl_id_var}, + outputs={}, + attrs={ + 'nranks': nranks, + 'rank': rank, + 'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Forward, + }) + + def _broadcast_params(self, ring_id, broadcast_distributed_weight): + block = self.startup_program.global_block() + for param in block.iter_parameters(): + if not broadcast_distributed_weight and param.is_distributed: + continue + + block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': ring_id, + 'root': 0, + OP_ROLE_KEY: OpRole.Forward + }) + + block.append_op( + type='c_sync_comm_stream', + inputs={'X': param}, + outputs={'Out': param}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Forward}) + + +class ModelParallelOptimizer(MetaOptimizerBase): + def __init__(self, optimizer): + super(ModelParallelOptimizer, self).__init__(optimizer) + self.inner_opt = optimizer + self.meta_optimizers_white_list = [ + "RecomputeOptimizer", + "AMPOptimizer", + "LarsOptimizer", + "LambOptimizer", + ] + self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] + self.megatron_dp = False + + def _set_basic_info(self, loss, role_maker, user_defined_optimizer, + user_defined_strategy): + super(ModelParallelOptimizer, self)._set_basic_info( + loss, role_maker, user_defined_optimizer, user_defined_strategy) + self.inner_parallelism = user_defined_strategy.model_parallel_configs[ + 'parallelism'] + + def _can_apply(self): + if not self.role_maker._is_collective: + return False + + if self.user_defined_strategy.model_parallel == True: + return True + return False + + def _disable_strategy(self, dist_strategy): + dist_strategy.model_parallel = False + dist_strategy.model_parallel_configs = {} + + def _enable_strategy(self, dist_strategy, context): + dist_strategy.model_parallel = True + dist_strategy.model_parallel_configs = {"parallelism": 1, } + + # the following function will be used by AMP if both Megatron and AMP are turn on together. 
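ModelParallelHelper above carves the trainer endpoints into a model-parallel ring (ring_id 0, consecutive ranks) and, only when megatron_dp is set, a data-parallel ring (ring_id 1, strided ranks). A small standalone sketch of that grouping, assuming ranks are laid out in blocks of inner_parallelism:

```python
# Sketch of the rank grouping built in update_startup_program above.
def mp_dp_groups(world_size, inner_parallelism):
    # model-parallel groups: consecutive blocks of inner_parallelism ranks
    mp_groups = [list(range(s, s + inner_parallelism))
                 for s in range(0, world_size, inner_parallelism)]
    # data-parallel groups: same position across blocks (stride grouping)
    dp_groups = [list(range(r, world_size, inner_parallelism))
                 for r in range(inner_parallelism)]
    return mp_groups, dp_groups

# 8 ranks with inner_parallelism=4:
#   mp_groups -> [[0, 1, 2, 3], [4, 5, 6, 7]]      (ring_id 0)
#   dp_groups -> [[0, 4], [1, 5], [2, 6], [3, 7]]  (ring_id 1, when megatron_dp)
```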
+ def apply_gradients(self, params_grads): + return self.minimize_impl(params_grads=params_grads) + + def minimize_impl(self, + loss, + startup_program=None, + parameter_list=None, + no_grad_set=None): + endpoints = self.role_maker._get_trainer_endpoints() + current_endpoint = endpoints[self.role_maker._worker_index()] + self.startup_program = startup_program + if startup_program is None: + self.startup_program = fluid.default_startup_program() + + # (TODO) check the order of metaoptimizer + # (TODO) check the params_grads + optimize_ops, params_grads = self.inner_opt.minimize( + loss, self.startup_program, parameter_list, no_grad_set) + + self.main_program = loss.block.program + self.inner_parallelism = self.inner_parallelism + self.nranks = len(endpoints) + + pipeline_helper = ModelParallelHelper(self.role_maker) + pipeline_helper.update_startup_program(self.startup_program, + self.inner_parallelism) + + assert self.nranks % self.inner_parallelism == 0 + + if self.megatron_dp: + # data parallelism + dp_parallelism = self.nranks // self.inner_parallelism + + self._transpile_main_program(loss, dp_parallelism) + return optimize_ops, params_grads + + def _transpile_main_program(self, loss, dp_parallelism): + self._insert_loss_grad_ops(loss, dp_parallelism) + ring_id = 1 + print("ring_id: ", ring_id) + # for ring_id in range(1, dp_parallelism + 1): + self._insert_allreduce_ops(loss, ring_id) + + def _insert_loss_grad_ops(self, loss, dp_parallelism): + """ + In order to keep the learning rate consistent in different numbers of + training workers, we scale the loss grad by the number of workers + """ + block = loss.block + for idx, op in reversed(list(enumerate(block.ops))): + if is_loss_grad_op(op): + loss_grad_var = block.vars[op.output_arg_names[0]] + block._insert_op( + idx + 1, + type='scale', + inputs={'X': loss_grad_var}, + outputs={'Out': loss_grad_var}, + attrs={ + 'scale': 1.0 / dp_parallelism, + OP_ROLE_KEY: OpRole.Backward + }) + + def _insert_allreduce_ops(self, loss, ring_id): + block = loss.block + grad = None + for idx, op in reversed(list(enumerate(block.ops))): + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] + if len(op_role_var) == 0: + continue + assert len(op_role_var) % 2 == 0 + offset = idx + for i in range(0, len(op_role_var), 2): + param = block.vars[op_role_var[i]] + grad = block.vars[op_role_var[i + 1]] + #if param.is_distributed: + # continue + if offset == idx: + offset += 1 + block._insert_op( + offset, + type='c_sync_calc_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={OP_ROLE_KEY: OpRole.Backward}) + offset += 1 + + block._insert_op( + offset, + type='c_allreduce_sum', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={ + 'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward + }) + + if grad is None: + return + + for idx, op in list(enumerate(block.ops)): + if is_optimizer_op(op): + block._insert_op( + idx, + type='c_sync_comm_stream', + inputs={'X': grad}, + outputs={'Out': grad}, + attrs={'ring_id': ring_id, + OP_ROLE_KEY: OpRole.Backward}) + break diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py index da8adf47b854bf3cf74eab712088ad1d481face3..779b7534494c2ddb2c346c79717ab0ff4d8ca15f 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py @@ -154,8 +154,10 @@ class 
PipelineOptimizer(MetaOptimizerBase): def __init__(self, optimizer): super(PipelineOptimizer, self).__init__(optimizer) self.inner_opt = optimizer - # we do not allow meta optimizer to be inner optimizer currently - self.meta_optimizers_white_list = [] + self.meta_optimizers_white_list = [ + "RecomputeOptimizer", + "AMPOptimizer", + ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] def _set_basic_info(self, loss, role_maker, user_defined_optimizer, diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py index 03b36262a4fb1e095eb17fa57bf27b5c9f3cf74c..788d4e526352a9197bd47f791a2df4e90344739e 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/fp16_helper.py @@ -73,7 +73,7 @@ class FP16Utils(object): @staticmethod def prune_fp16(block, shard, reduced_grads_to_param, ring_id): """ - 1. prune all cast_fp32_to_fp16 ops if the param not belongs to this shard + 1. prune all cast_fp16_to_fp32 ops if the param not belongs to this shard 2. revise amp inifine grad checking for sharding """ # remove cast @@ -81,7 +81,9 @@ class FP16Utils(object): if not FP16Utils.is_fp32_cast_op(block, op): continue output_name = op.desc.output_arg_names()[0] - param_name = output_name.strip("@GRAD") + param_name = output_name.strip( + "@GRAD@MERGED" + ) if "@MERGED" in output_name else output_name.strip("@GRAD") if param_name not in shard.global_params: raise ValueError("Output 'X' of cast_op must be a grad of" "model param, but {} is not a grad".format( @@ -103,20 +105,35 @@ class FP16Utils(object): op._rename_input(inf_var_name, inf_var_name + "@sharding") if op.type in ["check_finite_and_unscale", "update_loss_scaling"]: reversed_x = [] + reversed_x_paramname = [] for input_name in op.desc.input('X'): - param_name = input_name.strip("@GRAD") + param_name = input_name.strip("@GRAD@MERGED") if param_name not in shard.global_params: raise ValueError( "Input 'X' of check_finite_and_unscale must" "be grads, but {} is not a grad".format(input_name)) if shard.has_param(param_name): reversed_x.append(input_name) + reversed_x_paramname.append(param_name) op.desc.set_input('X', reversed_x) op.desc.set_output('Out', reversed_x) + + # the grad checking should take the all and only param in the current shard + to_check_param = set(reversed_x_paramname) + should_check_param = set(shard.global_params).intersection( + set([ + param + for param, worker_idx in shard.global_param2device. 
+ items() if worker_idx == shard.worker_idx + ])) + #assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format( + # should_check_param - to_check_param, + # to_check_param - should_check_param) + if update_loss_scaling_op_idx == -1: return inf_var = block.var(inf_var_name) - inf_var_fp32 = block.create_var( + inf_var_int32 = block.create_var( name=inf_var_name + "@cast_int32", shape=inf_var.shape, dtype=core.VarDesc.VarType.INT32) @@ -128,32 +145,36 @@ class FP16Utils(object): update_loss_scaling_op_idx, type='cast', inputs={'X': inf_var}, - outputs={'Out': inf_var_fp32}, + outputs={'Out': inf_var_int32}, attrs={ "in_dtype": inf_var.dtype, - "out_dtype": inf_var_fp32.dtype, + "out_dtype": inf_var_int32.dtype, OP_ROLE_KEY: OpRole.Optimize }) - insert_sync_calc_op(block, update_loss_scaling_op_idx + 1, - [inf_var_fp32]) + # this allreduce communication should not overlap with calc + # insert_sync_calc_op(block, update_loss_scaling_op_idx + 1, + # [inf_var_int32]) block._insert_op_without_sync( - update_loss_scaling_op_idx + 2, + update_loss_scaling_op_idx + 1, type='c_allreduce_max', - inputs={'X': inf_var_fp32}, - outputs={'Out': inf_var_fp32}, - attrs={'ring_id': ring_id, - OP_ROLE_KEY: OpRole.Optimize}) + inputs={'X': inf_var_int32}, + outputs={'Out': inf_var_int32}, + attrs={ + 'ring_id': ring_id, + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize + }) - comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, - ring_id, [inf_var_fp32]) + # comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3, + # ring_id, [inf_var_int32]) block._insert_op_without_sync( - update_loss_scaling_op_idx + 3 + comm_op_num, + update_loss_scaling_op_idx + 2, type='cast', - inputs={'X': inf_var_fp32}, + inputs={'X': inf_var_int32}, outputs={'Out': inf_var_sharding}, attrs={ - "in_dtype": inf_var_fp32.dtype, + "in_dtype": inf_var_int32.dtype, "out_dtype": inf_var_sharding.dtype, OP_ROLE_KEY: OpRole.Optimize }) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py index c6aee792fcf745a6ec51b3c4d1945415bfd9324f..961a789dc081a0c04a32741526a9426a53d72e18 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/gradient_clip_helper.py @@ -16,8 +16,8 @@ from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole class GradientClipHelper(object): - def __init__(self, sharding_ring_id): - self.sharding_ring_id = sharding_ring_id + def __init__(self, mp_ring_id): + self.mp_ring_id = mp_ring_id def _is_gradient_clip_op(self, op): return op.desc.has_attr("op_namescope") \ @@ -31,6 +31,7 @@ class GradientClipHelper(object): """ deperated_vars = set() deperate_op_idx = set() + reversed_x_paramname = [] for idx, op in enumerate(block.ops): if not self._is_gradient_clip_op(op): continue @@ -40,15 +41,18 @@ class GradientClipHelper(object): for input_name in op.desc.input_arg_names(): if input_name in deperated_vars: deperate_op = True - param_name = input_name.strip("@GRAD") + param_name = input_name.strip("@GRAD@MERGED") if shard.is_param(param_name) and \ not shard.has_param(param_name): deperate_op = True + elif shard.is_param(param_name): + reversed_x_paramname.append(param_name) if deperate_op: deperate_op_idx.add(idx) for output_name in op.desc.output_arg_names(): - 
deperated_vars.add(output_name) + if output_name not in op.desc.input_arg_names(): + deperated_vars.add(output_name) if not deperated_vars: # got no gradient_clip op @@ -65,31 +69,47 @@ class GradientClipHelper(object): for input_name in op.desc.input_arg_names(): if input_name not in deperated_vars: reversed_inputs.append(input_name) + op.desc.set_input("X", reversed_inputs) assert (len(op.desc.output_arg_names()) == 1) sum_res = op.desc.output_arg_names()[0] - block._insert_op_without_sync( - idx + 1, - type='c_sync_comm_stream', - inputs={'X': sum_res}, - outputs={'Out': sum_res}, - attrs={'ring_id': 0, - OP_ROLE_KEY: OpRole.Optimize}) + + # this allreduce should not overlap with calc and should be scheduled in calc stream + # block._insert_op_without_sync( + # idx + 1, + # type='c_sync_comm_stream', + # inputs={'X': sum_res}, + # outputs={'Out': sum_res}, + # attrs={'ring_id': 0, + # OP_ROLE_KEY: OpRole.Optimize}) block._insert_op_without_sync( idx + 1, type='c_allreduce_sum', inputs={'X': sum_res}, outputs={'Out': sum_res}, attrs={ - 'ring_id': self.sharding_ring_id, - OP_ROLE_KEY: OpRole.Optimize + 'ring_id': self.mp_ring_id, + 'op_namescope': "/gradient_clip_model_parallelism", + 'use_calc_stream': True, + OP_ROLE_KEY: OpRole.Optimize, }) - block._insert_op_without_sync( - idx + 1, - type='c_sync_calc_stream', - inputs={'X': sum_res}, - outputs={'Out': sum_res}, - attrs={OP_ROLE_KEY: OpRole.Optimize}) + # block._insert_op_without_sync( + # idx + 1, + # type='c_sync_calc_stream', + # inputs={'X': sum_res}, + # outputs={'Out': sum_res}, + # attrs={OP_ROLE_KEY: OpRole.Optimize}) + + # the grad sum here should take the all and only param in the current shard + to_check_param = set(reversed_x_paramname) + should_check_param = set(shard.global_params).intersection( + set([ + param for param, worker_idx in shard.global_param2device.items() + if worker_idx == shard.worker_idx + ])) + assert to_check_param == should_check_param, "amp check_finite_and_unscale checking miss [{}] and got unexpected [{}]".format( + should_check_param - to_check_param, + to_check_param - should_check_param) for var_name in deperated_vars: block._remove_var(var_name, sync=False) diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..76803818453c929d1dbf503159c59e1325c8337e --- /dev/null +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py @@ -0,0 +1,281 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole +from paddle.fluid import core, unique_name + + +class OffloadHelper(object): + cpu_place_type = 0 + cuda_place_type = 1 + cuda_pinned_place_type = 2 + + def __init__(self): + pass + "0: dst is on CPUPlace. " + "1: dst is on CUDAPlace. " + "2: dst is on CUDAPinnedPlace. 
" + + def _insert_cast_op(self, block, idx, src_name, dst_name): + src_var = block.var(src_name) + if not block.has_var(dst_name): + block.create_var( + name=dst_name, + shape=src_var.shape, + dtype=core.VarDesc.VarType.FP16, + persistable=True) + dst_var = block.var(dst_name) + assert dst_var.dtype == core.VarDesc.VarType.FP16 + block._insert_op_without_sync( + idx, + type='cast', + inputs={'X': src_var}, + outputs={'Out': dst_var}, + attrs={ + 'in_dtype': src_var.dtype, + 'out_dtype': dst_var.dtype, + OP_ROLE_KEY: OpRole.Optimize + }) + + def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type): + src_var = block.var(src_name) + dst_var = block.var(dst_name) + block._insert_op_without_sync( + idx, + type='memcpy', + inputs={'X': src_var}, + outputs={'Out': dst_var}, + attrs={ + 'dst_place_type': dst_place_type, + OP_ROLE_KEY: OpRole.Optimize, + }) + + def _insert_fetch_op(self, block, idx, src_name, dst_name): + self._insert_memcpy_op(block, idx, src_name, dst_name, + OffloadHelper.cuda_place_type) + + def _insert_offload_op(self, block, idx, src_name, dst_name): + self._insert_memcpy_op(block, idx, src_name, dst_name, + OffloadHelper.cuda_pinned_place_type) + + def _get_offload_var_name(self, name): + return unique_name.generate(name + '@offload') + + def _create_offload_var(self, var_name, offload_var_name, blocks): + for block in blocks: + var = block.var(var_name) + var.persistable = False + offload_var = block.create_var( + name=offload_var_name, + shape=var.shape, + dtype=var.dtype, + persistable=True) + + def offload_fp32param(self, block, startup_block): + """ + (p_fp16) = cast(p) + (p_fp16_recompute) = cast(p) + (pout,) = adam(p) + ===========================> + rename(p_fp16_recompute, p_fp16) + + (p,) = prefetch(p@offload) + (pout,) = adam(p) + (p_fp16) = cast(p) + (p@offload) = memcpy(p) + """ + param_to_idx = dict() + param_to_fp16 = dict() + # recompute_var which need rename to fp16_param + fp16_param_to_recompute = dict() + recompute_to_fp16 = dict() + + def remove_param(input_name): + param_to_idx.pop(input_name) + if input_name in param_to_fp16: + fp16_param = param_to_fp16.pop(input_name) + if fp16_param in fp16_param_to_recompute: + recompute = fp16_param_to_recompute.pop(fp16_param) + recompute_to_fp16.pop(recompute) + + # step1: record param + for idx, op in reversed(list(enumerate(block.ops))): + if op.type in ('adam', 'momentum', 'lars', 'lamb'): + param = op.desc.input("Param")[0] + param_to_idx[param] = idx + + # step2: remove param which can't offload + for idx, op in enumerate(block.ops): + if is_optimizer_op(op): + break + for input_name in op.desc.input_arg_names(): + if input_name not in param_to_idx: + continue + + # param is real used by fp32 op + if op.type != 'cast': + remove_param(input_name) + continue + + # param is only used by cast op, + # which to cast fp32_param to fp16_param + output_name = op.output_arg_names[0] + if 'cast_fp16' not in output_name: + remove_param(input_name) + continue + + if 'subprog' not in output_name: + assert output_name == input_name + '.cast_fp16' + assert input_name not in param_to_fp16, \ + "There must be only one cast op from fp32 param to fp16 param." 
+ param_to_fp16[input_name] = output_name + else: + # fp16-->recompute_var + assert input_name in param_to_fp16, \ + "param must first be cast to fp16" + fp16_param = param_to_fp16[input_name] + fp16_param_to_recompute[fp16_param] = output_name + recompute_to_fp16[output_name] = fp16_param + + param_name_to_offload_name = dict() + # step3: main_block add offload, cast op + # change recompute to fp16, remove cast(param) to fp16 + for idx, op in reversed(list(enumerate(block.ops))): + if op.type in ('adam', 'momentum', 'lars', 'lamb'): + param = op.desc.input("Param")[0] + if param not in param_to_idx: continue + # step3.1: create offload_var + offload_var_name = self._get_offload_var_name(param) + param_name_to_offload_name[param] = offload_var_name + self._create_offload_var(param, offload_var_name, + [block, startup_block]) + + # step3.2: insert cast op and offload op + self._insert_offload_op(block, idx + 1, param, offload_var_name) + + assert param in param_to_fp16 + fp16_param_name = param_to_fp16[param] + fp16_param_var = block.var(fp16_param_name) + fp16_param_var.persistable = True + self._insert_cast_op(block, idx + 1, param, + param_to_fp16[param]) + + # step3.3: insert fetch op + self._insert_fetch_op(block, idx, offload_var_name, param) + continue + + # step3.4: remove cast op + if op.type == 'cast': + input_name = op.desc.input_arg_names()[0] + if input_name in param_to_idx: + block._remove_op(idx, sync=False) + continue + + # step3.5: change recompute_param to fp16_param + for input_name in op.desc.input_arg_names(): + if input_name in recompute_to_fp16: + op._rename_input(input_name, recompute_to_fp16[input_name]) + for output_name in op.desc.output_arg_names(): + if output_name in recompute_to_fp16: + op._rename_output(output_name, + recompute_to_fp16[output_name]) + + # step4: remove recompute_param + for name in recompute_to_fp16.keys(): + block._remove_var(name, sync=False) + + # step5: startup_block add offload + visited_vars = set() + for idx, op in reversed(list(enumerate(startup_block.ops))): + for out_name in op.output_arg_names: + if out_name in visited_vars: + continue + + if out_name in param_name_to_offload_name: + var_name = out_name + offload_var_name = param_name_to_offload_name[var_name] + self._insert_offload_op(startup_block, idx + 1, var_name, + offload_var_name) + self._insert_cast_op(startup_block, idx + 1, var_name, + param_to_fp16[var_name]) + + visited_vars.add(out_name) + + block._sync_with_cpp() + startup_block._sync_with_cpp() + + def offload(self, block, startup_block): + """ + (m1, m2) = prefetch(m1@offload, m2@offload) + (m1out, m2out, pout) = adam(m1, m2, p) + (m1@offload, m2@offload) = memcpy(m1, m2) + """ + vars_name_to_offload_name = dict() + + # main_block add offload + for idx, op in reversed(list(enumerate(block.ops))): + if not is_optimizer_op(op): + break + + vars_name = [] + if op.type == "adam": + # {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} = + # adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']}) + vars_name.append(op.desc.input("Moment1")[0]) + vars_name.append(op.desc.input("Moment2")[0]) + elif op.type == 'momentum': + pass + elif op.type == 'lars': + pass + elif op.type == 'lamb': + pass + + # step1: create and init offload_var + for var_name in vars_name: + assert var_name not in vars_name_to_offload_name + + offload_var_name = self._get_offload_var_name(var_name) + vars_name_to_offload_name[var_name] = offload_var_name + + self._create_offload_var(var_name, offload_var_name, + [block, startup_block]) + + # 
step2: insert offload op + for var_name in vars_name: + offload_var_name = vars_name_to_offload_name[var_name] + self._insert_offload_op(block, idx + 1, var_name, + offload_var_name) + + # step3: insert fetch op + for var_name in vars_name: + offload_var_name = vars_name_to_offload_name[var_name] + self._insert_fetch_op(block, idx, offload_var_name, var_name) + + # startup_block add offload + visited_vars = set() + for idx, op in reversed(list(enumerate(startup_block.ops))): + for out_name in op.output_arg_names: + if out_name in visited_vars: + continue + + if out_name in vars_name_to_offload_name: + var_name = out_name + offload_var_name = vars_name_to_offload_name[var_name] + # insert offload op after var is generated + self._insert_offload_op(startup_block, idx + 1, var_name, + offload_var_name) + visited_vars.add(out_name) + + block._sync_with_cpp() + startup_block._sync_with_cpp() diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py index 70753b59ccc318a25661e084bd305d7d76b0e2a6..9748bec3454d5368972baf42cbb8448869c8315c 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/prune.py @@ -126,6 +126,9 @@ class ProgramDeps(object): def should_remove_op(self, op_idx): op = self._block.ops[op_idx] + # remove check_finite_and_unscale op if its input 'X' is empty + if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0: + return True for output_name in op.desc.output_arg_names(): if output_name not in self._should_removed_var: return False diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py index ad1cd4f60826bbf434294114d1982cb4beb3f00a..c25c81ee114e2bb849b470cb1b5ccdf7c070cc30 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py @@ -28,21 +28,24 @@ def check_broadcast(block): if the broadcasted var has a fill_constant op, the fill_constant op should stay forward before the broadcast op, and before a sync_calc op. Otherwise, raise error. + + should ignore and skip broadcast_op of inner_parallelism (e.g. Megatron) """ broadcast_vars = {} for idx, op in enumerate(block.ops): if op.type == "c_broadcast": - var_name = op.desc.input_arg_names()[0] - if "@BroadCast" in var_name: - if var_name in broadcast_vars: - raise ValueError("var_name areadly exist: {}" - "the old pos is {}, the new pos is {}". - format(var_name, broadcast_vars[var_name][ - "broadcast_pos"], idx)) - broadcast_vars[var_name] = { - "fill_constant_pos": -1, - "broadcast_pos": idx, - } + if op.all_attrs()["use_calc_stream"] == False: + var_name = op.desc.input_arg_names()[0] + if "@BroadCast" in var_name: + if var_name in broadcast_vars: + raise ValueError("var_name areadly exist: {}" + "the old pos is {}, the new pos is {}". 
+ format(var_name, broadcast_vars[ + var_name]["broadcast_pos"], idx)) + broadcast_vars[var_name] = { + "fill_constant_pos": -1, + "broadcast_pos": idx, + } for idx, op in enumerate(block.ops): if op.type == "fill_constant": @@ -61,14 +64,15 @@ def check_broadcast(block): last_sync_calc_op_idx = idx continue if op.type == "c_broadcast": - var_name = op.desc.input_arg_names()[0] - if "@BroadCast" in var_name: - if broadcast_vars[var_name]["fill_constant_pos"] != -1: - assert (last_sync_calc_op_idx != -1) - assert (broadcast_vars[var_name]["fill_constant_pos"] < - last_sync_calc_op_idx) - assert (last_sync_calc_op_idx < idx) - continue + if op.all_attrs()["use_calc_stream"] == False: + var_name = op.desc.input_arg_names()[0] + if "@BroadCast" in var_name: + if broadcast_vars[var_name]["fill_constant_pos"] != -1: + assert (last_sync_calc_op_idx != -1) + assert (broadcast_vars[var_name]["fill_constant_pos"] < + last_sync_calc_op_idx) + assert (last_sync_calc_op_idx < idx) + continue for input_name in op.desc.input_arg_names(): if input_name in broadcast_vars: assert (broadcast_vars[input_name]["broadcast_pos"] != -1) @@ -78,43 +82,47 @@ def check_broadcast(block): return -def check_allreduce_sum(block, shard, dp_ring_id=-1): +def check_allreduce_sum(block, shard, sharding_ring_id, dp_ring_id=-1): """ the op order should be: grad: - 0: op that generate Var - 1: sync_calc - - 2: allreduce_sum_sharding + - 2: reduce_sum_sharding (allreduce --> reduce) - 3: sync_comm - 4: allreuce_sum_dp (dp_grads) - 5: sync_comm (dp_grads) - 6: op that use Var (dp_grads & sum) + + should ignore and skip allreduce_op of inner_parallelism (e.g. Megatron) """ vars_status = {} dp_grads_status = {} idx_last_grad_allreduce = -1 idx_amp_allreduce = -1 idx_gradient_clip_allreduce = -1 + for idx, op in enumerate(block.ops): - if op.type == "c_allreduce_sum": - ring_id = op.desc.attr("ring_id") - var_name = op.desc.input_arg_names()[0] - param = var_name.split("@")[0] + if op.type == "c_allreduce_sum" or op.type == "c_reduce_sum": + if op.all_attrs()["use_calc_stream"] == False: + ring_id = op.desc.attr("ring_id") + var_name = op.desc.input_arg_names()[0] + param = var_name.split("@")[0] - assert 'sum' in var_name or ("@GRAD" in var_name) - if 'sum' in var_name or (not shard.has_param(param)): - vars_status[var_name] = -1 - else: - dp_grads_status[var_name] = -1 + assert 'sum' in var_name or ("@GRAD" in var_name) + if 'sum' in var_name or (not shard.has_param(param)): + vars_status[var_name] = -1 + else: + dp_grads_status[var_name] = -1 - if ring_id != 0: - assert shard.has_param(param) - assert ring_id == dp_ring_id + if ring_id != sharding_ring_id: + assert shard.has_param(param) + assert ring_id == dp_ring_id - if "sum" in var_name: - idx_amp_allreduce = idx - elif "@GRAD": - idx_last_grad_allreduce = idx + if "sum" in var_name: + idx_amp_allreduce = idx + elif "@GRAD": + idx_last_grad_allreduce = idx if op.type == "c_allreduce_max": idx_gradient_clip_allreduce = idx @@ -129,37 +137,40 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): var_name] == 0: dp_grads_status[var_name] = 1 - elif op.type == "c_allreduce_sum": - var_name = op.desc.input_arg_names()[0] - ring_id = op.desc.attr("ring_id") - if ring_id == 0: - if var_name in vars_status: - _status = vars_status[var_name] - else: - _status = dp_grads_status[var_name] - if _status == -1: - raise ValueError("{} is not generated, but you are" - "trying to all-reduce it".format(var_name)) - if _status == 0: - raise ValueError("There should be a sync_calc op " - 
"after generate Var: {} and before the" - "c_allreduce_sum op".format(var_name)) - assert (_status == 1) - if var_name in vars_status: - vars_status[var_name] = 2 + elif op.type == "c_allreduce_sum" or op.type == "c_reduce_sum": + if op.all_attrs()["use_calc_stream"] == False: + var_name = op.desc.input_arg_names()[0] + ring_id = op.desc.attr("ring_id") + if ring_id == sharding_ring_id: + assert op.type == "c_reduce_sum", "Grad in Sharding group should be reduce rather than allreduce" + if var_name in vars_status: + _status = vars_status[var_name] + else: + _status = dp_grads_status[var_name] + if _status == -1: + raise ValueError("{} is not generated, but you are" + "trying to all-reduce it".format( + var_name)) + if _status == 0: + raise ValueError("There should be a sync_calc op " + "after generate Var: {} and before the" + "c_allreduce_sum op".format(var_name)) + assert (_status == 1) + if var_name in vars_status: + vars_status[var_name] = 2 + else: + dp_grads_status[var_name] = 2 else: - dp_grads_status[var_name] = 2 - else: - assert ring_id == dp_ring_id - param = var_name.split("@")[0] - assert shard.has_param(param) - assert dp_grads_status[var_name] == 3 - dp_grads_status[var_name] = 4 + assert ring_id == dp_ring_id + param = var_name.split("@")[0] + assert shard.has_param(param) + assert dp_grads_status[var_name] == 3 + dp_grads_status[var_name] = 4 elif op.type == "c_sync_comm_stream": var_name = op.desc.input_arg_names()[0] ring_id = op.desc.attr("ring_id") - if ring_id == 0: + if ring_id == sharding_ring_id: for var_name in op.desc.input_arg_names(): if var_name in vars_status: assert vars_status[var_name] == 2 @@ -181,6 +192,9 @@ def check_allreduce_sum(block, shard, dp_ring_id=-1): raise ValueError("There should be a sync_comm op " "after allreduce the Var: {}".format( input_name)) + raise ValueError( + "The reduce output grad [{}] should NOT be be used in Non-root rank.". 
+ format(input_name)) if input_name in dp_grads_status: if dp_ring_id == -1: if dp_grads_status[input_name] != 3: @@ -225,6 +239,13 @@ def get_valid_op_role(block, insert_idx): return get_valid_op_role(block, insert_idx + 1) + # if insert_idx >= len(block.ops): return OpRole.Optimize + # if op_role == int(OpRole.Backward): return OpRole.Backward + # if op_role == int(OpRole.Optimize): return OpRole.Optimize + # if op_role in [int(OpRole.Forward), int(OpRole.Loss)]: + # return OpRole.Forward + # return get_valid_op_role(block, insert_idx + 1) + def insert_sync_calc_op(block, insert_idx, calc_dep_vars): """ @@ -259,6 +280,9 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars): """ insert sync_comm_op for vars """ + if len(comm_dep_vars) == 0: + return 0 + op_role = get_valid_op_role(block, insert_idx) block._insert_op_without_sync( insert_idx, @@ -313,6 +337,9 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): """ _add_allreduce_ops """ + if len(allreduce_vars) == 0: + return + for var in allreduce_vars: block._insert_op_without_sync( insert_idx, @@ -325,6 +352,62 @@ def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars): return +def get_grad_device(grad_name, shard): + assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format( + grad_name) + base_name = None + # mind the traversal order + possible_suffixes = [ + '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD' + ] + for suffix in possible_suffixes: + if suffix in grad_name: + base_name = re.sub(suffix, '', grad_name) + break + + assert base_name in shard.global_param2device, "[{}] should be a param variable.".format( + base_name) + + return shard.global_param2device[base_name] + + +def insert_reduce_ops(block, + insert_idx, + ring_id, + reduce_vars, + shard, + op_role=OpRole.Backward, + use_calc_stream=False): + """ + _add_allreduce_ops + """ + for var in reduce_vars: + + root_id = get_grad_device(var, shard) + assert root_id >= 0, "root id should be a positive int".format(var) + block._insert_op_without_sync( + insert_idx, + type='c_reduce_sum', + inputs={'X': var}, + outputs={'Out': var}, + attrs={ + 'ring_id': ring_id, + 'root_id': root_id, + 'use_calc_stream': use_calc_stream, + OP_ROLE_KEY: op_role + }) + return + + +def get_first_check_finite_and_unscale_op_idx(block): + + for idx, op in enumerate(block.ops): + if op.type == "check_finite_and_unscale": + return idx + + raise ValueError("check_finite_and_unscale does not exist in block") + + def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root): """ _add_broadcast_ops @@ -428,7 +511,7 @@ def comm_analyse(main_program): count)) -def add_sync_comm(program, dist_strategy): +def add_sync_comm(program, nccl_ids): """ When clone a test prog by clone from the sharding main prog, part of the sync_comm op maybe be pruned by mistake, this function @@ -438,6 +521,9 @@ def add_sync_comm(program, dist_strategy): #NOTE (liangjianzhong): only support one comm stream by now, use more than one # comm streams will cause error. should be revise in future. 
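Note on the reduce-to-root path added in utils.py: insert_reduce_ops picks the root rank for each c_reduce_sum by mapping the gradient name back to the parameter that owns it, stripping the longest known suffix first (the "mind the traversal order" comment in get_grad_device). A minimal sketch of that lookup, using a plain dict as a stand-in for shard.global_param2device; the parameter name below is illustrative, not taken from the patch:

import re

# Longest suffixes first, so '.cast_fp16@GRAD@MERGED' wins over '@GRAD'.
POSSIBLE_SUFFIXES = [
    '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
]

def grad_to_root_rank(grad_name, param2device):
    """Return the rank that owns the parameter behind grad_name."""
    assert "@GRAD" in grad_name, "[{}] should be a grad variable".format(grad_name)
    base_name = None
    for suffix in POSSIBLE_SUFFIXES:
        if suffix in grad_name:
            # The patch calls re.sub with the raw suffix; re.escape is used
            # here only to keep the sketch safe for arbitrary names.
            base_name = re.sub(re.escape(suffix), '', grad_name)
            break
    assert base_name in param2device, "[{}] should be a param variable".format(base_name)
    return param2device[base_name]

# Hypothetical shard layout: fc_0.w_0 is owned by sharding rank 1.
param2device = {"fc_0.w_0": 1}
assert grad_to_root_rank("fc_0.w_0.cast_fp16@GRAD@MERGED", param2device) == 1
assert grad_to_root_rank("fc_0.w_0@GRAD", param2device) == 1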
+ assert isinstance( + nccl_ids, list + ), "the second argument of this function should be a list of nccl_ids" block = program.global_block() not_sync_vars = set([]) for op in block.ops: @@ -448,7 +534,7 @@ def add_sync_comm(program, dist_strategy): for input_name in op.desc.input_arg_names(): not_sync_vars.remove(input_name) if not_sync_vars: - for nccl_id in range(dist_strategy.nccl_comm_num): + for nccl_id in nccl_ids: block.append_op( type='c_sync_comm_stream', inputs={'X': list(not_sync_vars)}, @@ -467,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None): This function handles the model saving for sharding training. """ + if main_program._pipeline_opt: + main_program = main_program._pipeline_opt['section_program']['program'] + def is_opt_vars(var): # NOTE(liangjianzhong): The checks should be updated when add new compatible optimizer # now only Momentum and adam are compatible with sharding diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py index 7fed227cd9936ccac730b58959a2ce8bed51e4ef..97febe8db2b88e4a2ef495cd6540f7916286ce43 100755 --- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py @@ -16,14 +16,16 @@ from paddle.fluid import unique_name, core import paddle.fluid as fluid from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper -from paddle.distributed.fleet.meta_optimizers.common import is_backward_op +from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper +from .sharding.offload_helper import OffloadHelper from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps from paddle.distributed.fleet.meta_optimizers.sharding.utils import * + import logging from functools import reduce @@ -31,6 +33,8 @@ __all__ = ["ShardingOptimizer"] class ShardingOptimizer(MetaOptimizerBase): + """Sharding Optimizer.""" + def __init__(self, optimizer): super(ShardingOptimizer, self).__init__(optimizer) self.inner_opt = optimizer @@ -39,6 +43,8 @@ class ShardingOptimizer(MetaOptimizerBase): "AMPOptimizer", "LarsOptimizer", "LambOptimizer", + # "ModelParallelOptimizer", + "PipelineOptimizer", ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] self._main_program = None @@ -51,6 +57,10 @@ class ShardingOptimizer(MetaOptimizerBase): self._reduced_grads_to_param = {} self._shard = Shard() + # use sharding as outer parallelism (e.g. 
inner:Megatron & outer sharding) + self._as_outer_parallelism = False + self._inner_parallelism_size = None + def _can_apply(self): if not self.role_maker._is_collective: return False @@ -71,6 +81,7 @@ class ShardingOptimizer(MetaOptimizerBase): startup_program=None, parameter_list=None, no_grad_set=None): + """Implementation of minimize.""" # TODO: (JZ-LIANG) support multiple comm in future # self._nrings = self.user_defined_strategy.nccl_comm_num self._nrings_sharding = 1 @@ -79,20 +90,72 @@ class ShardingOptimizer(MetaOptimizerBase): "fuse_broadcast_MB"] self.hybrid_dp = self.user_defined_strategy.sharding_configs[ "hybrid_dp"] + self._as_outer_parallelism = self.user_defined_strategy.sharding_configs[ + "as_outer_parallelism"] + self._inner_parallelism_size = int( + self.user_defined_strategy.sharding_configs["parallelism"]) + self.use_pipeline = self.user_defined_strategy.sharding_configs[ + "use_pipeline"] + self.acc_steps = self.user_defined_strategy.sharding_configs[ + "acc_steps"] + self.schedule_mode = self.user_defined_strategy.sharding_configs[ + "schedule_mode"] + self.pp_bz = self.user_defined_strategy.sharding_configs["pp_bz"] + self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[ + "pp_allreduce_in_optimize"] if self.inner_opt is None: raise ValueError( "self.inner_opt of ShardingOptimizer should not be None.") - optimize_ops, params_grads = self.inner_opt.minimize( - loss, startup_program, parameter_list, no_grad_set) + if self.use_pipeline: + pp_optimizer = fluid.optimizer.PipelineOptimizer(self.inner_opt, + self.acc_steps) + main_program = loss.block.program + main_program._pipeline_opt = dict() + main_program._pipeline_opt['schedule_mode'] = self.schedule_mode + main_program._pipeline_opt['pp_bz'] = self.pp_bz + pp_rank = self.role_maker._worker_index() // ( + self.user_defined_strategy.sharding_configs[ + 'sharding_group_size'] * self._inner_parallelism_size) + main_program._pipeline_opt['local_rank'] = pp_rank + main_program._pipeline_opt[ + 'global_rank'] = self.role_maker._worker_index() + main_program._pipeline_opt['use_sharding'] = True + main_program._pipeline_opt['ring_id'] = 20 + optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize( + loss, startup_program, parameter_list, no_grad_set) + self.pipeline_nodes = len(program_list) + else: + optimize_ops, params_grads = self.inner_opt.minimize( + loss, startup_program, parameter_list, no_grad_set) if startup_program is None: startup_program = default_startup_program() - main_block = loss.block + if self.use_pipeline: + startup_program = startup_program._pipeline_opt['startup_program'] + #main_program = main_program._pipeline_opt['section_program']['program'] + print("pp_rank:", pp_rank) + main_program = program_list[pp_rank]['program'] + with open("main_%d" % self.role_maker._worker_index(), 'w') as f: + f.writelines(str(main_program)) + main_block = main_program.global_block() + new_params_grads = [] + for param, grad in params_grads: + if main_block.has_var(param.name): + new_params_grads.append((param, grad)) + params_grads = new_params_grads + + else: + main_block = loss.block startup_block = startup_program.global_block() self._main_program = main_block.program self._startup_program = startup_program + if self.use_pipeline: + pp_optimizer._rename_gradient_var_name(main_block) + with open("main_%d" % self.role_maker._worker_index(), 'w') as f: + f.writelines(str(main_program)) + # step1: set_up self._set_up(params_grads) @@ -105,17 +168,76 
@@ class ShardingOptimizer(MetaOptimizerBase): startup_block._sync_with_cpp() # step4: insert reduce_sum for grad - insert_scale_loss_grad_ops( - main_block, scale=1.0 / self.role_maker._worker_num()) + # grad_scale_coeff = self.role_maker._worker_num() + # if self._as_outer_parallelism: + # grad_scale_coeff = grad_scale_coeff / self._inner_parallelism_size + # insert_scale_loss_grad_ops(main_block, scale=1.0 / grad_scale_coeff) + sharding_group_size = self.user_defined_strategy.sharding_configs[ + 'sharding_group_size'] + insert_scale_loss_grad_ops(main_block, scale=1.0 / sharding_group_size) main_block._sync_with_cpp() # step5: remove unneeded ops and vars from block self._prune_main_program(main_block) self._prune_startup_program(startup_block) + if self.hybrid_dp: + self._initialization_broadcast(startup_program) + + if self.use_pipeline: + # pp_optimizer._rename_gradient_var_name(main_block) + # crop ops + for idx, op in reversed(list(enumerate(main_block.ops))): + if is_update_op(op): + op_role_var = op.attr('op_role_var') + param_name = op_role_var[0] + if not self._shard.has_param(param_name): + main_block._remove_op(idx) + + for idx, op in reversed(list(enumerate(main_block.ops))): + if op.type != 'cast': continue + in_name = op.input_arg_names[0] + if in_name not in self._params: continue + #if self._shard.has_param(param_name): continue + if in_name not in main_block.vars: + main_block._remove_op(idx) + accumulated_grad_names = pp_optimizer._accumulate_gradients( + main_block) + # accumulated_grad_names = sorted(accumulated_grad_names) + if self.pp_allreduce_in_optimize: + print("persistable FP32 grad: ") + print(accumulated_grad_names) + first_optimize_op_index = get_first_check_finite_and_unscale_op_idx( + main_block) + insert_reduce_ops( + main_block, + first_optimize_op_index, + self.sharding_ring_id, + accumulated_grad_names, + self._shard, + core.op_proto_and_checker_maker.OpRole.Optimize, + use_calc_stream=True) + + main_block._sync_with_cpp() + + # TODO(wangxi): add optimize offload + if self.optimize_offload: + logging.info("Sharding with optimize offload !") + offload_helper = OffloadHelper() + offload_helper.offload(main_block, startup_block) + offload_helper.offload_fp32param(main_block, startup_block) + + with open("start_sharding_%d" % self.role_maker._worker_index(), + 'w') as f: + f.writelines(str(startup_block.program)) + with open("main_sharding_%d" % self.role_maker._worker_index(), + 'w') as f: + f.writelines(str(main_block.program)) # check op dependecy check_broadcast(main_block) - check_allreduce_sum(main_block, self._shard, self.dp_ring_id) + #check_allreduce_sum(main_block, self._shard, self.sharding_ring_id, + # self.dp_ring_id) + #check_allreduce_sum(main_block, self._shard, self.dp_ring_id) self._wait() return optimize_ops, params_grads @@ -129,16 +251,72 @@ class ShardingOptimizer(MetaOptimizerBase): self._nrings_sharding) # config sharding & dp groups self._init_comm() + + # global + if self._as_outer_parallelism: + print("global_group_endpoints:", self.global_group_endpoints) + print("global_rank:", self.global_rank) + print("global_ring_id:", self.global_group_id) + self._collective_helper._init_communicator( + self._startup_program, self.current_endpoint, + self.global_group_endpoints, self.global_rank, + self.global_group_id, False) + + if self._as_outer_parallelism: + print("mp_group_endpoints:", self.mp_group_endpoints) + print("mp_rank:", self.mp_rank) + print("mp_ring_id:", self.mp_group_id) + self._collective_helper._init_communicator( + 
self._startup_program, self.current_endpoint, + self.mp_group_endpoints, self.mp_rank, self.mp_group_id, False) + # sharding + print("sharding_group_endpoints:", self.sharding_group_endpoints) + print("sharding_rank:", self.sharding_rank) + print("sharding_ring_id:", self.sharding_ring_id) self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, self.sharding_group_endpoints, self.sharding_rank, - self.sharding_ring_id, True) + self.sharding_ring_id, False) + # dp if self.hybrid_dp: self._collective_helper._init_communicator( self._startup_program, self.current_endpoint, - self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, True) + self.dp_group_endpoints, self.dp_rank, self.dp_ring_id, False) + # pp + if self.use_pipeline: + print("pp_group_endpoints:", self.pp_group_endpoints) + print("pp_rank:", self.pp_rank) + print("pp_ring_id:", self.pp_ring_id) + if self.schedule_mode == 0: # GPipe + self._collective_helper._init_communicator( + self._startup_program, self.current_endpoint, + self.pp_group_endpoints, self.pp_rank, self.pp_ring_id, + False) + self._collective_helper._init_communicator( + self._startup_program, self.current_endpoint, + self.pp_group_endpoints, self.pp_rank, self.pp_ring_id + 2, + False) + else: + for pair in self.pipeline_pair: + pair_key = pair[0] * 1000 + pair[1] + ring_id = self.pp_ring_map[pair_key] + print("pp pair:{}, ring_id: {}".format(pair, ring_id)) + if self.pp_rank not in pair: continue + pp_group_endpoints = [ + self.pp_group_endpoints[pair[0]], + self.pp_group_endpoints[pair[1]], + ] + if pair[0] < pair[1]: + start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1 + else: + start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[ + 1] - 1 + pp_rank = 0 if self.pp_rank == pair[0] else 1 + self._collective_helper._init_communicator( + self._startup_program, self.current_endpoint, + pp_group_endpoints, pp_rank, ring_id, False, False) startup_block = self._startup_program.global_block() startup_block._sync_with_cpp() @@ -153,10 +331,35 @@ class ShardingOptimizer(MetaOptimizerBase): self._main_program.global_block()) def _wait(self, ): - endpoints = self.role_maker._get_trainer_endpoints() - current_endpoint = endpoints[self.role_maker._worker_index()] - if self.role_maker._worker_index() == 0: - self._collective_helper._wait(current_endpoint, endpoints) + # only the first parallelsm group that init nccl need to be wait. + if self._as_outer_parallelism: + endpoints = self.role_maker._get_trainer_endpoints() + current_endpoint = endpoints[self.role_maker._worker_index()] + else: + endpoints = self.sharding_group_endpoints[:] + current_endpoint = self.sharding_group_endpoints[self.sharding_rank] + + if self._as_outer_parallelism: + if self.role_maker._worker_index() == 0: + self._collective_helper._wait(current_endpoint, endpoints) + else: + if self.sharding_rank == 0: + self._collective_helper._wait(current_endpoint, endpoints) + + # def _wait(self, ): + # # only the first parallelsm group that init nccl need to be wait. 
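Note on the per-pair pipeline communicators initialized above: with the non-GPipe schedule, every send/recv stage pair gets its own two-rank ring whose id is looked up with the key pair[0] * 1000 + pair[1]. A small sketch of that bookkeeping, assuming pp_ring_map has already been filled by the pipeline optimizer; the endpoints and ring id below are made up for illustration:

def pp_pair_comm_args(pair, pp_rank, pp_group_endpoints, pp_ring_map):
    """Return (endpoints, local_rank, ring_id) for one stage pair, or None
    if this rank does not take part in the pair."""
    src, dst = pair
    if pp_rank not in pair:
        return None
    # Each ordered pair owns a dedicated two-rank ring.
    ring_id = pp_ring_map[src * 1000 + dst]
    endpoints = [pp_group_endpoints[src], pp_group_endpoints[dst]]
    local_rank = 0 if pp_rank == src else 1
    return endpoints, local_rank, ring_id

# Toy 4-stage pipeline: stage 1 sends to stage 2 over ring 21.
eps = ["127.0.0.1:617{}".format(i) for i in range(4)]
print(pp_pair_comm_args((1, 2), pp_rank=2, pp_group_endpoints=eps,
                        pp_ring_map={1 * 1000 + 2: 21}))
# (['127.0.0.1:6171', '127.0.0.1:6172'], 1, 21)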
+ # if self._as_outer_parallelism: + # endpoints = self.role_maker._get_trainer_endpoints() + # else: + # endpoints = self.sharding_group_endpoints[:] + # current_endpoint = endpoints[self.role_maker._worker_index()] + + # if self._as_outer_parallelism: + # if self.role_maker._worker_index() == 0: + # self._collective_helper._wait(current_endpoint, endpoints) + # else: + # if self.sharding_rank == 0: + # self._collective_helper._wait(current_endpoint, endpoints) def _split_program(self, block): for op_idx, op in reversed(list(enumerate(block.ops))): @@ -197,17 +400,22 @@ class ShardingOptimizer(MetaOptimizerBase): self._main_program.global_block().var(input_name)) # find reduce vars - if is_backward_op(op) and \ - OP_ROLE_VAR_KEY in op.attr_names: - op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] - if len(op_role_var) != 0: - assert len(op_role_var) % 2 == 0 - for i in range(0, len(op_role_var), 2): - param, reduced_grad = op_role_var[i], op_role_var[i + 1] - segment._allreduce_vars.append(reduced_grad) - assert ( - reduced_grad not in self._reduced_grads_to_param) - self._reduced_grads_to_param[reduced_grad] = param + if self.use_pipeline and self.pp_allreduce_in_optimize: + # place pipeline gradient allreduce in optimize + pass + else: + if is_backward_op(op) and \ + OP_ROLE_VAR_KEY in op.attr_names: + op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY] + if len(op_role_var) != 0: + assert len(op_role_var) % 2 == 0 + for i in range(0, len(op_role_var), 2): + param, reduced_grad = op_role_var[i], op_role_var[ + i + 1] + segment._allreduce_vars.append(reduced_grad) + #assert ( + # reduced_grad not in self._reduced_grads_to_param) + self._reduced_grads_to_param[reduced_grad] = param # find cast op if FP16Utils.is_fp16_cast_op(block, op, self._params): @@ -234,9 +442,14 @@ class ShardingOptimizer(MetaOptimizerBase): """ weightdecay_helper = WeightDecayHelper() weightdecay_helper.prune_weight_decay(block, self._shard) + # NOTE (JZ-LIANG) the sync of FoundInfinite should among one entire Model Parallelism + # group. 
and each Data Parallelism group should have its own sync of FoundInfinite + Model_Paramllelism_ring_id = self.sharding_ring_id + if self._as_outer_parallelism: + Model_Paramllelism_ring_id = self.global_group_id FP16Utils.prune_fp16(block, self._shard, self._reduced_grads_to_param, - self.sharding_ring_id) - gradientclip_helper = GradientClipHelper(self.sharding_ring_id) + Model_Paramllelism_ring_id) + gradientclip_helper = GradientClipHelper(Model_Paramllelism_ring_id) gradientclip_helper.prune_gradient_clip(block, self._shard) # build prog deps @@ -264,8 +477,13 @@ class ShardingOptimizer(MetaOptimizerBase): # Prune for idx, op in reversed(list(enumerate(block.ops))): if op.type in [ - "c_allreduce_sum", "c_sync_comm_stream", - "c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init, c_comm_init_hcom" + "c_allreduce_sum", + "c_sync_comm_stream", + "c_calc_comm_stream", + "c_gen_nccl_id", + "c_comm_init", + 'send_v2', + 'recv_v2', ]: pass elif op.type == "conditional_block": @@ -303,15 +521,41 @@ class ShardingOptimizer(MetaOptimizerBase): program_deps.remove_op(idx) block._sync_with_cpp() + for idx, op in reversed(list(enumerate(block.ops))): + if op.type == 'concat' and is_optimizer_op(op): + # remove inputs that not on this card + reserved_x = [] + for var_name in op.desc.input("X"): + if block.has_var(var_name): reserved_x.append(var_name) + op.desc.set_input('X', reserved_x) + block._sync_with_cpp() return def _add_broadcast_allreduce(self, block): """ _add_broadcast_allreduce + + if combined with pipeline(grad accumulate), + the grad allreduce should be done in optimize role """ if len(self._segments) < 1: return # sharding + if self.use_pipeline and self.pp_allreduce_in_optimize: + for idx in range(len(self._segments)): + assert len(self._segments[idx]._allreduce_vars) == 0 + + # fix the _end_idx for segments[-1] if pp is used. 
+ new_end_idx = self._segments[-1]._end_idx + for idx in range(self._segments[-1]._end_idx - 1, + self._segments[-1]._start_idx - 1, -1): + op = block.ops[idx] + if op.type == "fill_constant" or op.type == "sum": + if "MERGED" in op.output_arg_names[0]: new_end_idx = idx + 1 + elif op.type == "cast": + if "@TMP" in op.output_arg_names[0]: new_end_idx = idx + 1 + self._segments[-1]._end_idx = new_end_idx + if self._segments[-1]._allreduce_vars: shard_allredue_vars = self._shard.filter_grads(self._segments[-1] ._allreduce_vars) @@ -323,9 +567,15 @@ class ShardingOptimizer(MetaOptimizerBase): insert_sync_comm_ops(block, self._segments[-1]._end_idx, self.sharding_ring_id, self._segments[-1]._allreduce_vars) - insert_allreduce_ops(block, self._segments[-1]._end_idx, - self.sharding_ring_id, - self._segments[-1]._allreduce_vars) + # allreduce --> reduce + insert_reduce_ops( + block, + self._segments[-1]._end_idx, + self.sharding_ring_id, + self._segments[-1]._allreduce_vars, + self._shard, + op_role=OpRole.Backward, + use_calc_stream=False) for idx, segment in reversed(list(enumerate(self._segments))): allreduce_vars = self._segments[ @@ -391,6 +641,7 @@ class ShardingOptimizer(MetaOptimizerBase): fill_constant_vars) # step4: add `cast` ops + print("cast_ops:", cast_ops) insert_cast_ops(block, segment._end_idx, cast_ops) # step5: add broadcast ops @@ -404,8 +655,15 @@ class ShardingOptimizer(MetaOptimizerBase): insert_sync_comm_ops(block, segment._start_idx, self.sharding_ring_id, allreduce_vars) # sharding - insert_allreduce_ops(block, segment._start_idx, - self.sharding_ring_id, allreduce_vars) + # allreduce --> reduce + insert_reduce_ops( + block, + segment._start_idx, + self.sharding_ring_id, + allreduce_vars, + self._shard, + op_role=OpRole.Backward, + use_calc_stream=False) block._sync_with_cpp() @@ -459,6 +717,7 @@ class ShardingOptimizer(MetaOptimizerBase): def _init_comm(self): if self.hybrid_dp: + assert self._as_outer_parallelism == False, "hybrid dp is conflict when using sharding as outer parallelism" self.sharding_group_size = self.user_defined_strategy.sharding_configs[ "sharding_group_size"] self.sharding_ring_id = 0 @@ -476,6 +735,7 @@ class ShardingOptimizer(MetaOptimizerBase): ep for idx, ep in enumerate(self.endpoints) if (idx % self.sharding_group_size) == self.sharding_rank ] + assert self.global_word_size > self.sharding_group_size, \ "global_word_size: {} should be larger than sharding_group_size: {}".format(self.global_word_size, self.sharding_group_size) assert self.global_word_size % self.sharding_group_size == 0, \ @@ -485,30 +745,215 @@ class ShardingOptimizer(MetaOptimizerBase): self.global_word_size, self.sharding_group_size, self.dp_group_size) + self.pp_ring_id = -1 + self.pp_rank = -1 + self.pp_group_size = None + self.pp_group_endpoints = None + + # sharding parallelism is the only model parallelism in the current setting + self.mp_group_id = self.sharding_ring_id + self.mp_rank = self.sharding_rank + self.mp_group_size = self.sharding_group_size + self.mp_group_endpoints = self.sharding_group_endpoints[:] logging.info("Using Sharing&DP mode !") else: - self.sharding_ring_id = 0 - self.sharding_rank = self.global_rank - self.sharding_group_size = self.role_maker._worker_num() - self.sharding_group_endpoints = self.endpoints + if self._as_outer_parallelism and not self.use_pipeline: + self.sharding_ring_id = 1 + assert self.global_word_size > self._inner_parallelism_size, \ + "global_word_size: {} should be larger than inner_parallelism_size: 
{}".format(self.global_word_size, self._inner_parallelism_size) + assert self.global_word_size % self._inner_parallelism_size == 0, \ + "global_word_size: {} should be divisible to the inner_parallelism_size: {}".format(self.global_word_size, self._inner_parallelism_size) + self.sharding_rank = self.global_rank // self._inner_parallelism_size + self.sharding_group_size = self.role_maker._worker_num( + ) // self._inner_parallelism_size + _offset = self.global_rank % self._inner_parallelism_size + self.sharding_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if idx % self._inner_parallelism_size == _offset + ] + + # the current entire model parallelism group is the combination of innert & sharding parallelism + self.mp_group_id = 2 + self.mp_rank = self.global_rank + self.mp_group_size = self.role_maker._worker_num() + self.mp_group_endpoints = self.endpoints[:] + logging.info("Using Sharing as Outer parallelism mode !") + + # print( + # "init the nccl comm for megatron paramllelism, this should be done in Megatron Metaoptimizer" + # ) + # partition_idx = self.global_rank // self._inner_parallelism_size + # magetron_endpoints = self.endpoints[ + # partition_idx * self._inner_parallelism_size:partition_idx * + # self._inner_parallelism_size + self._inner_parallelism_size] + # magetron_rank = self.global_rank % self._inner_parallelism_size + + # self._collective_helper._init_communicator( + # program=self._startup_program, + # current_endpoint=self.current_endpoint, + # endpoints=magetron_endpoints, + # rank=magetron_rank, + # ring_id=0, + # wait_port=True) + # logging.info("megatron group size: {}".format( + # self._inner_parallelism_size)) + # logging.info("megatron rank: {}".format(magetron_rank)) + # logging.info("megatron endpoints: {}".format( + # magetron_endpoints)) + if self.use_pipeline: + if self._inner_parallelism_size == 1: + self.sharding_ring_id = 0 + self.sharding_group_size = self.user_defined_strategy.sharding_configs[ + 'sharding_group_size'] + self.sharding_rank = self.global_rank % self.sharding_group_size + assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num( + ) + self.pp_ring_id = 20 + self.pp_rank = self.global_rank // ( + self.sharding_group_size * self._inner_parallelism_size) + self.sharding_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if (idx // self.sharding_group_size) == self.pp_rank + ] + self.pp_group_size = self.pipeline_nodes + self.pp_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if (idx % self.sharding_group_size + ) == self.sharding_rank + ] + else: + self.mp_group_id = 0 + self.sharding_ring_id = 1 + self.pp_ring_id = 20 + self.mp_rank = self.global_rank % self._inner_parallelism_size + self.mp_group = self.global_rank // self._inner_parallelism_size + self.mp_group_endpoints = [ + ep for idx, ep in enumerate(self.endpoints) + if idx // self._inner_parallelism_size == self.mp_group + ] + print("megatron_group_endpoints:", self.mp_group_endpoints) + print("megatron_rank:", self.mp_rank) + # self.cards_per_node = 8 + self.sharding_group_size = self.user_defined_strategy.sharding_configs[ + 'sharding_group_size'] + self.sharding_rank = ( + self.global_rank // + self._inner_parallelism_size) % self.sharding_group_size + self.sharding_group_id = self.global_rank // ( + self._inner_parallelism_size * self.sharding_group_size) + self.megatron_rank = self.global_rank % self._inner_parallelism_size + self.sharding_group_endpoints = [ + 
ep for idx, ep in enumerate(self.endpoints) + if (idx // (self._inner_parallelism_size * + self.sharding_group_size) + ) == self.sharding_group_id and idx % + self._inner_parallelism_size == self.megatron_rank + ] + print("sharding_endpoint:", self.sharding_group_endpoints) + print("sharding_rank:", self.sharding_rank) + assert self.sharding_group_size * self.pipeline_nodes * self._inner_parallelism_size == self.role_maker._worker_num( + ) + self.pp_rank = self.global_rank // ( + self.sharding_group_size * + self._inner_parallelism_size) % self.pipeline_nodes + offset = self.sharding_group_size * self._inner_parallelism_size + # TODO: Adjust for dp + idx_with_pp_0 = self.global_rank % ( + self.sharding_group_size * self._inner_parallelism_size) + self.pp_group_endpoints = [] + for i in range(self.pipeline_nodes): + self.pp_group_endpoints.append(self.endpoints[ + idx_with_pp_0]) + idx_with_pp_0 += offset + print("pp_group_endpoints:", self.pp_group_endpoints) + print("pp_rank:", self.pp_rank) + + #self.pp_group_endpoints = [ + # ep for idx, ep in enumerate(self.endpoints) + # if (idx % self.sharding_group_size) == self.sharding_rank + #] + self.global_group_id = 3 + self.global_rank = self.global_rank + self.global_group_size = self.role_maker._worker_num() + self.global_group_endpoints = self.endpoints[:] + logging.info("Using Sharing as Outer parallelism mode !") + self.dp_ring_id = -1 + self.dp_rank = -1 + self.dp_group_size = None + self.dp_group_endpoints = None + + logging.info("Using Sharing with pipeline !") + #else: + # self.sharding_ring_id = 0 + # self.sharding_rank = self.global_rank + # self.sharding_group_size = self.role_maker._worker_num() + # self.sharding_group_endpoints = self.endpoints + + # # sharding parallelism is the only model parallelism in the current setting + # self.mp_group_id = self.sharding_ring_id + # self.mp_rank = self.sharding_rank + # self.mp_group_size = self.sharding_group_size + # self.mp_group_endpoints = self.sharding_group_endpoints[:] + + # logging.info("Using Sharing alone mode !") + self.dp_ring_id = -1 self.dp_rank = -1 self.dp_group_size = None self.dp_group_endpoints = None + #self.pp_ring_id = -1 + #self.pp_rank = -1 + #self.pp_group_size = None + #self.pp_group_endpoints = None + #self.dp_ring_id = -1 + #self.dp_rank = -1 + #self.dp_group_size = None + #self.dp_group_endpoints = None + logging.info("Using Sharing alone mode !") - logging.info("global word size: {}".format(self.global_word_size)) - logging.info("global rank: {}".format(self.global_rank)) - logging.info("sharding group_size: {}".format(self.sharding_group_size)) - logging.info("sharding rank: {}".format(self.sharding_rank)) - logging.info("dp group size: {}".format(self.dp_group_size)) - logging.info("dp rank: {}".format(self.dp_rank)) - logging.info("current endpoint: {}".format(self.current_endpoint)) - logging.info("sharding group endpoints: {}".format( - self.sharding_group_endpoints)) - logging.info("dp group endpoints: {}".format(self.dp_group_endpoints)) - logging.info("global word endpoints: {}".format(self.endpoints)) + #logging.info("global word size: {}".format(self.global_word_size)) + #logging.info("global rank: {}".format(self.global_rank)) + #logging.info("sharding group_size: {}".format(self.sharding_group_size)) + #logging.info("sharding rank: {}".format(self.sharding_rank)) + #logging.info("current model parallelism group_size: {}".format( + # self.mp_group_size)) + #logging.info("current model parallelism rank: {}".format(self.mp_rank)) + #logging.info("dp 
group size: {}".format(self.dp_group_size)) + #logging.info("dp rank: {}".format(self.dp_rank)) + #logging.info("current endpoint: {}".format(self.current_endpoint)) + #logging.info("global word endpoints: {}".format(self.endpoints)) + #logging.info("sharding group endpoints: {}".format( + # self.sharding_group_endpoints)) + #logging.info("current model parallelism group endpoints: {}".format( + # self.mp_group_endpoints)) + #logging.info("dp group endpoints: {}".format(self.dp_group_endpoints)) return + + def _initialization_broadcast(self, startup_prog): + """ + this funtion is to ensure the initialization between dp group to be + identical when hybrid-dp is used. + """ + block = startup_prog.global_block() + params = [] + for param in block.iter_parameters(): + params.append(param) + block.append_op( + type='c_broadcast', + inputs={'X': param}, + outputs={'Out': param}, + attrs={ + 'ring_id': self.dp_ring_id, + 'root': 0, + OP_ROLE_KEY: OpRole.Forward + }) + block.append_op( + type='c_sync_comm_stream', + inputs={'X': params}, + outputs={'Out': params}, + attrs={'ring_id': self.dp_ring_id, + OP_ROLE_KEY: OpRole.Forward}) diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 33e2e387a82758ba9cd59dc40d41fb5ad05ee29b..abf02851b9c2dcc7b33ed19b2f2a16f1b693313e 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -115,7 +115,7 @@ class ProgramStats(object): updated_min_idx = min_idx while idx_ > pre_segment_end_idx: if is_amp_cast(self.ops[idx_]): - _logger.debug("found amp-cast op: {}, : {}".format(self.ops[ + _logger.info("found amp-cast op: {}, : {}".format(self.ops[ idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[ 0])) updated_min_idx = idx_ @@ -155,7 +155,7 @@ class ProgramStats(object): sorted_checkpoints = [] for name in checkpoints_name: if name not in self.var_op_deps: - _logger.debug( + _logger.info( "Recompute Optimizer: deleted %s from checkpoints, because it is not used in paddle program." 
% name) elif self.var_op_deps[name]["var_as_output_ops"] == []: @@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars): new_op_desc = block.desc.append_op() new_op_desc.copy_from(desc) new_op_desc._set_attr(op_role_attr_name, backward) + if desc.has_attr('op_device'): + new_op_desc._set_attr('op_device', desc.attr('op_device')) result_descs.append(new_op_desc) return result_descs @@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block): new_op_desc = block.desc.append_op() new_op_desc.copy_from(desc) new_op_desc._set_attr(op_role_attr_name, backward) + if desc.has_attr('op_device'): + new_op_desc._set_attr('op_device', desc.attr('op_device')) result_descs.append(new_op_desc) return result_descs @@ -784,7 +788,6 @@ def _append_backward_ops_with_checkpoints_( start_idx = 0 pre_segment_end_idx = -1 while True: - _logger.debug("FW op range[0] - [{}]".format(len(ops))) if start_idx >= len(checkpoints_name) - 1: break # min_idx: checkpoint_1' s input op @@ -797,6 +800,9 @@ def _append_backward_ops_with_checkpoints_( min_idx = program_stat._update_segment_start( min_idx, pre_segment_end_idx) segments.append([min_idx, max_idx + 1]) + else: + _logger.info("Could not recompute op range [{}] - [{}] ".format( + min_idx, max_idx + 1)) start_idx += 1 @@ -806,15 +812,15 @@ def _append_backward_ops_with_checkpoints_( recompute_segments = segments for i, (idx1, idx2) in enumerate(recompute_segments): - _logger.debug("recompute segment[{}]".format(i)) - _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( + _logger.info("recompute segment[{}]".format(i)) + _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( ), ops[idx1].desc.input_arg_names())) - _logger.debug("segment end op: [{}]: [{}]".format(ops[ + _logger.info("segment end op: [{}]: [{}]".format(ops[ idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) - _logger.debug("recompute segment[{}]".format(i)) - _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( + _logger.info("recompute segment[{}]".format(i)) + _logger.info("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( ), ops[idx1].desc.input_arg_names())) - _logger.debug("segment end op: [{}]: [{}]".format(ops[ + _logger.info("segment end op: [{}]: [{}]".format(ops[ idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) # 2) go through all forward ops and induct all variables that will be hold in memory @@ -825,9 +831,9 @@ def _append_backward_ops_with_checkpoints_( program_stat.get_out_of_subgraph_vars(segment[0], segment[1])) cross_vars = set(vars_should_be_hold) - set(checkpoints_name) - _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ + _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ len(cross_vars), cross_vars)) - _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ + _logger.info("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ len(cross_vars), cross_vars)) # b. 
output of seed op should be kept in memory @@ -843,6 +849,7 @@ def _append_backward_ops_with_checkpoints_( vars_in_memory = vars_should_be_hold + checkpoints_name max_calculated_op_position = len(ops) + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() if recompute_segments == []: gap_ops = ops[0:max_calculated_op_position] for op in reversed(gap_ops): @@ -852,6 +859,11 @@ def _append_backward_ops_with_checkpoints_( _pretty_op_desc_(op.desc, "with_sub_block")) grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, cpt.to_text(no_grad_dict[block.idx]), []) + # Set device for grad_op according to forward Op + if op.desc.has_attr(device_attr_name): + op_device = op.desc.attr(device_attr_name) + for op_desc in grad_op_desc: + op_desc._set_attr(device_attr_name, op_device) added_descs = _add_descs_to_block(grad_op_desc, local_block) grad_op_descs.extend(added_descs) grad_to_var.update(op_grad_to_var) @@ -866,6 +878,11 @@ def _append_backward_ops_with_checkpoints_( _pretty_op_desc_(op.desc, "with_sub_block")) grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op.desc, cpt.to_text(no_grad_dict[block.idx]), []) + # Set device for grad_op according to forward Op + if op.desc.has_attr(device_attr_name): + op_device = op.desc.attr(device_attr_name) + for op_desc in grad_op_desc: + op_desc._set_attr(device_attr_name, op_device) added_descs = _add_descs_to_block(grad_op_desc, local_block) grad_op_descs.extend(added_descs) grad_to_var.update(op_grad_to_var) @@ -888,6 +905,18 @@ def _append_backward_ops_with_checkpoints_( continue if name not in var_name_dict: var_name_dict[name] = name + var_suffix + + # we should create the rename var in subprog, otherwise its VarType will be BOOL + block.create_var( + name=var_name_dict[name], + shape=block.program.global_block().var(name).shape, + dtype=block.program.global_block().var(name).dtype, + type=block.program.global_block().var(name).type, + persistable=block.program.global_block().var( + name).persistable, + stop_gradient=block.program.global_block().var(name) + .stop_gradient) + # 3.a. 
add ops in current recompute_segment as forward recomputation ops buffer_descs = _add_needed_descs_to_block(ff_ops, buffer_block, block, vars_in_memory) diff --git a/python/paddle/fluid/clip.py b/python/paddle/fluid/clip.py index 8fd01509331e207af1aaabde1e40404f1a8c6f74..5e4ea24137e538558a1885b5748c80655d7dd6f1 100644 --- a/python/paddle/fluid/clip.py +++ b/python/paddle/fluid/clip.py @@ -489,9 +489,14 @@ class ClipGradByGlobalNorm(ClipGradBase): continue with p.block.program._optimized_guard([p, g]): - new_grad = layers.elementwise_mul(x=g, y=scale_var) - param_new_grad_name_dict[p.name] = new_grad.name - params_and_grads.append((p, new_grad)) + p.block.append_op( + type='elementwise_mul', + inputs={'X': g, + 'Y': scale_var}, + outputs={'Out': g}) + + param_new_grad_name_dict[p.name] = p.name + params_and_grads.append((p, p)) _correct_clip_op_role_var(params_and_grads, param_new_grad_name_dict) return params_and_grads diff --git a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py index f9c3a613c4053a79cb467d752b20f6f4ed3ea4ec..67e83a2ec4617c0c59bdb1f92c983e3b5ae471a3 100644 --- a/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py +++ b/python/paddle/fluid/contrib/mixed_precision/fp16_utils.py @@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype): outputs={"Out": out_var}, attrs={ "in_dtype": in_var.dtype, - "out_dtype": out_var.dtype + "out_dtype": out_var.dtype, + "op_device": op.attr("op_device") }) num_cast_ops += 1 _rename_arg(op, in_var.name, out_var.name) @@ -171,8 +172,11 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name, type="cast", inputs={"X": target_var}, outputs={"Out": cast_var}, - attrs={"in_dtype": target_var.dtype, - "out_dtype": cast_var.dtype}) + attrs={ + "in_dtype": target_var.dtype, + "out_dtype": cast_var.dtype, + "op_device": op.attr("op_device") + }) num_cast_ops += 1 op_var_rename_map[block.idx][target_var.name] = cast_var.name diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py index 838aea37f18344b257e2dd8a9063ebc7f7202152..3c5906ceb9df9558dd8fbf7080422a492eeddb8e 100644 --- a/python/paddle/fluid/device_worker.py +++ b/python/paddle/fluid/device_worker.py @@ -413,6 +413,9 @@ class Section(DeviceWorker): section_param = trainer_desc.section_param section_param.num_microbatches = pipeline_opt["num_microbatches"] section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"] + section_param.pipeline_stage = pipeline_opt["pipeline_stage"] + section_param.num_pipeline_stages = pipeline_opt["num_pipeline_stages"] + section_param.schedule_mode = pipeline_opt["schedule_mode"] cfg = section_param.section_config program = pipeline_opt["section_program"] cfg.program_desc.ParseFromString(program["program"]._get_desc() diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index 3c560689e1210fcb312a2311da72c720afb2fe0a..2b70b670a475e9166e8b9491f4d603ff505d4feb 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -19,6 +19,7 @@ import six import os import logging from collections import defaultdict +import time import paddle from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table @@ -3759,15 +3760,21 @@ class PipelineOptimizer(object): def __init__(self, optimizer, num_microbatches=1, start_cpu_core_id=0): if framework.in_dygraph_mode(): raise Exception("In dygraph, don't support PipelineOptimizer.") - if not 
isinstance(optimizer, Optimizer) and not isinstance( - optimizer, paddle.optimizer.Optimizer) and not isinstance( - optimizer, paddle.fluid.contrib.mixed_precision.decorator. - OptimizerWithMixedPrecision): + supported_opt_types = (Optimizer, paddle.fluid.contrib.mixed_precision. + decorator.OptimizerWithMixedPrecision) + if not isinstance(optimizer, supported_opt_types): raise ValueError("The 'optimizer' parameter for " - "PipelineOptimizer must be an instance of " - "Optimizer, but the given type is {}.".format( - type(optimizer))) + "PipelineOptimizer must be an instance of one of " + "{}, but the type is {}.".format( + supported_opt_types, type(optimizer))) + self._optimizer = optimizer + + # Get the original optimizer defined by users, such as SGD + self._origin_optimizer = self._optimizer + while hasattr(self._origin_optimizer, "inner_opt"): + self._origin_optimizer = self._origin_optimizer.inner_opt + assert num_microbatches >= 1, ( "num_microbatches must be a positive value.") self._num_microbatches = num_microbatches @@ -3781,52 +3788,147 @@ class PipelineOptimizer(object): self._op_role_var_key = op_maker.kOpRoleVarAttrName() self._op_device_key = op_maker.kOpDeviceAttrName() self._param_device_map = None + self._pipeline_pair = [] + self._pp_ring_map = dict() def _create_vars(self, block, ori_block): - # Create vars for block, copied from main_program's global block + # Create vars for block, copied from ori_block used_var_set = set() for op_idx in range(block.desc.op_size()): - op_desc = block.desc.op(op_idx) - vars = op_desc.input_arg_names() + op_desc.output_arg_names() + # Whether to insert allreduce_sum or allreduce_max op? + # For amp and global gradient clip strategies, we should + # get the global infomation, so allreduce op is needed. 
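Note on the allreduce insertion that follows in _create_vars: the AMP found-infinite flag produced by reduce_any is a bool, so the patch casts it to int32, runs c_allreduce_max over the ring, and casts the result back, which makes every pipeline section agree on whether to skip the update. A plain-Python sketch of the equivalent reduction; global_found_inf and its flag list are stand-ins, not Paddle APIs:

def global_found_inf(local_flags):
    """Emulate cast(bool -> int32) -> allreduce_max -> cast(int32 -> bool)."""
    as_int = [int(flag) for flag in local_flags]  # cast to int32
    reduced = max(as_int)                         # c_allreduce_max over the ring
    return bool(reduced)                          # cast back to bool

# If any rank saw a non-finite gradient, all ranks must skip this step.
assert global_found_inf([False, True, False]) is True
assert global_found_inf([False, False]) is False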
+ should_insert = False + + op = block.ops[op_idx] + # For op process vars on all devices, remove its input + # vars not in this block + reserved_x = [] + + if op.type == 'reduce_any' and self._is_optimize_op(op): + should_insert = True + if op.type == 'concat' and self._is_optimize_op(op): + for input_name in op.desc.input("X"): + if block._find_var_recursive(input_name): + reserved_x.append(input_name) + op.desc.set_input('X', reserved_x) + print('reserved_x:', reserved_x) + if op.type == 'update_loss_scaling': + for input_name in op.desc.input("X"): + if block._find_var_recursive(input_name): + reserved_x.append(input_name) + op.desc.set_input('X', reserved_x) + op.desc.set_output('Out', reserved_x) + if op.type == 'sum' and self._is_gradient_clip_op(op): + for input_name in op.desc.input("X"): + if block._find_var_recursive(input_name): + reserved_x.append(input_name) + op.desc.set_input('X', reserved_x) + should_insert = True + vars = op.desc.input_arg_names() + op.desc.output_arg_names() for var in vars: # a var whose name contains "blocking_queue" # only exists in startup program - if var in used_var_set or "_blocking_queue" in var: - continue + if var in used_var_set or "_blocking_queue" in var: continue used_var_set.add(var) if block._find_var_recursive(str(var)): continue source_var = ori_block._var_recursive(str(var)) if source_var.type == core.VarDesc.VarType.READER: - block.create_var( + dest_var = block.create_var( name=var, type=core.VarDesc.VarType.READER, persistable=source_var.persistable) else: - block._clone_variable(source_var, False) + dest_var = block._clone_variable(source_var, False) + dest_var.stop_gradient = source_var.stop_gradient + + continue + # TODO add allreduce_max when without sharding + if not should_insert: continue + out_name = op.desc.output_arg_names()[0] + out_var = block.var(out_name) + offset = 0 + if op.type == "reduce_any": + # cast the bool var to int32 to use allreduce op + temp_var_name = unique_name.generate(out_name + "_cast_int32") + temp_var = block.create_var( + name=temp_var_name, shape=[1], dtype="int32") + block._insert_op( + op_idx + 1 + offset, + type='cast', + inputs={'X': out_var}, + outputs={'Out': temp_var}, + attrs={ + 'in_dtype': out_var.dtype, + 'out_dtype': temp_var.dtype, + self._op_role_key: + core.op_proto_and_checker_maker.OpRole.Optimize + }) + offset += 1 + # block._insert_op( + # op_idx + 1 + offset, + # type='c_sync_calc_stream', + # inputs={'X': temp_var if op.type == "reduce_any" else out_var}, + # outputs={ + # 'Out': temp_var if op.type == "reduce_any" else out_var + # }, + # attrs={ + # OP_ROLE_KEY: + # core.op_proto_and_checker_maker.OpRole.Optimize, + # }) + # offset += 1 + block._insert_op( + op_idx + 1 + offset, + type='c_allreduce_max' + if op.type == "reduce_any" else 'c_allreduce_sum', + inputs={'X': temp_var if op.type == "reduce_any" else out_var}, + outputs={ + 'Out': temp_var if op.type == "reduce_any" else out_var + }, + attrs={ + 'ring_id': self.ring_id, + self._op_role_key: + core.op_proto_and_checker_maker.OpRole.Optimize, + 'use_calc_stream': True + }) + offset += 1 + # block._insert_op( + # # op_idx + 1 + extra_index, + # op_idx + 1 + offset, + # type='c_sync_comm_stream', + # inputs={'X': temp_var if op.type == "reduce_any" else out_var}, + # outputs={ + # 'Out': temp_var if op.type == "reduce_any" else out_var + # }, + # attrs={ + # 'ring_id': self.ring_id, + # OP_ROLE_KEY: + # core.op_proto_and_checker_maker.OpRole.Optimize, + # }) + # offset += 1 + if op.type == "reduce_any": + 
block._insert_op( + op_idx + 1 + offset, + type='cast', + inputs={'X': temp_var}, + outputs={'Out': out_var}, + attrs={ + 'in_dtype': temp_var.dtype, + 'out_dtype': out_var.dtype, + self._op_role_key: + core.op_proto_and_checker_maker.OpRole.Optimize + }) def _is_loss_grad_op(self, op): - if self._op_role_key not in op.attr_names: - return False - op_role = int(op.all_attrs()[self._op_role_key]) + assert self._op_role_key in op.attr_names + op_role = int(op.attr(self._op_role_key)) return op_role & int(self._op_role.Backward) and op_role & int( self._op_role.Loss) - def _is_backward_op(self, op): - return self._op_role_key in op.attr_names and int(op.all_attrs()[ - self._op_role_key]) & int(self._op_role.Backward) - - def _is_optimize_op(self, op): - return self._op_role_key in op.attr_names and int(op.all_attrs()[ - self._op_role_key]) & int(self._op_role.Optimize) - - def _is_update_op(self, op): - return 'Param' in op.input_names and 'Grad' in op.input_names and ( - "LearningRate" in op.input_names) - def _split_program(self, main_program, devices): """ Split a program into sections according to devices that ops run on. - The ops of the role LRSched are copied to all sections. + The op whose op_device attr is "gpu:all" is copied to all sections. Args: main_program (Program): the main program @@ -3842,27 +3944,20 @@ class PipelineOptimizer(object): block = main_program.block(0) for op in block.ops: device = op.attr(self._op_device_key) - op_role = op.attr(self._op_role_key) - if int(op_role) & int(self._op_role.LRSched): - # Copy ops of the role LRSched to all sections. - for device in device_program_map.keys(): - program = device_program_map[device] - op_desc = op.desc - ap_op = program["program"].block(0).desc.append_op() - ap_op.copy_from(op_desc) - # ap_op._set_attr(self._op_device_key, "") - elif op.type == "create_py_reader" or op.type == "read" or op.type == "create_double_buffer_reader": - # Copy read related ops to all section to make them exit after each epoch. + # Copy ops whose op_device set to "gpu:all" to all sections. + if device == "gpu:all": for device in device_program_map.keys(): program = device_program_map[device] op_desc = op.desc ap_op = program["program"].block(0).desc.append_op() ap_op.copy_from(op_desc) + ap_op._set_attr(self._op_device_key, "") else: program = device_program_map[device] op_desc = op.desc ap_op = program["program"].block(0).desc.append_op() ap_op.copy_from(op_desc) + ap_op._set_attr(self._op_device_key, "") for key in devices: program = device_program_map[key] @@ -3921,6 +4016,11 @@ class PipelineOptimizer(object): var_name as output. var_name (string): Variable name. 
""" + # To skip the cast op added by amp which has no op_device set + if '.cast_fp32' in var_name: + var_name = var_name.replace('.cast_fp32', '') + if '.cast_fp16' in var_name: + var_name = var_name.replace('.cast_fp16', '') post_op = [] before = True for op in ops: @@ -3949,7 +4049,7 @@ class PipelineOptimizer(object): """ prev_op = [] for op in ops: - if op.type == 'send_v2' or op.type == 'recv_v2': + if op.type == 'send_v2' or op.type == 'recv_v2' or op.type == 'c_broadcast': continue if op == cur_op: break @@ -3964,11 +4064,8 @@ class PipelineOptimizer(object): return None def _rename_arg(self, op, old_name, new_name): - op_desc = op.desc - if isinstance(op_desc, tuple): - op_desc = op_desc[0] - op_desc._rename_input(old_name, new_name) - op_desc._rename_output(old_name, new_name) + op._rename_input(old_name, new_name) + op._rename_output(old_name, new_name) def _create_var(self, block, ref_var, name): """ @@ -3982,9 +4079,10 @@ class PipelineOptimizer(object): dtype=ref_var.dtype, type=ref_var.type, lod_level=ref_var.lod_level, - persistable=False, - is_data=False, + persistable=ref_var.persistable, + is_data=ref_var.is_data, need_check_feed=ref_var.desc.need_check_feed()) + new_var.stop_gradient = ref_var.stop_gradient return new_var def _get_data_var_info(self, block): @@ -4037,6 +4135,7 @@ class PipelineOptimizer(object): if not var_name in first_block.vars: self._create_var(first_block, main_var, var_name) dev_index = int(device.split(':')[1]) + print("dev_index:", dev_index) first_block._insert_op( index=insert_index, type='send_v2', @@ -4044,8 +4143,11 @@ class PipelineOptimizer(object): attrs={ self._op_device_key: first_dev_spec, self._op_role_key: self._op_role.Forward, - 'use_calc_stream': True, + 'use_calc_stream': False, 'peer': dev_index, + #'ring_id': self.ring_id, + 'ring_id': self.ring_id + if dev_index > first_dev_index else self.ring_id + 2, }) # Get the device that that data on assert device in devices @@ -4070,6 +4172,21 @@ class PipelineOptimizer(object): self._op_role_key: self._op_role.Forward, 'peer': first_dev_index, 'use_calc_stream': True, + #'ring_id': self.ring_id, + 'ring_id': self.ring_id + if first_dev_index < dev_index else self.ring_id + 2, + }) + block._insert_op( + index=index + 1, + type='c_sync_comm_stream', + inputs={'X': [new_var]}, + outputs={'Out': [new_var]}, + attrs={ + self._op_device_key: device, + self._op_role_key: self._op_role.Forward, + #'ring_id': self.ring_id, + 'ring_id': self.ring_id + if first_dev_index > dev_index else self.ring_id + 2, }) def _strip_grad_suffix(self, name): @@ -4085,79 +4202,190 @@ class PipelineOptimizer(object): """ return name + core.grad_var_suffix() - def _add_opdevice_attr_for_regularization_clip(self, block): + def _is_forward_op(self, op): """ - Add op_device attribute for regulization and clip ops. + Is the op_role attribute of a op is Forward. 
""" - for op in block.ops: - # role for regularization and clip ops is optimize - if int(op.attr(self._op_role_key)) != int(self._op_role.Optimize): - continue - if op.has_attr(self._op_device_key) and ( - op.attr(self._op_device_key) != ""): - continue - assert self._op_role_var_key in op.attr_names - op_role_var = op.all_attrs()[self._op_role_var_key] - assert len(op_role_var) == 2 - param_name = op_role_var[0] - device = self._param_device_map[param_name] - op._set_attr(self._op_device_key, device) + assert self._op_role_key in op.attr_names + return int(op.attr(self._op_role_key)) == int(self._op_role.Forward) + + def _is_backward_op(self, op): + """ + Is the op_role attribute of a op is Backward. + """ + assert self._op_role_key in op.attr_names + return int(op.attr(self._op_role_key)) == int(self._op_role.Backward) - def _add_default_opdevice_attr(self, block): + def _is_loss_op(self, op): """ - 1. Add default op_device attribute for lr-related ops. - The default value is the one that of the first place. - 2. Add default op_device attribute for sum ops added during - backward. For these ops, we set the op_device attribute - as the one of its post op, i.e, which op has the output of the - sum op as an input. + Is the op_role attribute of a op is Loss. """ - first_devcie = "" + assert self._op_role_key in op.attr_names + return int(op.attr(self._op_role_key)) == int(self._op_role.Loss) - # Get the device spec of the first place. - # device_spec: 'cpu' for cpu device and 'gpu:id' for gpu device, - # e.g. 'gpu:0', 'gpu:1', etc. - for op in block.ops: - if op.has_attr(self._op_device_key) and ( - op.attr(self._op_device_key) != ""): - first_device = op.attr(self._op_device_key) - break - assert first_device - first_device_type = first_device.split(":")[0] - assert first_device_type == "gpu" + def _is_optimize_op(self, op): + """ + Is the op_role attribute of a op is Optimize. + """ + assert self._op_role_key in op.attr_names + return int(op.attr(self._op_role_key)) == int(self._op_role.Optimize) - # set op_device attr for lr-related ops - lrsched_role = int(self._op_role.LRSched) - for op in block.ops: - if not op.has_attr(self._op_device_key) or ( - op.attr(self._op_device_key) == ""): - if op.type == "sum": - # For sum ops that compute the sum of @RENAMED@ vars - for name in op.desc.input_arg_names(): - assert '@RENAME@' in name - assert len(op.desc.output_arg_names()) == 1 - out_name = op.desc.output_arg_names()[0] - post_op = self._find_post_op(block.ops, op, out_name) - device = post_op.attr(self._op_device_key) - assert device - op._set_attr(self._op_device_key, device) - continue + def _is_lrsched_op(self, op): + """ + Is the op_role attribute of a op is LRSched. + """ + assert self._op_role_key in op.attr_names + return int(op.attr(self._op_role_key)) == int(self._op_role.LRSched) + + def _is_update_op(self, op): + """ + Is the op updates the parameter using gradient. + """ + return 'Param' in op.input_names and 'Grad' in op.input_names and ( + "LearningRate" in op.input_names) + + def _get_op_device_attr(self, op): + """ + Get the op_device attribute of a op. + """ + device = op.attr(self._op_device_key) \ + if op.has_attr(self._op_device_key) else None + if device: + assert device[0:3] == 'gpu', "Now, only gpu devices are " \ + "supported in pipeline parallemism." + return device + + def _add_op_device_attr_for_op(self, op, idx, block): + """ + Add op_device attrribute for ops that have not that attribute set. 
- assert op.attr(self._op_role_key) == lrsched_role, ( - "Op whose op_device attr has not been set for pipeline" - " must be of the role LRSched.") - op._set_attr(self._op_device_key, first_device) + We use "gpu:all" to represent that the op should be put on all + sub-programs, such as lr-related ops. Note that "gpu:all" + is only used by pipeline as an indicator. + """ + lrsched_role = int(self._op_role.LRSched) + if op.attr(self._op_role_key) == lrsched_role: + # For LRSched ops, we should put them on all sub-programs to + # make sure each sub-program updates the lr correctly + op._set_attr(self._op_device_key, "gpu:all") + elif op.type == "sum" and self._is_backward_op(op): + # For sum ops that compute the sum of @RENAMED@ vars + for name in op.desc.input_arg_names(): + assert '@RENAME@' in name, \ + "The op must be a sum op used to accumulate renamed vars." + assert len(op.desc.output_arg_names()) == 1 + out_name = op.desc.output_arg_names()[0] + post_op = self._find_post_op(block.ops, op, out_name) + assert post_op.has_attr( + 'op_device'), "{} has no op_device attr for var {}".format( + post_op.type, out_name) + device = post_op.attr(self._op_device_key) + assert device, "The post op must have op_device set." + op._set_attr(self._op_device_key, device) + elif (op.type == "cast" or + op.type == "scale") and self._is_backward_op(op): + prev_op = self._find_real_prev_op(block.ops, op, + op.desc.input("X")[0]) + op._set_attr('op_device', prev_op.attr('op_device')) + elif op.type == "memcpy" and not self._is_optimize_op(op): + assert len(op.input_arg_names) == 1 and len( + op.output_arg_names) == 1 + input_name = op.input_arg_names[0] + output_name = op.output_arg_names[0] + if '@Fetch' in output_name: + post_op = self._find_post_op(block.ops, op, output_name) + op._set_attr('op_device', post_op.attr('op_device')) + else: + prev_op = self._find_real_prev_op(block.ops, op, + op.desc.input("X")[0]) + op._set_attr('op_device', prev_op.attr('op_device')) + elif self._is_loss_op(op): + # For loss * loss_scaling op added by AMP + offset = 1 + while (not block.ops[idx + offset].has_attr(self._op_device_key) or + not block.ops[idx + offset].attr(self._op_device_key)): + offset += 1 + # assert block.ops[idx + 1].type == "fill_constant" + # assert block.ops[idx + 2].type == "elementwise_mul_grad" + # assert block.ops[idx + 3].type == "elementwise_add_grad" + # assert block.ops[idx + 4].type == "mean_grad" + # device = block.ops[idx + 4].attr(self._op_device_key) + device = block.ops[idx + offset].attr(self._op_device_key) + assert device, "Please put your program within device_guard scope."
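The loss-op branch above scans forward from the loss op until it reaches an op whose op_device attribute is already non-empty, then stamps the loss op and every op in between with that device. A minimal standalone sketch of the same scan, for illustration only (the helper name stamp_ops_after_loss is not part of the patch, and it assumes each element of ops exposes has_attr, attr, and _set_attr the way framework ops do):

def stamp_ops_after_loss(ops, loss_idx, device_key="op_device"):
    # Walk forward until an op with a non-empty op_device is found.
    offset = 1
    while (not ops[loss_idx + offset].has_attr(device_key) or
           not ops[loss_idx + offset].attr(device_key)):
        offset += 1
    device = ops[loss_idx + offset].attr(device_key)
    # Stamp the loss op and the AMP-inserted ops that sit before that op.
    for i in range(offset):
        ops[loss_idx + i]._set_attr(device_key, device)
    return device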
+ # op._set_attr(self._op_device_key, device) + # block.ops[idx + 1]._set_attr(self._op_device_key, device) + # block.ops[idx + 2]._set_attr(self._op_device_key, device) + # block.ops[idx + 2]._set_attr(self._op_device_key, device) + for i in range(offset): + block.ops[idx + i]._set_attr(self._op_device_key, device) + elif self._is_optimize_op(op) and op.type == "check_finite_and_unscale": + #op._set_attr(self._op_device_key, "gpu:all") + op_role_var = op.attr(self._op_role_var_key) + param_name = op_role_var[0] + device = self._param_device_map[param_name] + op._set_attr(self._op_device_key, device) + elif self._is_optimize_op(op) and op.type == "cast": + # For fp16-->fp32 cast added by AMP + grad_name = op.output('Out') + assert len(grad_name) == 1 + param_name = grad_name[0].strip(core.grad_var_suffix()) + device = self._param_device_map[param_name] + op._set_attr(self._op_device_key, device) + elif self._is_gradient_clip_op(op) or self._is_regularization_op(op): + # For gradient clip and regularization ops, we set their op_device + # attribute to the device where their corresponding parameters are. + assert self._op_role_var_key in op.attr_names, "gradient_clip " \ + "and regularization ops must have op_role_var attribute." + op_role_var = op.attr(self._op_role_var_key) + assert len(op_role_var) == 2, "op_role_var for gradient_clip " \ + "and regularization ops must have two elements." + param_name = op_role_var[0] + device = self._param_device_map[param_name] + # For sum op added by global gradient clip, it must be + # put on all devices + if (op.type == 'sum' or op.type == 'sqrt' or + op.type == 'fill_constant' or + op.type == 'elementwise_max' or + op.type == 'elementwise_div'): + device = "gpu:all" + op._set_attr(self._op_device_key, device) + else: + other_known_ops = [ + 'update_loss_scaling', 'reduce_any', 'concat', 'sum' + ] + assert op.type in other_known_ops, "For other ops without " \ + "op_device set, they must be one of {}, but it " \ + "is {}".format(other_known_ops, op.type) + assert self._is_optimize_op(op) + op._set_attr(self._op_device_key, "gpu:all") + + def _add_op_device_attr(self, block): + """ + Add the op_device attribute for ops in the block that do + not have it set. + """ + for idx, op in enumerate(list(block.ops)): + if (op.type == "create_py_reader" or op.type == "read" or + op.type == "create_double_buffer_reader"): + # Copy read related ops to all sections to make them exit + # after each epoch. + # We use "gpu:all" to represent that the op should be put on all + # sub-programs, such as lr-related ops. Note that "gpu:all" + # is only used by pipeline as an indicator. + op._set_attr(self._op_device_key, "gpu:all") + continue + # op_device attribute has been set + if self._get_op_device_attr(op): continue + self._add_op_device_attr_for_op(op, idx, block) def _check_validation(self, block): """ - Check whether ops in a block are all validate (i.e., the - op_device attribute has been set). - Then, return all device specifications in order. + Check whether ops in a block have the op_device attribute set. + Then, return all devices in order.
""" - device_specs = [] + device_list = [] for op in block.ops: - type = op.type - if not op._has_kernel(type): + if not op._has_kernel(op.type): assert op.type == "conditional_block" and ( op.attr(self._op_role_key) == int(self._op_role.LRSched)), ( "Now, the only supported op without kernel is " @@ -4165,15 +4393,16 @@ class PipelineOptimizer(object): assert op.has_attr(self._op_device_key), ( "op ({}) has no {} attribute.".format(op.type, self._op_device_key)) - dev_spec = op.attr(self._op_device_key) - assert dev_spec, ("op_device attribute for op " - "{} has not been set.".format(op.type)) - dev_type = dev_spec.split(':')[0] + device = op.attr(self._op_device_key) + assert device, ("op_device attribute for op " + "{} has not been set.".format(op.type)) + if device == "gpu:all": continue + dev_type = device.split(':')[0] assert dev_type == "gpu", ("Now only gpu devices are supported " "for pipeline parallelism.") - if not dev_spec in device_specs: - device_specs.append(dev_spec) - return device_specs + if not device in device_list: + device_list.append(device) + return device_list def _insert_sendrecv_ops_for_boundaries(self, block): """ @@ -4182,75 +4411,387 @@ class PipelineOptimizer(object): """ extra_index = 0 - # A map from var to device spec where op takes it as input, + # A map from var to device where op takes it as input, # avoiding multiple send and recv ops. - var_devspec = dict() + var_dev_map = dict() for index, op in enumerate(list(block.ops)): - # skips lr-related ops and vars, as we will process them later. - if int(op.attr(self._op_role_key)) & int(self._op_role.LRSched): - continue - # skips update ops and vars, as we will process them later. - if self._is_update_op(op): continue - - cur_device_spec = op.attr(self._op_device_key) + cur_device = op.attr(self._op_device_key) + if cur_device == "gpu:all": continue for var_name in op.input_arg_names: # i.e., lod_tensor_blocking_queue created by DataLoader, # which only exists in startup program. 
- if not var_name in block.vars: continue + # if not var_name in block.vars: continue var = block.var(var_name) # skip data, because we will process it later if var.is_data: continue + prev_device = None + if var_name in self._param_device_map: + prev_device = self._param_device_map[var_name] prev_op = self._find_real_prev_op(block.ops, op, var_name) - if prev_op is None: - continue - prev_device_spec = prev_op.attr(self._op_device_key) + if not prev_device: + prev_device = prev_op.attr(self._op_device_key) \ + if prev_op else None + if not prev_device or prev_device == 'gpu:all': continue - if prev_device_spec != cur_device_spec: - if var_name not in var_devspec: - var_devspec[var_name] = [] - if cur_device_spec in var_devspec[var_name]: continue - var_devspec[var_name].append(cur_device_spec) + if prev_device != cur_device: + if var_name not in var_dev_map: var_dev_map[var_name] = [] + if cur_device in var_dev_map[var_name]: continue + var_dev_map[var_name].append(cur_device) op_role = op.all_attrs()[self._op_role_key] var = block.vars[var_name] - prev_device_index = int(prev_device_spec.split(':')[1]) - cur_device_index = int(cur_device_spec.split(':')[1]) + prev_device_index = int(prev_device.split(':')[1]) + cur_device_index = int(cur_device.split(':')[1]) + pair = (prev_device_index, cur_device_index) + pair_key = prev_device_index * 1000 + cur_device_index + if cur_device_index > prev_device_index: + ring_id = self.ring_id + cur_device_index - prev_device_index - 1 + else: + ring_id = self.ring_id + 2 + prev_device_index - cur_device_index - 1 + if self.schedule_mode == 0: # GPipe + block._insert_op( + index=index + extra_index, + type='send_v2', + inputs={'X': var}, + attrs={ + self._op_device_key: prev_device, + self._op_role_key: op_role, + 'use_calc_stream': True, + 'peer': cur_device_index, + 'ring_id': self.ring_id + if cur_device_index > prev_device_index else + self.ring_id + 2, + }) + extra_index += 1 + block._insert_op( + index=index + extra_index, + type='recv_v2', + outputs={'Out': [var]}, + attrs={ + 'out_shape': var.shape, + 'dtype': var.dtype, + self._op_device_key: cur_device, + self._op_role_key: op_role, + 'use_calc_stream': True, + 'peer': prev_device_index, + 'ring_id': self.ring_id + if cur_device_index > prev_device_index else + self.ring_id + 2, + }) + extra_index += 1 + continue + assert self.schedule_mode == 1 + if pair not in self._pipeline_pair: + self._pipeline_pair.append(pair) + self._pp_ring_map[pair_key] = self.ring_id + ring_id = self.ring_id + self.ring_id += 1 + else: + ring_id = self._pp_ring_map[pair_key] block._insert_op( index=index + extra_index, - type='send_v2', + #type='send_v2', + type='c_broadcast', inputs={'X': var}, + outputs={'Out': var}, attrs={ - self._op_device_key: prev_device_spec, + self._op_device_key: prev_device, self._op_role_key: op_role, - 'use_calc_stream': True, - 'peer': cur_device_index, + 'use_calc_stream': False, + #'peer': cur_device_index, + #'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2, + 'ring_id': ring_id, + #'ring_id': self.ring_id, + #'root': prev_device_index, + 'root': 0, + }) + extra_index += 1 + block._insert_op( + index=index + extra_index, + type='c_sync_comm_stream', + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={ + self._op_device_key: cur_device, + self._op_role_key: + core.op_proto_and_checker_maker.OpRole.Backward, + #self._op_role_key: op_role, + 'ring_id': ring_id, + #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id
+ 2, + }) + extra_index += 1 + #block._insert_op( + # index=index + extra_index, + # type='c_sync_comm_stream', + # inputs={'X': [var]}, + # outputs={'Out': [var]}, + # attrs={ + # self._op_device_key: cur_device, + # self._op_role_key: + # core.op_proto_and_checker_maker.OpRole.Backward, + # 'ring_id': self.ring_id, + # #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, + # }) + #extra_index += 1 + fill_shape = list(var.shape) + fill_shape[0] = self.pp_bz + block._insert_op( + index=index + extra_index, + #type='recv_v2', + type='fill_constant', + inputs={}, + outputs={'Out': [var]}, + attrs={ + 'shape': fill_shape, + 'dtype': var.dtype, + self._op_device_key: cur_device, + self._op_role_key: op_role, + 'value': float(0.0), + }) + extra_index += 1 + block._insert_op( + index=index + extra_index, + type='c_sync_comm_stream', + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={ + self._op_device_key: cur_device, + #self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward, + self._op_role_key: op_role, + 'ring_id': ring_id, + #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, + }) + extra_index += 1 + block._insert_op( + index=index + extra_index, + #type='recv_v2', + type='c_broadcast', + inputs={'X': var}, + outputs={'Out': var}, + attrs={ + #'out_shape': var.shape, + #'dtype': var.dtype, + self._op_device_key: cur_device, + self._op_role_key: op_role, + 'use_calc_stream': False, + #'peer': prev_device_index, + #'root': prev_device_index, + 'root': 0, + #'ring_id': self.ring_id, + 'ring_id': ring_id, + #'ring_id': self.ring_id if cur_device_index > prev_device_index else self.ring_id + 2, + #'ring_id': self.ring_id if prev_device_index < cur_device_index else self.ring_id + 2, }) extra_index += 1 block._insert_op( index=index + extra_index, - type='recv_v2', + type='c_sync_comm_stream', + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={ + self._op_device_key: cur_device, + #self._op_role_key: core.op_proto_and_checker_maker.OpRole.Backward, + self._op_role_key: op_role, + 'ring_id': ring_id, + #'ring_id': self.ring_id if prev_device_index > cur_device_index else self.ring_id + 2, + }) + extra_index += 1 + + def _xx_insert_sendrecv_ops_for_boundaries(self, block): + """ + Insert a pair of send and recv ops for every two + consecutive ops on different devices. + """ + + extra_index = 0 + + # A map from var to device where op takes it as input, + # avoiding multiple send and recv ops. + input_var_to_device = dict() + + # A map from output var to op which generate it. 
+ output_var_to_op = dict() + + for index, op in enumerate(list(block.ops)): + for var_name in op.output_arg_names: + ops = output_var_to_op.setdefault(var_name, []) + ops.append([op, index]) + + for index, op in enumerate(list(block.ops)): + cur_device = op.attr(self._op_device_key) + if cur_device == "gpu:all": continue + for var_name in op.input_arg_names: + var = block.var(var_name) + if var.is_data: continue + + #if var_name not in input_var_to_device: + # input_var_to_device[var_name] = [] + #if cur_device in input_var_to_device[var_name]: + # continue + #input_var_to_device[var_name].append(cur_device) + + prev_device = None + generate_ops = output_var_to_op.get(var_name) + if generate_ops is None: + if var_name not in self._param_device_map: + continue + prev_device = self._param_device_map[var_name] + + prev_op = None + for gen_op, gen_idx in reversed(generate_ops): + if gen_idx < index: + prev_op = gen_op + break + + if not prev_device: + prev_device = prev_op.attr(self._op_device_key) \ + if prev_op else None + + if prev_device is None or prev_device == 'gpu:all': continue + + if prev_device == cur_device: continue + + if var_name not in input_var_to_device: + input_var_to_device[var_name] = [] + if (cur_device, prev_device) in input_var_to_device[var_name]: + continue + + device_type = cur_device.split(':')[0] + ':' + + def _insert_send_recv(cur_id, prev_id): + nonlocal extra_index + + cur_dev = device_type + str(cur_id) + prev_dev = device_type + str(prev_id) + if (cur_dev, prev_dev) in input_var_to_device[var_name]: + return + + if cur_id - prev_id > 1: + _insert_send_recv(cur_id - 1, prev_id) + _insert_send_recv(cur_id, cur_id - 1) + input_var_to_device[var_name].append( + (cur_dev, prev_dev)) + return + elif cur_id - prev_id < -1: + _insert_send_recv(cur_id + 1, prev_id) + _insert_send_recv(cur_id, cur_id + 1) + input_var_to_device[var_name].append( + (cur_dev, prev_dev)) + return + + assert abs(cur_id - prev_id) == 1 + + input_var_to_device[var_name].append((cur_dev, prev_dev)) + + op_role = op.all_attrs()[self._op_role_key] + var = block.vars[var_name] + + pair = (prev_id, cur_id) + pair_key = prev_id * 1000 + cur_id + if cur_id > prev_id: + ring_id = self.ring_id + cur_id - prev_id - 1 + else: + ring_id = self.ring_id + 2 + prev_id - cur_id - 1 + + print("call xx_insert, schedule_mode:", self.schedule_mode) + assert self.schedule_mode == 1 + if pair not in self._pipeline_pair: + self._pipeline_pair.append(pair) + self._pp_ring_map[pair_key] = self.ring_id + ring_id = self.ring_id + self.ring_id += 1 + else: + ring_id = self._pp_ring_map[pair_key] + + print("opt: pp_pair: {}, ring_id: {}".format(pair, ring_id)) + + block._insert_op_without_sync( + index=index + extra_index, + type='c_sync_calc_stream', + inputs={'X': [var]}, + outputs={'Out': [var]}, + attrs={ + self._op_device_key: prev_dev, + self._op_role_key: op_role, + }) + extra_index += 1 + + block._insert_op_without_sync( + index=index + extra_index, + type="c_broadcast", + inputs={'X': var}, + outputs={'Out': var}, + attrs={ + self._op_device_key: prev_dev, + self._op_role_key: op_role, + 'use_calc_stream': False, + 'ring_id': ring_id, + 'root': 0, + }) + extra_index += 1 + + fill_shape = list(var.shape) + fill_shape[0] = self.pp_bz + block._insert_op_without_sync( + index=index + extra_index, + type='fill_constant', + inputs={}, outputs={'Out': [var]}, attrs={ - 'out_shape': var.shape, + 'shape': fill_shape, 'dtype': var.dtype, - self._op_device_key: cur_device_spec, + self._op_device_key: cur_dev, + 
self._op_role_key: op_role, + 'value': float(0.0), + }) + extra_index += 1 + block._insert_op_without_sync( + index=index + extra_index, + type='c_broadcast', + inputs={'X': var}, + outputs={'Out': var}, + attrs={ + self._op_device_key: cur_dev, self._op_role_key: op_role, 'use_calc_stream': True, - 'peer': prev_device_index, + 'root': 0, + 'ring_id': ring_id, }) extra_index += 1 - def _clear_gradients(self, main_block, dev_spec): + # block._insert_op_without_sync( + # index=index + extra_index, + # type='c_sync_comm_stream', + # inputs={'X': [var]}, + # outputs={'Out': [var]}, + # attrs={ + # self._op_device_key: cur_dev, + # self._op_role_key: op_role, + # 'ring_id': ring_id, + # }) + # extra_index += 1 + + _insert_send_recv( + int(cur_device.split(':')[1]), + int(prev_device.split(':')[1])) + + block._sync_with_cpp() + + def _clear_gradients(self, main_block, param_names): """ Clear gradients at the begining of each run of a minibatch. """ - for param_name in self._param_device_map: - device = self._param_device_map[param_name] - if device != dev_spec: continue + # for param_name in self._param_device_map: + print("param_names:", param_names) + for param_name in param_names: + # device = self._param_device_map[param_name] + # if device != dev_spec: continue grad_name = self._append_grad_suffix(param_name) - if not main_block.has_var(grad_name): continue - grad_var = main_block.vars[grad_name] + # if not main_block.has_var(grad_name): continue + assert main_block.has_var(grad_name) + grad_var = main_block.var(grad_name) + grad_var.persistable = True main_block._insert_op( index=0, type='fill_constant', @@ -4260,21 +4801,20 @@ class PipelineOptimizer(object): 'shape': grad_var.shape, 'dtype': grad_var.dtype, 'value': float(0), - self._op_device_key: device, + # self._op_device_key: device, # a trick to run this op once per mini-batch self._op_role_key: self._op_role.Optimize.LRSched, }) - def _accumulate_gradients(self, block): + def _insert_loss_scale(self, block): """ - Accumulate the gradients generated in microbatch to the one in mini-batch. We also scale the loss corresponding to number of micro-batches as well. """ + if self._num_microbatches == 1: return for index, op in reversed(tuple(enumerate(list(block.ops)))): offset = index - device = op.attr(self._op_device_key) + #device = op.attr(self._op_device_key) - # Backward pass if self._is_loss_grad_op(op): loss_grad_var = block.vars[op.output_arg_names[0]] scale_factor = self._num_microbatches @@ -4285,36 +4825,130 @@ class PipelineOptimizer(object): outputs={'Out': loss_grad_var}, attrs={ 'scale': 1.0 / scale_factor, - self._op_device_key: device, + #self._op_device_key: device, self._op_role_key: self._op_role.Backward }) break + + def _rename_gradient_var_name(self, block): + for index, op in enumerate(block.ops): + if not self._is_optimize_op(op): continue + input_names = op.input_arg_names + output_names = op.output_arg_names + in_out_names = input_names + output_names + if op.type == 'cast': continue + # append "MERGED" to the names of parameter gradients, + # and mofify the op_role_var attribute (by rename_arg func). + for name in in_out_names: + if not core.grad_var_suffix() in name: continue + param_name = name.strip(core.grad_var_suffix()) + new_grad_name = name + "@MERGED" + self._rename_arg(op, name, new_grad_name) + + def _accumulate_gradients(self, block, pp_allreduce_in_optimize=False): + """ + Create a new merged gradient for each parameter and accumulate the + corresponding gradient to it. 
+ """ + merged_gradient_names = [] + first_opt_op_idx = None + + for index, op in reversed(tuple(enumerate(list(block.ops)))): + # remove the cast op of fp16 grad to fp32 grad + if self._is_optimize_op(op) and op.type == 'cast': + in_name = op.input_arg_names[0] + out_name = op.output_arg_names[0] + if out_name.strip('@GRAD') in self._param_device_map: + assert in_name.replace('.cast_fp16', '') == out_name + block._remove_op(index) + continue + + if self._is_backward_op(op) and not first_opt_op_idx: + first_opt_op_idx = index + 1 + if block.ops[first_opt_op_idx].type == "c_sync_comm_stream": + #block.ops[first_opt_op_idx]._set_attr(self._op_role_key, self._op_role.Backward) + first_opt_op_idx += 1 + if self._is_backward_op(op) and ( self._op_role_var_key in op.attr_names): - op_role_var = op.all_attrs()[self._op_role_var_key] + op_role_var = op.attr(self._op_role_var_key) if len(op_role_var) == 0: continue assert len(op_role_var) % 2 == 0 - offset = index + #op._remove_attr(self._op_role_var_key) for i in range(0, len(op_role_var), 2): - grad_name = op_role_var[i + 1] - grad_var = block.vars[grad_name] - new_grad_var_name = unique_name.generate(grad_name) - new_var = self._create_var(block, grad_var, - new_grad_var_name) - self._rename_arg(op, grad_name, new_grad_var_name) + offset = 0 + param_name = op_role_var[i] + if not block.has_var(param_name): continue + if '@BroadCast' in param_name: continue + param_grad_name = param_name + core.grad_var_suffix() + merged_param_grad_name = param_grad_name + '@MERGED' + if not block.has_var(merged_param_grad_name): + self._create_var(block, block.vars[param_name], + merged_param_grad_name) + assert block.has_var(merged_param_grad_name) + param_grad_var = block.var(param_grad_name) + merged_param_grad_var = block.var(merged_param_grad_name) + merged_param_grad_var.persistable = True block._insert_op( - index=offset + 1, - type='sum', - inputs={'X': [grad_var, new_var]}, - outputs={'Out': grad_var}, + index=first_opt_op_idx + offset, + type='fill_constant', + inputs={}, + outputs={'Out': [merged_param_grad_var]}, attrs={ - self._op_device_key: device, - self._op_role_key: self._op_role.Backward, - self._op_role_var_key: op_role_var + 'shape': merged_param_grad_var.shape, + 'dtype': merged_param_grad_var.dtype, + 'value': float(0), + # a trick to run this op once per mini-batch + self._op_role_key: self._op_role.Optimize.LRSched, }) offset += 1 + grad_name = op_role_var[i + 1] + grad_var = block.vars[grad_name] + if not 'cast_fp16' in grad_name: + block._insert_op( + index=first_opt_op_idx + offset, + type='sum', + inputs={'X': [grad_var, merged_param_grad_var]}, + outputs={'Out': merged_param_grad_var}, + attrs={ + self._op_role_key: self._op_role.Backward, + }) + offset += 1 + merged_gradient_names.append(merged_param_grad_name) + else: + # cast gradient to fp32 to accumulate to merged gradient + cast_grad_var_name = param_grad_name + '@TMP' + cast_grad_var = self._create_var(block, param_grad_var, + cast_grad_var_name) + cast_grad_var.persistable = False + block._insert_op( + index=first_opt_op_idx + offset, + type='cast', + inputs={'X': grad_var}, + outputs={'Out': cast_grad_var}, + attrs={ + 'in_dtype': grad_var.dtype, + 'out_dtype': cast_grad_var.dtype, + self._op_role_key: self._op_role.Backward, + }) + offset += 1 + block._insert_op( + index=first_opt_op_idx + offset, + type='sum', + inputs={ + 'X': [merged_param_grad_var, cast_grad_var] + }, + outputs={'Out': merged_param_grad_var}, + attrs={ + # self._op_device_key: device, + 
self._op_role_key: self._op_role.Backward, + #self._op_role_var_key: op_role_var + }) + offset += 1 + merged_gradient_names.append(merged_param_grad_name) + return merged_gradient_names def _add_sub_blocks(self, main_block, program_list): main_program = main_block.program @@ -4372,7 +5006,7 @@ class PipelineOptimizer(object): block = prog.block(0) for op in block.ops: if op.type == "recv_v2" or op.type == "create_py_reader" or \ - op.type == "read": + op.type == "read" or op.type == "update_loss_scaling": continue # We have processed lr related vars if op.attr(self._op_role_key) == int( @@ -4407,11 +5041,14 @@ class PipelineOptimizer(object): inputs={'X': write_block.var(var_name), }, attrs={ self._op_device_key: write_device, - 'use_calc_stream': True, + 'use_calc_stream': False, # A trick to make the role LRSched to avoid copy every # microbatch self._op_role_key: self._op_role.LRSched, 'peer': read_dev_index, + #'ring_id': self.ring_id, + 'ring_id': self.ring_id if + read_dev_index > write_dev_index else self.ring_id + 2, }) read_block._insert_op( index=0, @@ -4421,34 +5058,77 @@ class PipelineOptimizer(object): 'out_shape': read_block.var(var_name).shape, 'dtype': read_block.var(var_name).dtype, self._op_device_key: read_device, - 'use_calc_stream': True, + 'use_calc_stream': False, + # A trick to make the role LRSched to avoid copy every + # microbatch + self._op_role_key: self._op_role.LRSched, + 'peer': write_dev_index, + #'ring_id': self.ring_id, + 'ring_id': self.ring_id if + write_dev_index < read_dev_index else self.ring_id + 2, + }) + read_block._insert_op( + index=1, + type='c_sync_comm_stream', + inputs={'X': [read_block.var(var_name)]}, + outputs={'Out': [read_block.var(var_name)]}, + attrs={ + self._op_device_key: read_device, # A trick to make the role LRSched to avoid copy every # microbatch self._op_role_key: self._op_role.LRSched, - 'peer': write_dev_index + #'ring_id': self.ring_id, + 'ring_id': self.ring_id if + write_dev_index > read_dev_index else self.ring_id + 2, }) + def _is_gradient_clip_op(self, op): + return op.desc.has_attr("op_namescope") \ + and op.desc.attr("op_namescope").startswith("/gradient_clip") + + def _is_regularization_op(self, op): + return op.desc.has_attr("op_namescope") \ + and op.desc.attr("op_namescope").startswith("/regularization") + def minimize(self, loss, startup_program=None, parameter_list=None, no_grad_set=None): main_block = loss.block + self.origin_main_block = main_block if startup_program is None: startup_program = default_startup_program() optimize_ops, params_grads = self._optimizer.minimize( loss, startup_program, parameter_list, no_grad_set) - self._param_device_map = self._optimizer._param_device_map - - # Step1: add default op_device attribute for regulization and clip ops - self._add_opdevice_attr_for_regularization_clip(main_block) - - # Step2: add default op_device attribute for ops whose op_device - # attribute have not been set yet. Then check all ops have the - # op_device attribute. - self._add_default_opdevice_attr(main_block) - - device_specs = self._check_validation(main_block) + self._param_device_map = self._origin_optimizer._param_device_map + assert main_block.program._pipeline_opt \ + and 'local_rank' in main_block.program._pipeline_opt, \ + 'Please use pipeline with fleet.' 
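The assert above only checks that fleet attached a _pipeline_opt dict to the program; the keys the optimizer consumes next are local_rank, schedule_mode, pp_bz, use_sharding, ring_id, and global_rank. A hedged sketch of what such a dict might look like, with placeholder values (the real dict is filled in by the fleet meta-optimizer, not by this code):

pipeline_opt_example = {
    'local_rank': 0,        # this worker's stage index within the pipeline group
    'global_rank': 0,       # rank in the whole distributed job
    'schedule_mode': 1,     # 0: GPipe-style send_v2/recv_v2, 1: the c_broadcast-based schedule
    'pp_bz': 1,             # first dim of the fill_constant placeholder vars (micro-batch size)
    'use_sharding': False,  # True when sharding wraps the pipeline optimizer
    'ring_id': 20,          # placeholder base ring id for the per-pair pipeline rings
}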
+ local_rank = main_block.program._pipeline_opt['local_rank'] + schedule_mode = 0 + if 'schedule_mode' in main_block.program._pipeline_opt: + schedule_mode = main_block.program._pipeline_opt['schedule_mode'] + self.schedule_mode = schedule_mode + self.pp_bz = main_block.program._pipeline_opt['pp_bz'] + + self.use_sharding = False + if 'use_sharding' in main_block.program._pipeline_opt: + self.use_sharding = main_block.program._pipeline_opt['use_sharding'] + + self.ring_id = 0 + if 'ring_id' in main_block.program._pipeline_opt: + self.ring_id = main_block.program._pipeline_opt['ring_id'] + + if main_block.program._pipeline_opt['global_rank'] == 0: + with open("startup_raw", 'w') as f: + f.writelines(str(startup_program)) + with open("main_raw", 'w') as f: + f.writelines(str(main_block.program)) + + # Step1: add default op_device attribute for ops. + self._add_op_device_attr(main_block) + device_list = self._check_validation(main_block) def device_cmp(device1, device2): dev1_id = int(device1.split(':')[1]) @@ -4460,66 +5140,169 @@ class PipelineOptimizer(object): else: return 0 - sorted_device_spec = sorted(device_specs, key=cmp_to_key(device_cmp)) - assert sorted_device_spec == device_specs, ( - "With pipeline " - "parallelism, you must use gpu devices one after another " - "in the order of their ids.") + sorted_device_list = sorted(device_list, key=cmp_to_key(device_cmp)) + assert sorted_device_list == device_list, ( + "With pipeline parallelism, you must use gpu devices one after " + "another in the order of their ids.") - # Step3: add send and recv ops between section boundaries - self._insert_sendrecv_ops_for_boundaries(main_block) + # Step2: add send and recv ops between section boundaries + self._xx_insert_sendrecv_ops_for_boundaries(main_block) - # Step4: split program into sections and add pairs of + # Step3: split program into sections and add pairs of # send and recv ops for data var. main_program = main_block.program - program_list = self._split_program(main_program, device_specs) + program_list = self._split_program(main_program, device_list) + #cur_device_index = 0 + #device_num = len(program_list) for p in program_list: - self._create_vars(p["program"].block(0), - main_program.global_block()) - self._insert_sendrecv_for_data_var(main_block, program_list, - startup_program, device_specs) - - # Step5: Special Case: process persistable vars that exist in + self._create_vars(p["program"].block(0), main_block) + # # Add send/recv pair to sync the execution. 
+ # block = p['program'].block(0) + # prev_device_index = cur_device_index - 1 + # next_device_index = cur_device_index + 1 + # add_send_for_forward = False + # add_send_for_backward = False + # add_recv_for_backward = False + # extra_index = 0 + # new_var = block.create_var( + # name=unique_name.generate('sync'), + # shape=[1], + # dtype='float32', + # persistable=False, + # stop_gradient=True) + # block._insert_op( + # index=0, + # type='fill_constant', + # inputs={}, + # outputs={'Out': [new_var]}, + # attrs={ + # 'shape': [1], + # 'dtype': new_var.dtype, + # self._op_role_key: self._op_role.Forward, + # 'value': float(0.0), + # }) + # extra_index += 1 + # for op_idx, op in enumerate(list(block.ops)): + # if op_idx == extra_index: + # if cur_device_index > 0: + # pair_key = prev_device_index * 1000 + cur_device_index + # ring_id = self._pp_ring_map[pair_key] + # block._insert_op( + # index=op_idx, + # type='recv_v2', + # outputs={'Out': [new_var]}, + # attrs={ + # 'out_shape': new_var.shape, + # 'dtype': new_var.dtype, + # self._op_role_key: self._op_role.Forward, + # 'peer': 0, + # 'use_calc_stream': True, + # 'ring_id': ring_id, + # }) + # extra_index += 1 + # continue + # if op.type == "send_v2" and self._is_forward_op(op) \ + # and not add_send_for_forward \ + # and cur_device_index < device_num - 1: + # add_send_for_forward = True + # pair_key = cur_device_index * 1000 + next_device_index + # ring_id = self._pp_ring_map[pair_key] + # block._insert_op( + # index=op_idx + extra_index, + # type='send_v2', + # inputs={'Out': new_var}, + # attrs={ + # 'out_shape': new_var.shape, + # 'dtype': new_var.dtype, + # self._op_role_key: self._op_role.Forward, + # 'peer': 1, + # 'use_calc_stream': True, + # 'ring_id': ring_id, + # }) + # extra_index += 1 + # if self._is_backward_op(op) and not add_recv_for_backward \ + # and cur_device_index < device_num - 1: + # pair_key = next_device_index * 1000 + cur_device_index + # add_recv_for_backward = True + # ring_id = self._pp_ring_map[pair_key] + # block._insert_op( + # index=op_idx + extra_index, + # type='recv_v2', + # outputs={'Out': [new_var]}, + # attrs={ + # 'out_shape': new_var.shape, + # 'dtype': new_var.dtype, + # self._op_role_key: self._op_role.Backward, + # 'peer': 0, + # 'use_calc_stream': True, + # 'ring_id': ring_id, + # }) + # if op.type == "send_v2" and self._is_backward_op(op) \ + # and not add_send_for_backward \ + # and cur_device_index > 0: + # pair_key = cur_device_index * 1000 + prev_device_index + # add_send_for_backward = True + # ring_id = self._pp_ring_map[pair_key] + # block._insert_op( + # index=op_idx + extra_index, + # type='send_v2', + # outputs={'Out': [new_var]}, + # attrs={ + # 'out_shape': new_var.shape, + # 'dtype': new_var.dtype, + # self._op_role_key: self._op_role.Backward, + # 'peer': 1, + # 'use_calc_stream': True, + # 'ring_id': ring_id, + # }) + # cur_device_index += 1 + #self._insert_sendrecv_for_data_var(main_block, program_list, + # startup_program, device_list) + + # Step4: Special Case: process persistable vars that exist in # multiple sections - self._process_persistable_vars_in_multi_sections( - main_program, startup_program, program_list) + #self._process_persistable_vars_in_multi_sections( + # main_program, startup_program, program_list) - # Step6: Add sub blocks for section programs + # Step5: Add sub blocks for section programs self._add_sub_blocks(main_block, program_list) - assert (main_program._pipeline_opt and - isinstance(main_program._pipeline_opt, dict) and - 'local_rank' in 
main_program._pipeline_opt), \ - "You must use pipeline with fleet" - local_rank = main_program._pipeline_opt['local_rank'] % len( - device_specs) + local_rank = main_program._pipeline_opt['local_rank'] % len(device_list) place_list = [] - for dev_spec in device_specs: - dev_index = dev_spec.split(":")[1] - place_list.append(core.CUDAPlace(local_rank)) + for dev in device_list: + dev_index = int(dev.split(":")[1]) + place_list.append(core.CUDAPlace(dev_index % 8)) - # Step7: Split startup program + # Step6: Split startup program new_startup_program = self._split_startup_program(startup_program, local_rank) - - # Step8: clear gradients before each mini-batch and - # accumulate gradients during backward - self._clear_gradients( - program_list[local_rank]['program'].global_block(), - dev_spec=device_specs[local_rank]) - self._accumulate_gradients(program_list[local_rank]['program'] - .global_block()) - startup_program._pipeline_opt = { "startup_program": new_startup_program, } + real_block = program_list[local_rank]['program'].global_block() + self._insert_loss_scale(real_block) + if not self.use_sharding: + # Step7: clear gradients before each mini-batch and + # accumulate gradients during backward + param_list = [] + for param, grad in params_grads: + if real_block.has_var(param): param_list.append(param) + #self._clear_gradients(real_block, param_list) + self._rename_gradient_var_name(real_block) + real_block._sync_with_cpp() + self._accumulate_gradients(real_block) + real_block._sync_with_cpp() + place_id = int(os.getenv("FLAGS_selected_gpus", "0")) main_program._pipeline_opt = { "trainer": "PipelineTrainer", "device_worker": "Section", - "inner_parallelism": len(device_specs), + "inner_parallelism": len(device_list), + "num_pipeline_stages": len(device_list), + "pipeline_stage": local_rank, + "schedule_mode": schedule_mode, "section_program": program_list[local_rank], "place": place_list[local_rank], "place_id": place_id, @@ -4527,7 +5310,7 @@ class PipelineOptimizer(object): "num_microbatches": self._num_microbatches, "start_cpu_core_id": self._start_cpu_core_id, } - return optimize_ops, params_grads, program_list + return optimize_ops, params_grads, program_list, self._pipeline_pair, self._pp_ring_map class RecomputeOptimizer(Optimizer): @@ -4928,10 +5711,10 @@ class RecomputeOptimizer(Optimizer): for output_var in output_vars: if output_var in need_offload_checkpoint_names: - assert len( - output_vars - ) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format( - output_var, op) + #assert len( + # output_vars + #) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format( + # output_var, op) if output_var in self.un_offload_checkpoint_names: # insert sync op if last checkpoint has not been sync @@ -4956,14 +5739,14 @@ class RecomputeOptimizer(Optimizer): format(output_var)) # need to sync the last need to offload checkpoint before the last checkpoint as output op if output_var == last_checkpoint: - assert len( - output_vars - ) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format( - output_var, op) - assert last_offload_checkpoint == self.sorted_checkpoint_names[ - -2], "the last offload chekpoint before [{}] is suppose to be [{}], but got [{}]".format( - last_checkpoint, self.sorted_checkpoint_names[-2], - last_offload_checkpoint) + #assert len( + # output_vars + #) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format( + # output_var, op) + 
#assert last_offload_checkpoint == self.sorted_checkpoint_names[ + # -2], "the last offload checkpoint before [{}] is supposed to be [{}], but got [{}]".format( + # last_checkpoint, self.sorted_checkpoint_names[-2], + # last_offload_checkpoint) # sync if last checkpoint has not been sync if self.checkpoint_usage_count_and_idx[ last_offload_checkpoint]['idx'] == 0: @@ -5487,7 +6270,7 @@ class GradientMergeOptimizer(object): def _is_the_backward_op(self, op): op_maker = core.op_proto_and_checker_maker - backward = core.op_proto_and_checker_maker.OpRole.Backward if op_maker.kOpRoleVarAttrName() in op.attr_names and \ int(op.all_attrs()[op_maker.kOpRoleAttrName()]) == int(backward): return True
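Both send/recv insertion paths above allocate one communication ring per ordered pair of pipeline stages, keyed by prev_id * 1000 + cur_id, and minimize() now returns self._pipeline_pair and self._pp_ring_map so the caller can set up the matching communicators. A self-contained sketch of that bookkeeping (the class and names below are illustrative, not the optimizer's API):

class PipelinePairRings:
    # Illustrative sketch of the pair -> ring_id bookkeeping used above.
    def __init__(self, base_ring_id):
        self.ring_id = base_ring_id  # next ring id to hand out
        self.pipeline_pair = []      # ordered (prev_stage, cur_stage) pairs seen so far
        self.pp_ring_map = {}        # pair_key -> allocated ring_id

    def ring_for(self, prev_id, cur_id):
        pair = (prev_id, cur_id)
        pair_key = prev_id * 1000 + cur_id
        if pair not in self.pipeline_pair:
            self.pipeline_pair.append(pair)
            self.pp_ring_map[pair_key] = self.ring_id
            self.ring_id += 1
        return self.pp_ring_map[pair_key]


# Example: the forward pair (0, 1) and the backward pair (1, 0) get distinct
# rings, and a ring is reused when later variables cross the same boundary.
rings = PipelinePairRings(base_ring_id=20)
assert rings.ring_for(0, 1) == 20
assert rings.ring_for(1, 0) == 21
assert rings.ring_for(0, 1) == 20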