提交 57aee805 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!274 GPU multiple stream feature

Merge pull request !274 from ZPaC/gpu-backend-supports-multiple-streams
......@@ -25,7 +25,7 @@ namespace device {
namespace gpu {
void GPUDeviceManager::InitDevice() {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::set_current_device(SizeToInt(cur_dev_id_)), "Failed to set current device id");
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(&stream_), "Failed to create CUDA stream.");
CHECK_OP_RET_WITH_EXCEPT(CreateStream(&default_stream_), "Failed to create CUDA stream.");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnCreate(&cudnn_handle_), "Failed to create cuDNN handle");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetStream(cudnn_handle_, reinterpret_cast<cudaStream_t>(default_stream())),
"Failed to set stream for cuDNN handle.");
......@@ -36,19 +36,27 @@ void GPUDeviceManager::InitDevice() {
}
void GPUDeviceManager::ReleaseDevice() {
if (stream_ != nullptr) {
CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream_), "Failed to destroy cuda stream.");
for (DeviceStream stream : gpu_streams_) {
if (stream != nullptr) {
CHECK_OP_RET_WITH_ERROR(CudaDriver::DestroyStream(stream), "Failed to destroy CUDA stream.");
}
}
if (cudnn_handle_ != nullptr) {
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroy(cudnn_handle_), "Failed to destroy cudnn handle");
CHECK_CUDNN_RET_WITH_ERROR(cudnnDestroy(cudnn_handle_), "Failed to destroy cuDNN handle");
}
if (cublas_handle_ != nullptr) {
CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cublas handle.");
CHECK_CUBLAS_RET_WITH_ERROR(cublasDestroy(cublas_handle_), "Failed to destroy cuBLAS handle.");
}
CHECK_OP_RET_WITH_ERROR(GPUMemoryAllocator::GetInstance().Finalize(), "Failed to destroy gpu memory allocator");
}
const DeviceStream& GPUDeviceManager::default_stream() const { return stream_; }
bool GPUDeviceManager::CreateStream(DeviceStream* stream) {
CHECK_OP_RET_WITH_EXCEPT(CudaDriver::CreateStream(stream), "Failed to create CUDA stream");
gpu_streams_.emplace_back(*stream);
return true;
}
const DeviceStream& GPUDeviceManager::default_stream() const { return default_stream_; }
int GPUDeviceManager::device_count() const { return CudaDriver::device_count(); }
......
......@@ -19,6 +19,7 @@
#include <cudnn.h>
#include <cublas_v2.h>
#include <vector>
#include <memory>
#include "device/gpu/cuda_driver.h"
#include "device/gpu/gpu_memory_allocator.h"
......@@ -36,13 +37,15 @@ class GPUDeviceManager {
uint32_t cur_device_id() const;
bool is_device_id_init() const;
bool CreateStream(DeviceStream* stream);
bool SyncStream(const DeviceStream& stream) const;
const DeviceStream& default_stream() const;
const cudnnHandle_t& GetCudnnHandle() const;
const cublasHandle_t& GetCublasHandle() const;
bool CopyDeviceMemToHost(const HostMemPtr& dst, const DeviceMemPtr& src, size_t size) const;
bool CopyHostMemToDevice(const DeviceMemPtr& dst, const void* src, size_t size) const;
bool SyncStream(const DeviceStream& stream) const;
static GPUDeviceManager& GetInstance() {
static GPUDeviceManager instance;
......@@ -55,13 +58,16 @@ class GPUDeviceManager {
GPUDeviceManager(const GPUDeviceManager&) = delete;
GPUDeviceManager& operator=(const GPUDeviceManager&) = delete;
// default cuda stream used for all the kernels.
DeviceStream stream_{nullptr};
// default CUDA stream used for all the kernels.
DeviceStream default_stream_{nullptr};
// all gpu CUDA streams including default_stream_.
std::vector<DeviceStream> gpu_streams_;
// handle used for cudnn kernels.
// handle used for cuDNN kernels.
cudnnHandle_t cudnn_handle_{nullptr};
// handle used for cublas kernels.
// handle used for cuBLAS kernels.
cublasHandle_t cublas_handle_{nullptr};
bool dev_id_init_;
......
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <set>
#include <string>
#include <memory>
#include <algorithm>
#include "device/gpu/gpu_common.h"
#include "device/gpu/kernel_info_setter.h"
#include "device/gpu/gpu_device_manager.h"
#include "device/gpu/gpu_stream_assign.h"
namespace mindspore {
namespace device {
namespace gpu {
void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
std::vector<CNodePtr> allreduce_cnodes;
auto execution_kernels = kernel_graph->execution_order();
for (auto kernel : execution_kernels) {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel);
if (kernel_name == kAllReduceOpName) {
allreduce_cnodes.emplace_back(kernel);
}
}
if (allreduce_cnodes.size() > 1) {
DeviceStream comm_stream = nullptr;
GPUDeviceManager::GetInstance().CreateStream(&comm_stream);
std::transform(allreduce_cnodes.begin(), allreduce_cnodes.end(), allreduce_cnodes.begin(), [&](CNodePtr node) {
AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast<uintptr_t>(comm_stream)), node);
return node;
});
std::vector<SendRecvPair> send_recv_pairs;
FindAllReduceStreamSwitchPos(kernel_graph, &send_recv_pairs);
InsertStreamSwitchNode(kernel_graph, send_recv_pairs);
}
}
void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
std::vector<SendRecvPair> *send_recv_pairs) {
auto execution_kernels = kernel_graph->execution_order();
std::vector<CNodePtr>::iterator iter, iter_begin;
iter = iter_begin = execution_kernels.begin();
std::vector<CNodePtr>::iterator iter_end = execution_kernels.end();
for (; iter != execution_kernels.end(); ++iter) {
std::string kernel_name = AnfAlgo::GetCNodeName(*iter);
if (kernel_name == kAllReduceOpName) {
// Find AllReduce node's last input node.
std::vector<CNodePtr>::iterator mock_send_node_iter =
FindSendNodePos(iter_begin, iter + 1, *iter, kAllReduceStreamSwitch);
if (mock_send_node_iter == iter + 1) {
MS_LOG(WARNING) << "Can't find send node place before AllReduce node.";
continue;
}
SendRecvPair pair1 = {kAllReduceStreamSwitch, *mock_send_node_iter, *iter,
IntToSize(mock_send_node_iter - iter_begin + 1), IntToSize(iter - iter_begin)};
send_recv_pairs->push_back(pair1);
// Find node which uses AllReduce as input[0].
std::vector<CNodePtr>::iterator mock_recv_node_iter =
FindRecvNodePos(iter, iter_end, *iter, kAllReduceStreamSwitch);
if (mock_recv_node_iter == iter_end) {
MS_LOG(WARNING) << "Can't find send node place before AllReduce node.";
continue;
}
SendRecvPair pair2 = {kAllReduceStreamSwitch, *iter, *mock_recv_node_iter, IntToSize(iter - iter_begin + 1),
IntToSize(mock_recv_node_iter - iter_begin)};
send_recv_pairs->push_back(pair2);
}
}
}
std::vector<CNodePtr>::iterator FindSendNodePos(std::vector<CNodePtr>::iterator begin,
std::vector<CNodePtr>::iterator end, const CNodePtr mock_recv_node,
StreamSwitchType stream_switch_type) {
MS_EXCEPTION_IF_NULL(mock_recv_node);
if (stream_switch_type == kAllReduceStreamSwitch) {
for (auto iter = begin; iter != end; iter++) {
if (*(iter + 1) == mock_recv_node) {
return iter;
}
}
}
return end;
}
std::vector<CNodePtr>::iterator FindRecvNodePos(std::vector<CNodePtr>::iterator begin,
std::vector<CNodePtr>::iterator end, const CNodePtr mock_send_node,
StreamSwitchType stream_switch_type) {
MS_EXCEPTION_IF_NULL(mock_send_node);
for (auto iter = begin; iter != end; iter++) {
auto node = *iter;
if (stream_switch_type == kAllReduceStreamSwitch) {
for (auto input : node->inputs()) {
if (mock_send_node == AnfAlgo::VisitKernel(input, 0).first) {
return iter;
}
}
}
}
return end;
}
void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph,
const std::vector<SendRecvPair> &send_recv_pairs) {
std::set<StreamSwitchNode> ordered_stream_switch_nodes;
for (SendRecvPair pair : send_recv_pairs) {
StreamSwitchType stream_switch_type = pair.stream_switch_type;
CNodePtr mock_send_node = pair.mock_send_node;
CNodePtr mock_recv_node = pair.mock_recv_node;
size_t send_node_offset = pair.send_node_offset;
size_t recv_node_offset = pair.recv_node_offset;
CNodePtr send_node = nullptr;
CNodePtr recv_node = nullptr;
// Step 1: generate Send and Recv CNodes.
if (stream_switch_type == kAllReduceStreamSwitch) {
if (!GenSendRecvCNodesForAllReduce(kernel_graph, mock_send_node, mock_recv_node, &send_node, &recv_node)) {
MS_LOG(EXCEPTION) << "Generating CNodes for send and recv failed. Stream switch type: kAllReduceStreamSwitch";
}
}
// Step 2: sort send and recv CNodes by offset.
ordered_stream_switch_nodes.insert({send_node_offset, send_node});
ordered_stream_switch_nodes.insert({recv_node_offset, recv_node});
}
// Step 3: insert stream switch CNodes into execution kernel list.
auto execution_kernels = kernel_graph->execution_order();
for (auto node = ordered_stream_switch_nodes.begin(); node != ordered_stream_switch_nodes.end(); node++) {
execution_kernels.insert(execution_kernels.begin() + node->offset, node->cnode);
}
kernel_graph->set_execution_order(execution_kernels);
}
bool GenSendRecvCNodesForAllReduce(const std::shared_ptr<session::KernelGraph> &kernel_graph,
const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node,
CNodePtr *recv_node) {
*send_node = CreateStreamSwitchNode(kernel_graph, kSendOpName);
MS_EXCEPTION_IF_NULL(*send_node);
*recv_node = CreateStreamSwitchNode(kernel_graph, kRecvOpName);
MS_EXCEPTION_IF_NULL(*recv_node);
cudaEvent_t event = nullptr;
CHECK_CUDA_RET_WITH_EXCEPT(cudaEventCreate(&event, cudaEventDisableTiming), "Creating cuda event failed.");
AnfAlgo::SetNodeAttr("record_event", MakeValue(reinterpret_cast<uintptr_t>(event)), *send_node);
AnfAlgo::SetNodeAttr("wait_event", MakeValue(reinterpret_cast<uintptr_t>(event)), *recv_node);
uintptr_t send_stream = AnfAlgo::GetNodeAttr<uintptr_t>(mock_send_node, "stream_id");
AnfAlgo::SetNodeAttr("record_event_stream", MakeValue(send_stream), *send_node);
uintptr_t recv_stream = AnfAlgo::GetNodeAttr<uintptr_t>(mock_recv_node, "stream_id");
AnfAlgo::SetNodeAttr("wait_event_stream", MakeValue(recv_stream), *recv_node);
return true;
}
CNodePtr CreateStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &name) {
auto op = std::make_shared<Primitive>(name);
auto apply = std::make_shared<ValueNode>(op);
std::vector<AnfNodePtr> input_list = {apply};
CNodePtr node = kernel_graph->NewCNode(input_list);
MS_EXCEPTION_IF_NULL(node);
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
AnfAlgo::SetSelectKernelBuildInfo(selected_kernel_builder.Build(), node.get());
auto abstract_none = std::make_shared<abstract::AbstractNone>();
node->set_abstract(abstract_none);
SetKernelInfo(node);
return node;
}
} // namespace gpu
} // namespace device
} // namespace mindspore
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_
#define MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_
#include <vector>
#include <string>
#include <memory>
#include "session/kernel_graph.h"
#include "session/anf_runtime_algorithm.h"
namespace mindspore {
namespace device {
namespace gpu {
enum StreamSwitchType { kAllReduceStreamSwitch, kStreamSwitchInvalidType = 255 };
struct SendRecvPair {
StreamSwitchType stream_switch_type;
CNodePtr mock_send_node;
CNodePtr mock_recv_node;
size_t send_node_offset;
size_t recv_node_offset;
};
struct StreamSwitchNode {
size_t offset;
CNodePtr cnode;
bool operator<(const StreamSwitchNode &n) const {
if (offset < n.offset) {
return true;
} else if (offset == n.offset) {
return AnfAlgo::GetCNodeName(cnode) == kSendOpName ? true : false;
} else {
return false;
}
}
};
void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph);
void FindAllReduceStreamSwitchPos(const std::shared_ptr<session::KernelGraph> &kernel_graph,
std::vector<SendRecvPair> *send_recv_pairs);
// Find Send node position according to "mock" recv node.
// "mock" recv node is a gpu kernel node after a real Recv node, e.g. AllReduce node.
std::vector<CNodePtr>::iterator FindSendNodePos(std::vector<CNodePtr>::iterator begin,
std::vector<CNodePtr>::iterator end, const CNodePtr mock_recv_node,
StreamSwitchType stream_switch_type);
// Find Recv node position according to "mock" send node.
// "mock" send node is a gpu kernel node before a real send node, e.g. AllReduce node.
std::vector<CNodePtr>::iterator FindRecvNodePos(std::vector<CNodePtr>::iterator begin,
std::vector<CNodePtr>::iterator end, const CNodePtr mock_send_node,
StreamSwitchType stream_switch_type);
void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph,
const std::vector<SendRecvPair> &send_recv_pairs);
bool GenSendRecvCNodesForAllReduce(const std::shared_ptr<session::KernelGraph> &kernel_graph,
const CNodePtr &mock_send_node, const CNodePtr &mock_recv_node, CNodePtr *send_node,
CNodePtr *recv_node);
CNodePtr CreateStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_graph, const std::string &name);
} // namespace gpu
} // namespace device
} // namespace mindspore
#endif
......@@ -52,7 +52,8 @@ class NcclGpuKernel : public GpuKernel {
nccl_reduce_type_(ncclSum),
input_size_(0),
output_size_(0),
collective_handle_(nullptr) {}
collective_handle_(nullptr),
comm_stream_(nullptr) {}
~NcclGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
......@@ -63,34 +64,33 @@ class NcclGpuKernel : public GpuKernel {
T *input_addr = GetDeviceAddress<T>(inputs, 0);
T *output_addr = GetDeviceAddress<T>(outputs, 0);
cudaStream_t stream = comm_stream_ ? comm_stream_ : reinterpret_cast<cudaStream_t>(stream_ptr);
switch (nccl_kernel_type_) {
case NCCL_ALL_REDUCE: {
auto all_reduce_funcptr =
reinterpret_cast<AllReduce>(dlsym(const_cast<void *>(collective_handle_), "AllReduce"));
MS_EXCEPTION_IF_NULL(all_reduce_funcptr);
CHECK_NCCL_RET_WITH_EXCEPT(
(*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_, nccl_reduce_type_,
reinterpret_cast<cudaStream_t>(stream_ptr)),
"ncclAllReduce failed");
CHECK_NCCL_RET_WITH_EXCEPT((*all_reduce_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
nccl_data_type_, nccl_reduce_type_, stream),
"ncclAllReduce failed");
break;
}
case NCCL_ALL_GATHER: {
auto all_gather_funcptr =
reinterpret_cast<AllGather>(dlsym(const_cast<void *>(collective_handle_), "AllGather"));
MS_EXCEPTION_IF_NULL(all_gather_funcptr);
CHECK_NCCL_RET_WITH_EXCEPT((*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T),
nccl_data_type_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"ncclAllGather failed");
CHECK_NCCL_RET_WITH_EXCEPT(
(*all_gather_funcptr)(input_addr, output_addr, input_size_ / sizeof(T), nccl_data_type_, stream),
"ncclAllGather failed");
break;
}
case NCCL_REDUCE_SCATTER: {
auto reduce_scatter_funcptr =
reinterpret_cast<ReduceScatter>(dlsym(const_cast<void *>(collective_handle_), "ReduceScatter"));
MS_EXCEPTION_IF_NULL(reduce_scatter_funcptr);
CHECK_NCCL_RET_WITH_EXCEPT(
(*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T), nccl_data_type_,
nccl_reduce_type_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"ncclReduceScatter failed");
CHECK_NCCL_RET_WITH_EXCEPT((*reduce_scatter_funcptr)(input_addr, output_addr, output_size_ / sizeof(T),
nccl_data_type_, nccl_reduce_type_, stream),
"ncclReduceScatter failed");
break;
}
default: {
......@@ -167,6 +167,7 @@ class NcclGpuKernel : public GpuKernel {
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
const void *collective_handle_;
cudaStream_t comm_stream_;
};
} // namespace kernel
} // namespace mindspore
......
......@@ -17,6 +17,7 @@
#include "device/gpu/kernel_info_setter.h"
#include "device/gpu/gpu_kernel_build.h"
#include "device/gpu/gpu_kernel_runtime.h"
#include "device/gpu/gpu_stream_assign.h"
#include "pre_activate/common/optimizer.h"
#include "pre_activate/common/pass_manager.h"
#include "pre_activate/common/ir_fusion/allreduce_fusion.h"
......@@ -55,6 +56,11 @@ void GPUSession::Optimize(const std::shared_ptr<KernelGraph> &kernel_graph) {
kernel_graph->SetExecOrderByDefault();
}
void GPUSession::AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph) {
MS_EXCEPTION_IF_NULL(kernel_graph);
device::gpu::AssignGpuStream(kernel_graph);
}
void GPUSession::BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const {
device::gpu::GpuBuild(kernel_graph);
}
......@@ -94,6 +100,8 @@ GraphId GPUSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList
StartKernelRT();
// AllReduce Optimize
Optimize(graph);
// Assign CUDA streams
AssignStream(graph);
// Build kernel if node is cnode
BuildKernel(graph);
// Set graph execution order before memory alloc, ensure that memory alloc is according to the reorder graph
......
......@@ -49,6 +49,8 @@ class GPUSession : public SessionBasic {
void Optimize(const std::shared_ptr<KernelGraph> &kernel_graph);
void AssignStream(const std::shared_ptr<KernelGraph> &kernel_graph);
void BuildKernel(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void AllocateMemory(KernelGraph *kernel_graph) const;
......
......@@ -113,6 +113,8 @@ constexpr auto kFusedMulAddNOpName = "FusedMulAddN";
constexpr auto kFusedMulApplyMomentumOpName = "FusedMulApplyMomentum";
constexpr auto kBiasAddOpName = "BiasAdd";
constexpr auto kConfusionMulGradOpName = "ConfusionMulGrad";
constexpr auto kSendOpName = "Send";
constexpr auto kRecvOpName = "Recv";
// attr key name
constexpr auto kAttrInputNames = "input_names";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册