未验证 提交 d00b7d83 编写于 作者: R Ruibiao Chen 提交者: GitHub

Support stream overlap for c_allreduce_sum (#47030)

* Support stream overlap for c_allreduce_sum

* Test CI

* Add notes

* Add SingleStreamGuard for BuildOpFuncList
上级 3bc4b850
......@@ -511,8 +511,10 @@ void ReplaceAllReduceOp(const Node &node,
std::vector<OpDesc> *ops) {
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
bool is_fused = (node.Name() == "fused_all_reduce");
details::OpHandleBase &op_handle =
const_cast<Node *>(&node)->Wrapper<details::OpHandleBase>();
auto &in_var_handles = op_handle.Inputs();
// Even if PADDLE_WITH_NCCL is defined, if the program runs on CPU,
// nccl_ctxs_ in NCCLOpHandleBase will be nullptr, and calling the
......@@ -537,7 +539,6 @@ void ReplaceAllReduceOp(const Node &node,
<< all_reduce_var_name;
// get inputs of check_memory_continue
auto in_var_handles = op_handle.Inputs();
std::vector<std::string> in_names;
for (const auto &in : in_var_handles) {
if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
......@@ -555,7 +556,7 @@ void ReplaceAllReduceOp(const Node &node,
fuse_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward)));
} else {
all_reduce_var_name = op_handle.Inputs()[0]->Name();
all_reduce_var_name = in_var_handles[0]->Name();
}
// add c_allreduce_sum OP
......@@ -568,7 +569,7 @@ void ReplaceAllReduceOp(const Node &node,
int ring_id = platform::NCCLCommContext::Instance().GetRingId(
dynamic_cast<details::NCCLOpHandleBase *>(&op_handle)->GetComm());
all_reduce_op_desc.SetAttr("ring_id", ring_id);
all_reduce_op_desc.SetAttr("use_calc_stream", true);
all_reduce_op_desc.SetAttr("use_calc_stream", false);
all_reduce_op_desc.SetAttr(OpProtoAndCheckerMaker::OpRoleAttrName(),
(static_cast<int>(OpRole::kBackward)));
......@@ -586,6 +587,34 @@ void ReplaceAllReduceOp(const Node &node,
->GradMergeCondName();
all_reduce_op_desc.SetInput("Cond", {cond_name});
}
// Add dependency for FusedAllReduce.
// For the following example:
// ### fused_grad = FusedAllReduce(grad0, grad1, grad2, ...)
// ### v0 = op0(grad0)
// ### v1 = op1(grad1)
// It is converted to:
// ### fused_grad = check_memory_continue(grad0, grad1, grad2, ...)
// ### fused_grad = c_sum_allreduce(fused_grad)
// ### v0 = op0(grad0)
// ### v1 = op1(grad1)
// We should add the following dependency to ensure that op0 and op1 both run
// afer c_sum_allreduce:
// ### grad0 = depend(grad0, fused_grad)
// ### grad1 = depend(grad1, fused_grad)
if (is_fused) {
for (const auto &in : in_var_handles) {
if (dynamic_cast<details::DummyVarHandle *>(in) != nullptr) {
continue;
}
ops->emplace_back();
OpDesc &depend_op_desc = ops->back();
depend_op_desc.SetType("depend");
depend_op_desc.SetInput("X", {in->Name()});
depend_op_desc.SetInput("Dep", {all_reduce_var_name});
depend_op_desc.SetOutput("Out", {in->Name()});
}
}
#else
PADDLE_THROW(
platform::errors::Unimplemented("ReplaceAllReduceOp is only implemented "
......
......@@ -2,48 +2,11 @@ add_subdirectory(garbage_collector)
add_subdirectory(interpreter)
add_subdirectory(workqueue)
set(STANDALONE_EXECUTOR_SRCS
data_transfer.cc
new_executor_defs.cc
interpretercore_util.cc
event_manager.cc
stream_analyzer.cc
interpretercore.cc
standalone_executor.cc)
set(STANDALONE_EXECUTOR_SRCS interpretercore.cc new_executor_defs.cc
stream_analyzer.cc standalone_executor.cc)
set(STANDALONE_EXECUTOR_DEPS
dependency_builder
device_context
execution_config
op_registry
scope
framework_proto
data_feed_proto
ops_extra_info
heter_service_proto
trainer_desc_proto
glog
lod_rank_table
fs
shell
fleet_wrapper
heter_wrapper
ps_gpu_wrapper
box_wrapper
lodtensor_printer
feed_fetch_method
graph_to_program_pass
variable_helper
timer
monitor
nan_inf_utils
enforce
scope
glog
workqueue
interpretercore_garbage_collector
${DEVICE_EVENT_LIBS}
glog)
set(STANDALONE_EXECUTOR_DEPS interpreter interpretercore_garbage_collector
workqueue)
cc_library(
standalone_executor
......
cc_library(
dependency_builder
SRCS dependency_builder.cc
DEPS operator)
set(INTERPRETER_SRCS data_transfer.cc dependency_builder.cc event_manager.cc
execution_config.cc interpreter_util.cc)
set(INTERPRETER_DEPS
device_context
op_registry
scope
framework_proto
data_feed_proto
ops_extra_info
heter_service_proto
trainer_desc_proto
glog
lod_rank_table
fs
shell
fleet_wrapper
heter_wrapper
ps_gpu_wrapper
box_wrapper
lodtensor_printer
feed_fetch_method
graph_to_program_pass
variable_helper
timer
monitor
nan_inf_utils
enforce
scope
glog
${DEVICE_EVENT_LIBS}
glog)
cc_library(
execution_config
SRCS execution_config.cc
DEPS phi_backends)
interpreter
SRCS ${INTERPRETER_SRCS}
DEPS ${INTERPRETER_DEPS})
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/core/kernel_context.h"
......
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/new_executor/interpreter/dependency_builder.h"
#include <queue>
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
namespace paddle {
namespace framework {
......@@ -27,22 +28,6 @@ size_t CountDownstreamMap(const std::map<int, std::set<int>>& downstream_map) {
}
return count;
}
bool IsCommunicationOp(const std::string& op_name) {
const std::set<std::string> special_comm_op_set = {
"send",
"recv",
"send_v2",
"recv_v2",
};
const std::string communication_op_prefix = "c_";
if (op_name.find(communication_op_prefix) != std::string::npos ||
special_comm_op_set.count(op_name)) {
return true;
}
return false;
}
const std::string StringizeDownstreamMap(
const std::map<int, std::set<int>>& downstream_map) {
std::ostringstream oss;
......@@ -187,21 +172,6 @@ void DependencyBuilder::AddDependencyForCoalesceTensorOp() {
}
void DependencyBuilder::AddDependencyForCommunicationOp() {
auto IsCommunicationOp = [](std::string op) -> bool {
const std::set<std::string> special_comm_op_set = {
"send",
"recv",
"send_v2",
"recv_v2",
};
const std::string communication_op_prefix = "c_";
if (op.find(communication_op_prefix) != std::string::npos ||
special_comm_op_set.count(op)) {
return true;
}
return false;
};
int dependence_op_idx = -1;
for (size_t op_idx = 0; op_idx < op_num_; ++op_idx) {
if (IsCommunicationOp(instructions_->at(op_idx).OpBase()->Type())) {
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/event_manager.h"
#include "paddle/fluid/framework/new_executor/interpreter/event_manager.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
......
......@@ -12,13 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include <algorithm>
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/executor_gc_helper.h"
#include "paddle/fluid/framework/new_executor/data_transfer.h"
#include "paddle/fluid/framework/new_executor/interpreter/data_transfer.h"
#include "paddle/fluid/memory/stats.h"
#include "paddle/fluid/operators/controlflow/conditional_block_op_helper.h"
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
......@@ -50,6 +50,38 @@ namespace interpreter {
using VariableIdMap = std::map<std::string, std::vector<int>>;
// NOTE(Ruibiao): SingleStreamGuard make some multi-strem op (i.e.,
// c_allreduce_sum) run in single stream. It is dedicated to BuildOpFuncList
// which run kernel without stream synchronization.
class SingleStreamGuard {
public:
explicit SingleStreamGuard(std::shared_ptr<OperatorBase>& op) : op_(op) {
if (op_->Type() == "c_allreduce_sum" &&
op_->Attr<bool>("use_calc_stream") == false) {
VLOG(6) << "Set c_allredce_sum's attr use_calc_stream to true";
op_->SetAttr("use_calc_stream", true);
is_changed = true;
}
}
~SingleStreamGuard() {
if (!is_changed) {
return;
}
if (op_->Type() == "c_allreduce_sum") {
op_->SetAttr("use_calc_stream", false);
VLOG(6) << "Set c_allredce_sum's attr use_calc_stream to false";
}
}
DISABLE_COPY_AND_ASSIGN(SingleStreamGuard);
private:
bool is_changed{false};
std::shared_ptr<OperatorBase> op_;
};
const std::vector<WorkQueueOptions> ConstructWorkQueueOptions(
size_t host_num_threads, size_t device_num_threads, EventsWaiter* waiter) {
std::vector<WorkQueueOptions> group_options;
......@@ -471,6 +503,9 @@ void BuildOpFuncList(const platform::Place& place,
op_func_node.operator_base_ = ops[i];
op_func_node.input_index = ins_name2id;
op_func_node.output_index = outs_name2id;
SingleStreamGuard single_stream_guard(ops[i]);
VLOG(4) << "Start run " << place << " " << op->DebugStringEx(local_scope);
#ifdef PADDLE_WITH_ASCEND_CL
......@@ -514,16 +549,13 @@ void BuildOpFuncList(const platform::Place& place,
auto& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
VLOG(4) << "get dev_ctx";
auto exec_ctx = ExecutionContext(
*op_with_kernel, *runtime_scope, *dev_ctx, runtime_context);
VLOG(4) << "get exec_ctx";
auto expected_kernel_key =
op_with_kernel->GetExpectedKernelType(exec_ctx);
VLOG(4) << "get expected_kernel_key";
VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
// change device by the device_guard()
ApplyDeviceGuard(op, place, &expected_kernel_key);
VLOG(4) << "expected_kernel_key : " << expected_kernel_key;
// step 2. select op kernel
auto run_phi_kernel = false;
......@@ -722,6 +754,21 @@ void AddFetch(const std::vector<std::string>& fetch_names,
}
}
bool IsCommunicationOp(const std::string& op_name) {
const std::set<std::string> special_comm_op_set = {
"send",
"recv",
"send_v2",
"recv_v2",
};
const std::string communication_op_prefix = "c_";
if (op_name.find(communication_op_prefix) != std::string::npos ||
special_comm_op_set.count(op_name)) {
return true;
}
return false;
}
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -81,6 +81,8 @@ void BuildOpFuncList(const platform::Place& place,
void AddFetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block);
bool IsCommunicationOp(const std::string& op_name);
} // namespace interpreter
} // namespace framework
} // namespace paddle
......@@ -18,7 +18,7 @@
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/share_tensor_buffer_functor.h"
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/os_info.h"
......@@ -101,7 +101,7 @@ inline void SetDeviceId(const platform::Place& place) {
}
}
// TODO(Ruibia): Pass skip_gc_vars, used_for_jit, and other config messages by
// TODO(Ruibiao): Pass skip_gc_vars, used_for_jit, and other config messages by
// constructing an interpreter::ExecutionConfig
InterpreterCore::InterpreterCore(const platform::Place& place,
const BlockDesc& block,
......
......@@ -20,11 +20,11 @@
#include <vector>
#include "paddle/fluid/framework/details/exception_holder.h"
#include "paddle/fluid/framework/new_executor/event_manager.h"
#include "paddle/fluid/framework/new_executor/garbage_collector/garbage_collector.h"
#include "paddle/fluid/framework/new_executor/interpreter/dependency_builder.h"
#include "paddle/fluid/framework/new_executor/interpreter/event_manager.h"
#include "paddle/fluid/framework/new_executor/interpreter/execution_config.h"
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/framework/new_executor/new_executor_defs.h"
#include "paddle/fluid/framework/new_executor/profiler.h"
#include "paddle/fluid/framework/new_executor/stream_analyzer.h"
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/new_executor/standalone_executor.h"
#include "paddle/fluid/framework/new_executor/interpretercore_util.h"
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle {
......
......@@ -17,44 +17,50 @@
#include <future>
#include <unordered_set>
#include "paddle/fluid/framework/new_executor/interpreter/interpreter_util.h"
#include "paddle/fluid/platform/collective_helper.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
namespace {
std::map<Place, std::shared_future<std::unique_ptr<platform::DeviceContext>>>*
d2h_ctxs = nullptr;
std::map<Place, std::shared_future<std::unique_ptr<platform::DeviceContext>>>*
h2d_ctxs = nullptr;
std::mutex ctx_mtx;
} // namespace
StreamAnalyzer::StreamAnalyzer(const platform::Place& place) : place_(place) {
if (platform::is_gpu_place(place) || platform::is_npu_place(place) ||
platform::is_custom_place(place)) {
std::lock_guard<std::mutex> lk(ctx_mtx);
if (d2h_ctxs == nullptr) {
d2h_ctxs = new std::map<
Place,
std::shared_future<std::unique_ptr<platform::DeviceContext>>>();
h2d_ctxs = new std::map<
Place,
std::shared_future<std::unique_ptr<platform::DeviceContext>>>();
}
if (d2h_ctxs->find(place) == d2h_ctxs->end()) {
platform::EmplaceDeviceContexts(
d2h_ctxs,
{place},
/*disable_setting_default_stream_for_allocator=*/true);
// stream types
constexpr const char* kD2HStream = "D2HStream";
constexpr const char* kH2DStream = "H2DStream";
class ContextManager {
public:
using DeviceContextMap =
std::map<Place,
std::shared_future<std::unique_ptr<platform::DeviceContext>>>;
static ContextManager& Instance() {
static ContextManager* ctx_manager = new ContextManager;
return *ctx_manager;
}
std::shared_future<std::unique_ptr<platform::DeviceContext>> Get(
const std::string& type, const platform::Place& place) {
std::lock_guard<std::mutex> lk(ctx_mtx_);
VLOG(6) << "Get dev_ctx for " << type << " - " << place;
DeviceContextMap& ctxs = ctx_pool_[type];
if (ctxs.find(place) == ctxs.end()) {
platform::EmplaceDeviceContexts(
h2d_ctxs,
&ctxs,
{place},
/*disable_setting_default_stream_for_allocator=*/true);
}
d2h_ctx_ = (*d2h_ctxs)[place];
h2d_ctx_ = (*h2d_ctxs)[place];
return ctxs[place];
}
}
private:
ContextManager() {}
DISABLE_COPY_AND_ASSIGN(ContextManager);
std::mutex ctx_mtx_;
std::unordered_map<std::string, DeviceContextMap> ctx_pool_;
};
/*
* Parse the var_ids that need to be associated with an event.
......@@ -88,23 +94,26 @@ std::vector<size_t> StreamAnalyzer::GetNeedEventVarIds(
return false;
};
bool is_comm = interpreter::IsCommunicationOp(cur_instr.OpBase()->Type()) ||
interpreter::IsCommunicationOp(next_instr.OpBase()->Type());
std::vector<size_t> need_event_var_ids;
for (auto& item : next_instr.Inputs()) {
for (auto var_id : item.second) {
if (unique_var_ids.count(var_id) > 0) {
if (next_instr.NoDataTransformVars().count(var_id)) {
VLOG(4) << "Skip inserting event at variable " << item.first
<< " of operator " << next_instr.OpBase()->Type()
<< " since it is NoDataTransform";
continue;
if (!is_comm) {
if (next_instr.NoDataTransformVars().count(var_id)) {
VLOG(4) << "Skip inserting event at variable " << item.first
<< " of operator " << next_instr.OpBase()->Type()
<< " since it is NoDataTransform";
continue;
}
if (is_no_need_buffer(item.first)) {
VLOG(4) << "Skip inserting event at variable " << item.first
<< " of operator " << next_instr.OpBase()->Type()
<< " since it is NoNeedBufferVar";
continue;
}
}
if (is_no_need_buffer(item.first)) {
VLOG(4) << "Skip inserting event at variable " << item.first
<< " of operator " << next_instr.OpBase()->Type()
<< " since it is NoNeedBufferVar";
continue;
}
need_event_var_ids.push_back(var_id);
}
}
......@@ -175,21 +184,40 @@ void StreamAnalyzer::Schedule(const std::vector<size_t>& downstream_ops,
platform::DeviceContext* StreamAnalyzer::ParseDeviceContext(
const OpFuncNode& op_func_node) {
auto& op_type = op_func_node.operator_base_->Type();
auto* dev_ctx = op_func_node.dev_ctx_;
auto& op = op_func_node.operator_base_;
auto& op_type = op->Type();
ContextManager& ctx_manager = ContextManager::Instance();
// only gpu/npu need update. xpu not need, because xpu memcpy op kernel is
// synchronous.
if (platform::is_gpu_place(place_) || platform::is_npu_place(place_) ||
platform::is_custom_place(place_)) {
if (op_type == interpreter::kMemcpyD2H) {
VLOG(3) << "Get dev_ctx from d2h_context_pool_";
dev_ctx = d2h_ctx_.get().get();
return ctx_manager.Get(std::string(kD2HStream), place_).get().get();
} else if (op_type == interpreter::kMemcpyH2D) {
VLOG(3) << "Get dev_ctx from h2d_context_pool_";
dev_ctx = h2d_ctx_.get().get();
return ctx_manager.Get(std::string(kH2DStream), place_).get().get();
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
// NOTE(Ruibiao): Here supports multi-stream overlap for c_allreduce_sum
// with use_cal_stream==false by returning a device context getting from the
// global NCCLCommContext instance. Because when use_calc_stream==false, in
// OP kernel, the NCCL communication will be launched to the stream directly
// getting from the global NCCLCommContext instance rather than the
// DeviceContext passed from executor (see CAllReduceOpCUDAKernel in
// c_allreduce_op.h). Now it is just a temporary solution for ONLY
// c_allreduce_sum which is used in ResNet50 distributed training.
if (op_type == "c_allreduce_sum" &&
op->Attr<bool>("use_calc_stream") == false) {
int ring_id = op->Attr<int>("ring_id");
return platform::NCCLCommContext::Instance()
.Get(ring_id, place_)
->dev_context();
}
#endif
}
return dev_ctx;
return op_func_node.dev_ctx_;
}
/*
......
......@@ -29,7 +29,7 @@ class StreamAnalyzer {
using Place = platform::Place;
using DeviceContext = platform::DeviceContext;
explicit StreamAnalyzer(const Place& place);
explicit StreamAnalyzer(const Place& place) : place_(place) {}
~StreamAnalyzer() {}
......@@ -54,8 +54,6 @@ class StreamAnalyzer {
platform::DeviceType GetWaiterType(const Instruction& instr);
const Place place_;
std::shared_future<std::unique_ptr<platform::DeviceContext>> d2h_ctx_;
std::shared_future<std::unique_ptr<platform::DeviceContext>> h2d_ctx_;
std::map<size_t, std::shared_ptr<platform::DeviceEvent>> var_id2event_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册