Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
2ca9e448
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2ca9e448
编写于
4月 16, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 16, 2020
浏览文件
操作
浏览文件
下载
差异文件
!360 Add GPU send and recv control kernels
Merge pull request !360 from ZPaC/gpu-backend-supports-multiple-streams
上级
5ed799d7
b8a91215
变更
8
隐藏空白更改
内联
并排
Showing
8 changed files
with
201 additions
and
13 deletions
+201
-13
mindspore/ccsrc/device/gpu/cuda_driver.cc
mindspore/ccsrc/device/gpu/cuda_driver.cc
+1
-1
mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
+14
-10
mindspore/ccsrc/device/gpu/gpu_stream_assign.h
mindspore/ccsrc/device/gpu/gpu_stream_assign.h
+2
-2
mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.cc
mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.cc
+23
-0
mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.h
+66
-0
mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc
mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc
+23
-0
mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.h
+66
-0
mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h
mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h
+6
-0
未找到文件。
mindspore/ccsrc/device/gpu/cuda_driver.cc
浏览文件 @
2ca9e448
...
...
@@ -96,7 +96,7 @@ size_t CudaDriver::free_mem_size() {
}
bool
CudaDriver
::
CreateStream
(
DeviceStream
*
stream
)
{
auto
ret
=
cudaStreamCreate
(
reinterpret_cast
<
CUstream_st
**>
(
stream
)
);
auto
ret
=
cudaStreamCreate
WithFlags
(
reinterpret_cast
<
CUstream_st
**>
(
stream
),
cudaStreamNonBlocking
);
if
(
ret
!=
cudaSuccess
)
{
MS_LOG
(
ERROR
)
<<
"cudaStreamCreate failed, ret["
<<
static_cast
<
int
>
(
ret
)
<<
"], "
<<
cudaGetErrorString
(
ret
);
return
false
;
...
...
mindspore/ccsrc/device/gpu/gpu_stream_assign.cc
浏览文件 @
2ca9e448
...
...
@@ -28,21 +28,25 @@ namespace device {
namespace
gpu
{
void
AssignGpuStream
(
const
std
::
shared_ptr
<
session
::
KernelGraph
>
&
kernel_graph
)
{
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
std
::
vector
<
CNodePtr
>
allreduce_
cnode
s
;
std
::
vector
<
CNodePtr
>
allreduce_
kernel
s
;
auto
execution_kernels
=
kernel_graph
->
execution_order
();
for
(
auto
kernel
:
execution_kernels
)
{
std
::
string
kernel_name
=
AnfAlgo
::
GetCNodeName
(
kernel
);
for
(
auto
kernel
_node
:
execution_kernels
)
{
std
::
string
kernel_name
=
AnfAlgo
::
GetCNodeName
(
kernel
_node
);
if
(
kernel_name
==
kAllReduceOpName
)
{
allreduce_cnodes
.
emplace_back
(
kernel
);
allreduce_kernels
.
emplace_back
(
kernel_node
);
}
else
{
DeviceStream
compute_stream
=
GPUDeviceManager
::
GetInstance
().
default_stream
();
AnfAlgo
::
SetNodeAttr
(
"stream_id"
,
MakeValue
(
reinterpret_cast
<
uintptr_t
>
(
compute_stream
)),
kernel_node
);
}
}
if
(
allreduce_
cnode
s
.
size
()
>
1
)
{
if
(
allreduce_
kernel
s
.
size
()
>
1
)
{
DeviceStream
comm_stream
=
nullptr
;
GPUDeviceManager
::
GetInstance
().
CreateStream
(
&
comm_stream
);
std
::
transform
(
allreduce_cnodes
.
begin
(),
allreduce_cnodes
.
end
(),
allreduce_cnodes
.
begin
(),
[
&
](
CNodePtr
node
)
{
AnfAlgo
::
SetNodeAttr
(
"stream_id"
,
MakeValue
(
reinterpret_cast
<
uintptr_t
>
(
comm_stream
)),
node
);
return
node
;
});
std
::
transform
(
allreduce_kernels
.
begin
(),
allreduce_kernels
.
end
(),
allreduce_kernels
.
begin
(),
[
&
](
CNodePtr
allreduce_kernel
)
{
AnfAlgo
::
SetNodeAttr
(
"stream_id"
,
MakeValue
(
reinterpret_cast
<
uintptr_t
>
(
comm_stream
)),
allreduce_kernel
);
return
allreduce_kernel
;
});
std
::
vector
<
SendRecvPair
>
send_recv_pairs
;
FindAllReduceStreamSwitchPos
(
kernel_graph
,
&
send_recv_pairs
);
...
...
@@ -137,7 +141,7 @@ void InsertStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &kernel_
}
// Step 3: insert stream switch CNodes into execution kernel list.
auto
execution_kernels
=
kernel_graph
->
execution_order
();
for
(
auto
node
=
ordered_stream_switch_nodes
.
begin
();
node
!=
ordered_stream_switch_nodes
.
end
();
node
++
)
{
for
(
auto
node
=
ordered_stream_switch_nodes
.
rbegin
();
node
!=
ordered_stream_switch_nodes
.
r
end
();
node
++
)
{
execution_kernels
.
insert
(
execution_kernels
.
begin
()
+
node
->
offset
,
node
->
cnode
);
}
kernel_graph
->
set_execution_order
(
execution_kernels
);
...
...
mindspore/ccsrc/device/gpu/gpu_stream_assign.h
浏览文件 @
2ca9e448
/**
* Copyright 20
19
Huawei Technologies Co., Ltd
* Copyright 20
20
Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
...
...
@@ -70,4 +70,4 @@ CNodePtr CreateStreamSwitchNode(const std::shared_ptr<session::KernelGraph> &ker
}
// namespace device
}
// namespace mindspore
#endif
#endif
// MINDSPORE_CCSRC_DEVICE_GPU_GPU_STREAM_ASSIGN_H_
mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.cc
0 → 100644
浏览文件 @
2ca9e448
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/control/recv_gpu_kernel.h"

namespace mindspore {
namespace kernel {
// Register the "Recv" control op with the GPU kernel factory, backed by
// RecvGpuKernel. KernelAttr() is default-constructed — presumably meaning the
// kernel takes no typed inputs/outputs; verify against the macro's definition.
MS_REG_GPU_KERNEL_REGULAR(Recv, KernelAttr(), RecvGpuKernel)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/kernel/gpu/control/recv_gpu_kernel.h
0 → 100644
浏览文件 @
2ca9e448
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_

#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
// Stream-synchronization control kernel: at launch time it inserts a
// cudaStreamWaitEvent on `wait_stream_`, making that stream block until
// `wait_event_` has been recorded (by a matching Send kernel elsewhere —
// NOTE(review): pairing is implied by the attr names, confirm with the
// stream-assign pass that sets these attributes).
// The kernel moves no data, so all size lists are empty.
class RecvGpuKernel : public GpuKernel {
 public:
  RecvGpuKernel() {}
  ~RecvGpuKernel() override = default;

  // No inputs/outputs/workspace: the lists are cleared in InitSizeLists().
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  // Address lists and the launch stream argument are ignored; the wait is
  // enqueued on wait_stream_ (taken from node attrs), not the passed-in stream.
  bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
              uintptr_t) override {
    // Flags must be 0 per the CUDA runtime API for cudaStreamWaitEvent.
    CHECK_CUDA_RET_WITH_EXCEPT(cudaStreamWaitEvent(wait_stream_, wait_event_, 0), "Waiting cuda event failed.");
    return true;
  }

  // Reads the stream and event handles that an earlier compilation pass stored
  // on the node as uintptr_t attributes, and casts them back to CUDA handles.
  bool Init(const CNodePtr &kernel_node) override {
    wait_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "wait_event_stream"));
    wait_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "wait_event"));
    InitSizeLists();
    return true;
  }

 protected:
  // Control kernel: no memory is consumed or produced, so keep lists empty.
  void InitSizeLists() override {
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
    return;
  }

 private:
  cudaStream_t wait_stream_{nullptr};  // stream that will wait on the event
  cudaEvent_t wait_event_{nullptr};    // event recorded by the paired Send
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_RECV_GPU_KERNEL_H_
mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.cc
0 → 100644
浏览文件 @
2ca9e448
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "kernel/gpu/control/send_gpu_kernel.h"

namespace mindspore {
namespace kernel {
// Register the "Send" control op with the GPU kernel factory, backed by
// SendGpuKernel. KernelAttr() is default-constructed — presumably meaning the
// kernel takes no typed inputs/outputs; verify against the macro's definition.
MS_REG_GPU_KERNEL_REGULAR(Send, KernelAttr(), SendGpuKernel)
}  // namespace kernel
}  // namespace mindspore
mindspore/ccsrc/kernel/gpu/control/send_gpu_kernel.h
0 → 100644
浏览文件 @
2ca9e448
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_

#include <vector>
#include "kernel/gpu/gpu_kernel.h"
#include "kernel/gpu/gpu_kernel_factory.h"

namespace mindspore {
namespace kernel {
// Stream-synchronization control kernel: at launch time it records
// `record_event_` on `record_stream_`, signalling any stream that waits on
// that event (a matching Recv kernel — NOTE(review): pairing is implied by the
// attr names, confirm with the stream-assign pass that sets these attributes).
// The kernel moves no data, so all size lists are empty.
class SendGpuKernel : public GpuKernel {
 public:
  SendGpuKernel() {}
  ~SendGpuKernel() override = default;

  // No inputs/outputs/workspace: the lists are cleared in InitSizeLists().
  const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
  const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
  const std::vector<size_t> &GetWorkspaceSizeList() const override { return workspace_size_list_; }

  // Address lists and the launch stream argument are ignored; the event is
  // recorded on record_stream_ (taken from node attrs), not the passed-in stream.
  bool Launch(const std::vector<AddressPtr> &, const std::vector<AddressPtr> &, const std::vector<AddressPtr> &,
              uintptr_t) override {
    CHECK_CUDA_RET_WITH_EXCEPT(cudaEventRecord(record_event_, record_stream_), "Recording cuda event failed.");
    return true;
  }

  // Reads the stream and event handles that an earlier compilation pass stored
  // on the node as uintptr_t attributes, and casts them back to CUDA handles.
  bool Init(const CNodePtr &kernel_node) override {
    record_stream_ = reinterpret_cast<cudaStream_t>(GetAttr<uintptr_t>(kernel_node, "record_event_stream"));
    record_event_ = reinterpret_cast<cudaEvent_t>(GetAttr<uintptr_t>(kernel_node, "record_event"));
    InitSizeLists();
    return true;
  }

 protected:
  // Control kernel: no memory is consumed or produced, so keep lists empty.
  void InitSizeLists() override {
    input_size_list_.clear();
    output_size_list_.clear();
    workspace_size_list_.clear();
    return;
  }

 private:
  cudaStream_t record_stream_{nullptr};  // stream the event is recorded on
  cudaEvent_t record_event_{nullptr};    // event a paired Recv waits for
  std::vector<size_t> input_size_list_;
  std::vector<size_t> output_size_list_;
  std::vector<size_t> workspace_size_list_;
};
}  // namespace kernel
}  // namespace mindspore

#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CONTROL_SEND_GPU_KERNEL_H_
mindspore/ccsrc/kernel/gpu/nccl/nccl_gpu_kernel.h
浏览文件 @
2ca9e448
...
...
@@ -124,6 +124,12 @@ class NcclGpuKernel : public GpuKernel {
InferCommType
(
kernel_node
);
collective_handle_
=
device
::
gpu
::
CollectiveInitializer
::
instance
().
collective_handle
();
MS_EXCEPTION_IF_NULL
(
collective_handle_
);
auto
comm_stream_attr
=
AnfAlgo
::
GetCNodePrimitive
(
kernel_node
)
->
GetAttr
(
"stream_id"
);
if
(
comm_stream_attr
)
{
comm_stream_
=
reinterpret_cast
<
cudaStream_t
>
(
GetValue
<
uintptr_t
>
(
comm_stream_attr
));
MS_EXCEPTION_IF_NULL
(
comm_stream_
);
}
return
true
;
}
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录