Unverified commit 172d1de6, authored by Ruibiao Chen, committed by GitHub

Support stream priority for standalone executor (#49939)

* Support stream priority for standalone executor

* Fix compile error

* Fix compile error

* Fix compile error

* Fix compile error

* Fix compile error
Parent f12f2a9d
......@@ -293,6 +293,7 @@ std::vector<std::string> OperatorDistAttr::fields_{"process_mesh",
"impl_idx",
"is_recompute",
"execution_stream",
"stream_priority",
"scheduling_priority"};
OperatorDistAttr::OperatorDistAttr(const OpDesc& op) {
......@@ -318,6 +319,8 @@ OperatorDistAttr& OperatorDistAttr::operator=(
std::swap(this->impl_idx_, tmp.impl_idx_);
std::swap(this->is_recompute_, tmp.is_recompute_);
std::swap(this->execution_stream_, tmp.execution_stream_);
std::swap(this->stream_priority_, tmp.stream_priority_);
std::swap(this->scheduling_priority_, tmp.scheduling_priority_);
std::swap(this->annotated_, tmp.annotated_);
// Note: Make sure all tensor dist attr has the same process_mesh
set_process_mesh(this->process_mesh_);
......@@ -349,6 +352,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) {
impl_idx_ = 0;
is_recompute_ = false;
execution_stream_ = kDefault;
stream_priority_ = 0;
scheduling_priority_ = 0;
}
......@@ -361,6 +365,7 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
set_impl_idx(dist_attr.impl_idx());
set_is_recompute(dist_attr.is_recompute());
set_execution_stream(dist_attr.execution_stream());
set_stream_priority(dist_attr.stream_priority());
set_scheduling_priority(dist_attr.scheduling_priority());
set_annotated(dist_attr.annotated());
}
......@@ -599,6 +604,7 @@ std::string OperatorDistAttr::to_string() const {
str += "{impl_type: " + impl_type_ + ", ";
str += "impl_idx: " + std::to_string(impl_idx_) + ", ";
str += "execution_stream: " + execution_stream_ + ", ";
str += "stream_priority: " + std::to_string(stream_priority_) + ", ";
str += "scheduling_priority: " + std::to_string(scheduling_priority_) + ", ";
str += "annotated: [" + str_join(annotated_) + "], ";
str += "\nprocess_mesh: " + process_mesh_.to_string() + ", ";
......@@ -684,6 +690,9 @@ bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) {
if (lhs.execution_stream() != rhs.execution_stream()) {
return false;
}
if (lhs.stream_priority() != rhs.stream_priority()) {
return false;
}
if (lhs.scheduling_priority() != rhs.scheduling_priority()) {
return false;
}
......
......@@ -222,6 +222,12 @@ class OperatorDistAttr {
execution_stream_ = execution_stream;
}
int stream_priority() const { return stream_priority_; }
void set_stream_priority(int stream_priority) {
stream_priority_ = stream_priority;
}
int64_t scheduling_priority() const { return scheduling_priority_; }
void set_scheduling_priority(int64_t scheduling_priority) {
......@@ -289,6 +295,7 @@ class OperatorDistAttr {
int64_t impl_idx_ = 0;
bool is_recompute_ = false;
std::string execution_stream_ = kDefault;
int stream_priority_ = 0; // lower value, higher priority
int64_t scheduling_priority_ = 0; // lower value, higher priority
std::map<std::string, bool> annotated_;
};
......
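The new field travels with execution_stream and scheduling_priority through initialize, copy_from, swap, comparison, and to_string. A minimal configuration sketch follows; the header path and the surrounding OpDesc setup are assumptions, not part of this diff:

#include "paddle/fluid/distributed/auto_parallel/dist_attr.h"  // assumed path

// Sketch: route an op to a named custom stream with elevated priority.
void ConfigureHighPriorityStream(const paddle::framework::OpDesc& op_desc) {
  paddle::distributed::auto_parallel::OperatorDistAttr dist_attr(op_desc);
  dist_attr.set_execution_stream("s1");  // hypothetical stream name
  dist_attr.set_stream_priority(-1);     // lower value, higher priority
  dist_attr.set_scheduling_priority(0);  // default dispatch order
}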
......@@ -46,7 +46,8 @@ FastThreadedSSAGraphExecutor::FastThreadedSSAGraphExecutor(
platform::EmplaceDeviceContexts(
&fetch_ctxs_,
places,
/*disable_setting_default_stream_for_allocator=*/true);
/*disable_setting_default_stream_for_allocator=*/true,
/*stream_priority=*/0);
if (ir::IsTopologySortOperationsUnique(*graph_)) {
VLOG(10)
<< "Change thread number to 1 because the toposort order is unique";
......
......@@ -41,7 +41,8 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
platform::EmplaceDeviceContexts(
&fetch_ctxs_,
places,
/*disable_setting_default_stream_for_allocator=*/true);
/*disable_setting_default_stream_for_allocator=*/true,
/*stream_priority=*/0);
if (strategy_.num_iteration_per_run_ > 1) {
int read_op_num = 0;
......
......@@ -666,21 +666,28 @@ bool BuildOpFuncList(const platform::Place& place,
op_func_node.output_index = outs_name2id;
const OperatorDistAttr* dist_attr = block.Op(i)->DistAttr();
if (dist_attr &&
dist_attr->execution_stream() != distributed::auto_parallel::kDefault) {
op_func_node.execution_stream_ = dist_attr->execution_stream();
}
if (dist_attr) {
op_func_node.priority_ = dist_attr->scheduling_priority();
} else if (interpreter::IsCommunicationOp(op_type)) {
// NOTE(Ruibiao): Dispatching computation before communication improves
// multi-stream overlap when the time cost of communication less than that
// of the calculation (e.g., ResNet50_bs128_pure_fp16 N4C32 training).
op_func_node.priority_ = 1;
if (dist_attr->execution_stream() !=
distributed::auto_parallel::kDefault) {
op_func_node.execution_stream_ = dist_attr->execution_stream();
}
op_func_node.stream_priority_ = dist_attr->stream_priority();
op_func_node.scheduling_priority_ = dist_attr->scheduling_priority();
} else {
if (interpreter::IsCommunicationOp(op_type)) {
// NOTE(Ruibiao): Dispatching computation before communication improves
// multi-stream overlap when the time cost of communication is less than
// that of the computation (e.g., ResNet50_bs128_pure_fp16 N4C32
// training).
op_func_node.scheduling_priority_ = 1;
}
}
VLOG(6) << "scheduling priority of " << op_type << " : "
<< op_func_node.priority_;
VLOG(6) << op_type
<< " : [execution_stream, stream_priority, scheduling_priority] = ["
<< op_func_node.execution_stream_ << ", "
<< op_func_node.stream_priority_ << ", "
<< op_func_node.scheduling_priority_ << "]";
SingleStreamGuard single_stream_guard(ops[i]);
......
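The branch above encodes a simple rule: explicit dist_attr values always win, and only communication ops without attributes are demoted. A self-contained restatement of that rule with stand-in types (a sketch, not Paddle code):

#include <cstdint>
#include <string>

// Stand-ins for the real attribute types; lower values mean higher priority.
struct DistAttrView {
  std::string execution_stream;  // empty stands in for kDefault here
  int stream_priority = 0;
  int64_t scheduling_priority = 0;
};

struct Resolved {
  std::string execution_stream = "DefaultStream";
  int stream_priority = 0;
  int64_t scheduling_priority = 0;
};

Resolved ResolvePriorities(const DistAttrView* dist_attr, bool is_comm_op) {
  Resolved r;
  if (dist_attr) {
    if (!dist_attr->execution_stream.empty()) {
      r.execution_stream = dist_attr->execution_stream;
    }
    r.stream_priority = dist_attr->stream_priority;
    r.scheduling_priority = dist_attr->scheduling_priority;
  } else if (is_comm_op) {
    // No attributes: communication is dispatched after computation,
    // since scheduling_priority 1 > the default 0.
    r.scheduling_priority = 1;
  }
  return r;
}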
......@@ -39,7 +39,9 @@ class ContextManager {
}
std::shared_future<std::unique_ptr<DeviceContext>> Get(
const std::string& type, const platform::Place& place) {
const std::string& type,
const platform::Place& place,
int stream_priority) {
std::lock_guard<std::mutex> lk(ctx_mtx_);
VLOG(6) << "Get dev_ctx for " << type << " - " << place;
......@@ -48,7 +50,8 @@ class ContextManager {
platform::EmplaceDeviceContexts(
&ctxs,
{place},
/*disable_setting_default_stream_for_allocator=*/true);
/*disable_setting_default_stream_for_allocator=*/true,
stream_priority);
}
return ctxs[place];
}
......@@ -142,6 +145,7 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
auto& op = op_func_node.operator_base_;
auto& op_type = op->Type();
const std::string& execution_stream = op_func_node.execution_stream_;
const int stream_priority = op_func_node.stream_priority_;
ContextManager& ctx_manager = ContextManager::Instance();
// only gpu/npu need updating; xpu does not, because the xpu memcpy op kernel is
......@@ -152,15 +156,21 @@ DeviceContext* StreamAnalyzer::ParseDeviceContext(
<< ", execution stream = " << execution_stream;
if (execution_stream != kDefaultStream) {
return ctx_manager
.Get(std::string(kCustomStream) + "-" + execution_stream, place_)
.Get(std::string(kCustomStream) + "-" + execution_stream,
place_,
stream_priority)
.get()
.get();
}
if (op_type == interpreter::kMemcpyD2H) {
return ctx_manager.Get(std::string(kD2HStream), place_).get().get();
return ctx_manager.Get(std::string(kD2HStream), place_, stream_priority)
.get()
.get();
} else if (op_type == interpreter::kMemcpyH2D) {
return ctx_manager.Get(std::string(kH2DStream), place_).get().get();
return ctx_manager.Get(std::string(kH2DStream), place_, stream_priority)
.get()
.get();
}
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
......
......@@ -139,13 +139,15 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
}
var_scope_.SetLocalScope(local_scope_);
instruction_prority_less = [this](size_t lhs, size_t rhs) {
Priority lhs_prority = vec_instruction_[lhs].GetPriority();
Priority rhs_prority = vec_instruction_[rhs].GetPriority();
if (lhs_prority == rhs_prority) {
instruction_scheduling_priority_less = [this](size_t lhs, size_t rhs) {
SchedulingPriority lhs_scheduling_priority =
vec_instruction_[lhs].GetSchedulingPriority();
SchedulingPriority rhs_scheduling_priority =
vec_instruction_[rhs].GetSchedulingPriority();
if (lhs_scheduling_priority == rhs_scheduling_priority) {
return lhs < rhs;
}
return lhs_prority > rhs_prority;
return lhs_scheduling_priority > rhs_scheduling_priority;
};
PrepareForCUDAGraphCapture();
......@@ -1089,7 +1091,7 @@ void InterpreterCore::RunInstructionAsync(size_t instr_id) {
// scheduling, the priority order involving cross-thread scheduling is not
// guaranteed. Only Ops scheduled by the same AddTask call are guaranteed
// to run in priority order.
SchedulingQueue ready_ops(instruction_prority_less);
SchedulingQueue ready_ops(instruction_scheduling_priority_less);
ready_ops.push(instr_id);
while (!ready_ops.empty()) {
instr_id = ready_ops.top();
......@@ -1427,7 +1429,7 @@ void InterpreterCore::AnalyseExecuteOrderForTrace() {
};
std::vector<size_t> trace_order;
SchedulingQueue ready_ops(instruction_prority_less);
SchedulingQueue ready_ops(instruction_scheduling_priority_less);
for (size_t instr_id = 0; instr_id < dependecy_count_.size(); ++instr_id) {
if (dependecy_count_[instr_id] == 0) {
......
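std::priority_queue pops the element that compares greatest under the supplied "less" comparator, so returning `lhs_scheduling_priority > rhs_scheduling_priority` makes the smallest priority value pop first. A standalone demo of that convention (not Paddle code):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <queue>
#include <vector>

int main() {
  // Per-instruction scheduling priorities; lower value = higher priority.
  std::vector<int64_t> scheduling_priority = {1, 0, 1, 0};
  auto less = [&](size_t lhs, size_t rhs) {
    if (scheduling_priority[lhs] == scheduling_priority[rhs]) {
      return lhs < rhs;  // tie-break on instruction id
    }
    // Inverted comparison: smaller priority values compare "greater".
    return scheduling_priority[lhs] > scheduling_priority[rhs];
  };
  std::priority_queue<size_t, std::vector<size_t>, decltype(less)> ready_ops(
      less);
  for (size_t i = 0; i < scheduling_priority.size(); ++i) ready_ops.push(i);
  while (!ready_ops.empty()) {
    // Pops all priority-0 instructions before any priority-1 instruction.
    std::printf("instr %zu\n", ready_ops.top());
    ready_ops.pop();
  }
  return 0;
}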
......@@ -79,9 +79,11 @@ class InterpreterCore {
const platform::Place& GetPlace() const { return place_; }
private:
using InstructionPriorityLess = std::function<bool(size_t, size_t)>;
using InstructionSchedulingPriorityLess = std::function<bool(size_t, size_t)>;
using SchedulingQueue =
std::priority_queue<size_t, std::vector<size_t>, InstructionPriorityLess>;
std::priority_queue<size_t,
std::vector<size_t>,
InstructionSchedulingPriorityLess>;
// build graph
void Convert(std::vector<paddle::framework::OpFuncNode>* op_func_nodes);
......@@ -181,7 +183,7 @@ class InterpreterCore {
int64_t sync_op_num_{-1};
std::vector<size_t> trace_execute_order_;
InstructionPriorityLess instruction_prority_less;
InstructionSchedulingPriorityLess instruction_scheduling_priority_less;
};
std::shared_ptr<InterpreterCore> CreateInterpreterCore(
......
......@@ -32,12 +32,12 @@ namespace framework {
using OpKernelComputeFunc = std::function<void(const ExecutionContext&)>;
using Priority = int64_t;
using SchedulingPriority = int64_t;
constexpr const char* kCoalesceTensor = "coalesce_tensor";
// stream types
constexpr const char* kCustomStream = "CustromStream";
constexpr const char* kCustomStream = "CustomStream";
constexpr const char* kDefaultStream = "DefaultStream";
constexpr const char* kD2HStream = "D2HStream";
constexpr const char* kH2DStream = "H2DStream";
......@@ -263,6 +263,7 @@ enum class OpFuncType {
class RuntimeInferShapeContext;
struct OpFuncNode {
int stream_priority_{0}; // lower value, higher priority
// fit for phi kernel
phi::Kernel* phi_kernel_{nullptr}; // not owned
platform::DeviceContext* dev_ctx_; // not owned
......@@ -279,7 +280,7 @@ struct OpFuncNode {
OpFuncType type_;
OpKernelComputeFunc kernel_func_;
Priority priority_{0}; // lower value, higher priority
SchedulingPriority scheduling_priority_{0}; // lower value, higher priority
};
class Instruction {
......@@ -369,7 +370,9 @@ class Instruction {
void ClearInplace();
Priority GetPriority() const { return op_func_node_.priority_; }
SchedulingPriority GetSchedulingPriority() const {
return op_func_node_.scheduling_priority_;
}
private:
bool is_artificial_; // Instruction is artificial means that it is only used
......
......@@ -25,13 +25,13 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/allocator.h"
#include "paddle/phi/core/expect.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
#include "paddle/fluid/platform/cuda_device_guard.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#endif
#ifdef PADDLE_WITH_MLU
......@@ -145,12 +145,37 @@ void DeviceContextPool::SetDeviceContexts(
external_device_contexts_ = dev_ctxs;
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
template <typename DevCtx>
typename std::enable_if<!std::is_same<DevCtx, phi::GPUContext>::value,
DevCtx*>::type
ConstructDevCtx(const platform::Place& p, /*unused*/ int stream_priority = 0) {
return new DevCtx(p);
}
template <typename DevCtx>
typename std::enable_if<std::is_same<DevCtx, phi::GPUContext>::value,
DevCtx*>::type
ConstructDevCtx(const platform::Place& p, int stream_priority) {
return new DevCtx(p, /*init=*/true, stream_priority);
}
#else
template <typename DevCtx>
DevCtx* ConstructDevCtx(const platform::Place& p,
/*unused*/ int stream_priority) {
return new DevCtx(p);
}
#endif
template <typename DevCtx>
std::unique_ptr<DeviceContext> CreateDeviceContext(
const platform::Place& p,
bool disable_setting_default_stream_for_allocator = false) {
bool disable_setting_default_stream_for_allocator = false,
int stream_priority = 0) {
using PtrType = std::unique_ptr<DeviceContext>;
auto* dev_ctx = new DevCtx(p);
DevCtx* dev_ctx = ConstructDevCtx<DevCtx>(p, stream_priority);
if (is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto* cuda_ctx = dynamic_cast<phi::GPUContext*>(dev_ctx);
......@@ -201,21 +226,24 @@ inline void EmplaceDeviceContext(
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
place_to_device_context,
platform::Place place,
bool disable_setting_default_stream_for_allocator) {
bool disable_setting_default_stream_for_allocator,
int stream_priority) {
// Lazy evaluation, i.e., only create the device context at the first `Get`
place_to_device_context->emplace(
place,
std::async(std::launch::deferred,
CreateDeviceContext<DevCtx>,
place,
disable_setting_default_stream_for_allocator));
disable_setting_default_stream_for_allocator,
stream_priority));
}
void EmplaceDeviceContexts(
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
place_to_device_context,
const std::vector<platform::Place>& places,
bool disable_setting_default_stream_for_allocator) {
bool disable_setting_default_stream_for_allocator,
int stream_priority) {
PADDLE_ENFORCE_GT(
places.size(),
0,
......@@ -234,19 +262,22 @@ void EmplaceDeviceContexts(
EmplaceDeviceContext<phi::OneDNNContext>(
place_to_device_context,
p,
disable_setting_default_stream_for_allocator);
disable_setting_default_stream_for_allocator,
/*unused*/ stream_priority);
#else
EmplaceDeviceContext<phi::CPUContext>(
place_to_device_context,
p,
disable_setting_default_stream_for_allocator);
disable_setting_default_stream_for_allocator,
/*unused*/ stream_priority);
#endif
} else if (platform::is_gpu_place(p)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
EmplaceDeviceContext<phi::GPUContext>(
place_to_device_context,
p,
disable_setting_default_stream_for_allocator);
disable_setting_default_stream_for_allocator,
stream_priority);
#else
PADDLE_THROW(
platform::errors::Unimplemented("CUDAPlace is not supported. Please "
......@@ -257,7 +288,8 @@ void EmplaceDeviceContexts(
EmplaceDeviceContext<CUDAPinnedDeviceContext>(
place_to_device_context,
p,
disable_setting_default_stream_for_allocator);
disable_setting_default_stream_for_allocator,
/*unused*/ stream_priority);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"CUDAPlace is not supported. Please re-compile with WITH_GPU "
......@@ -268,7 +300,8 @@ void EmplaceDeviceContexts(
EmplaceDeviceContext<XPUDeviceContext>(
place_to_device_context,
p,
disable_setting_default_stream_for_allocator);
disable_setting_default_stream_for_allocator,
/*unused*/ stream_priority);
#else
PADDLE_THROW(
platform::errors::Unimplemented("XPUPlace is not supported. Please "
......@@ -279,7 +312,8 @@ void EmplaceDeviceContexts(
EmplaceDeviceContext<MLUDeviceContext>(
place_to_device_context,
p,
disable_setting_default_stream_for_allocator);
disable_setting_default_stream_for_allocator,
/*unused*/ stream_priority);
#else
PADDLE_THROW(
platform::errors::Unimplemented("MLUPlace is not supported. Please "
......@@ -290,7 +324,8 @@ void EmplaceDeviceContexts(
EmplaceDeviceContext<IPUDeviceContext>(
place_to_device_context,
p,
disable_setting_default_stream_for_allocator);
disable_setting_default_stream_for_allocator,
/*unused*/ stream_priority);
#else
PADDLE_THROW(
platform::errors::Unimplemented("IPUPlace is not supported. Please "
......@@ -301,7 +336,8 @@ void EmplaceDeviceContexts(
EmplaceDeviceContext<NPUDeviceContext>(
place_to_device_context,
p,
disable_setting_default_stream_for_allocator);
disable_setting_default_stream_for_allocator,
/*unused*/ stream_priority);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPlace is not supported. Please "
......@@ -312,7 +348,8 @@ void EmplaceDeviceContexts(
EmplaceDeviceContext<NPUPinnedDeviceContext>(
place_to_device_context,
p,
disable_setting_default_stream_for_allocator);
disable_setting_default_stream_for_allocator,
/*unused*/ stream_priority);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"NPUPinnedPlace is not supported. Please re-compile with "
......@@ -324,7 +361,8 @@ void EmplaceDeviceContexts(
EmplaceDeviceContext<CustomDeviceContext>(
place_to_device_context,
p,
disable_setting_default_stream_for_allocator);
disable_setting_default_stream_for_allocator,
/*unused*/ stream_priority);
#else
PADDLE_THROW(platform::errors::Unimplemented(
"CustomPlace is not supported. Please re-compile with "
......@@ -339,7 +377,8 @@ DeviceContextPool::DeviceContextPool(
const std::vector<platform::Place>& places) {
EmplaceDeviceContexts(&device_contexts_,
places,
/*disable_setting_default_stream_for_allocator=*/false);
/*disable_setting_default_stream_for_allocator=*/false,
/*stream_priority=*/0);
}
#ifdef PADDLE_WITH_IPU
......
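Among these changes, only phi::GPUContext actually consumes stream_priority; the ConstructDevCtx overloads in the hunk above select the right constructor at compile time via enable_if. A stripped-down sketch of the same dispatch pattern with hypothetical context types:

#include <type_traits>

// Hypothetical stand-ins: only GpuCtx has a priority-taking constructor.
struct CpuCtx {
  explicit CpuCtx(int /*place*/) {}
};
struct GpuCtx {
  GpuCtx(int /*place*/, bool /*init*/, int /*stream_priority*/) {}
};

// Generic path: drop the priority argument.
template <typename Ctx>
typename std::enable_if<!std::is_same<Ctx, GpuCtx>::value, Ctx*>::type
Construct(int place, int /*stream_priority*/) {
  return new Ctx(place);
}

// GPU path: forward the priority to the constructor.
template <typename Ctx>
typename std::enable_if<std::is_same<Ctx, GpuCtx>::value, Ctx*>::type
Construct(int place, int stream_priority) {
  return new Ctx(place, /*init=*/true, stream_priority);
}

// Usage: Construct<CpuCtx>(0, 0) ignores the priority;
//        Construct<GpuCtx>(0, -1) builds a high-priority context.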
......@@ -346,7 +346,8 @@ void EmplaceDeviceContexts(
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>*
place_to_device_context,
const std::vector<platform::Place>& places,
bool disable_setting_default_stream_for_allocator);
bool disable_setting_default_stream_for_allocator,
int stream_priority);
/*! \brief device context pool singleton */
class DeviceContextPool {
......
......@@ -288,6 +288,9 @@ void BindAutoParallel(py::module *m) {
.def_property("execution_stream",
&OperatorDistAttr::execution_stream,
&OperatorDistAttr::set_execution_stream)
.def_property("stream_priority",
&OperatorDistAttr::stream_priority,
&OperatorDistAttr::set_stream_priority)
.def_property("scheduling_priority",
&OperatorDistAttr::scheduling_priority,
&OperatorDistAttr::set_scheduling_priority)
......
......@@ -252,8 +252,6 @@ void BindCudaStream(py::module *m_ptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Priority should be 1(high) or 2(normal) "));
}
auto prio = phi::CUDAStream::Priority(priority);
auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking;
if (place == nullptr) {
int curr_device_id = platform::GetCurrentDeviceId();
......@@ -261,7 +259,10 @@ void BindCudaStream(py::module *m_ptr) {
place = &place_tmp;
}
new (&self) phi::CUDAStream(*place, prio, stream_flag);
auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking;
// Setting priority to 1 (high) or 2 (normal) maps to the actual
// CUDA stream priorities -1 and 0, respectively.
new (&self) phi::CUDAStream(*place, priority - 2, stream_flag);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform."));
......@@ -277,8 +278,6 @@ void BindCudaStream(py::module *m_ptr) {
PADDLE_THROW(platform::errors::InvalidArgument(
"Priority should be 1(high) or 2(normal) "));
}
auto prio = phi::CUDAStream::Priority(priority);
auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking;
int device_count = platform::GetGPUDeviceCount();
if (device < 0) {
......@@ -291,8 +290,11 @@ void BindCudaStream(py::module *m_ptr) {
device));
}
new (&self)
phi::CUDAStream(platform::CUDAPlace(device), prio, stream_flag);
auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking;
// Setting priority to 1 (high) or 2 (normal) maps to the actual
// CUDA stream priorities -1 and 0, respectively.
new (&self) phi::CUDAStream(
platform::CUDAPlace(device), priority - 2, stream_flag);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform."));
......@@ -302,13 +304,10 @@ void BindCudaStream(py::module *m_ptr) {
py::arg("priority") = 2)
.def("__init__", [](phi::CUDAStream &self) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
auto prio = phi::CUDAStream::Priority::kNormal;
auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking;
int device_id = platform::GetCurrentDeviceId();
new (&self)
phi::CUDAStream(platform::CUDAPlace(device_id), prio, stream_flag);
auto stream_flag = phi::CUDAStream::StreamFlag::kStreamNonBlocking;
new (&self) phi::CUDAStream(
platform::CUDAPlace(device_id), /*priority=*/0, stream_flag);
#else
PADDLE_THROW(platform::errors::Unavailable(
"Class CUDAStream can only be initialized on the GPU platform."));
......
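The replacement code keeps the public Python contract (1 = high, 2 = normal) and converts it with plain arithmetic: 1 - 2 = -1 and 2 - 2 = 0, the actual CUDA stream priorities. A one-line sketch of the mapping (hypothetical helper name):

// Hypothetical helper mirroring the pybind arithmetic above; callers have
// already validated that python_priority is 1 or 2.
int ToCudaStreamPriority(int python_priority) {
  return python_priority - 2;  // 1 -> -1 (high), 2 -> 0 (normal/default)
}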
......@@ -196,6 +196,13 @@ std::array<int, 3> GetGpuMaxGridDimSize(int id) {
return ret;
}
std::pair<int, int> GetGpuStreamPriorityRange() {
int least_priority, greatest_priority;
PADDLE_ENFORCE_GPU_SUCCESS(
cudaDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
return std::make_pair(least_priority, greatest_priority);
}
const gpuDeviceProp &GetDeviceProperties(int id) {
std::call_once(g_device_props_size_init_flag, [&] {
int gpu_num = 0;
......
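A usage sketch for the new query (CUDA build assumed; the include path is an assumption based on the header hunk further below):

#include <utility>

#include "paddle/phi/backends/gpu/gpu_info.h"  // assumed header location

void CheckPriorityRange() {
  // On a V100 this returns {0, -5}: numerically smaller means higher
  // priority, so valid stream priorities lie in [-5, 0].
  std::pair<int, int> range = phi::backends::gpu::GetGpuStreamPriorityRange();
  int least_priority = range.first;      // e.g. 0
  int greatest_priority = range.second;  // e.g. -5
  (void)least_priority;
  (void)greatest_priority;
}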
......@@ -218,7 +218,7 @@ struct GPUContext::Impl {
InitDnnWorkspace();
}
void PartialInitWithoutAllocator() {
void PartialInitWithoutAllocator(int stream_priority) {
owned_ = true;
stream_owned_ = true;
backends::gpu::GPUDeviceGuard guard(place_.device);
......@@ -230,7 +230,7 @@ struct GPUContext::Impl {
&max_threads_per_mp_,
&max_threads_per_block_,
&max_grid_dim_size_);
stream_ = new CUDAStream(place_);
stream_ = new CUDAStream(place_, stream_priority);
}
void PartialInitWithAllocator() {
......@@ -818,10 +818,10 @@ GPUContext::GPUContext(GPUContext&&) = default;
GPUContext& GPUContext::operator=(GPUContext&&) = default;
GPUContext::GPUContext(const GPUPlace& place, bool init)
GPUContext::GPUContext(const GPUPlace& place, bool init, int stream_priority)
: DeviceContext(), impl_(std::make_unique<Impl>(place)) {
if (init) {
impl_->PartialInitWithoutAllocator();
impl_->PartialInitWithoutAllocator(stream_priority);
}
}
......@@ -1001,8 +1001,8 @@ void GPUContext::SetDnnWorkspaceHandle(DnnWorkspaceHandle* handle) {
impl_->workspace_ = handle;
}
void GPUContext::PartialInitWithoutAllocator() {
impl_->PartialInitWithoutAllocator();
void GPUContext::PartialInitWithoutAllocator(int stream_priority) {
impl_->PartialInitWithoutAllocator(stream_priority);
}
void GPUContext::PartialInitWithAllocator() {
......
......@@ -81,7 +81,9 @@ class DnnWorkspaceHandle {
class PADDLE_API GPUContext : public DeviceContext,
public TypeInfoTraits<DeviceContext, GPUContext> {
public:
explicit GPUContext(const GPUPlace& place, bool init = true);
explicit GPUContext(const GPUPlace& place,
bool init = true,
int stream_priority = 0);
GPUContext(GPUContext&&);
GPUContext& operator=(GPUContext&&);
......@@ -198,7 +200,7 @@ class PADDLE_API GPUContext : public DeviceContext,
// Note that this is a trick implementation, which can be used to partially
// initialize when the SetAllocator interface is not called.
void PartialInitWithoutAllocator();
void PartialInitWithoutAllocator(int stream_priority = 0);
// Note that this is a trick implementation that can be used to initialize
// resources that require an Allocator when the SetAllocator interface is
// called.
......
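With the widened constructor, a caller can request a prioritized compute stream at context creation; the default of 0 keeps the old behavior. A minimal sketch (CUDA build assumed; the allocator wiring that gpu_context.h requires separately is elided):

#include "paddle/phi/backends/gpu/gpu_context.h"

// Sketch: a context whose stream outranks default-priority streams.
phi::GPUContext MakeHighPriorityContext() {
  return phi::GPUContext(phi::GPUPlace(0),
                         /*init=*/true,
                         /*stream_priority=*/-1);
}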
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <array>
#include <string>
#include <utility>
#include <vector>
#include "paddle/phi/backends/gpu/gpu_types.h"
......@@ -58,6 +59,8 @@ int GetCurrentDeviceId();
//! Get the maximum GridDim size for GPU buddy allocator.
std::array<int, 3> GetGpuMaxGridDimSize(int);
std::pair<int, int> GetGpuStreamPriorityRange();
//! Get a list of device ids from environment variable or use all.
std::vector<int> GetSelectedDevices();
......
......@@ -200,6 +200,13 @@ std::array<int, 3> GetGpuMaxGridDimSize(int id) {
return ret;
}
std::pair<int, int> GetGpuStreamPriorityRange() {
int least_priority, greatest_priority;
PADDLE_ENFORCE_GPU_SUCCESS(
hipDeviceGetStreamPriorityRange(&least_priority, &greatest_priority));
return std::make_pair(least_priority, greatest_priority);
}
const gpuDeviceProp &GetDeviceProperties(int id) {
std::call_once(g_device_props_size_init_flag, [&] {
int gpu_num = 0;
......
......@@ -35,12 +35,6 @@ namespace phi {
// Currently, CudaStream is used in python-side API only
class CUDAStream {
public:
enum class Priority : uint8_t {
kNull = 0x0,
kHigh = 0x1,
kNormal = 0x2,
};
enum class StreamFlag : uint8_t {
kDefaultFlag = 0x0,
kStreamNonBlocking = 0x1,
......@@ -50,29 +44,41 @@ class CUDAStream {
CUDAStream(const Place& place, const Stream& stream)
: place_(place), stream_(stream) {}
CUDAStream(const Place& place,
const Priority& priority = Priority::kNormal,
const int priority = 0,
const StreamFlag& flag = StreamFlag::kDefaultFlag) {
place_ = place;
gpuStream_t stream = nullptr;
backends::gpu::GPUDeviceGuard guard(place_.device);
if (priority == Priority::kHigh) {
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority(
&stream, static_cast<unsigned int>(flag), -1));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority(
&stream, static_cast<unsigned int>(flag), -1));
#endif
} else if (priority == Priority::kNormal) {
// Stream priorities follow a convention where lower numbers imply greater
// priorities
auto priority_range = backends::gpu::GetGpuStreamPriorityRange();
int least_priority = priority_range.first;  // 0 on V100
int greatest_priority = priority_range.second;  // -5 on V100
// NOTE(Ruibiao): Replacing the following `PADDLE_ENFORCE_EQ` with
// `PADDLE_ENFORCE` leads to an nvcc compile error. This is probably a bug.
PADDLE_ENFORCE_EQ(
priority <= least_priority && priority >= greatest_priority,
true,
phi::errors::InvalidArgument(
"Cannot create a stream with priority = %d because stream priority "
"must be inside the meaningful range [%d, %d].",
priority,
least_priority,
greatest_priority));
#ifdef PADDLE_WITH_HIP
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority(
&stream, static_cast<unsigned int>(flag), 0));
PADDLE_ENFORCE_GPU_SUCCESS(hipStreamCreateWithPriority(
&stream, static_cast<unsigned int>(flag), priority));
#else
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority(
&stream, static_cast<unsigned int>(flag), 0));
PADDLE_ENFORCE_GPU_SUCCESS(cudaStreamCreateWithPriority(
&stream, static_cast<unsigned int>(flag), priority));
#endif
}
VLOG(10) << "CUDAStream " << stream;
VLOG(10) << "Create CUDAStream " << stream
<< " with priority = " << priority
<< ", flag = " << static_cast<unsigned int>(flag);
stream_ = Stream(reinterpret_cast<StreamId>(stream));
owned_ = true;
}
......
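Downstream creation then looks like the following sketch (CUDA build and include path assumed); a priority outside the range returned by GetGpuStreamPriorityRange() now fails with InvalidArgument per the check above:

#include "paddle/phi/core/cuda_stream.h"  // assumed header location

// Sketch: a non-blocking stream with priority -1, one level above the
// default priority 0.
void MakePriorityStream() {
  phi::CUDAStream stream(phi::GPUPlace(0),
                         /*priority=*/-1,
                         phi::CUDAStream::StreamFlag::kStreamNonBlocking);
}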
......@@ -50,8 +50,10 @@ class TestCustomStream(unittest.TestCase):
ops = prog.global_block().ops
for op_index in op_index_for_stream1:
ops[op_index].dist_attr.execution_stream = "s1"
ops[op_index].dist_attr.stream_priority = -1
for op_index in op_index_for_stream2:
ops[op_index].dist_attr.execution_stream = "s2"
ops[op_index].dist_attr.stream_priority = -2
def run_program(self, apply_custom_stream=False):
paddle.seed(2022)
......