Commit 2ca9e448 authored by mindspore-ci-bot, committed by Gitee

!360 Add GPU send and recv control kernels

Merge pull request !360 from ZPaC/gpu-backend-supports-multiple-streams
@@ -96,7 +96,7 @@ size_t CudaDriver::free_mem_size() {
 }
 bool CudaDriver::CreateStream(DeviceStream *stream) {
-  auto ret = cudaStreamCreate(reinterpret_cast<CUstream_st **>(stream));
+  auto ret = cudaStreamCreateWithFlags(reinterpret_cast<CUstream_st **>(stream), cudaStreamNonBlocking);
   if (ret != cudaSuccess) {
     MS_LOG(ERROR) << "cudaStreamCreate failed, ret[" << static_cast<int>(ret) << "], " << cudaGetErrorString(ret);
     return false;
...
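The switch from cudaStreamCreate to cudaStreamCreateWithFlags is what makes the multi-stream design work: streams created with default flags synchronize implicitly with the legacy default stream (stream 0), so work on stream 0 serializes against them, while cudaStreamNonBlocking streams skip that implicit synchronization. A minimal standalone sketch of the distinction (plain CUDA, not MindSpore code):

```cpp
#include <cuda_runtime.h>

int main() {
  cudaStream_t blocking = nullptr, non_blocking = nullptr;
  // Default-flag streams serialize against work on the legacy default stream.
  cudaStreamCreate(&blocking);
  // Non-blocking streams can run concurrently with the legacy default stream,
  // which lets the communication stream overlap with compute.
  cudaStreamCreateWithFlags(&non_blocking, cudaStreamNonBlocking);
  cudaStreamDestroy(blocking);
  cudaStreamDestroy(non_blocking);
  return 0;
}
```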
@@ -28,20 +28,24 @@ namespace device {
 namespace gpu {
 void AssignGpuStream(const std::shared_ptr<session::KernelGraph> &kernel_graph) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
-  std::vector<CNodePtr> allreduce_cnodes;
+  std::vector<CNodePtr> allreduce_kernels;
   auto execution_kernels = kernel_graph->execution_order();
-  for (auto kernel : execution_kernels) {
-    std::string kernel_name = AnfAlgo::GetCNodeName(kernel);
+  for (auto kernel_node : execution_kernels) {
+    std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
     if (kernel_name == kAllReduceOpName) {
-      allreduce_cnodes.emplace_back(kernel);
+      allreduce_kernels.emplace_back(kernel_node);
+    } else {
+      DeviceStream compute_stream = GPUDeviceManager::GetInstance().default_stream();
+      AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast<uintptr_t>(compute_stream)), kernel_node);
     }
   }
-  if (allreduce_cnodes.size() > 1) {
+  if (allreduce_kernels.size() > 1) {
     DeviceStream comm_stream = nullptr;
     GPUDeviceManager::GetInstance().CreateStream(&comm_stream);
-    std::transform(allreduce_cnodes.begin(), allreduce_cnodes.end(), allreduce_cnodes.begin(), [&](CNodePtr node) {
-      AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast<uintptr_t>(comm_stream)), node);
-      return node;
-    });
+    std::transform(
+      allreduce_kernels.begin(), allreduce_kernels.end(), allreduce_kernels.begin(), [&](CNodePtr allreduce_kernel) {
+        AnfAlgo::SetNodeAttr("stream_id", MakeValue(reinterpret_cast<uintptr_t>(comm_stream)), allreduce_kernel);
+        return allreduce_kernel;
+      });
     std::vector<SendRecvPair> send_recv_pairs;
@@ -137,7 +141,7 @@ void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_
   }
   // Step 3: insert stream switch CNodes into execution kernel list.
   auto execution_kernels = kernel_graph->execution_order();
-  for (auto node = ordered_stream_switch_nodes.begin(); node != ordered_stream_switch_nodes.end(); node++) {
+  for (auto node = ordered_stream_switch_nodes.rbegin(); node != ordered_stream_switch_nodes.rend(); node++) {
     execution_kernels.insert(execution_kernels.begin() + node->offset, node->cnode);
   }
   kernel_graph->set_execution_order(execution_kernels);
...
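The change from begin()/end() to rbegin()/rend() fixes an offset-invalidation bug: every insert() shifts all later positions in execution_kernels right by one, so walking the switch nodes front to back would make each stored offset stale after the first insertion. Assuming ordered_stream_switch_nodes is sorted by ascending offset, inserting from the largest offset first leaves the remaining offsets valid. A self-contained illustration (hypothetical SwitchNode type, not the MindSpore one):

```cpp
#include <cstdio>
#include <vector>

struct SwitchNode {
  size_t offset;  // intended insertion position in the original order
  char tag;       // stands in for the stream switch CNode
};

int main() {
  std::vector<char> order = {'A', 'B', 'C', 'D'};
  // Sorted by ascending offset, like ordered_stream_switch_nodes.
  std::vector<SwitchNode> switches = {{1, 's'}, {3, 'r'}};

  // Reverse iteration performs the offset-3 insertion before the offset-1 one,
  // so earlier offsets still point at the intended elements.
  for (auto it = switches.rbegin(); it != switches.rend(); ++it) {
    order.insert(order.begin() + it->offset, it->tag);
  }
  for (char c : order) printf("%c", c);  // prints AsBCrD
  printf("\n");
  return 0;
}
```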
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2020 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
...
@@ -70,4 +70,4 @@ CNodePtr CreateStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &ker
 }  // namespace device
 }  // namespace mindspore
-#endif
+#endif  // MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_
mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.cc (new file):
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/control/recv_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_REGULAR(Recv, KernelAttr(), RecvGpuKernel)
} // namespace kernel
} // namespace mindspore
mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.h (new file):
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class RecvGpuKernel : public GpuKernel {
public:
RecvGpuKernel() {}
~RecvGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
uintptr_t) override {
CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamWaitEvent(wait_stream_, wait_event_, 0), "Waiting cuda event failed.");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
wait_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "wait_event_stream"));
wait_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "wait_event"));
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
return;
}
private:
cudaStream_t wait_stream_{nullptr};
cudaEvent_t wait_event_{nullptr};
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_
mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc (new file):
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/control/send_gpu_kernel.h"
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_REGULAR(Send, KernelAttr(), SendGpuKernel)
} // namespace kernel
} // namespace mindspore
mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.h (new file):
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"
namespace mindspore {
namespace kernel {
class SendGpuKernel : public GpuKernel {
public:
SendGpuKernel() {}
~SendGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }
bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
uintptr_t) override {
CHECK_CUDA_RET_WITH_EXCEPT(cudaEventRecord(record_event_, record_stream_), "Recording cuda event failed.");
return true;
}
bool Init(const CNodePtr &kernel_node) override {
record_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "record_event_stream"));
record_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "record_event"));
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.clear();
output_size_list_.clear();
workspace_size_list_.clear();
return;
}
private:
cudaStream_t record_stream_{nullptr};
cudaEvent_t record_event_{nullptr};
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
} // namespace kernel
} // namespace mindspore
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
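Note that neither kernel moves data: Launch() ignores its address lists and InitSizeLists() leaves every size list empty, because Send and Recv exist purely to impose an ordering between two streams via one shared cudaEvent_t. A runnable sketch of the handshake (plain CUDA; the attribute names match the kernels' Init() above, but the std::map is only a stand-in for the CNode attribute storage that the stream-assignment pass populates):

```cpp
#include <cuda_runtime.h>
#include <cstdint>
#include <map>
#include <string>

int main() {
  cudaStream_t compute_stream = nullptr, comm_stream = nullptr;
  cudaEvent_t event = nullptr;
  cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking);
  cudaStreamCreateWithFlags(&comm_stream, cudaStreamNonBlocking);
  cudaEventCreate(&event);

  // The pass stores raw CUDA handles as uintptr_t node attributes.
  std::map<std::string, uintptr_t> send_attrs = {
    {"record_event", reinterpret_cast<uintptr_t>(event)},
    {"record_event_stream", reinterpret_cast<uintptr_t>(compute_stream)}};
  std::map<std::string, uintptr_t> recv_attrs = {
    {"wait_event", reinterpret_cast<uintptr_t>(event)},
    {"wait_event_stream", reinterpret_cast<uintptr_t>(comm_stream)}};

  // SendGpuKernel::Launch: record the event on the producing stream.
  cudaEventRecord(reinterpret_cast<cudaEvent_t>(send_attrs["record_event"]),
                  reinterpret_cast<cudaStream_t>(send_attrs["record_event_stream"]));
  // RecvGpuKernel::Launch: make the consuming stream wait for it.
  cudaStreamWaitEvent(reinterpret_cast<cudaStream_t>(recv_attrs["wait_event_stream"]),
                      reinterpret_cast<cudaEvent_t>(recv_attrs["wait_event"]), 0);

  cudaStreamSynchronize(comm_stream);
  cudaEventDestroy(event);
  cudaStreamDestroy(compute_stream);
  cudaStreamDestroy(comm_stream);
  return 0;
}
```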
@@ -124,6 +124,12 @@ class NcclGpuKernel : public GpuKernel {
     InferCommType(kernel_node);
     collective_handle_ = device::gpu::CollectiveInitializer::instance().collective_handle();
     MS_EXCEPTION_IF_NULL(collective_handle_);
+    auto comm_stream_attr = AnfAlgo::GetCNodePrimitive(kernel_node)->GetAttr("stream_id");
+    if (comm_stream_attr) {
+      comm_stream_ = reinterpret_cast<cudaStream_t>(GetValue<uintptr_t>(comm_stream_attr));
+      MS_EXCEPTION_IF_NULL(comm_stream_);
+    }
     return true;
   }
...
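With comm_stream_ filled in from the "stream_id" attribute, the collective can presumably launch on the dedicated communication stream instead of the stream handed to Launch; that launch-side change is outside this excerpt, so the sketch below is an assumption. It uses the real NCCL API with a single-rank communicator for brevity:

```cpp
#include <cuda_runtime.h>
#include <nccl.h>
#include <cstdio>

int main() {
  // Single-process, single-GPU communicator just to make the example runnable.
  ncclComm_t comm;
  ncclUniqueId id;
  ncclGetUniqueId(&id);
  ncclCommInitRank(&comm, /*nranks=*/1, id, /*rank=*/0);

  // The dedicated communication stream, analogous to comm_stream_.
  cudaStream_t comm_stream = nullptr;
  cudaStreamCreateWithFlags(&comm_stream, cudaStreamNonBlocking);

  const size_t count = 1024;
  float *buf = nullptr;
  cudaMalloc(&buf, count * sizeof(float));

  // Launch the collective on the communication stream rather than the
  // default compute stream, which is what the attribute enables.
  ncclAllReduce(buf, buf, count, ncclFloat, ncclSum, comm, comm_stream);
  cudaStreamSynchronize(comm_stream);

  cudaFree(buf);
  cudaStreamDestroy(comm_stream);
  ncclCommDestroy(comm);
  printf("done\n");
  return 0;
}
```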