diff --git a/paddle/fluid/framework/device_worker.h b/paddle/fluid/framework/device_worker.h
index 49908be40e7d743c3d02c9ad686d3234de21279c..2f047eb6de914a184d804e8dd4d9a5c9680822a9 100644
--- a/paddle/fluid/framework/device_worker.h
+++ b/paddle/fluid/framework/device_worker.h
@@ -441,14 +441,14 @@ class SectionWorker : public DeviceWorker {
   void SetSkipVars(const std::vector<std::string>& skip_vars) {
     skip_vars_ = skip_vars;
   }
+  void SetStartCpuCoreId(int id) { cpu_id_ = id; }
   // static void ResetBatchId() { batch_id_ = 0; }
 
-  static std::atomic<int> cpu_id_;
-
 protected:
   void AutoSetCPUAffinity(bool reuse);
   int section_id_;
   int thread_id_;
+  int cpu_id_;
   int num_microbatches_;
   std::vector<Scope*> microbatch_scopes_;
   std::vector<std::string> skip_vars_;
diff --git a/paddle/fluid/framework/pipeline_trainer.cc b/paddle/fluid/framework/pipeline_trainer.cc
index 8c6c31574374142dceeff538b7b386c649c55577..62429b7bee14a5e6f74ac6972372972e5a3c4c57 100644
--- a/paddle/fluid/framework/pipeline_trainer.cc
+++ b/paddle/fluid/framework/pipeline_trainer.cc
@@ -34,8 +34,8 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
   ParseDumpConfig(trainer_desc);
   // get filelist from trainer_desc here
   // const std::vector<paddle::framework::DataFeed*> readers =
-  // VLOG(3) << "Number of program sections: " << section_num_;
   // dataset->GetReaders();
+  // VLOG(3) << "Number of program sections: " << section_num_;
   // VLOG(3) << "readers num: " << readers.size();
   // int num_readers = readers.size();
   // PADDLE_ENFORCE_EQ(num_readers, 1,
@@ -108,6 +108,7 @@ void PipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
   this_worker->SetPlace(place_);
   this_worker->Initialize(trainer_desc);
   this_worker->SetMicrobatchNum(num_microbatches_);
+  this_worker->SetStartCpuCoreId(start_cpu_core_id_);
 
   // set debug here
   SetDebug(trainer_desc.debug());
@@ -207,7 +208,7 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
     } else if (!var->Persistable() && !is_param_grad) {
       auto* ptr = microbatch_scopes_[microbatch_id]->Var(var->Name());
       VLOG(3) << "Create variable " << var->Name() << " microbatch "
-              << ", which pointer is " << ptr;
+              << microbatch_id << ", which pointer is " << ptr;
       InitializeVariable(ptr, var->GetType());
     }
   }
@@ -235,39 +236,40 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
 //   }
 // }
 
-void PipelineTrainer::GetSkipVars(const ProgramDesc& program) {
-  auto& global_block = program.Block(0);
-  for (auto& op : global_block.AllOps()) {
-    if (op->Type() != "c_send") {
-      continue;
-    }
-    auto input_arg_names = op->InputArgumentNames();
-    PADDLE_ENFORCE_EQ(input_arg_names.size(), 1,
-                      platform::errors::InvalidArgument(
-                          "Number of input arguments for c_send op must be 1, "
-                          "but the value given is %d.",
-                          input_arg_names.size()));
-    std::string input_arg_name = input_arg_names[0];
-    if (input_arg_name.rfind("@GRAD") != input_arg_name.size() - 5) {
-      skip_vars_.emplace_back(input_arg_name);
-      VLOG(3) << "add skip var name: " << input_arg_name;
-    }
-  }
-}
+// void PipelineTrainer::GetSkipVars(const ProgramDesc& program) {
+//   auto& global_block = program.Block(0);
+//   for (auto& op : global_block.AllOps()) {
+//     if (op->Type() != "c_send") {
+//       continue;
+//     }
+//     auto input_arg_names = op->InputArgumentNames();
+//     PADDLE_ENFORCE_EQ(input_arg_names.size(), 1,
+//                       platform::errors::InvalidArgument(
+//                           "Number of input arguments for c_send op must be 1, "
+//                           "but the value given is %d.",
+//                           input_arg_names.size()));
+//     std::string input_arg_name = input_arg_names[0];
+//     if (input_arg_name.rfind("@GRAD") != input_arg_name.size() - 5) {
+//       skip_vars_.emplace_back(input_arg_name);
+//       VLOG(3) << "add skip var name: " << input_arg_name;
+//     }
+//   }
+// }
 
 void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
                                      const platform::Place& place) {
   PADDLE_ENFORCE_NOT_NULL(root_scope_, platform::errors::InvalidArgument(
                                            "root_scope_ can not be nullptr"));
-  auto start_cpu_id = trainer_desc_.section_param().start_cpu_core_id();
-  SectionWorker::cpu_id_.store(start_cpu_id);
+  // auto start_cpu_id = trainer_desc_.section_param().start_cpu_core_id();
+  // SectionWorker::cpu_id_.store(start_cpu_id);
   // minibatch_scopes_.resize(section_num_);
   // microbatch_scopes_.resize(section_num_);
   // minibatch_scopes_.resize(1);
   microbatch_scopes_.resize(num_microbatches_);
   // skip_vars_.resize(section_num_);
-  VLOG(3) << "Init ScopeQueues and create all scopes";
+  VLOG(3) << "Create minibatch and microbatch scopes...";
   // for (int i = 0; i < section_num_; ++i) {
   minibatch_scope_ = &root_scope_->NewScope();
   std::shared_ptr<framework::ProgramDesc> program;
@@ -282,7 +284,7 @@ void PipelineTrainer::InitTrainerEnv(const ProgramDesc& main_program,
     CopyParameters(j, *program, place_);
   }
   // GetSkipVars(i, *program);
-  GetSkipVars(*program);
+  // GetSkipVars(*program);
   // }
 
   // for (int i = 0; i < section_num_; ++i) {
diff --git a/paddle/fluid/framework/section_worker.cc b/paddle/fluid/framework/section_worker.cc
index 29beffdbd58f6371d12039abc1a2ccc5ef2939a0..d86eaf1ac5988766da9d1cde371110a4c45578c7 100644
--- a/paddle/fluid/framework/section_worker.cc
+++ b/paddle/fluid/framework/section_worker.cc
@@ -30,7 +30,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
 
-std::atomic<int> SectionWorker::cpu_id_(0);
+// std::atomic<int> SectionWorker::cpu_id_(0);
 // std::mutex SectionWorker::thread_mutex;
 // std::mutex SectionWorker::cout_mutex;
 // std::condition_variable SectionWorker::thread_condition;
@@ -48,18 +48,20 @@ void SectionWorker::Initialize(const TrainerDesc& desc) {
 }
 
 void SectionWorker::AutoSetCPUAffinity(bool reuse) {
-  int thread_cpu_id = cpu_id_.fetch_add(1);
+  // int thread_cpu_id = cpu_id_.fetch_add(1);
 
   unsigned concurrency_cap = std::thread::hardware_concurrency();
-  unsigned proc = thread_cpu_id;
+  // unsigned proc = thread_cpu_id;
+  unsigned proc = cpu_id_;
 
   if (proc >= concurrency_cap) {
     if (reuse) {
       proc %= concurrency_cap;
     } else {
       LOG(INFO) << "All " << concurrency_cap
-                << " CPUs have been set affinities. Fail to set "
-                << thread_cpu_id << "th thread";
+                << " CPUs have been set affinities. Fail to set " << cpu_id_
+                << "th thread.";
+      // << thread_cpu_id << "th thread";
       return;
     }
   }
@@ -78,7 +80,8 @@ void SectionWorker::AutoSetCPUAffinity(bool reuse) {
       (0 == CPU_ISSET(proc, &mask))) {
     LOG(WARNING) << "Fail to set thread affinity to CPU " << proc;
   }
-  VLOG(3) << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc;
+  // VLOG(3) << "Set " << thread_cpu_id << "th thread affinity to CPU " << proc;
+  VLOG(3) << "Set " << cpu_id_ << "th thread affinity to CPU " << proc;
 }
 
 void SectionWorker::TrainFiles() {
@@ -141,7 +144,8 @@ void SectionWorker::TrainFiles() {
         VLOG(3) << "thread completed.";
         // VLOG(3) << "called notify all";
         // thread_condition.notify_all();
-        VLOG(0) << "EOF encountered";
+        VLOG(3) << "EOF encountered";
+        // throw platform::EOFException();
         break;
       }
     }
@@ -191,8 +195,8 @@ void SectionWorker::TrainFilesWithProfiler() {
   platform::Timer batch_timer;
   platform::Timer timeline;
 
-  std::vector<double> op_total_time;
   std::vector<std::string> op_name;
+  std::vector<double> op_total_time;
   std::vector<double> op_max_time;
   std::vector<double> op_min_time;
   std::vector<int> op_count;
@@ -204,6 +208,7 @@ void SectionWorker::TrainFilesWithProfiler() {
   op_min_time.resize(ops_.size());
   for (size_t i = 0; i < op_min_time.size(); ++i) {
     op_min_time[i] = DBL_MAX;
+    op_max_time[i] = 0.0;
   }
   op_count.resize(ops_.size());
@@ -235,7 +240,7 @@ void SectionWorker::TrainFilesWithProfiler() {
   struct timeval micro_end;
   // Start a minibatch.
   batch_timer.Start();
-  int real_microbatch_num = 0;
+  // int real_microbatch_num = 0;
   for (int i = 0; i < num_microbatches_; ++i) {
     try {
       int op_idx = 0;
@@ -253,8 +258,9 @@ void SectionWorker::TrainFilesWithProfiler() {
             op_role == (static_cast<int>(OpRole::kForward) |
                         static_cast<int>(OpRole::kLoss));
         if ((i == 0 && run_first_mbatch) || (i != 0 && run_others)) {
-          VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
-                  << " for scope " << i;
+          // VLOG(3) << "running an op " << op->Type() << " for " << thread_id_
+          //         << " for scope " << i;
+          VLOG(3) << "running an op " << op->Type() << " for scope " << i;
           timeline.Start();
           op->Run(*microbatch_scopes_[i], place_);
           if (gc) {
@@ -365,11 +371,11 @@ void SectionWorker::TrainFilesWithProfiler() {
     }
   }
   dev_ctx_->Wait();
-  if (real_microbatch_num == 0) {
-    batch_timer.Pause();
-    VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
-    return;
-  }
+  // if (real_microbatch_num == 0) {
+  //   batch_timer.Pause();
+  //   VLOG(0) << "batch time: " << batch_timer.ElapsedUS();
+  //   return;
+  // }
   // update pass
   int op_idx = 0;
   gettimeofday(&micro_start, NULL);
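Taken together, the section_worker.cc changes replace the shared atomic core counter with a per-worker cpu_id_ that the trainer passes in through SetStartCpuCoreId. For illustration only, a rough Python analogue of that affinity logic (auto_set_cpu_affinity and cpu_id are made-up names, and os.sched_setaffinity is a Linux-only stand-in for the pthread call the C++ code uses):

    import os

    def auto_set_cpu_affinity(cpu_id, reuse=True):
        # Rough analogue of SectionWorker::AutoSetCPUAffinity with a per-worker core id.
        concurrency_cap = os.cpu_count() or 1
        proc = cpu_id
        if proc >= concurrency_cap:
            if not reuse:
                print("All %d CPUs have been set affinities. Fail to set %dth thread."
                      % (concurrency_cap, cpu_id))
                return
            proc %= concurrency_cap
        # Pin the calling process/thread to the chosen core (Linux only).
        os.sched_setaffinity(0, {proc})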
diff --git a/paddle/fluid/framework/trainer_desc.proto b/paddle/fluid/framework/trainer_desc.proto
index a83107c0d2e5499819f92b58f575b0669e410a19..fc2c4b1f076bae1fca082b7b36515dce5bae0921 100644
--- a/paddle/fluid/framework/trainer_desc.proto
+++ b/paddle/fluid/framework/trainer_desc.proto
@@ -84,7 +84,7 @@ message DownpourWorkerParameter {
 }
 
 message SectionWorkerParameter {
-  SectionConfig section_config = 1;
+  optional SectionConfig section_config = 1;
   optional int32 queue_size = 2 [ default = 1 ];
   optional int64 sync_steps = 3 [ default = 1 ];
   optional int32 start_cpu_core_id = 4 [ default = 1 ];
diff --git a/paddle/fluid/operators/collective/c_recv_op.cc b/paddle/fluid/operators/collective/c_recv_op.cc
index a3d59648e09735f90a4e82dd97e81d6dba2a9e51..10e0ba50e038e861c752d4583dd20e042ee483f7 100644
--- a/paddle/fluid/operators/collective/c_recv_op.cc
+++ b/paddle/fluid/operators/collective/c_recv_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/collective/c_recv_op.h"
+#include <string>
 
 namespace paddle {
 namespace operators {
@@ -33,14 +34,36 @@ class CRecvOp : public framework::OperatorWithKernel {
         ring_id, 0,
         platform::errors::InvalidArgument(
             "The ring_id (%d) for c_send_op must be non-negative.", ring_id));
+    auto out_shape = ctx->Attrs().Get<std::vector<int>>("out_shape");
+    PADDLE_ENFORCE_GE(out_shape.size(), 1,
+                      platform::errors::InvalidArgument(
+                          "The size of the output shape must be greater than 0 "
+                          "but the value given is %d.",
+                          out_shape.size()));
   }
 
 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
-    auto out = ctx.Output<framework::LoDTensor>("Out");
-    auto dtype = out->type();
-    return framework::OpKernelType(dtype, ctx.GetPlace());
+    VLOG(0) << "wow1";
+    std::string dtype = ctx.Attr<std::string>("dtype");
+    framework::proto::VarType::Type type;
+    if (dtype == "fp32") {
+      type = framework::proto::VarType::FP32;
+    } else if (dtype == "fp64") {
+      type = framework::proto::VarType::FP64;
+    } else if (dtype == "fp16") {
+      type = framework::proto::VarType::FP16;
+    } else if (dtype == "int32") {
+      type = framework::proto::VarType::INT32;
+    } else if (dtype == "int64") {
+      type = framework::proto::VarType::INT64;
+    } else {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "Unknown data type %s for c_recv op.", dtype));
+    }
+    VLOG(0) << "wow2";
+    return framework::OpKernelType(type, ctx.GetPlace());
     // OperatorWithKernel::IndicateVarDataType(ctx, "Out"), ctx.GetPlace());
   }
 };
@@ -52,6 +75,11 @@ class CRecvOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("ring_id", "(int default 0) nccl communication ring id.")
         .SetDefault(0);
     AddAttr<int>("peer", "(int default 0) rank id for sender.").SetDefault(0);
+    AddAttr<std::string>("dtype",
+                         "(std::string default fp32) data type of tensor.")
+        .SetDefault("fp32");
+    AddAttr<std::vector<int>>("out_shape", "shape of the output tensor.")
+        .SetDefault(std::vector<int>());
     AddAttr<bool>(
         "use_calc_stream",
         "(bool default false) eject CUDA operations to calculation stream.")
diff --git a/paddle/fluid/operators/collective/c_recv_op.cu.cc b/paddle/fluid/operators/collective/c_recv_op.cu.cc
index 4a716ab61b5aa2bfb97b61729be11f4cfd618b3d..5ea96dfed583b5231140b0f501995238467f6839 100644
--- a/paddle/fluid/operators/collective/c_recv_op.cu.cc
+++ b/paddle/fluid/operators/collective/c_recv_op.cu.cc
@@ -27,13 +27,20 @@ class CRecvOpCUDAKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
 #if defined(PADDLE_WITH_NCCL)
+    VLOG(0) << "here1";
     auto out = ctx.Output<framework::LoDTensor>("Out");
-    int numel = out->numel();
-    ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
+    VLOG(0) << "here2";
+    auto out_shape = ctx.Attr<std::vector<int>>("out_shape");
+    auto out_dims = paddle::framework::make_ddim(out_shape);
     int rid = ctx.Attr<int>("ring_id");
     auto place = ctx.GetPlace();
     auto comm = platform::NCCLCommContext::Instance().Get(rid, place);
+    out->mutable_data<T>(out_dims, place);
+    VLOG(0) << "out_dims:" << out_dims;
+    ncclDataType_t dtype = platform::ToNCCLDataType(out->type());
+    int numel = out->numel();
+    VLOG(0) << "numel:" << numel;
 
     cudaStream_t stream = nullptr;
     if (ctx.Attr<bool>("use_calc_stream")) {
@@ -49,9 +56,10 @@ class CRecvOpCUDAKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument("The value of peer (%d) you set must "
                                           "be less than comm->nranks (%d).",
                                           peer, comm->nranks()));
+    VLOG(0) << "here3";
    PADDLE_ENFORCE_CUDA_SUCCESS(platform::dynload::ncclRecv(
        out->data<T>(), numel, dtype, peer, comm->comm(), stream));
-    VLOG(3) << "rank " << comm->rank() << " recv "
+    VLOG(0) << "rank " << comm->rank() << " recv "
             << framework::product(out->dims()) << " from " << peer;
 #else
     PADDLE_THROW(
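With these operator changes, c_recv no longer infers the output's type and size from an already-initialized tensor: the buffer is allocated from the out_shape attribute and the element type comes from the string dtype attribute (default fp32). A minimal, hypothetical sketch of emitting such a receive op from Python, mirroring the optimizer changes further below (block, var and peer_rank are assumed to exist in the caller and are not taken from this patch):

    # `block` is a fluid Block, `var` the Variable to receive into,
    # `peer_rank` the sender's device index parsed from its device spec.
    block.append_op(
        type='c_recv',
        outputs={'Out': [var]},
        attrs={
            'out_shape': var.shape,   # receiver allocates the buffer from this
            'dtype': 'fp32',          # one of fp16 / fp32 / fp64 / int32 / int64
            'peer': peer_rank,
            'ring_id': 0,
            'use_calc_stream': False,
        })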
diff --git a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
index d5a45e2b4e1aeda2e1c66c0a5a36236622f093ec..d82f240af7f531193ca6a7f49773e03e882c915f 100644
--- a/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/pipeline_optimizer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 
 from __future__ import print_function
+from __future__ import division
 
 import paddle.fluid as fluid
 from paddle.fluid import core, unique_name
@@ -21,9 +22,50 @@
 from .meta_optimizer_base import MetaOptimizerBase
 from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
 
 
-class PipelineHelper(CollectiveHelper):
-    def __init__(self, role_maker, nrings=1, wait_port='6174'):
-        super(PipelineHelper, self).__init__(role_maker, nrings, wait_port)
+def _get_node_num(endpoints):
+    ss = set()
+    for ep in endpoints:
+        ip = ep.split(":")[0].strip()
+        if ip not in ss:
+            ss.add(ip)
+    return len(ss)
+
+
+class PipelineHelper(object):
+    def __init__(self, role_maker, wait_port='6174'):
+        self.wait_port = wait_port
+        self.role_maker = role_maker
+
+    def update_startup_program(self, startup_program=None):
+        self.startup_program = startup_program
+        if startup_program is None:
+            self.startup_program = fluid.default_startup_program()
+
+        endpoints = self.role_maker.get_trainer_endpoints()
+        current_endpoint = endpoints[self.role_maker.worker_index()]
+        node_num = _get_node_num(endpoints)
+        assert len(endpoints) % node_num == 0
+        gpus_per_node = len(endpoints) // node_num
+
+        # Create a global ring for all gpus
+        print("current_endpoint:", current_endpoint)
+        print("endpoints:", endpoints)
+        print("rank:", self.role_maker.worker_index())
+        self._init_communicator(
+            self.startup_program, current_endpoint, endpoints,
+            self.role_maker.worker_index(), 0, self.wait_port)
+
+        if node_num == 1: return
+        # Create rings for gpus with the same gpu id
+        eps = []
+        local_rank = self.role_maker.worker_index() % gpus_per_node
+        ring_id = local_rank + 1
+        for i in range(node_num):
+            eps.append(endpoints[i * gpus_per_node + local_rank])
+        temp_rank = self.role_maker.worker_index() // node_num
+        self._init_communicator(self.startup_program, current_endpoint, eps,
+                                temp_rank, ring_id, self.wait_port)
+        self._broadcast_params(ring_id)
 
     def _init_communicator(self, program, current_endpoint, endpoints, rank,
                            ring_id, wait_port):
@@ -46,9 +88,8 @@ class PipelineHelper(CollectiveHelper):
                     'rank': rank,
                     'endpoint': current_endpoint,
                     'other_endpoints': other_endpoints,
-                    OP_ROLE_KEY: OpRole.Forward
+                    OP_ROLE_KEY: OpRole.Forward,
                 })
-
         block.append_op(
             type='c_comm_init',
             inputs={'X': nccl_id_var},
@@ -58,12 +99,10 @@ class PipelineHelper(CollectiveHelper):
                 'rank': rank,
                 'ring_id': ring_id,
                 OP_ROLE_KEY: OpRole.Forward,
-                'device_id': OpRole.Forward
             })
 
-    def _broadcast_params(self):
+    def _broadcast_params(self, ring_id):
         block = self.startup_program.global_block()
-        ring_id = 0
         for param in block.iter_parameters():
             if param.is_distributed:
                 continue
@@ -78,13 +117,12 @@ class PipelineHelper(CollectiveHelper):
                     OP_ROLE_KEY: OpRole.Forward
                 })
 
-        for ring_id in range(self.nrings):
-            block.append_op(
-                type='c_sync_comm_stream',
-                inputs={'X': param},
-                outputs={'Out': param},
-                attrs={'ring_id': ring_id,
-                       OP_ROLE_KEY: OpRole.Forward})
+        block.append_op(
+            type='c_sync_comm_stream',
+            inputs={'X': param},
+            outputs={'Out': param},
+            attrs={'ring_id': ring_id,
+                   OP_ROLE_KEY: OpRole.Forward})
 
 
 class PipelineOptimizer(MetaOptimizerBase):
@@ -100,7 +138,12 @@ class PipelineOptimizer(MetaOptimizerBase):
         super(PipelineOptimizer, self)._set_basic_info(
             loss, role_maker, user_defined_optimizer, user_defined_strategy)
         num_microbatches = user_defined_strategy.pipeline_configs['micro_batch']
-        self.wrapped_opt = PO(self.inner_opt, num_microbatches=num_microbatches)
+        endpoints = role_maker.get_trainer_endpoints()
+        current_endpoint = endpoints[role_maker.worker_index()]
+        self.local_rank = self._get_local_rank(current_endpoint, endpoints)
+        self.wrapped_opt = PO(self.inner_opt,
+                              num_microbatches=num_microbatches,
+                              start_cpu_core_id=self.local_rank)
 
     def _can_apply(self):
         if self.user_defined_strategy.pipeline == True:
@@ -111,23 +154,37 @@ class PipelineOptimizer(MetaOptimizerBase):
         dist_strategy.pipeline = False
         dist_strategy.pipeline_configs = {}
 
+    def _get_local_rank(self, current_endpoint, endpoints):
+        cur_node_endpoints = []
+        cur_ip = current_endpoint.split(':')[0].strip()
+        for ep in endpoints:
+            if cur_ip == ep.split(':')[0].strip():
+                cur_node_endpoints.append(ep)
+        return cur_node_endpoints.index(current_endpoint)
+
     def minimize_impl(self,
                       loss,
                       startup_program=None,
                       parameter_list=None,
                       no_grad_set=None):
-        optimize_ops, params_grads, prog_list = \
-            self.wrapped_opt.minimize(loss, startup_program,
-                                      parameter_list, no_grad_set)
-        if self.role_maker.worker_num() == 1:
-            return optimize_ops, params_grads
-
         endpoints = self.role_maker.get_trainer_endpoints()
         current_endpoint = endpoints[self.role_maker.worker_index()]
+        node_num = _get_node_num(endpoints)
+        gpus_per_node = len(endpoints) // node_num
         self.startup_program = startup_program
+        self.local_rank = self._get_local_rank(current_endpoint, endpoints)
         if startup_program is None:
             self.startup_program = fluid.default_startup_program()
 
+        if self.role_maker.worker_num() == 1:
+            return self.inner_opt.minimize(loss, startup_program,
+                                           parameter_list, no_grad_set)
+
+        loss.block.program._pipeline_opt = dict()
+        loss.block.program._pipeline_opt['local_rank'] = self.local_rank
+
+        optimize_ops, params_grads, prog_list = \
+            self.wrapped_opt.minimize(loss, startup_program,
+                                      parameter_list, no_grad_set)
+        assert prog_list
+
         self.main_program_list = prog_list
         self.main_program = loss.block.program
@@ -139,24 +196,24 @@ class PipelineOptimizer(MetaOptimizerBase):
         self.endpoints = endpoints
         self.current_endpoint = current_endpoint
 
-        pipeline_helper = PipelineHelper(self.role_maker, nrings=self.nrings)
+        pipeline_helper = PipelineHelper(self.role_maker)
         pipeline_helper.update_startup_program(self.startup_program)
 
-        self._transpile_main_program()
+        self._transpile_main_program(loss, node_num, gpus_per_node)
         return optimize_ops, params_grads
 
-    def _transpile_main_program(self):
-        self._insert_loss_grad_ops()
-        for ring_id in range(self.nrings):
+    def _transpile_main_program(self, loss, node_num, gpus_per_node):
+        self._insert_loss_grad_ops(loss, gpus_per_node, node_num)
+        for ring_id in range(1, node_num + 1):
             self._insert_allreduce_ops(ring_id)
 
-    def _insert_loss_grad_ops(self):
+    def _insert_loss_grad_ops(self, loss, gpus_per_node, node_num):
         """
         In order to keep the learning rate consistent in different numbers of
         training workers, we scale the loss grad by the number of workers
         """
-        block = self.main_program_list[self.nrings - 1]['program'].global_block(
-        )
+        block = self.main_program_list[gpus_per_node - 1][
+            'program'].global_block()
         for idx, op in reversed(list(enumerate(block.ops))):
             if is_loss_grad_op(op):
                 loss_grad_var = block.vars[op.output_arg_names[0]]
@@ -166,12 +223,12 @@ class PipelineOptimizer(MetaOptimizerBase):
                     inputs={'X': loss_grad_var},
                     outputs={'Out': loss_grad_var},
                     attrs={
-                        'scale': 1.0 / self.nranks,
+                        'scale': 1.0 / node_num,
                         OP_ROLE_KEY: OpRole.Backward
                     })
 
     def _insert_allreduce_ops(self, ring_id):
-        block = self.main_program_list[ring_id]['program'].global_block()
+        block = self.main_program_list[ring_id - 1]['program'].global_block()
         origin_block = self.main_program.global_block()
         grad = None
         for idx, op in reversed(list(enumerate(block.ops))):
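As a concrete illustration of the ring layout that PipelineHelper.update_startup_program builds, assume 2 nodes with 4 GPUs each (the endpoint addresses below are made up): ring 0 contains every trainer, and each local GPU index i additionally joins ring i + 1 together with the GPUs at the same index on the other nodes.

    endpoints = ["10.0.0.1:617%d" % i for i in range(4)] + \
                ["10.0.0.2:617%d" % i for i in range(4)]

    node_num = len({ep.split(":")[0] for ep in endpoints})  # 2
    gpus_per_node = len(endpoints) // node_num              # 4

    # Ring 0 spans all 8 trainers (global init / parameter broadcast).
    # Rings 1..4 each connect one GPU slot across the two nodes.
    for local_rank in range(gpus_per_node):
        ring_id = local_rank + 1
        members = [endpoints[n * gpus_per_node + local_rank]
                   for n in range(node_num)]
        print(ring_id, members)
    # 1 ['10.0.0.1:6170', '10.0.0.2:6170']
    # 2 ['10.0.0.1:6171', '10.0.0.2:6171']
    # 3 ['10.0.0.1:6172', '10.0.0.2:6172']
    # 4 ['10.0.0.1:6173', '10.0.0.2:6173']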
diff --git a/python/paddle/fluid/device_worker.py b/python/paddle/fluid/device_worker.py
index 4796cd5ada420567fa126154cc1ac28badc0f2c0..2d7a4ebcc48fbd4ed22498f4a2e2882a14aaa872 100644
--- a/python/paddle/fluid/device_worker.py
+++ b/python/paddle/fluid/device_worker.py
@@ -406,25 +406,44 @@ class Section(DeviceWorker):
         section_param = trainer_desc.section_param
         section_param.num_microbatches = pipeline_opt["num_microbatches"]
         section_param.start_cpu_core_id = pipeline_opt["start_cpu_core_id"]
-        for i, program in enumerate(pipeline_opt["section_program_list"]):
-            cfg = section_param.section_config.add()
-            cfg.program_desc.ParseFromString(program["program"]._get_desc()
-                                             .serialize_to_string())
-            # TODO: why does not work
-            # cfg.program_desc.CopyFrom(program.program._get_desc())
-            place = pipeline_opt["place_list"][i]
-            place_id = pipeline_opt["place_id_list"][i]
-            if isinstance(place, core.CPUPlace):
-                cfg.place = cfg.CPUPlace
-            elif isinstance(place, core.CUDAPlace):
-                cfg.place = cfg.CUDAPlace
-            elif isinstance(place, core.CUDAPinnedPlace):
-                cfg.place = cfg.CUDAPinnedPlace
-            else:
-                raise NotImplementedError(
-                    "SectionWorker only supports CPUPlace, CUDAPlace and CUDAPinnedPlace now."
-                )
-            cfg.place_id = place_id
+        cfg = section_param.section_config
+        program = pipeline_opt["section_program"]
+        cfg.program_desc.ParseFromString(program["program"]._get_desc()
+                                         .serialize_to_string())
+        # TODO: why does not work
+        # cfg.program_desc.CopyFrom(program.program._get_desc())
+        place = pipeline_opt["place"]
+        place_id = pipeline_opt["place_id"]
+        if isinstance(place, core.CPUPlace):
+            cfg.place = cfg.CPUPlace
+        elif isinstance(place, core.CUDAPlace):
+            cfg.place = cfg.CUDAPlace
+        elif isinstance(place, core.CUDAPinnedPlace):
+            cfg.place = cfg.CUDAPinnedPlace
+        else:
+            raise NotImplementedError(
+                "SectionWorker only supports CPUPlace, CUDAPlace and CUDAPinnedPlace now."
+            )
+        cfg.place_id = place_id
+        # for i, program in enumerate(pipeline_opt["section_program_list"]):
+        #     cfg = section_param.section_config.add()
+        #     cfg.program_desc.ParseFromString(program["program"]._get_desc()
+        #                                      .serialize_to_string())
+        #     # TODO: why does not work
+        #     # cfg.program_desc.CopyFrom(program.program._get_desc())
+        #     place = pipeline_opt["place_list"][i]
+        #     place_id = pipeline_opt["place_id_list"][i]
+        #     if isinstance(place, core.CPUPlace):
+        #         cfg.place = cfg.CPUPlace
+        #     elif isinstance(place, core.CUDAPlace):
+        #         cfg.place = cfg.CUDAPlace
+        #     elif isinstance(place, core.CUDAPinnedPlace):
+        #         cfg.place = cfg.CUDAPinnedPlace
+        #     else:
+        #         raise NotImplementedError(
+        #             "SectionWorker only supports CPUPlace, CUDAPlace and CUDAPinnedPlace now."
+        #         )
+        #     cfg.place_id = place_id
 
 
 class DeviceWorkerFactory(object):
diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py
index a74071b5ca546385f10925da37f1d0716f9c212d..1e74513de8d69d2ab0b4747cab8c4bb0abb60d9c 100644
--- a/python/paddle/fluid/optimizer.py
+++ b/python/paddle/fluid/optimizer.py
@@ -3818,6 +3818,24 @@ class PipelineOptimizer(object):
 
         return programs
 
+    def _split_startup_program(self, startup_program, local_rank):
+        block = startup_program.block(0)
+        new_startup_program = Program()
+        for op in block.ops:
+            device = op.attr(self._op_device_key)
+            if device:
+                device_index = int(device.split(":")[1])
+            else:
+                device_index = 0
+            if device_index != local_rank: continue
+            op_role = op.attr(self._op_role_key)
+            op_desc = op.desc
+            ap_op = new_startup_program.block(0).desc.append_op()
+            ap_op.copy_from(op_desc)
+            ap_op._set_attr(self._op_device_key, device)
+        self._create_vars(new_startup_program.block(0), startup_program)
+        return new_startup_program
+
     def _find_post_op(self, ops, cur_op, var_name):
         """
         Find the real post op that has variable named var_name as input.
@@ -3933,6 +3951,7 @@ class PipelineOptimizer(object):
             if op.type == "read":
                 break
         first_dev_spec = devices[0]
+        first_dev_index = int(first_dev_spec.split(':')[1])
         for var_name in data_devices_map.keys():
             for device in data_devices_map[var_name]:
                 if device == first_dev_spec: continue
@@ -3940,13 +3959,15 @@ class PipelineOptimizer(object):
                 assert main_var.is_data
                 if not var_name in first_block.vars:
                     self._create_var(first_block, main_var, var_name)
+                dev_index = int(device.split(':')[1])
                 first_block._insert_op(
                     index=insert_index,
                     type='c_send',
                     inputs={'X': first_block.var(var_name)},
                     attrs={
                         self._op_device_key: first_dev_spec,
-                        self._op_role_key: self._op_role.Forward
+                        self._op_role_key: self._op_role.Forward,
+                        'peer': dev_index
                     })
                 # Get the device that that data on
                 assert device in devices
@@ -3961,8 +3982,10 @@ class PipelineOptimizer(object):
                     type='c_recv',
                     outputs={'Out': [new_var]},
                     attrs={
+                        'out_shape': new_var.shape,
                         self._op_device_key: device,
                         self._op_role_key: self._op_role.Forward,
+                        'peer': first_dev_index
                     })
 
     def _strip_grad_suffix(self, name):
@@ -4105,13 +4128,16 @@ class PipelineOptimizer(object):
             op_role = op.all_attrs()[self._op_role_key]
             var = block.vars[var_name]
+            prev_device_index = int(prev_device_spec.split(':')[1])
+            cur_device_index = int(cur_device_spec.split(':')[1])
             block._insert_op(
                 index=index + extra_index,
                 type='c_send',
                 inputs={'X': var},
                 attrs={
                     self._op_device_key: prev_device_spec,
-                    self._op_role_key: op_role
+                    self._op_role_key: op_role,
+                    'peer': prev_device_index
                 })
             extra_index += 1
             block._insert_op(
@@ -4119,8 +4145,10 @@ class PipelineOptimizer(object):
                 type='c_recv',
                 outputs={'Out': [var]},
                 attrs={
+                    'out_shape': var.shape,
                     self._op_device_key: cur_device_spec,
-                    self._op_role_key: op_role
+                    self._op_role_key: op_role,
+                    'peer': cur_device_index
                 })
             extra_index += 1
 
@@ -4271,9 +4299,13 @@ class PipelineOptimizer(object):
             write_prog = write_info[var_name]
             write_block = write_prog.block(0)
             write_device = self._get_device_info(write_block)
+            write_dev_index = int(write_device.split(':')[1])
             all_progs = var_info[var_name]
             for prog in all_progs:
                 if prog == write_prog: continue
+                read_block = prog.block(0)
+                read_device = self._get_device_info(read_block)
+                read_dev_index = int(read_device.split(':')[1])
 
                 write_block._insert_op(
                     index=0,
                     type='c_send',
                     inputs={'X': write_block.var(var_name)},
                     attrs={
                         self._op_device_key: write_device,
                         # A trick to make the role LRSched to avoid copy every
                         # microbatch
-                        self._op_role_key: self._op_role.LRSched
+                        self._op_role_key: self._op_role.LRSched,
+                        'peer': read_dev_index
                     })
-                read_block = prog.block(0)
-                read_device = self._get_device_info(read_block)
                 read_block._insert_op(
                     index=0,
                     type='c_recv',
                     outputs={'Out': [read_block.var(var_name)]},
                     attrs={
+                        'out_shape': read_block.var(var_name).shape,
                         self._op_device_key: read_device,
                         # A trick to make the role LRSched to avoid copy every
                         # microbatch
                         self._op_role_key: self._op_role.LRSched,
+                        'peer': write_dev_index
                     })
 
     def minimize(self,
@@ -4363,12 +4396,25 @@ class PipelineOptimizer(object):
         # Step7: Add sub blocks for section programs
         self._add_sub_blocks(main_block, program_list)
 
+        assert (main_program._pipeline_opt and
+                isinstance(main_program._pipeline_opt, dict) and
+                'local_rank' in main_program._pipeline_opt), \
+            "You must use pipeline with fleet"
+        local_rank = main_program._pipeline_opt['local_rank']
+
+        # Step8: Split startup program
+        startup_program = self._split_startup_program(startup_program,
+                                                      local_rank)
+
+        with open("startup_prog_%d" % local_rank, 'w') as f:
+            f.writelines(str(startup_program))
+        with open("main_prog_%d" % local_rank, 'w') as f:
+            f.writelines(str(program_list[local_rank]['program']))
+
         main_program._pipeline_opt = {
             "trainer": "PipelineTrainer",
             "device_worker": "Section",
-            "section_program_list": program_list,
-            "place_list": place_list,
-            "place_id_list": place_id_list,
+            "section_program": program_list[local_rank],
+            "place": place_list[local_rank],
+            "place_id": place_id_list[local_rank],
             "sync_steps": -1,
             "num_microbatches": self._num_microbatches,
             "start_cpu_core_id": self._start_cpu_core_id,