From e755c07e6a5c0880e28b3025a7bb41d684850dd3 Mon Sep 17 00:00:00 2001
From: Leo Chen
Date: Fri, 9 Sep 2022 14:25:22 +0800
Subject: [PATCH] [new-exe] convert fused_all_reduce_op_handle to program
 (#45774)

* add operator<< for BuildStrategy

* add fake_coalesce

* fit allreduce mode for new_exe

* remove debug code

* follow comments
---
 .../fluid/framework/details/build_strategy.h  | 70 +++++++++++++++-
 .../fluid/framework/details/nccl_op_handle.h  | 22 +++++
 .../details/scale_loss_grad_op_handle.h       |  2 +
 paddle/fluid/framework/ir/graph_helper.cc     | 80 ++++++++++++++++++-
 .../framework/new_executor/interpretercore.cc |  2 +-
 .../operators/check_memory_continue_op.cc     | 61 ++++++++++++++
 .../operators/collective/c_allreduce_op.h     |  4 +
 paddle/fluid/platform/collective_helper.cc    |  3 +-
 paddle/fluid/platform/collective_helper.h     | 11 +++
 paddle/fluid/pybind/parallel_executor.cc      |  6 ++
 paddle/phi/infermeta/multiary.cc              | 18 +++++
 paddle/phi/infermeta/multiary.h               |  5 ++
 12 files changed, 277 insertions(+), 7 deletions(-)
 create mode 100644 paddle/fluid/operators/check_memory_continue_op.cc

diff --git a/paddle/fluid/framework/details/build_strategy.h b/paddle/fluid/framework/details/build_strategy.h
index 513df4f1974..c1ef2eba646 100644
--- a/paddle/fluid/framework/details/build_strategy.h
+++ b/paddle/fluid/framework/details/build_strategy.h
@@ -61,8 +61,8 @@ struct BuildStrategy {
   // separately, if you choose kReduce, every thread is to optimize 25
   // parameters.
   // Of particular note is, if you use kReduce when using CPU training,
-  // all the parameters are shared between different threads. This feature will
-  // save memory.
+  // all the parameters are shared between different threads. This
+  // feature will save memory.
   // FIXME(zcd): The result of the two modes(kAllReduce and kReduce) maybe not
   // equal for GPU. Because, the result of the different order of summing maybe
   // different, for example, the result of `a+b+c+d` may be different with the
@@ -88,8 +88,8 @@ struct BuildStrategy {
   std::string debug_graphviz_path_{""};
 
   // Add dependency between backward ops and optimization ops, make sure that
-  // all the backward ops are finished before running the optimization ops.
-  // It might make the training speed of data parallelism faster.
+  // all the backward ops are finished before running the optimization
+  // ops. It might make the training speed of data parallelism faster.
   bool enable_backward_optimizer_op_deps_{true};
   // TODO(dev-paddle): enable_sequential_execution depends on
   // kStaleProgramOpDescs, it is not appropriate, because kStaleProgramOpDescs
@@ -229,6 +229,68 @@ struct BuildStrategy {
   mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
 };
 
+inline std::ostream &operator<<(std::ostream &os,
+                                const BuildStrategy &strategy) {
+  os << "BuildStrategy: " << &strategy << std::endl;
+  os << "reduce_: " << static_cast<int>(strategy.reduce_) << std::endl;
+  os << "gradient_scale_: " << static_cast<int>(strategy.gradient_scale_)
+     << std::endl;
+  os << "debug_graphviz_path_: " << strategy.debug_graphviz_path_ << std::endl;
+  os << "enable_backward_optimizer_op_deps_: "
+     << strategy.enable_backward_optimizer_op_deps_ << std::endl;
+  os << "enable_sequential_execution_: "
+     << strategy.enable_sequential_execution_ << std::endl;
+  os << "remove_unnecessary_lock_: " << strategy.remove_unnecessary_lock_
+     << std::endl;
+  os << "cache_runtime_context_: " << strategy.cache_runtime_context_
+     << std::endl;
+  os << "fix_op_run_order_: " << strategy.fix_op_run_order_ << std::endl;
+  os << "fuse_bn_act_ops_: " << strategy.fuse_bn_act_ops_ << std::endl;
+  os << "fuse_bn_add_act_ops_: " << strategy.fuse_bn_add_act_ops_ << std::endl;
+  os << "fuse_elewise_add_act_ops_: " << strategy.fuse_elewise_add_act_ops_
+     << std::endl;
+  os << "enable_auto_fusion_: " << strategy.enable_auto_fusion_ << std::endl;
+  os << "fuse_all_optimizer_ops_: " << strategy.fuse_all_optimizer_ops_
+     << std::endl;
+  os << "fuse_all_reduce_ops_: " << strategy.fuse_all_reduce_ops_ << std::endl;
+  os << "fuse_relu_depthwise_conv_: " << strategy.fuse_relu_depthwise_conv_
+     << std::endl;
+  os << "fuse_broadcast_ops_: " << strategy.fuse_broadcast_ops_ << std::endl;
+  os << "sync_batch_norm_: " << strategy.sync_batch_norm_ << std::endl;
+  os << "fuse_gemm_epilogue_: " << strategy.fuse_gemm_epilogue_ << std::endl;
+  os << "mkldnn_enabled_op_types_: ";
+  for (auto str : strategy.mkldnn_enabled_op_types_) {
+    os << str << ", ";
+  }
+  os << std::endl;
+  os << "memory_optimize_: " << strategy.memory_optimize_ << std::endl;
+  os << "enable_inplace_: " << strategy.enable_inplace_ << std::endl;
+  os << "allow_cuda_graph_capture_: " << strategy.allow_cuda_graph_capture_
+     << std::endl;
+  os << "enable_inference_pass_: " << strategy.enable_inference_pass_
+     << std::endl;
+  os << "delete_dropout_: " << strategy.delete_dropout_ << std::endl;
+  os << "use_mkldnn_: " << strategy.use_mkldnn_ << std::endl;
+  os << "is_distribution_: " << strategy.is_distribution_ << std::endl;
+  os << "async_mode_: " << strategy.async_mode_ << std::endl;
+  os << "num_trainers_: " << strategy.num_trainers_ << std::endl;
+  os << "trainer_id_: " << strategy.trainer_id_ << std::endl;
+  os << "trainers_endpoints_: ";
+  for (auto str : strategy.trainers_endpoints_) {
+    os << str << ", ";
+  }
+  os << std::endl;
+  os << "nccl_comm_num_: " << strategy.nccl_comm_num_ << std::endl;
+  os << "bkcl_comm_num_: " << strategy.bkcl_comm_num_ << std::endl;
+  os << "use_hierarchical_allreduce_: " << strategy.use_hierarchical_allreduce_
+     << std::endl;
+  os << "hierarchical_allreduce_inter_nranks_: "
+     << strategy.hierarchical_allreduce_inter_nranks_ << std::endl;
+  os << "enable_parallel_graph_: " << strategy.enable_parallel_graph_
+     << std::endl;
+  return os;
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/nccl_op_handle.h b/paddle/fluid/framework/details/nccl_op_handle.h
index 41c0b4753a6..7aa200fc91e 100644
--- a/paddle/fluid/framework/details/nccl_op_handle.h
+++ b/paddle/fluid/framework/details/nccl_op_handle.h
@@ -66,6 +66,28 @@ class NCCLOpHandleBase : public OpHandleBase {
 #endif
     }
   }
+
+  const platform::NCCLCommunicator* GetNcclContext() const {
+    return nccl_ctxs_;
+  }
+
+  const ncclComm_t GetComm() const {
+    PADDLE_ENFORCE_EQ(
+        places_.size(),
+        1,
+        platform::errors::Unimplemented(
+            "Only supported for single place now, but got %d", places_.size()));
+    PADDLE_ENFORCE_EQ(use_hierarchical_allreduce_,
+                      0,
+                      platform::errors::Unimplemented(
+                          "Not supported use_hierarchical_allreduce_ now"));
+    auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
+    int dev_id = places_[0].device;
+    auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
+    auto comm = nccl_ctx.comm_;
+    return comm;
+  }
+
   void SetRunEnv(int run_order, bool use_hierarchical_allreduce) {
     PADDLE_ENFORCE_GE(
         run_order,
diff --git a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
index d447d0284d8..9351b8c0c31 100644
--- a/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
+++ b/paddle/fluid/framework/details/scale_loss_grad_op_handle.h
@@ -46,6 +46,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
 
   proto::VarType::Type DType() const { return out_dtype_; }
 
+  float Coeff() const { return coeff_; }
+
   std::string Name() const override;
 
   platform::Place GetPlace() const { return place_; }
diff --git a/paddle/fluid/framework/ir/graph_helper.cc b/paddle/fluid/framework/ir/graph_helper.cc
index f800a1eba89..51e49be5f96 100644
--- a/paddle/fluid/framework/ir/graph_helper.cc
+++ b/paddle/fluid/framework/ir/graph_helper.cc
@@ -23,6 +23,11 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_proto_maker.h"
 #include "paddle/fluid/framework/program_utils.h"
 
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+#include "paddle/fluid/framework/details/nccl_op_handle.h"
+#include "paddle/fluid/platform/collective_helper.h"
+#endif
+
 DECLARE_bool(convert_all_blocks);
 PADDLE_DEFINE_EXPORTED_string(print_sub_graph_dir,
                               "",
@@ -481,6 +486,9 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
     desc->SetAttr(
         "dtype",
         dynamic_cast<details::ScaleLossGradOpHandle *>(&op_hander)->DType());
+    desc->SetAttr(
+        "value",
+        dynamic_cast<details::ScaleLossGradOpHandle *>(&op_hander)->Coeff());
   }
 
   desc->SetAttr("force_cpu", false);
@@ -497,6 +505,71 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
   return desc;
 }
 
+static void ReplaceAllReduceOp(const Node &node,
+                               proto::BlockDesc *block,
+                               std::vector<OpDesc> *ops) {
+#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
+  ops->emplace_back();
+  auto &desc1 = ops->back();
+  std::string name = "fake_coalesce_" + std::to_string(ops->size());
+  desc1.SetType("check_memory_continue");
+
+  ops->emplace_back();
+  auto &desc2 = ops->back();
+  desc2.SetType("c_allreduce_sum");
+
+  if (node.IsWrappedBy<details::OpHandleBase>()) {
+    details::OpHandleBase &op_hander =
+        const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
+
+    // set inputs
+    auto in_var_handles = op_hander.Inputs();
+    std::vector<std::string> in_names;
+    for (const auto &in : in_var_handles) {
+      if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
+        continue;
+      }
+      in_names.emplace_back(in->Name());
+    }
+    desc1.SetInput("X", in_names);
+
+    proto::VarDesc var_desc;
+    var_desc.set_name(name);
+    var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
+    block->mutable_vars()->Add()->CopyFrom(var_desc);
+    desc1.SetOutput("Out", {name});
+    desc1.SetOutput("XOut", in_names);
+    VLOG(4) << "add variable for check_memory_continue: " << name;
+
+    desc2.SetInput("X", {name});
+    // set outputs
+    auto out_var_handles = op_hander.Outputs();
+    std::vector<std::string> out_names;
+    for (const auto &out : out_var_handles) {
+      if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) {
+        continue;
+      }
+      out_names.emplace_back(out->Name());
+    }
+    desc2.SetOutput("Out", {name});
+
+    int ring_id = platform::NCCLCommContext::Instance().GetRingId(
+        dynamic_cast<details::NCCLOpHandleBase *>(&op_hander)->GetComm());
+    desc2.SetAttr("ring_id", ring_id);
+    desc2.SetAttr("use_calc_stream", true);
+  }
+
+  desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                (static_cast<int>(OpRole::kBackward)));
+  desc2.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
+                (static_cast<int>(OpRole::kBackward)));
+#else
+  PADDLE_THROW(
+      platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
+                                      "for paddle compiled with NCCL/RCCL."));
+#endif
+}
+
 void UpdateControlOpSkipEagerDeletionVars(const Node &node,
                                           const Graph &graph,
                                           const size_t graph_idx,
@@ -526,6 +599,7 @@ void UpdateControlOpSkipEagerDeletionVars(const Node &node,
 }
 
 static void GetGraphOpDesc(const std::vector<Node *> &nodes,
+                           proto::BlockDesc *block,
                            std::vector<OpDesc> *ops,
                            const Graph &graph,
                            const size_t graph_idx) {
@@ -552,6 +626,10 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
       ops->emplace_back();
       auto &desc = ops->back();
       ReplaceScaleLossGradOp(*n, &desc);
+    } else if (n->Name() == "fused_all_reduce") {
+      VLOG(4) << "convert op node fused_all_reduce to desc c_allreduce_sum";
+      ReplaceAllReduceOp(*n, block, ops);
+      VLOG(4) << n->ToString();
     } else if (n->Op()) {
       VLOG(4) << "convert op node to desc " << n->Op()->Type();
       if (is_fused_opt(n)) {
@@ -645,7 +723,7 @@ static void GraphToBlock(const Graph &graph,
   }
 
   std::vector<OpDesc> ops;
-  GetGraphOpDesc(nodes, &ops, graph, graph_idx);
+  GetGraphOpDesc(nodes, block, &ops, graph, graph_idx);
 
   for (auto &op : ops) {
     RemoveControlDepInputAndOuput(&op);
diff --git a/paddle/fluid/framework/new_executor/interpretercore.cc b/paddle/fluid/framework/new_executor/interpretercore.cc
index 6be8aa776a8..39746f07340 100644
--- a/paddle/fluid/framework/new_executor/interpretercore.cc
+++ b/paddle/fluid/framework/new_executor/interpretercore.cc
@@ -600,7 +600,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
   auto place = instr_node.DeviceContext().GetPlace();
   Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope()
                                            : var_scope_.GetMutableScope();
-
+  VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_);
 #ifdef PADDLE_WITH_ASCEND_CL
   if (platform::is_npu_place(place)) {
     auto dev_id = place.device;
diff --git a/paddle/fluid/operators/check_memory_continue_op.cc b/paddle/fluid/operators/check_memory_continue_op.cc
new file mode 100644
index 00000000000..aca6951c87e
--- /dev/null
+++ b/paddle/fluid/operators/check_memory_continue_op.cc
@@ -0,0 +1,61 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <vector>
+
+#include "paddle/fluid/framework/infershape_utils.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/phi/infermeta/multiary.h"
+
+namespace paddle {
+namespace operators {
+
+class CheckMemoryContinueOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+};
+
+class CheckMemoryContinueOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(vector<LoDTensor>) The input tensors.").AsDuplicable();
+    AddOutput("Out", "(LoDTensor) The output tensor.").AsDuplicable();
+    AddOutput(
+        "XOut",
+        "(vector<LoDTensor>) The output tensors, which are the same as X. "
+        "They are used to build the graph dependency.");
+    AddComment(R"DOC(
+CheckMemoryContinue Operator.
+
+Check whether the addresses of the input tensors are contiguous in memory.
+
+Used for converting fused_all_reduce_op_handle in Graph to Program.
+
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+DECLARE_INFER_SHAPE_FUNCTOR(check_memory_continue,
+                            CheckMemoryContinueInferShapeFunctor,
+                            PD_INFER_META(phi::CheckMemoryContinueInferMeta));
+
+REGISTER_OPERATOR(check_memory_continue,
+                  paddle::operators::CheckMemoryContinueOp,
+                  paddle::operators::CheckMemoryContinueOpMaker,
+                  CheckMemoryContinueInferShapeFunctor);
diff --git a/paddle/fluid/operators/collective/c_allreduce_op.h b/paddle/fluid/operators/collective/c_allreduce_op.h
index 718c77aaa6f..299dd59d5ef 100644
--- a/paddle/fluid/operators/collective/c_allreduce_op.h
+++ b/paddle/fluid/operators/collective/c_allreduce_op.h
@@ -423,6 +423,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
     } else {
       stream = comm->stream();
     }
+    VLOG(10) << "all reduce buffer:" << sendbuff << ", numel:" << numel
+             << ", redtype:" << static_cast<int>(red_type)
+             << ", dtype:" << dtype << ", comm:" << comm
+             << ", stream:" << stream;
 
     ncclRedOp_t nccl_red_type = ncclSum;
     switch (red_type) {
diff --git a/paddle/fluid/platform/collective_helper.cc b/paddle/fluid/platform/collective_helper.cc
index 2589aa9acd0..95bf7a7aeaa 100644
--- a/paddle/fluid/platform/collective_helper.cc
+++ b/paddle/fluid/platform/collective_helper.cc
@@ -251,7 +251,8 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(
         platform::CUDAPlace(dev_id)));
     dev_ctx->set_nccl_comm(comm);
   }
-
+  VLOG(4) << "add nccl comm: " << comm_map_[ring_id][dev_id].get()
+          << ", ring_id:" << ring_id << ", dev_id:" << dev_id;
   return comm_map_[ring_id][dev_id].get();
 }
 
diff --git a/paddle/fluid/platform/collective_helper.h b/paddle/fluid/platform/collective_helper.h
index 207496d9f46..0b037c48f0b 100644
--- a/paddle/fluid/platform/collective_helper.h
+++ b/paddle/fluid/platform/collective_helper.h
@@ -105,6 +105,17 @@ class NCCLCommContext {
     return comm_map_.at(ring_id).begin()->second.get();
   }
 
+  int GetRingId(ncclComm_t comm) const {
+    for (const auto& pair : comm_map_) {
+      for (const auto& p : pair.second) {
+        if (p.second.get()->comm() == comm) {
+          return pair.first;
+        }
+      }
+    }
+    return -1;
+  }
+
   // retrieve a communicator by the ring id and the device id
   NCCLComm* Get(int ring_id, int dev_id) const {
     PADDLE_ENFORCE_GT(
diff --git a/paddle/fluid/pybind/parallel_executor.cc b/paddle/fluid/pybind/parallel_executor.cc
index ba1d4ca250b..0b44dc5d2a2 100644
--- a/paddle/fluid/pybind/parallel_executor.cc
+++ b/paddle/fluid/pybind/parallel_executor.cc
@@ -987,6 +987,12 @@ void BindParallelExecutor(pybind11::module &m) {  // NOLINT
             new_bs.ClearFinalized();
             return new_bs;
           })
+      .def("__str__",
+           [](const BuildStrategy &self) {
+             std::stringstream ss;
+             ss << self;
+             return ss.str();
+           })
       .def(
           "_finalize_strategy_and_create_passes",
          [](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index 56dc40cc7c9..7bab7747798 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -835,6 +835,24 @@ void CoalesceTensorInferMeta(const std::vector<const MetaTensor*>& input,
   }
 }
 
+void CheckMemoryContinueInferMeta(const std::vector<const MetaTensor*>& input,
+                                  MetaTensor* output,
+                                  std::vector<MetaTensor*> xout,
+                                  MetaConfig config) {
+  if (config.is_runtime) {
+    return;
+  }
+  int64_t numel = 0;
+  for (size_t i = 0; i < input.size(); ++i) {
+    const auto& dim = input[i]->dims();
+    auto size = phi::product(dim);
+    auto len = size * paddle::experimental::SizeOf(input[i]->dtype());
+    numel += len;
+  }
+  output->set_dims(phi::make_ddim({numel}));
+  output->set_dtype(phi::DataType::INT8);
+}
+
 void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
                      const Scalar& axis_scalar,
                      MetaTensor* out,
diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h
index 4e95303f1a0..b9d8aedfb2e 100644
--- a/paddle/phi/infermeta/multiary.h
+++ b/paddle/phi/infermeta/multiary.h
@@ -216,6 +216,11 @@ void CoalesceTensorInferMeta(const std::vector<const MetaTensor*>& input,
                              MetaTensor* fused_output,
                              MetaConfig config = MetaConfig());
 
+void CheckMemoryContinueInferMeta(const std::vector<const MetaTensor*>& input,
+                                  MetaTensor* output,
+                                  std::vector<MetaTensor*> xout,
+                                  MetaConfig config = MetaConfig());
+
 void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
                      const Scalar& axis_scalar,
                      MetaTensor* out,
-- 
GitLab
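
A quick way to see the effect of the BuildStrategy __str__ binding added above is to print a strategy object from Python. The sketch below is illustrative only and is not part of the patch; paddle.static.BuildStrategy, its fuse_all_reduce_ops flag, and its ReduceStrategy enum are existing Paddle APIs, while the printed field dump comes from the operator<< introduced in build_strategy.h.

    # Illustrative sketch (not part of the patch): exercise the new
    # BuildStrategy.__str__ binding and the fused-allreduce build option
    # that this patch converts from a fused_all_reduce_op_handle in the
    # Graph into check_memory_continue + c_allreduce_sum ops in the Program.
    import paddle

    paddle.enable_static()

    build_strategy = paddle.static.BuildStrategy()
    build_strategy.fuse_all_reduce_ops = True
    build_strategy.reduce_strategy = (
        paddle.static.BuildStrategy.ReduceStrategy.AllReduce)

    # With this patch, str(build_strategy) dumps every field via operator<<.
    print(build_strategy)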