Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
333045d7
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
333045d7
编写于
10月 19, 2017
作者:
D
Dong Zhihong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"move nccl to another directory"
上级
fdfc8f9b
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
215 addition
and
147 deletion
+215
-147
paddle/operators/CMakeLists.txt
paddle/operators/CMakeLists.txt
+15
-1
paddle/operators/nccl/CMakeLists.txt
paddle/operators/nccl/CMakeLists.txt
+2
-6
paddle/operators/nccl/nccl_gpu_common.cc
paddle/operators/nccl/nccl_gpu_common.cc
+12
-56
paddle/operators/nccl/nccl_gpu_common.h
paddle/operators/nccl/nccl_gpu_common.h
+13
-48
paddle/operators/nccl_op.cc
paddle/operators/nccl_op.cc
+32
-25
paddle/operators/nccl_op.cu
paddle/operators/nccl_op.cu
+66
-0
paddle/operators/nccl_op.h
paddle/operators/nccl_op.h
+50
-0
python/paddle/v2/framework/tests/test_nccl_ops.py
python/paddle/v2/framework/tests/test_nccl_ops.py
+25
-11
未找到文件。
paddle/operators/CMakeLists.txt
浏览文件 @
333045d7
...
...
@@ -76,6 +76,14 @@ function(op_library TARGET)
file
(
APPEND
${
pybind_file
}
"USE_OP(sigmoid);
\n
"
)
endif
()
# nccl_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"nccl_op"
)
set
(
pybind_flag 1
)
# It's enough to just adding one operator to pybind
file
(
APPEND
${
pybind_file
}
"USE_OP(ncclInit);
\n
"
)
# file(APPEND ${pybind_file} "USE_OP(ncclInit);\n")
endif
()
# reduce_op contains several operators
if
(
"
${
TARGET
}
"
STREQUAL
"reduce_op"
)
set
(
pybind_flag 1
)
...
...
@@ -116,7 +124,9 @@ set(DEPS_OPS
softmax_with_cross_entropy_op
sum_op
pool_op
pool_with_index_op
)
pool_with_index_op
nccl_op
)
op_library
(
recurrent_op SRCS recurrent_op.cc rnn/recurrent_op_utils.cc
...
...
@@ -127,6 +137,9 @@ op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax)
op_library
(
sum_op DEPS net_op
)
op_library
(
pool_op DEPS pooling
)
op_library
(
pool_with_index_op DEPS pooling
)
if
(
WITH_GPU
)
op_library
(
nccl_op DEPS nccl_common
)
endif
()
list
(
REMOVE_ITEM GENERAL_OPS
${
DEPS_OPS
}
)
foreach
(
src
${
GENERAL_OPS
}
)
...
...
@@ -134,6 +147,7 @@ foreach(src ${GENERAL_OPS})
endforeach
()
set
(
GLOB_OP_LIB
${
OP_LIBRARY
}
CACHE INTERNAL
"Global OP library"
)
message
(
STATUS
"operators_list:
${
OP_LIBRARY
}
"
)
cc_test
(
gather_test SRCS gather_test.cc DEPS tensor
)
cc_test
(
net_op_test SRCS net_op_test.cc DEPS net_op
)
...
...
paddle/operators/nccl/CMakeLists.txt
浏览文件 @
333045d7
if
(
WITH_GPU
)
nv_library
(
nccl_common SRCS nccl_gpu_common DEPS device_context operator
)
nv_library
(
nccl_op SRCS nccl_ops.cc DEPS nccl_common
)
else
()
cc_library
(
nccl_common SRCS nccl_gpu_common DEPS device_context operator
)
nv_library
(
nccl_common SRCS nccl_gpu_common.cc DEPS device_context operator
)
nv_test
(
nccl_gpu_common_test SRCS nccl_gpu_common_test.cc DEPS nccl_common
)
endif
()
cc_test
(
nccl_gpu_common_test SRCS nccl_gpu_common_test.cc DEPS nccl_common
)
paddle/operators/nccl/nccl_gpu_common.cc
浏览文件 @
333045d7
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/nccl/nccl_gpu_common.h"
#include "paddle/platform/gpu_info.h"
namespace
paddle
{
namespace
platform
{
NCCLManager
::
NCCLManager
()
{}
NCCLManager
::~
NCCLManager
()
{
for
(
auto
&
p
:
comm_table
)
{
auto
&
comm
=
p
.
second
;
auto
&
gpus_
=
comm
->
gpus_
;
for
(
size_t
i
=
0
;
i
<
gpus_
.
size
();
++
i
)
{
int
gid
=
gpus_
[
i
];
platform
::
SetDeviceId
(
gid
);
// mapping gid to idx
int
idx
=
gid
%
gpus_
.
size
();
// wait finish
PADDLE_ENFORCE
(
cudaStreamWaitEvent
(
comm
->
streams_
[
idx
],
comm
->
events_
[
idx
],
0
));
PADDLE_ENFORCE
(
cudaEventDestroy
(
comm
->
events_
[
idx
]));
PADDLE_ENFORCE
(
ncclCommDestroy
(
comm
->
comms_
[
idx
]));
}
comm
.
reset
(
nullptr
);
}
}
Communicator
*
NCCLManager
::
GetCommunicator
(
const
std
::
vector
<
int
>&
gpus
)
{
std
::
string
key
;
for
(
auto
&
id
:
gpus
)
{
key
+=
std
::
to_string
(
id
);
}
std
::
sort
(
key
.
begin
(),
key
.
end
());
std
::
mutex
mu
;
std
::
lock_guard
<
std
::
mutex
>
lk
(
mu
);
auto
it
=
comm_table
.
find
(
key
);
if
(
it
->
second
==
nullptr
)
{
auto
*
comm
=
new
Communicator
(
gpus
);
PADDLE_ENFORCE
(
ncclCommInitAll
(
comm
->
comms_
.
data
(),
gpus
.
size
(),
gpus
.
data
()));
for
(
size_t
i
=
0
;
i
<
gpus
.
size
();
++
i
)
{
platform
::
SetDeviceId
(
gpus
[
i
]);
// block wait
PADDLE_ENFORCE
(
cudaEventCreateWithFlags
(
&
comm
->
events_
[
i
],
cudaEventBlockingSync
|
cudaEventDisableTiming
));
}
comm_table
[
key
].
reset
(
comm
);
}
return
comm_table
[
key
].
get
();
}
}
// namespace operators
namespace
platform
{}
// namespace platform
}
// namespace paddle
paddle/operators/nccl/nccl_gpu_common.h
浏览文件 @
333045d7
...
...
@@ -65,65 +65,30 @@ class WaitGroup {
std
::
condition_variable
cv_
;
};
// TODO(dzh) : make resources managed unified with framework
struct
Communicator
{
std
::
vector
<
ncclComm_t
>
comms_
;
std
::
vector
<
cudaStream_t
>
streams_
;
std
::
vector
<
cudaEvent_t
>
events_
;
std
::
vector
<
int
>
gpus_
;
WaitGroup
wg_
;
int
root_gpu
=
-
1
;
// cudaEvent_t root_monitor;
explicit
Communicator
(
const
std
::
vector
<
int
>&
gpus
)
:
gpus_
(
gpus
)
{
std
::
unordered_map
<
int
,
int
>
comm_id_map_
;
int
GetCommId
(
int
device_id
)
const
{
return
comm_id_map_
.
at
(
device_id
);
}
void
InitAll
(
const
std
::
vector
<
int
>&
gpus
)
{
comms_
.
resize
(
gpus
.
size
());
streams_
.
resize
(
gpus
.
size
());
events_
.
resize
(
gpus
.
size
());
for
(
size_t
i
=
0
;
i
<
gpus
.
size
();
++
i
)
{
comm_id_map_
[
gpus
[
i
]]
=
i
;
}
PADDLE_ENFORCE
(
ncclCommInitAll
(
comms_
.
data
(),
gpus
.
size
(),
gpus
.
data
()));
}
~
Communicator
()
{
for
(
size_t
i
=
0
;
i
<
gpus_
.
size
();
++
i
)
{
int
gid
=
gpus_
[
i
];
platform
::
SetDeviceId
(
gid
);
int
idx
=
gid
%
gpus_
.
size
();
// wait finish
PADDLE_ENFORCE
(
cudaStreamWaitEvent
(
comm
->
streams_
[
idx
],
comm
->
events_
[
idx
],
0
));
PADDLE_ENFORCE
(
cudaEventDestroy
(
comm
->
events_
[
idx
]));
PADDLE_ENFORCE
(
ncclCommDestroy
(
comm
->
comms_
[
idx
]));
for
(
size_t
i
=
0
;
i
<
comms_
.
size
();
++
i
)
{
PADDLE_ENFORCE
(
ncclCommDestroy
(
comms_
[
i
]));
}
}
inline
int
get_root_gpu
()
const
{
return
root_gpu
;
}
inline
void
set_root_gpu
(
int
id
)
{
root_gpu
=
id
;
}
// DISABLE_COPY_AND_ASSIGN(Communicator);
};
class
NCCLManager
{
public:
static
NCCLManager
*
Get
()
{
static
NCCLManager
m
;
return
&
m
;
}
NCCLManager
();
~
NCCLManager
();
// for each card only have one communicator
Communicator
*
GetCommunicator
(
const
std
::
vector
<
int
>&
gpus
);
private:
// // the gpu id list available. Note that only support
// // whole world communication.
// std::vector<int> _gpu_worlds;
// communicator list
std
::
unordered_map
<
std
::
string
/* key*/
,
std
::
unique_ptr
<
Communicator
>>
comm_table
;
};
Communicator
*
NewCommunicator
(
const
std
::
vector
<
int
>&
gpus
);
}
// namespace platform
}
// namespace paddle
paddle/operators/nccl
/nccl_ops
.cc
→
paddle/operators/nccl
_op
.cc
浏览文件 @
333045d7
...
...
@@ -9,7 +9,7 @@
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/nccl
/nccl_ops
.h"
#include "paddle/operators/nccl
_op
.h"
namespace
paddle
{
namespace
operators
{
...
...
@@ -85,31 +85,36 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
// // BcastSendOp
// class NCCLBcastSendOpMaker : public framework::OpProtoAndCheckerMaker {
// public:
// NCCLAllReduceOpMaker(framework::OpProto *proto,
// framework::OpAttrChecker *op_checker)
// : OpProtoAndCheckerMaker(proto, op_checker) {
// AddInput("X", "The input of BcastSend op");
// AddComment(R"DOC(
// BcastSend the tensors.
// )DOC");
// }
// };
// BcastOp
class
NCCLBcastOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
NCCLAllBcastOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input of Bcast op"
);
AddInput
(
"Communicator"
,
"Communicator for communicating between gpus"
);
AddInput
(
"root"
,
"root gpu of Bcast"
);
AddComment
(
R"DOC(
Bcast the tensors.
)DOC"
);
}
};
// // BcastRecvOp
// class NCCLBcastRecvOpMaker : public framework::OpProtoAndCheckerMaker {
// public:
// NCCLAllReduceOpMaker(framework::OpProto *proto,
// framework::OpAttrChecker *op_checker)
// : OpProtoAndCheckerMaker(proto, op_checker) {
// AddOutput("Out", "The output of BcastRecv op");
// AddComment(R"DOC(
// BcastRecv the tensors.
// )DOC");
// }
// };
// BcastRecvOp
class
NCCLReduceOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
NCCLReduceOpMaker
(
framework
::
OpProto
*
proto
,
framework
::
OpAttrChecker
*
op_checker
)
:
OpProtoAndCheckerMaker
(
proto
,
op_checker
)
{
AddInput
(
"X"
,
"The input of Reduce op"
);
AddInput
(
"Communicator"
,
"Communicator for communicating between gpus"
);
AddInput
(
"root"
,
"root gpu of Reduce"
);
AddOutput
(
"Out"
,
"The output of Reduce op"
);
AddComment
(
R"DOC(
Reduce the tensors.
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
...
...
@@ -117,3 +122,5 @@ class NCCLAllReduceOpMaker : public framework::OpProtoAndCheckerMaker {
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
ncclAllReduce
,
ops
::
NCCLAllReduceOp
,
ops
::
NCCLAllReduceOpMaker
);
REGISTER_OP_WITHOUT_GRADIENT
(
ncclInit
,
ops
::
NCCLInitOp
,
ops
::
NCCLInitOpMaker
);
REGISTER_OP_CPU_KERNEL
(
ncclInit
,
ops
::
NCCLInitKernel
<
float
>
);
paddle/operators/nccl
/nccl_ops
.cu
→
paddle/operators/nccl
_op
.cu
浏览文件 @
333045d7
...
...
@@ -10,7 +10,57 @@ See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/nccl/nccl_ops.h"
#include "paddle/operators/nccl_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
NCCLAllReduceKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
"This kernel only runs on GPU device."
);
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
outs
=
ctx
.
MultiOutput
<
Tensor
>
(
"Out"
);
std
::
string
reduction
=
ctx
.
Attr
<
std
::
string
>
(
"reduction"
);
ncclRedOp_t
op_type
;
if
(
reduction
==
"ncclSum"
)
{
op_type
=
ncclSum
;
}
else
if
(
reduction
==
"ncclProd"
)
{
op_type
=
ncclProd
;
}
else
if
(
reduction
==
"ncclMin"
)
{
op_type
=
ncclMin
;
}
else
if
(
reduction
==
"ncclMax"
)
{
op_type
=
ncclMax
;
}
else
{
PADDLE_ENFORCE
(
false
,
"reduction error."
);
}
auto
*
comm
=
ctx
.
Input
<
Communicator
>
(
"Communicator"
);
auto
stream
=
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
ctx
.
device_context
())
.
stream
();
// device id
int
device_id
=
boost
::
get
<
platform
::
GPUPlace
>
(
ctx
.
GetPlace
()).
GetDeviceId
();
int
idx
=
comm
->
GetCommId
(
device_id
);
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
PADDLE_ENFORCE
(
ncclAllReduce
(
ins
[
i
]
->
data
<
T
>
(),
outs
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
outs
[
i
]
->
numel
()
*
sizeof
(
T
),
NCCLTypeWrapper
<
T
>::
type
,
op_type
,
comm
->
comms_
[
idx
],
stream
));
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream
));
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
ncclAllReduce
,
ops
::
NCCLAllReduceKernel
<
float
>
);
\ No newline at end of file
REGISTER_OP_GPU_KERNEL
(
ncclAllReduce
,
ops
::
NCCLAllReduceKernel
<
float
>
);
paddle/operators/nccl
/nccl_ops
.h
→
paddle/operators/nccl
_op
.h
浏览文件 @
333045d7
...
...
@@ -19,6 +19,7 @@ namespace paddle {
namespace
operators
{
using
framework
::
Tensor
;
using
platform
::
Communicator
;
template
<
typename
Type
>
class
NCCLTypeWrapper
;
...
...
@@ -35,67 +36,13 @@ class NCCLTypeWrapper<double> {
static
const
ncclDataType_t
type
=
ncclDouble
;
};
class
NCCLInitOp
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
gpus
=
ctx
.
Input
<
std
::
vector
<
int
>>
(
"gpus"
);
auto
*
comm
=
ctx
.
Output
<
Communicator
>
(
"Communicator"
);
comm
->
mutable_data
<
Communicator
>
(
CPUPlace
());
comm
=
NCCLManager
::
GetCommunicator
(
gpus
);
}
};
template
<
typename
T
>
class
NCCL
AllReduce
Kernel
:
public
framework
::
OpKernel
<
T
>
{
class
NCCL
Init
Kernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
ins
=
ctx
.
MultiInput
<
Tensor
>
(
"X"
);
auto
outs
=
ctx
.
MultiOutput
<
Tensor
>
(
"Out"
);
std
::
string
reduction
=
ctx
.
Attr
<
std
::
string
>
(
"reduction"
);
std
::
vector
<
int
>
gpus
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"gpus"
);
ncclRedOp_t
op_type
;
if
(
reduction
==
"ncclSum"
)
{
op_type
=
ncclSum
;
}
else
if
(
reduction
==
"ncclProd"
)
{
op_type
=
ncclProd
;
}
else
if
(
reduction
==
"ncclMin"
)
{
op_type
=
ncclMin
;
}
else
if
(
reduction
==
"ncclMax"
)
{
op_type
=
ncclMax
;
}
auto
*
comm
=
ctx
.
Input
<
Communicator
>
(
"Communicator"
);
auto
dev_ctx
=
static_cast
<
const
platform
::
CUDADeviceContext
>
(
ctx
.
device_context
());
// platform::NCCLManager* m = platform::NCCLManager::Get();
// auto* comm = m->GetCommunicator(gpus);
// comm->wg_.Add(1);
auto
stream
=
dev_ctx
.
stream
();
// device id
int
gid
=
static_cast
<
platform
::
GPUPlace
>
(
ctx
.
GetPlace
()).
GetDeviceId
();
int
idx
=
gid
%
gpus
.
size
();
comm
->
streams_
[
idx
]
=
stream
;
for
(
size_t
i
=
0
;
i
<
ins
.
size
();
++
i
)
{
PADDLE_ENFORCE
(
ncclAllReduce
(
ins
[
i
]
->
data
<
T
>
(),
outs
[
i
]
->
mutable_data
<
T
>
(),
outs
[
i
]
->
numel
()
*
sizeof
(
T
),
NCCLTypeWrapper
<
T
>::
type
,
op_type
,
comm
->
comms_
[
idx
],
comm
->
streams_
[
idx
]));
PADDLE_ENFORCE
(
cudaEventRecord
(
comm
->
events_
[
idx
],
comm
->
streams_
[
idx
]));
// // wait finish
// PADDLE_ENFORCE(
// cudaStreamWaitEvent(comm->streams_[idx], comm->events_[idx], 0));
}
// comm->wg_.Done();
// comm->wg_.Wait();
auto
*
gpus
=
ctx
.
Input
<
std
::
vector
<
int
>>
(
"gpus"
);
auto
*
comm
=
ctx
.
Output
<
Communicator
>
(
"Communicator"
);
comm
->
InitAll
(
*
gpus
);
}
};
...
...
python/paddle/v2/framework/tests/test_nccl_ops.py
浏览文件 @
333045d7
...
...
@@ -5,13 +5,15 @@ from paddle.v2.framework.op import Operator
import
paddle.v2.framework.core
as
core
from
op_test
import
OpTest
,
create_op
,
set_input
gpu_list
=
os
.
environ
[
"NV_LIST"
]
# gpu_list = os.environ["NV_LIST"]
gpu_list
=
"0,1,2,3"
if
not
core
.
is_compile_gpu
()
or
not
gpu_list
:
exit
(
0
)
def
allreduce
(
tensors
,
num_device
):
def
allreduce
(
tensors
,
gpus
):
num_device
=
len
(
gpus
)
assert
(
len
(
tensors
)
==
num_device
),
"not match of tensor and device"
Out
=
tensors
for
i
in
range
(
1
,
len
(
tensors
)):
...
...
@@ -24,23 +26,32 @@ def allreduce(tensors, num_device):
class
TestNCCLAllReduce
(
unittest
.
TestCase
):
def
__init__
(
self
):
self
.
op_type
=
"nnclAllReduce"
def
setUp
(
self
):
self
.
gpus
=
[
int
(
g
)
for
g
in
gpu_list
]
self
.
op_type
=
"ncclAllReduce"
self
.
gpus
=
[
int
(
g
)
for
g
in
gpu_list
.
split
(
","
)]
self
.
g_scope
=
core
.
Scope
()
self
.
g_ctx
=
core
.
DeviceContext
.
create
(
core
.
CPUPlace
())
self
.
scopes
=
[]
self
.
ops
=
[]
self
.
places
=
[]
self
.
input_data
=
[]
for
i
in
range
(
len
(
self
.
gpus
)):
input_data
.
append
(
np
.
random
.
random
((
32
,
32
)))
self
.
output_data
=
allreduce
(
input_data
)
self
.
input_data
.
append
(
np
.
random
.
random
((
32
,
32
)))
self
.
output_data
=
allreduce
(
self
.
input_data
,
self
.
gpus
)
nccl_init
=
Operator
(
"ncclInit"
,
Out
=
"Communicator"
,
gpus
=
self
.
gpus
)
op
.
run
(
self
.
g_scope
,
self
.
g_ctx
)
for
i
in
range
(
len
(
self
.
gpus
)):
scope
=
core
.
Scope
()
# insert kid scope
scope
=
self
.
g_scope
.
new_scope
()
place
=
core
.
GPUPlace
(
self
.
gpus
[
i
])
inputs
=
{
"X"
:
self
.
input_data
[
i
]}
outputs
=
{
"Out"
:
self
.
output_data
[
i
]}
attrs
=
{
"gpus"
:
self
.
gpus
}
...
...
@@ -66,8 +77,11 @@ class TestNCCLAllReduce(unittest.TestCase):
self
.
assertTrue
(
actual
,
expect
),
"has diff"
if
__name__
==
"__main__"
:
# usage : export NV_LIST=0,1,2,3 python *.py
# if __name__ == "__main__":
# unittest.main()
# usage : export NV_LIST=0,1,2,3 python *.py
# os.environ["NV_LIST"] = ["0,1,2,3"]
os
.
environ
[
"NV_LIST"
]
=
[
"0,1,2,3"
]
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录