// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/computation_op_handle.h" #include "paddle/fluid/framework/details/scale_loss_grad_op_handle.h" #include "paddle/fluid/framework/details/send_op_handle.h" #include "paddle/fluid/framework/scope.h" #ifdef PADDLE_WITH_CUDA #include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h" #endif #include #include namespace paddle { namespace framework { namespace details { #ifdef PADDLE_WITH_CUDA MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( const std::vector &places, const std::string &loss_var_name, const std::unordered_set ¶ms, const std::vector &local_scopes, bool skip_scale_loss, platform::NCCLContextMap *nccl_ctxs) : loss_var_name_(loss_var_name), places_(places), local_scopes_(local_scopes), nccl_ctxs_(nccl_ctxs) { #else MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder( const std::vector &places, const std::string &loss_var_name, const std::unordered_set ¶ms, const std::vector &local_scopes, bool skip_scale_loss) : loss_var_name_(loss_var_name), places_(places), local_scopes_(local_scopes) { #endif for (auto &p : params) { grad_names_.insert(GradVarName(p)); } skip_scale_loss_ = skip_scale_loss; } void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result, const OpDesc &op, const platform::Place &p, const size_t &i) const { auto *op_handle = result->ops_.back().get(); op_handle->SetDeviceContext(p, platform::DeviceContextPool::Instance().Get(p)); auto var_names = op.InputArgumentNames(); for (auto &each_var_name : var_names) { VarHandle *var = CreateOrGetLatestVarHandle(result, each_var_name, p, i); op_handle->AddInput(var); } var_names = op.OutputArgumentNames(); for (auto &each_var_name : var_names) { CreateOpOutput(result, op_handle, each_var_name, p, i); } } bool MultiDevSSAGraphBuilder::IsDistTrainOp(const OpDesc &op, OpDesc *send_op) const { if (send_op == nullptr) { return false; } auto checker = [&](const std::vector opvars, const std::vector sendvars) -> bool { bool is_dist_train_op = false; for (auto &var : opvars) { if (var.find(".block") != std::string::npos && std::find(sendvars.begin(), sendvars.end(), var) != sendvars.end()) { is_dist_train_op = true; break; } } return is_dist_train_op; }; if (op.Type() == "split") { return checker(op.OutputArgumentNames(), send_op->InputArgumentNames()); } else if (op.Type() == "concat") { return checker(op.InputArgumentNames(), send_op->OutputArgumentNames()); } return false; } std::unique_ptr MultiDevSSAGraphBuilder::Build( const ProgramDesc &program) const { auto graph = new SSAGraph(); SSAGraph &result = *graph; std::unordered_set og_has_been_broadcast; // We cannot invoke resize. It is a bug of GCC 4.8 result.vars_ = std::vector< std::unordered_map>>>( places_.size()); // Find "send" op first for split is in front of send. OpDesc *send_op = nullptr; for (auto *op : program.Block(0).AllOps()) { if (op->Type() == "send") { send_op = op; break; } } bool is_forwarding = true; for (auto *op : program.Block(0).AllOps()) { if (op->Type() == "send") { // append send op if program is distributed trainer main program. // always use the first device CreateSendOp(&result, *op); } else if (IsDistTrainOp(*op, send_op)) { CreateComputationalOps(&result, *op, 1); } else if (IsScaleLossOp(*op)) { if (!skip_scale_loss_) { CreateScaleLossGradOp(&result); } is_forwarding = false; } else { CreateComputationalOps(&result, *op, places_.size()); if (!is_forwarding) { // Currently, we assume that once gradient is generated, it can be // broadcast, and each gradient is only broadcast once. But there are no // other cases, for example, we need to adjust the gradient according to // the input when we get the gradient, which is not considered at // present. for (auto &og : op->OutputArgumentNames()) { if (IsParameterGradientOnce(og, &og_has_been_broadcast)) { InsertNCCLAllReduceOp(&result, og); } } } } } /* Dependency graph has been constructed. However, there are still data harzaeds need to be handled. */ PolishGraphToSupportDataHazards(&result); /* * Only variables should be the leaves of graph. */ AddOutputToLeafOps(&result); if (VLOG_IS_ON(10)) { std::ostringstream sout; PrintGraphviz(*graph, sout); VLOG(10) << sout.str(); } return std::unique_ptr(graph); } void MultiDevSSAGraphBuilder::InsertNCCLAllReduceOp( SSAGraph *result, const std::string &og) const { #ifdef PADDLE_WITH_CUDA result->ops_.emplace_back( new NCCLAllReduceOpHandle(local_scopes_, places_, *nccl_ctxs_)); auto *op_handle = result->ops_.back().get(); for (size_t i = 0; i < places_.size(); ++i) { auto &p = places_[i]; auto &vars = result->vars_[i][og]; PADDLE_ENFORCE(!vars.empty()); auto &prev_grad = vars.back(); op_handle->AddInput(prev_grad.get()); auto var = new VarHandle(vars.size() - 1, i, og, p); vars.emplace_back(var); op_handle->AddOutput(var); } #else PADDLE_ENFORCE("Not implemented"); #endif } bool MultiDevSSAGraphBuilder::IsParameterGradientOnce( const std::string &og, std::unordered_set *og_has_been_broadcast) const { bool is_pg_once = grad_names_.count(og) != 0 && og_has_been_broadcast->count(og) == 0; if (is_pg_once) { // Insert NCCL AllReduce Op og_has_been_broadcast->insert(og); } return is_pg_once; } void MultiDevSSAGraphBuilder::CreateScaleLossGradOp(SSAGraph *result) const { for (size_t i = 0; i < places_.size(); ++i) { // Insert ScaleCost OpHandle #ifdef PADDLE_WITH_CUDA auto *communication_dev_ctx = nccl_ctxs_->DevCtx(places_[i]); #else auto *communication_dev_ctx = platform::DeviceContextPool::Instance().Get(platform::CPUPlace()); #endif auto *op_handle = new ScaleLossGradOpHandle(local_scopes_.size(), local_scopes_[i], places_[i], communication_dev_ctx); result->ops_.emplace_back(op_handle); // FIXME: Currently ScaleLossGradOp only use device_count as scale // factor. So it does not depend on any other operators. // VarHandle *loss = GetVarHandle(loss_var_name, place); // loss->pending_ops_.emplace_back(op_handle); // op_handle->inputs_.emplace_back(loss); CreateOpOutput(result, op_handle, GradVarName(loss_var_name_), places_[i], i); } } void MultiDevSSAGraphBuilder::CreateComputationalOps(SSAGraph *result, const OpDesc &op, size_t num_places) const { for (size_t scope_idx = 0; scope_idx < num_places; ++scope_idx) { auto p = places_[scope_idx]; auto s = local_scopes_[scope_idx]; result->ops_.emplace_back(new ComputationOpHandle(op, s, p)); CreateOpHandleIOs(result, op, p, scope_idx); } } void MultiDevSSAGraphBuilder::CreateSendOp(SSAGraph *result, const OpDesc &op) const { auto &p = places_[0]; auto *s = local_scopes_[0]; // FIXME(wuyi): send op always copy from GPU 0 result->ops_.emplace_back(new SendOpHandle(op, s, p)); // Create inputs for output on original place and no ssa output // is created for send op. CreateOpHandleIOs(result, op, p, 0); } bool MultiDevSSAGraphBuilder::IsScaleLossOp(const OpDesc &op) const { // FIXME(yy): Do not hard code like this return op.OutputArgumentNames().size() == 1 && op.OutputArgumentNames()[0] == GradVarName(loss_var_name_); } } // namespace details } // namespace framework } // namespace paddle