Unverified commit e755c07e, authored by Leo Chen, committed by GitHub

[new-exe] convert fused_all_reduce_op_handle to program (#45774)

* add operator<< for BuildStrategy

* add fake_coalesce

* fit allreduce mode for new_exe

* remove debug code

* follow comments
Parent 2632d77d
......@@ -61,8 +61,8 @@ struct BuildStrategy {
// separately, if you choose kReduce, every thread is to optimize 25
// parameters.
// Of particular note is, if you use kReduce when using CPU training,
// all the parameters are shared between different threads. This feature will
// save memory.
// all the parameters are shared between different threads. This
// feature will save memory.
// FIXME(zcd): The result of the two modes(kAllReduce and kReduce) maybe not
// equal for GPU. Because, the result of the different order of summing maybe
// different, for example, the result of `a+b+c+d` may be different with the
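The FIXME above is about exactly this point: floating-point addition is not associative, so summing the same gradients in a different order (as kAllReduce and kReduce do) can change the result slightly. A minimal standalone illustration, not Paddle code:

#include <cstdio>
int main() {
  float a = 1e20f, b = -1e20f, c = 1.0f;
  // (a + b) + c keeps the 1.0, a + (b + c) loses it to rounding
  std::printf("%g vs %g\n", (a + b) + c, a + (b + c));  // prints "1 vs 0"
}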
......@@ -88,8 +88,8 @@ struct BuildStrategy {
std::string debug_graphviz_path_{""};
// Add dependency between backward ops and optimization ops, make sure that
// all the backward ops are finished before running the optimization ops.
// It might make the training speed of data parallelism faster.
// all the backward ops are finished before running the optimization
// ops. It might make the training speed of data parallelism faster.
bool enable_backward_optimizer_op_deps_{true};
// TODO(dev-paddle): enable_sequential_execution depends on
// kStaleProgramOpDescs, it is not appropriate, because kStaleProgramOpDescs
......@@ -229,6 +229,68 @@ struct BuildStrategy {
mutable std::shared_ptr<ir::PassBuilder> pass_builder_;
};
inline std::ostream &operator<<(std::ostream &os,
const BuildStrategy &strategy) {
os << "BuildStrategy: " << &strategy << std::endl;
os << "reduce_: " << static_cast<int>(strategy.reduce_) << std::endl;
os << "gradient_scale_: " << static_cast<int>(strategy.gradient_scale_)
<< std::endl;
os << "debug_graphviz_path_: " << strategy.debug_graphviz_path_ << std::endl;
os << "enable_backward_optimizer_op_deps_: "
<< strategy.enable_backward_optimizer_op_deps_ << std::endl;
os << "enable_sequential_execution_: "
<< strategy.enable_sequential_execution_ << std::endl;
os << "remove_unnecessary_lock_: " << strategy.remove_unnecessary_lock_
<< std::endl;
os << "cache_runtime_context_: " << strategy.cache_runtime_context_
<< std::endl;
os << "fix_op_run_order_: " << strategy.fix_op_run_order_ << std::endl;
os << "fuse_bn_act_ops_: " << strategy.fuse_bn_act_ops_ << std::endl;
os << "fuse_bn_add_act_ops_: " << strategy.fuse_bn_add_act_ops_ << std::endl;
os << "fuse_elewise_add_act_ops_: " << strategy.fuse_elewise_add_act_ops_
<< std::endl;
os << "enable_auto_fusion_: " << strategy.enable_auto_fusion_ << std::endl;
os << "fuse_all_optimizer_ops_: " << strategy.fuse_all_optimizer_ops_
<< std::endl;
os << "fuse_all_reduce_ops_: " << strategy.fuse_all_reduce_ops_ << std::endl;
os << "fuse_relu_depthwise_conv_: " << strategy.fuse_relu_depthwise_conv_
<< std::endl;
os << "fuse_broadcast_ops_: " << strategy.fuse_broadcast_ops_ << std::endl;
os << "sync_batch_norm_: " << strategy.sync_batch_norm_ << std::endl;
os << "fuse_gemm_epilogue_: " << strategy.fuse_gemm_epilogue_ << std::endl;
os << "mkldnn_enabled_op_types_: ";
for (const auto &str : strategy.mkldnn_enabled_op_types_) {
os << str << ", ";
}
os << std::endl;
os << "memory_optimize_: " << strategy.memory_optimize_ << std::endl;
os << "enable_inplace_: " << strategy.enable_inplace_ << std::endl;
os << "allow_cuda_graph_capture_: " << strategy.allow_cuda_graph_capture_
<< std::endl;
os << "enable_inference_pass_: " << strategy.enable_inference_pass_
<< std::endl;
os << "delete_dropout_: " << strategy.delete_dropout_ << std::endl;
os << "use_mkldnn_: " << strategy.use_mkldnn_ << std::endl;
os << "is_distribution_: " << strategy.is_distribution_ << std::endl;
os << "async_mode_: " << strategy.async_mode_ << std::endl;
os << "num_trainers_: " << strategy.num_trainers_ << std::endl;
os << "trainer_id_: " << strategy.trainer_id_ << std::endl;
os << "trainers_endpoints_: ";
for (const auto &str : strategy.trainers_endpoints_) {
os << str << ", ";
}
os << std::endl;
os << "nccl_comm_num_: " << strategy.nccl_comm_num_ << std::endl;
os << "bkcl_comm_num_: " << strategy.bkcl_comm_num_ << std::endl;
os << "use_hierarchical_allreduce_: " << strategy.use_hierarchical_allreduce_
<< std::endl;
os << "hierarchical_allreduce_inter_nranks_: "
<< strategy.hierarchical_allreduce_inter_nranks_ << std::endl;
os << "enable_parallel_graph_: " << strategy.enable_parallel_graph_
<< std::endl;
return os;
}
} // namespace details
} // namespace framework
} // namespace paddle
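The new operator<< simply dumps every BuildStrategy field, one per line; the __str__ binding added later in this commit exposes the same dump to Python, so printing a BuildStrategy object shows this text. A self-contained sketch of the same idiom on a toy config struct (field names are illustrative and only a few of the real fields are mirrored):

#include <iostream>
#include <string>
#include <vector>

struct DemoStrategy {
  bool fuse_all_reduce_ops_{false};
  int nccl_comm_num_{1};
  std::vector<std::string> trainers_endpoints_;
};

inline std::ostream &operator<<(std::ostream &os, const DemoStrategy &s) {
  os << "fuse_all_reduce_ops_: " << s.fuse_all_reduce_ops_ << std::endl;
  os << "nccl_comm_num_: " << s.nccl_comm_num_ << std::endl;
  os << "trainers_endpoints_: ";
  for (const auto &ep : s.trainers_endpoints_) {
    os << ep << ", ";
  }
  return os << std::endl;
}

int main() {
  DemoStrategy s;
  s.fuse_all_reduce_ops_ = true;
  s.trainers_endpoints_ = {"127.0.0.1:6170", "127.0.0.1:6171"};
  std::cout << s;  // one "name: value" line per field, like BuildStrategy
}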
......@@ -66,6 +66,28 @@ class NCCLOpHandleBase : public OpHandleBase {
#endif
}
}
const platform::NCCLCommunicator* GetNcclContext() const {
return nccl_ctxs_;
}
const ncclComm_t GetComm() const {
PADDLE_ENFORCE_EQ(
places_.size(),
1,
platform::errors::Unimplemented(
"Only a single place is supported now, but got %d", places_.size()));
PADDLE_ENFORCE_EQ(use_hierarchical_allreduce_,
0,
platform::errors::Unimplemented(
"use_hierarchical_allreduce_ is not supported now"));
auto flat_nccl_ctxs = nccl_ctxs_->GetFlatCtx(run_order_);
int dev_id = places_[0].device;
auto& nccl_ctx = flat_nccl_ctxs->at(dev_id);
auto comm = nccl_ctx.comm_;
return comm;
}
void SetRunEnv(int run_order, bool use_hierarchical_allreduce) {
PADDLE_ENFORCE_GE(
run_order,
......
......@@ -46,6 +46,8 @@ struct ScaleLossGradOpHandle : public OpHandleBase {
proto::VarType::Type DType() const { return out_dtype_; }
float Coeff() const { return coeff_; }
std::string Name() const override;
platform::Place GetPlace() const { return place_; }
......
......@@ -23,6 +23,11 @@ limitations under the License. */
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_utils.h"
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
#include "paddle/fluid/framework/details/nccl_op_handle.h"
#include "paddle/fluid/platform/collective_helper.h"
#endif
DECLARE_bool(convert_all_blocks);
PADDLE_DEFINE_EXPORTED_string(print_sub_graph_dir,
"",
......@@ -481,6 +486,9 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
desc->SetAttr(
"dtype",
dynamic_cast<details::ScaleLossGradOpHandle *>(&op_hander)->DType());
desc->SetAttr(
"value",
dynamic_cast<details::ScaleLossGradOpHandle *>(&op_hander)->Coeff());
}
desc->SetAttr("force_cpu", false);
......@@ -497,6 +505,71 @@ static OpDesc *ReplaceScaleLossGradOp(const Node &node, OpDesc *desc) {
return desc;
}
static void ReplaceAllReduceOp(const Node &node,
proto::BlockDesc *block,
std::vector<OpDesc> *ops) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
ops->emplace_back();
auto &desc1 = ops->back();
std::string name = "fake_coalesce_" + std::to_string(ops->size());
desc1.SetType("check_memory_continue");
ops->emplace_back();
auto &desc2 = ops->back();
desc2.SetType("c_allreduce_sum");
if (node.IsWrappedBy<details::OpHandleBase>()) {
details::OpHandleBase &op_hander =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
// set inputs
auto in_var_handles = op_hander.Inputs();
std::vector<std::string> in_names;
for (const auto &in : in_var_handles) {
if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
continue;
}
in_names.emplace_back(in->Name());
}
desc1.SetInput("X", in_names);
proto::VarDesc var_desc;
var_desc.set_name(name);
var_desc.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
block->mutable_vars()->Add()->CopyFrom(var_desc);
desc1.SetOutput("Out", {name});
desc1.SetOutput("XOut", in_names);
VLOG(4) << "add variable for check_memory_continue: " << name;
desc2.SetInput("X", {name});
// set outputs
auto out_var_handles = op_hander.Outputs();
std::vector<std::string> out_names;
for (const auto &out : out_var_handles) {
if (dynamic_cast<details::DummyVarHandle *>(out) != nullptr) {
continue;
}
out_names.emplace_back(out->Name());
}
desc2.SetOutput("Out", {name});
int ring_id = platform::NCCLCommContext::Instance().GetRingId(
dynamic_cast<details::NCCLOpHandleBase *>(&op_hander)->GetComm());
desc2.SetAttr("ring_id", ring_id);
desc2.SetAttr("use_calc_stream", true);
}
desc1.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward)));
desc2.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward)));
#else
PADDLE_THROW(
platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
"for paddle compiled with NCCL/RCCL."));
#endif
}
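To summarize the rewrite: one fused_all_reduce node becomes a check_memory_continue op that takes the gradient list and produces a synthetic fused variable (named fake_coalesce_N and registered in the block), followed by a c_allreduce_sum op that all-reduces that variable on the ring id recovered from the handle's communicator, running on the calc stream. A toy standalone sketch of the resulting wiring (struct and names are illustrative, not Paddle types):

#include <iostream>
#include <map>
#include <string>
#include <vector>

struct ToyOpDesc {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs;
  std::map<std::string, std::vector<std::string>> outputs;
  std::map<std::string, int> int_attrs;
};

int main() {
  std::vector<std::string> grads = {"g0", "g1", "g2"};
  std::string fused = "fake_coalesce_1";  // synthetic var recorded in the block

  ToyOpDesc check{"check_memory_continue",
                  {{"X", grads}},
                  {{"Out", {fused}}, {"XOut", grads}},
                  {}};
  ToyOpDesc allreduce{"c_allreduce_sum",
                      {{"X", {fused}}},
                      {{"Out", {fused}}},
                      {{"ring_id", 0}, {"use_calc_stream", 1}}};

  for (const auto &op : {check, allreduce}) {
    std::cout << op.type << "(X=";
    for (const auto &x : op.inputs.at("X")) std::cout << x << " ";
    std::cout << ")" << std::endl;
  }
}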
void UpdateControlOpSkipEagerDeletionVars(const Node &node,
const Graph &graph,
const size_t graph_idx,
......@@ -526,6 +599,7 @@ void UpdateControlOpSkipEagerDeletionVars(const Node &node,
}
static void GetGraphOpDesc(const std::vector<Node *> &nodes,
proto::BlockDesc *block,
std::vector<OpDesc> *ops,
const Graph &graph,
const size_t graph_idx) {
......@@ -552,6 +626,10 @@ static void GetGraphOpDesc(const std::vector<Node *> &nodes,
ops->emplace_back();
auto &desc = ops->back();
ReplaceScaleLossGradOp(*n, &desc);
} else if (n->Name() == "fused_all_reduce") {
VLOG(4) << "convert op node fused_all_reduce to desc c_allreduce_sum";
ReplaceAllReduceOp(*n, block, ops);
VLOG(4) << n->ToString();
} else if (n->Op()) {
VLOG(4) << "convert op node to desc " << n->Op()->Type();
if (is_fused_opt(n)) {
......@@ -645,7 +723,7 @@ static void GraphToBlock(const Graph &graph,
}
std::vector<OpDesc> ops;
GetGraphOpDesc(nodes, &ops, graph, graph_idx);
GetGraphOpDesc(nodes, block, &ops, graph, graph_idx);
for (auto &op : ops) {
RemoveControlDepInputAndOuput(&op);
......
......@@ -600,7 +600,7 @@ void InterpreterCore::RunInstruction(const Instruction& instr_node) {
auto place = instr_node.DeviceContext().GetPlace();
Scope* local_scope = create_local_scope_ ? var_scope_.GetMutableLocalScope()
: var_scope_.GetMutableScope();
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope_);
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(place)) {
auto dev_id = place.device;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
class CheckMemoryContinueOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
};
class CheckMemoryContinueOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(vector<LoDTensor>) The input tensors.").AsDuplicable();
AddOutput("Out", "(LoDTensor) The output tensor.").AsDuplicable();
AddOutput(
"XOut",
"(vector<LoDTensor>) The output tensors, which are the same as X. They are "
"used to build the graph dependency.");
AddComment(R"DOC(
CheckMemoryContinue Operator.
Check whether the addresses of the input tensors are contiguous.
Used for converting fused_all_reduce_op_handle in Graph to Program.
)DOC");
}
};
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(check_memory_continue,
CheckMemoryContinueInferShapeFunctor,
PD_INFER_META(phi::CheckMemoryContinueInferMeta));
REGISTER_OPERATOR(check_memory_continue,
paddle::operators::CheckMemoryContinueOp,
paddle::operators::CheckMemoryContinueOpMaker,
CheckMemoryContinueInferShapeFunctor);
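As the DOC string says, the operator checks whether the input tensors' addresses are contiguous, i.e. each tensor starts exactly where the previous one ends, which the coalesced gradient buffer is expected to provide. A minimal standalone sketch of that check on raw byte ranges (toy code, not the Paddle kernel):

#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Each element is (start address, length in bytes) of one tensor's storage.
bool IsContiguous(const std::vector<std::pair<const char*, int64_t>> &bufs) {
  for (size_t i = 1; i < bufs.size(); ++i) {
    if (bufs[i - 1].first + bufs[i - 1].second != bufs[i].first) return false;
  }
  return true;
}

int main() {
  std::vector<char> pool(64);
  // two 16-byte chunks packed back to back -> contiguous (prints 1)
  std::cout << IsContiguous({{pool.data(), 16}, {pool.data() + 16, 16}}) << "\n";
  // an 8-byte gap after the first chunk -> not contiguous (prints 0)
  std::cout << IsContiguous({{pool.data(), 16}, {pool.data() + 24, 16}}) << "\n";
}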
......@@ -423,6 +423,10 @@ class CAllReduceOpCUDAKernel : public framework::OpKernel<T> {
} else {
stream = comm->stream();
}
VLOG(10) << "all reduce buffer:" << sendbuff << ", numel:" << numel
<< ", redtype:" << static_cast<int>(red_type)
<< ", dtype:" << dtype << ", comm:" << comm
<< ", stream:" << stream;
ncclRedOp_t nccl_red_type = ncclSum;
switch (red_type) {
......
......@@ -251,7 +251,8 @@ NCCLComm* NCCLCommContext::AssignNCCLComm(
platform::CUDAPlace(dev_id)));
dev_ctx->set_nccl_comm(comm);
}
VLOG(4) << "add nccl comm: " << comm_map_[ring_id][dev_id].get()
<< ", ring_id:" << ring_id << ", dev_id:" << dev_id;
return comm_map_[ring_id][dev_id].get();
}
......
......@@ -105,6 +105,17 @@ class NCCLCommContext {
return comm_map_.at(ring_id).begin()->second.get();
}
int GetRingId(ncclComm_t comm) const {
for (const auto& pair : comm_map_) {
for (const auto& p : pair.second) {
if (p.second->comm() == comm) {
return pair.first;
}
}
}
return -1;
}
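GetRingId is a reverse lookup: given a raw communicator, walk the ring_id -> (dev_id -> comm) map and return the ring that owns it, or -1 if it is not registered. This is how ReplaceAllReduceOp recovers the ring_id attribute for c_allreduce_sum. A standalone sketch of the same lookup with integer stand-ins for the communicator handles:

#include <iostream>
#include <map>

using Comm = int;  // stand-in for ncclComm_t

int GetRingId(const std::map<int, std::map<int, Comm>> &comm_map, Comm comm) {
  for (const auto &ring : comm_map) {      // ring_id -> (dev_id -> comm)
    for (const auto &dev : ring.second) {
      if (dev.second == comm) return ring.first;
    }
  }
  return -1;  // not found
}

int main() {
  std::map<int, std::map<int, Comm>> comm_map = {{0, {{0, 101}}},
                                                 {3, {{1, 202}}}};
  std::cout << GetRingId(comm_map, 202) << "\n";  // prints 3
}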
// retrieve a communicator by the ring id and the device id
NCCLComm* Get(int ring_id, int dev_id) const {
PADDLE_ENFORCE_GT(
......
......@@ -987,6 +987,12 @@ void BindParallelExecutor(pybind11::module &m) { // NOLINT
new_bs.ClearFinalized();
return new_bs;
})
.def("__str__",
[](const BuildStrategy &self) {
std::stringstream ss;
ss << self;
return ss.str();
})
.def(
"_finalize_strategy_and_create_passes",
[](BuildStrategy &self) -> std::shared_ptr<ir::PassBuilder> {
......
......@@ -835,6 +835,24 @@ void CoalesceTensorInferMeta(const std::vector<const MetaTensor*>& input,
}
}
void CheckMemoryContinueInferMeta(const std::vector<const MetaTensor*>& input,
MetaTensor* output,
std::vector<MetaTensor*> xout,
MetaConfig config) {
if (config.is_runtime) {
return;
}
int64_t numel = 0;
for (size_t i = 0; i < input.size(); ++i) {
const auto& dim = input[i]->dims();
auto size = phi::product(dim);
auto len = size * paddle::experimental::SizeOf(input[i]->dtype());
numel += len;
}
output->set_dims(phi::make_ddim({numel}));
output->set_dtype(phi::DataType::INT8);
}
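At compile time this infer-meta sizes the fused output as a 1-D INT8 tensor whose length is the total byte size of the inputs: for each input, the product of its dims times the element size, summed over all inputs. A standalone sketch of that arithmetic with hypothetical shapes:

#include <cstdint>
#include <iostream>
#include <vector>

// numel of the fused INT8 output = sum over inputs of product(dims) * elem_bytes
int64_t FusedByteLen(const std::vector<std::vector<int64_t>> &dims,
                     const std::vector<int64_t> &elem_bytes) {
  int64_t numel = 0;
  for (size_t i = 0; i < dims.size(); ++i) {
    int64_t n = 1;
    for (int64_t d : dims[i]) n *= d;
    numel += n * elem_bytes[i];
  }
  return numel;
}

int main() {
  // two fp32 gradients, shapes [128, 256] and [256]: (32768 + 256) * 4 = 132096
  std::cout << FusedByteLen({{128, 256}, {256}}, {4, 4}) << " bytes\n";
}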
void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
const Scalar& axis_scalar,
MetaTensor* out,
......
......@@ -216,6 +216,11 @@ void CoalesceTensorInferMeta(const std::vector<const MetaTensor*>& input,
MetaTensor* fused_output,
MetaConfig config = MetaConfig());
void CheckMemoryContinueInferMeta(const std::vector<const MetaTensor*>& input,
MetaTensor* output,
std::vector<MetaTensor*> xout,
MetaConfig config = MetaConfig());
void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
const Scalar& axis_scalar,
MetaTensor* out,
......