Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)

Unverified commit 2ef9e0e2
Authored by ShenLiang on Dec 09, 2020; committed via GitHub on Dec 09, 2020.
Rebuild group automatically in dynamic graph distributed (#29255)
* add tensor_indices in AssignGroupBySize
* add rebuild group in reducer
Parent commit: 3a055833
Changes: 8 changed files with 318 additions and 118 deletions (+318 / -118).
paddle/fluid/imperative/reducer.cc                              +183  -58
paddle/fluid/imperative/reducer.h                                +31  -53
paddle/fluid/imperative/tests/CMakeLists.txt                      +4   -0
paddle/fluid/imperative/tests/test_group.cc                      +66   -0
paddle/fluid/pybind/imperative.cc                                 +5   -2
python/paddle/distributed/fleet/base/distributed_strategy.py      +0   -1
python/paddle/fluid/dygraph/parallel.py                           +5   -4
python/paddle/fluid/tests/unittests/test_imperative_group.py     +24   -0
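The shape of the change, in brief: the reducer starts from size-based groups computed before training; during the first backward pass each gradient hook records its variable's global index in arrival order, and FinalizeBackward then rebuilds the groups from that recorded order, so later iterations fuse gradients the way they actually become ready. Below is a minimal standalone sketch of that record/reverse/regroup/reverse flow; it is illustrative only, not Paddle code (a fixed pairs-of-two stand-in replaces the byte-limit packing shown in the reducer.cc diff):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  const size_t num_vars = 6;

  // 1) First backward pass: AddDistHook fires once per ready gradient; for a
  //    plain feed-forward net that is roughly reverse construction order.
  std::vector<int64_t> rebuild_var_indices;
  for (int64_t v = static_cast<int64_t>(num_vars) - 1; v >= 0; --v) {
    rebuild_var_indices.push_back(v);
  }

  // 2) RebuildGruops: restore construction order, regroup (stand-in: fixed
  //    pairs instead of the byte limits used by AssignGroupBySize), then
  //    reverse the groups so group 0 is again the first whose gradients
  //    complete in the next backward pass.
  std::reverse(rebuild_var_indices.begin(), rebuild_var_indices.end());
  std::vector<std::vector<size_t>> groups;
  for (size_t i = 0; i + 1 < rebuild_var_indices.size(); i += 2) {
    groups.push_back({static_cast<size_t>(rebuild_var_indices[i]),
                      static_cast<size_t>(rebuild_var_indices[i + 1])});
  }
  std::reverse(groups.begin(), groups.end());

  // The earliest-ready variables (5 and 4) now sit in group 0.
  assert((groups.front() == std::vector<size_t>{4, 5}));
  return 0;
}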
paddle/fluid/imperative/reducer.cc (view file @ 2ef9e0e2)
@@ -20,47 +20,98 @@ namespace imperative {
 #if defined(PADDLE_WITH_NCCL)
 std::shared_ptr<Reducer> Reducer::s_instance_ = NULL;

+// context is used to select the stream for concat
+void Group::ConcatTensors(const platform::CUDADeviceContext &context) {
+  switch (dtype_) {
+    case framework::proto::VarType::FP16:
+      ConcatTensorsForAllReduce<platform::float16>(context, dense_tensors_,
+                                                   &dense_contents_);
+      break;
+    case framework::proto::VarType::FP32:
+      ConcatTensorsForAllReduce<float>(context, dense_tensors_,
+                                       &dense_contents_);
+      break;
+    case framework::proto::VarType::FP64:
+      ConcatTensorsForAllReduce<double>(context, dense_tensors_,
+                                        &dense_contents_);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Data type (%s) is not supported when it concats tensors for "
+          "allreduce.",
+          framework::DataTypeToString(dtype_)));
+  }
+}
+
+// context is used to select the stream for split
+void Group::SplitTensors(const platform::CUDADeviceContext &context) {
+  switch (dtype_) {
+    case framework::proto::VarType::FP16:
+      SplitTensorsForAllReduce<platform::float16>(context, &dense_contents_,
+                                                  &dense_tensors_);
+      break;
+    case framework::proto::VarType::FP32:
+      SplitTensorsForAllReduce<float>(context, &dense_contents_,
+                                      &dense_tensors_);
+      break;
+    case framework::proto::VarType::FP64:
+      SplitTensorsForAllReduce<double>(context, &dense_contents_,
+                                       &dense_tensors_);
+      break;
+    default:
+      PADDLE_THROW(platform::errors::Unimplemented(
+          "Data type (%s) is not supported when it splits tensors for "
+          "allreduce.",
+          framework::DataTypeToString(dtype_)));
+  }
+}
+
+std::ostream &operator<<(std::ostream &out, const Group &group) {
+  const auto &vars = group.variable_indices_;
+  out << "numul: " << group.all_length_ << " ;is_sparse: " << group.is_sparse_
+      << " ;var number: " << vars.size() << "\n";
+  auto begin = vars.begin();
+  auto end = vars.end();
+  out << "[";
+  for (int i = 0; begin != end && i < 100; ++i, ++begin) {
+    if (i > 0) out << ' ';
+    out << *begin;
+  }
+  if (begin != end) {
+    out << " ...";
+  }
+  out << "]\n";
+  return out;
+}
+
 Reducer::Reducer(const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
                  const std::vector<std::vector<size_t>> &group_indices,
                  const std::vector<bool> &is_sparse_gradient,
-                 std::shared_ptr<imperative::ParallelContext> parallel_ctx)
+                 std::shared_ptr<imperative::ParallelContext> parallel_ctx,
+                 const std::vector<size_t> &group_size_limits)
     : vars_(vars),
       group_indices_(group_indices),
       is_sparse_gradient_(is_sparse_gradient),
-      parallel_ctx_(parallel_ctx) {
+      parallel_ctx_(parallel_ctx),
+      group_size_limits_(group_size_limits) {
   VLOG(3) << "Start construct the Reducer ...";
   // initialize groups
   InitializeGroups(group_indices);
-  for (size_t group_index = 0; group_index < group_indices.size();
-       ++group_index) {
-    for (size_t var_index = 0; var_index < group_indices[group_index].size();
-         ++var_index) {
-      size_t global_var_index = group_indices[group_index][var_index];
-      const auto variable_index = VariableIndex{
-          .group_index = group_index, .inside_group_index = var_index,
-      };
-      VLOG(3) << "add hook for var[" << vars_[global_var_index]->GradVarName()
-              << "], it's in group [" << group_index << "]";
-      vars_[global_var_index]->SharedVar()->AddGradVarLeafBackwardHook(
-          std::unique_ptr<LambdaGradAccumulatorPostHook>(
-              new LambdaGradAccumulatorPostHook([=](VariableWrapper *grad) {
-                this->AddDistHook(grad, variable_index);
-              })));
-    }
+  for (size_t global_var_index = 0; global_var_index < vars_.size();
+       ++global_var_index) {
+    vars_[global_var_index]->SharedVar()->AddGradVarLeafBackwardHook(
+        std::unique_ptr<LambdaGradAccumulatorPostHook>(
+            new LambdaGradAccumulatorPostHook([=](VariableWrapper *grad) {
+              this->AddDistHook(grad, global_var_index);
+            })));
   }
+  // create streams
   compute_stream_ = static_cast<platform::CUDADeviceContext *>(
                         platform::DeviceContextPool::Instance().Get(place_))
                         ->stream();
   comm_stream_ = platform::NCCLCommContext::Instance().Get(0, place_)->stream();
-  events_.resize(group_indices.size());
-  for (auto &event : events_) {
-    event = platform::CudaEventResourcePool::Instance().New(
-        BOOST_GET_CONST(platform::CUDAPlace, place_).device);
-  }
+  // create events
+  CreateGroupEvents(group_indices.size());
   comm_enent_ = platform::CudaEventResourcePool::Instance().New(
       BOOST_GET_CONST(platform::CUDAPlace, place_).device);
@@ -76,7 +127,20 @@ void Reducer::ReleaseReducer() {
   comm_enent_.reset();
 }

-int64_t Reducer::InitializeDenseGroups(
+void Reducer::CreateGroupEvents(int group_num) {
+  // release old events
+  for (auto &event : events_) {
+    event.reset();
+  }
+  events_.clear();
+  events_.resize(group_num);
+  for (auto &event : events_) {
+    event = platform::CudaEventResourcePool::Instance().New(
+        BOOST_GET_CONST(platform::CUDAPlace, place_).device);
+  }
+}
+
+void Reducer::InitializeDenseGroups(
     const std::vector<size_t> &variable_indices_, Group *p_group) {
   int64_t all_length = 0;
   for (size_t index = 0; index < variable_indices_.size(); ++index) {
@@ -85,18 +149,18 @@ int64_t Reducer::InitializeDenseGroups(
     const auto var_name = var->Name();
     PADDLE_ENFORCE_EQ(is_sparse_gradient_[variable_index], false,
                       platform::errors::PreconditionNotMet(
-                          "Tensor `%s`'s GRAD must be LoDTensor, but received "
+                          "Tensor %s's GRAD must be LoDTensor, but received "
                           "GRAD is SelectedRows",
                           var_name));

     auto lod_tensor = var->MutableVar()->GetMutable<framework::LoDTensor>();
     PADDLE_ENFORCE_EQ(lod_tensor->IsInitialized(), true,
                       platform::errors::PreconditionNotMet(
-                          "Tensor `%s` is not initialized.", var_name));
+                          "Tensor %s is not initialized.", var_name));
     auto size = lod_tensor->numel();
     PADDLE_ENFORCE_GT(
         size, 0, platform::errors::PreconditionNotMet(
-                     "The number of tensor `%s`'s elements is 0.", var_name));
+                     "The number of tensor %s's elements is 0.", var_name));
     all_length += size;

     p_group->length_.push_back(size);
@@ -124,7 +188,7 @@ int64_t Reducer::InitializeDenseGroups(
       place_ = place;
     }
   }
-  return all_length;
+  p_group->all_length_ = all_length;
 }

 // Each parameter will be initialized according to the group information.
@@ -137,6 +201,8 @@ void Reducer::InitializeGroups(
   // clear the group
   groups_.clear();
   groups_.reserve(group_indices.size());
+  variable_locators_.clear();
+  variable_locators_.resize(vars_.size());

   auto group_nums = group_indices.size();
   for (size_t group_index = 0; group_index < group_nums; ++group_index) {
@@ -144,10 +210,8 @@ void Reducer::InitializeGroups(
     PADDLE_ENFORCE_GT(
         variable_indices_.size(), 0,
         platform::errors::PreconditionNotMet(
-            "The number of group_index[`%d`]'s elements is 0.", group_index));
+            "The number of group[%d]'s elements is 0.", group_index));
     Group group;
-    group.variable_indices_ = variable_indices_;
-    int64_t all_length = 0;

     // It's just for check the sparse or dense
     auto first_varbase = vars_[variable_indices_.front()];
@@ -159,17 +223,27 @@ void Reducer::InitializeGroups(
       group.is_sparse_ = true;
     } else {
       // process the dense gradient.
-      all_length = InitializeDenseGroups(variable_indices_, &group);
+      InitializeDenseGroups(variable_indices_, &group);

       // Alloc the continuous space
       auto tensor = group.dense_contents_.GetMutable<framework::LoDTensor>();
-      tensor->Resize(framework::make_ddim({all_length}))
+      tensor->Resize(framework::make_ddim({group.all_length_}))
           .mutable_data(place_, group.dtype_);
     }
-    // Debug Message For Reducer
-    VLOG(3) << "the groups_[" << group_index << "] basic message:";
-    VLOG(3) << "numul: " << all_length << " ;is_sparse: " << group.is_sparse_
-            << " ;var number: " << group.variable_indices_.size();
+
+    // map variables to this group by VariableLocator
+    size_t inside_group_index = 0;
+    for (const auto var_index : group_indices[group_index]) {
+      variable_locators_[var_index] = VariableLocator{
+          .group_index = group_index,
+          .inside_group_index = inside_group_index++,
+      };
+    }
+    group.variable_indices_ = std::move(variable_indices_);
     groups_.emplace_back(std::move(group));
+
+    // Debug Message For Reducer
+    VLOG(3) << "The Group[" << group_index << "]:";
+    VLOG(3) << groups_.back();
   }
 }
@@ -192,11 +266,16 @@ void Reducer::PrepareForBackward() {
 // counter is 0, it means that allreduce can be emitted, and
 // concat + allreduce + split is emitted in turn according to next_group_.
 // 3, FinalizeBackward: after the end, synchronize each stream.
 void Reducer::AddDistHook(VariableWrapper *var_warpper,
-                          const VariableIndex &var_index) {
-  auto group_index = var_index.group_index;
+                          size_t var_index) {
+  const auto &var_locator = variable_locators_[var_index];
+  auto group_index = var_locator.group_index;
   auto &group = groups_[group_index];

+  if (!has_rebuilt_group_) {
+    rebuild_vars_.push_back(vars_[var_index]);
+    rebuild_var_indices_.push_back(var_index);
+  }
+
   if (!group.is_sparse_) {
     // Only dense_contents_ need memory copy
     MarkVariableReady(var_index, var_warpper);
@@ -211,21 +290,22 @@ void Reducer::AddDistHook(VariableWrapper *var_warpper,
   }
 }

-void Reducer::MarkVariableReady(const VariableIndex &var_index,
-                                VariableWrapper *var_warpper) {
-  auto group_index = var_index.group_index;
-  auto variable_index = var_index.inside_group_index;
+void Reducer::MarkVariableReady(size_t var_index,
+                                VariableWrapper *var_warpper) {
+  const auto &var_locator = variable_locators_[var_index];
+  auto group_index = var_locator.group_index;
+  auto inside_group_index = var_locator.inside_group_index;
   auto &group = groups_[group_index];
-  auto length = group.length_[variable_index];
+  auto length = group.length_[inside_group_index];

   auto tensor = var_warpper->MutableVar()->GetMutable<framework::LoDTensor>();
-  group.dense_tensors_[variable_index].ShareDataWith(*tensor).Resize(
+  group.dense_tensors_[inside_group_index].ShareDataWith(*tensor).Resize(
       {static_cast<int64_t>(length)});
 }

 void Reducer::MarkGroupReady(size_t group_index) {
   if (group_index > next_group_) {
-    VLOG(3) << "Maybe it need adjust the order of group";
+    VLOG(3) << "It will adjust the order of group in next batch automatically";
     return;
   }
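For a dense group, marking a variable ready only makes the group's tensor slot alias the gradient, flattened to its recorded length; ConcatTensors later packs the slots into the fused dense_contents_ buffer so one allreduce covers the whole group, and SplitTensors undoes it. A host-side sketch of that pack/unpack (illustrative only; the real code launches concat/split kernels on the communication stream via ConcatTensorsForAllReduce/SplitTensorsForAllReduce):

#include <cassert>
#include <cstring>
#include <vector>

int main() {
  // Two "gradients" whose lengths Group::length_ would record.
  std::vector<float> g0 = {1, 2, 3};
  std::vector<float> g1 = {4, 5};
  std::vector<size_t> length = {g0.size(), g1.size()};

  // Stand-in for dense_contents_: one buffer of all_length_ elements.
  std::vector<float> dense_contents(length[0] + length[1]);

  // "Concat": copy each ready slice to its offset in the fused buffer.
  std::memcpy(dense_contents.data(), g0.data(), length[0] * sizeof(float));
  std::memcpy(dense_contents.data() + length[0], g1.data(),
              length[1] * sizeof(float));

  // ... a single allreduce over dense_contents would run here ...

  // "Split": copy results back out, the inverse of the concat above.
  std::memcpy(g0.data(), dense_contents.data(), length[0] * sizeof(float));
  assert(dense_contents[3] == 4);
  return 0;
}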
@@ -257,10 +337,31 @@ void Reducer::MarkGroupReady(size_t group_index) {
   }
 }

+std::vector<std::vector<size_t>> Reducer::RebuildGruops() {
+  std::reverse(rebuild_vars_.begin(), rebuild_vars_.end());
+  std::reverse(rebuild_var_indices_.begin(), rebuild_var_indices_.end());
+
+  auto rebuild_group_indices =
+      AssignGroupBySize(rebuild_vars_, is_sparse_gradient_, group_size_limits_,
+                        rebuild_var_indices_);
+
+  has_rebuilt_group_ = true;
+  rebuild_vars_.clear();
+  rebuild_var_indices_.clear();
+  std::reverse(rebuild_group_indices.begin(), rebuild_group_indices.end());
+  return rebuild_group_indices;
+}
+
 void Reducer::FinalizeBackward() {
   PADDLE_ENFORCE_CUDA_SUCCESS(cudaEventRecord(comm_enent_.get(), comm_stream_));
   PADDLE_ENFORCE_CUDA_SUCCESS(
       cudaStreamWaitEvent(compute_stream_, comm_enent_.get(), 0));
+
+  if (!has_rebuilt_group_) {
+    VLOG(3) << "Start rebuilding the groups";
+    auto rebuild_group_indices = RebuildGruops();
+    auto rebuild_group_number = rebuild_group_indices.size();
+    group_indices_ = std::move(rebuild_group_indices);
+    CreateGroupEvents(rebuild_group_number);
+    InitializeGroups(group_indices_);
+  }
+
   VLOG(3) << "In the batch, Reducer is finished...";
 }
@@ -274,12 +375,28 @@ void Reducer::FinalizeBackward() {
 std::vector<std::vector<size_t>> AssignGroupBySize(
     const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
     const std::vector<bool> &is_sparse_gradient,
-    const std::vector<size_t> &group_size_limits) {
+    const std::vector<size_t> &group_size_limits,
+    const std::vector<int64_t> &tensor_indices) {
   PADDLE_ENFORCE_EQ(vars.size(), is_sparse_gradient.size(),
                     platform::errors::PreconditionNotMet(
                         "vars len must be equal to is_sparse_gradient len, but "
                         "[%lu] != [%lu]",
                         vars.size(), is_sparse_gradient.size()));
+  auto check_perm = [](const std::vector<int64_t> &x) -> bool {
+    size_t len = x.size();
+    std::vector<size_t> cnt(len, 0);
+    for (size_t i = 0; i < len; ++i) {
+      if (x[i] >= static_cast<int64_t>(len) || x[i] < 0 || cnt[x[i]]) {
+        return false;
+      }
+      cnt[x[i]]++;
+    }
+    return true;
+  };
+  PADDLE_ENFORCE_EQ(true, check_perm(tensor_indices),
+                    platform::errors::PreconditionNotMet(
+                        "tensor_indices must be a permutation from 0 to %lu",
+                        tensor_indices.size()));
   // the return vector
   std::vector<std::vector<size_t>> res;
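check_perm accepts exactly the permutations of 0..len-1, and the empty vector passes trivially, which is what keeps the default tensor_indices = {} legal. A standalone restatement of the lambda with a few spot checks:

#include <cassert>
#include <cstdint>
#include <vector>

// Same logic as the check_perm lambda above: counting-sort style duplicate
// and range check over a candidate permutation.
static bool CheckPerm(const std::vector<int64_t>& x) {
  size_t len = x.size();
  std::vector<size_t> cnt(len, 0);
  for (size_t i = 0; i < len; ++i) {
    if (x[i] >= static_cast<int64_t>(len) || x[i] < 0 || cnt[x[i]]) {
      return false;
    }
    cnt[x[i]]++;
  }
  return true;
}

int main() {
  assert(CheckPerm({3, 0, 1, 2}));   // valid permutation of 0..3
  assert(!CheckPerm({0, 0, 1}));     // duplicate 0
  assert(!CheckPerm({0, 1, 3}));     // 3 is out of range for len == 3
  assert(CheckPerm({}));             // empty tensor_indices is allowed
  return 0;
}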
@@ -294,9 +411,15 @@ std::vector<std::vector<size_t>> AssignGroupBySize(
   for (size_t i = 0; i < vars.size(); ++i) {
     const auto &var = vars[i];
-    if (is_sparse_gradient[i]) {
+
+    size_t tensor_real_index = i;
+    if (!tensor_indices.empty()) {
+      tensor_real_index = tensor_indices[i];
+    }
+
+    if (is_sparse_gradient[tensor_real_index]) {
       // we keep sparse var a single group
-      res.push_back({i});
+      res.push_back({tensor_real_index});
       continue;
     }
@@ -313,7 +436,7 @@ std::vector<std::vector<size_t>> AssignGroupBySize(
               << " is not tensor or selected_rows, so skip it";
       continue;
     }
-    group_info.first.push_back(i);
+    group_info.first.push_back(tensor_real_index);
     group_info.second += framework::SizeOfType(var_dtype) * var_size;

     if (group_limit_index.find(var_dtype_str) == group_limit_index.end()) {
@@ -344,10 +467,12 @@ std::vector<std::vector<size_t>> AssignGroupBySize(
         platform::errors::PreconditionNotMet(
             "AssignGroupBySize construct empty group, please check."));
   }
-  std::sort(res.begin(), res.end(),
-            [](const std::vector<size_t> &x, const std::vector<size_t> &y) {
-              return x.front() < y.front();
-            });
+  if (tensor_indices.empty()) {
+    std::sort(res.begin(), res.end(),
+              [](const std::vector<size_t> &x, const std::vector<size_t> &y) {
+                return x.front() < y.front();
+              });
+  }
   return res;
 }

 #endif
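To see the tensor_indices path end to end, here is a hand-check mirroring test_construct_group8 from this commit (FP32 assumed to be 4 bytes per element): positions 0..3 are mapped through tensor_indices and a group closes whenever the accumulated bytes reach the limit, and the final sort is skipped so ready order is preserved. Standalone sketch, not Paddle code:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // Byte sizes of vars v0..v3 with shapes [2,25], [2,100], [2,50], [2,25].
  std::vector<size_t> bytes = {200, 800, 400, 200};
  std::vector<int64_t> tensor_indices = {3, 0, 1, 2};
  const size_t limit = 400;

  std::vector<std::vector<size_t>> res;
  std::vector<size_t> cur;
  size_t cur_bytes = 0;
  for (int64_t idx : tensor_indices) {
    size_t real = static_cast<size_t>(idx);  // is_sparse_gradient all false here
    cur.push_back(real);
    cur_bytes += bytes[real];
    if (cur_bytes >= limit) {                // close the group at the limit
      res.push_back(cur);
      cur.clear();
      cur_bytes = 0;
    }
  }
  if (!cur.empty()) res.push_back(cur);

  // Matches the expected [[3, 0], [1], [2]]; res keeps ready order and is
  // NOT sorted, because tensor_indices is non-empty.
  assert((res == std::vector<std::vector<size_t>>{{3, 0}, {1}, {2}}));
  return 0;
}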
paddle/fluid/imperative/reducer.h (view file @ 2ef9e0e2)
@@ -86,6 +86,8 @@ class Group {
   std::vector<framework::Tensor> dense_tensors_;

   std::vector<size_t> length_;
+
+  int64_t all_length_{0};
   // Global indices of participating variables in the group
   std::vector<size_t> variable_indices_;
@@ -97,53 +99,15 @@ class Group {
   framework::proto::VarType::Type dtype_;

   // context is used to select the stream for concat
-  void ConcatTensors(const platform::CUDADeviceContext &context) {
-    switch (dtype_) {
-      case framework::proto::VarType::FP16:
-        ConcatTensorsForAllReduce<platform::float16>(context, dense_tensors_,
-                                                     &dense_contents_);
-        break;
-      case framework::proto::VarType::FP32:
-        ConcatTensorsForAllReduce<float>(context, dense_tensors_,
-                                         &dense_contents_);
-        break;
-      case framework::proto::VarType::FP64:
-        ConcatTensorsForAllReduce<double>(context, dense_tensors_,
-                                          &dense_contents_);
-        break;
-      default:
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Data type (%s) is not supported when it concats tensors for "
-            "allreduce.",
-            framework::DataTypeToString(dtype_)));
-    }
-  }
+  void ConcatTensors(const platform::CUDADeviceContext &context);

   // context is used to select the stream for split
-  void SplitTensors(const platform::CUDADeviceContext &context) {
-    switch (dtype_) {
-      case framework::proto::VarType::FP16:
-        SplitTensorsForAllReduce<platform::float16>(context, &dense_contents_,
-                                                    &dense_tensors_);
-        break;
-      case framework::proto::VarType::FP32:
-        SplitTensorsForAllReduce<float>(context, &dense_contents_,
-                                        &dense_tensors_);
-        break;
-      case framework::proto::VarType::FP64:
-        SplitTensorsForAllReduce<double>(context, &dense_contents_,
-                                         &dense_tensors_);
-        break;
-      default:
-        PADDLE_THROW(platform::errors::Unimplemented(
-            "Data type (%s) is not supported when it splits tensors for "
-            "allreduce.",
-            framework::DataTypeToString(dtype_)));
-    }
-  }
+  void SplitTensors(const platform::CUDADeviceContext &context);
+
+  friend std::ostream &operator<<(std::ostream &, const Group &);
 };

-struct VariableIndex {
+struct VariableLocator {
   // record the index in groups_
   size_t group_index;
   size_t inside_group_index;
@@ -155,22 +119,21 @@ class Reducer {
           const std::vector<std::shared_ptr<imperative::VarBase>>& vars,
           const std::vector<std::vector<size_t>>& group_indices,
           const std::vector<bool>& is_sparse_gradient,
-          std::shared_ptr<imperative::ParallelContext> parallel_ctx);
+          std::shared_ptr<imperative::ParallelContext> parallel_ctx,
+          const std::vector<size_t>& group_size_limits);

   virtual ~Reducer() {}

   void InitializeGroups(const std::vector<std::vector<size_t>>& group_indices);

-  int64_t InitializeDenseGroups(const std::vector<size_t>& variable_indices_,
-                                Group* p_group);
+  void InitializeDenseGroups(const std::vector<size_t>& variable_indices_,
+                             Group* p_group);

   void PrepareForBackward();

-  void AddDistHook(VariableWrapper* var_warpper,
-                   const VariableIndex& var_index);
+  void AddDistHook(VariableWrapper* var_warpper, size_t var_index);

-  void MarkVariableReady(const VariableIndex& var_index,
-                         VariableWrapper* var_warpper);
+  void MarkVariableReady(size_t var_index, VariableWrapper* var_warpper);

   void MarkGroupReady(size_t group_index);
@@ -178,15 +141,21 @@ class Reducer {
   void ReleaseReducer();

+  std::vector<std::vector<size_t>> RebuildGruops();
+
+  void CreateGroupEvents(int group_num);
+
   // Reducer Singleton
   static std::shared_ptr<Reducer> SetInstance(
       const std::vector<std::shared_ptr<imperative::VarBase>>& vars,
       const std::vector<std::vector<size_t>>& group_indices,
       const std::vector<bool>& is_sparse_gradient,
-      std::shared_ptr<imperative::ParallelContext> parallel_ctx) {
+      std::shared_ptr<imperative::ParallelContext> parallel_ctx,
+      const std::vector<size_t>& group_size_limits) {
     if (NULL == s_instance_) {
       s_instance_.reset(new paddle::imperative::Reducer(
-          vars, group_indices, is_sparse_gradient, parallel_ctx));
+          vars, group_indices, is_sparse_gradient, parallel_ctx,
+          group_size_limits));
     }
     return s_instance_;
   }
@@ -208,17 +177,26 @@ class Reducer {
   std::once_flag once_flag_;
   std::vector<bool> is_sparse_gradient_;
   std::shared_ptr<imperative::ParallelContext> parallel_ctx_;
+  std::vector<VariableLocator> variable_locators_;

+  // Following variables are to help sync stream
   std::vector<std::shared_ptr<platform::CudaEventObject>> events_;
   std::shared_ptr<platform::CudaEventObject> comm_enent_;
   cudaStream_t compute_stream_;
   cudaStream_t comm_stream_;
+
+  // Following variables are to help rebuild group
+  bool has_rebuilt_group_{false};
+  std::vector<std::shared_ptr<imperative::VarBase>> rebuild_vars_;
+  std::vector<int64_t> rebuild_var_indices_;
+  const std::vector<size_t> group_size_limits_;
 };

 std::vector<std::vector<size_t>> AssignGroupBySize(
     const std::vector<std::shared_ptr<imperative::VarBase>>& tensors,
     const std::vector<bool>& is_sparse_gradient,
-    const std::vector<size_t>& group_size_limits);
+    const std::vector<size_t>& group_size_limits,
+    const std::vector<int64_t>& tensor_indices = {});

 #endif

 }  // namespace imperative
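With hooks now carrying only a global variable index, variable_locators_ is the table that recovers a variable's group and its offset inside that group; InitializeGroups fills it as shown in the reducer.cc hunks above. A compact sketch of the mapping:

#include <cassert>
#include <cstddef>
#include <vector>

struct VariableLocator {
  size_t group_index;
  size_t inside_group_index;
};

int main() {
  // group_indices = {{2, 0}, {1}} as AssignGroupBySize might produce.
  std::vector<std::vector<size_t>> group_indices = {{2, 0}, {1}};
  std::vector<VariableLocator> variable_locators(3);
  for (size_t g = 0; g < group_indices.size(); ++g) {
    size_t inside = 0;
    for (size_t var : group_indices[g]) {
      variable_locators[var] = VariableLocator{g, inside++};
    }
  }
  // Global var 0 lives in group 0 at offset 1, exactly what MarkVariableReady
  // needs to pick group.length_[inside_group_index].
  assert(variable_locators[0].group_index == 0);
  assert(variable_locators[0].inside_group_index == 1);
  return 0;
}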
paddle/fluid/imperative/tests/CMakeLists.txt (view file @ 2ef9e0e2)
@@ -12,3 +12,7 @@ cc_test(test_layer SRCS test_layer.cc DEPS layer proto_desc operator op_registry
 cc_test(test_prepare_op SRCS test_prepare_op.cc DEPS prepared_operator op_info split_op layer concat_and_split activation_op place)
 cc_test(test_tracer SRCS test_tracer.cc DEPS tracer layer proto_desc operator op_registry variable_helper mul_op reduce_sum_op elementwise_add_op memcpy)
 cc_test(test_hooks SRCS test_hooks.cc DEPS tracer basic_engine layer proto_desc operator op_registry variable_helper mul_op elementwise_add_op memcpy)
+
+if (WITH_NCCL)
+cc_test(test_group SRCS test_group.cc DEPS reducer concat_and_split memcpy)
+endif()
paddle/fluid/imperative/tests/test_group.cc (new file, 0 → 100644, @ 2ef9e0e2)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <ostream>
#include <sstream>
#include <string>
#include "glog/logging.h"
#include "gtest/gtest.h"
#if defined(PADDLE_WITH_NCCL)
#include "paddle/fluid/imperative/reducer.h"
#endif
namespace paddle {
namespace imperative {

#if defined(PADDLE_WITH_NCCL)
TEST(TestGroup, TestPrintGroupMessage) {
  Group group;
  std::stringstream stream1, stream2;
  stream1 << group;
  ASSERT_STREQ(stream1.str().c_str(),
               "numul: 0 ;is_sparse: 0 ;var number: 0\n[]\n");

  std::vector<size_t> vars;
  size_t vars_num = 102;
  for (size_t i = 0; i < vars_num; ++i) {
    vars.push_back(i);
  }
  group.variable_indices_ = vars;
  group.all_length_ = 102;
  group.is_sparse_ = false;

  std::string head = "numul: 102 ;is_sparse: 0 ;var number: 102\n";
  head = head + "[";
  auto begin = vars.begin();
  auto end = vars.end();
  for (int i = 0; begin != end && i < 100; ++i, ++begin) {
    if (i > 0) head += ' ';
    head += std::to_string(*begin);
  }
  if (begin != end) {
    head += " ...";
  }
  head += "]\n";
  stream2 << group;
  ASSERT_STREQ(stream2.str().c_str(), head.c_str());
}
#endif

}  // namespace imperative
}  // namespace paddle
paddle/fluid/pybind/imperative.cc (view file @ 2ef9e0e2)
@@ -1289,9 +1289,11 @@ void BindImperative(py::module *m_ptr) {
           [](const std::vector<std::shared_ptr<imperative::VarBase>> &vars,
              const std::vector<std::vector<size_t>> &group_indices,
              const std::vector<bool> &is_sparse_gradient,
-             std::shared_ptr<imperative::ParallelContext> parallel_ctx) {
+             std::shared_ptr<imperative::ParallelContext> parallel_ctx,
+             const std::vector<size_t> &group_size_limits) {
             return imperative::Reducer::SetInstance(
-                vars, group_indices, is_sparse_gradient, parallel_ctx);
+                vars, group_indices, is_sparse_gradient, parallel_ctx,
+                group_size_limits);
           }))
       .def("prepare_for_backward", &imperative::Reducer::PrepareForBackward,
            py::call_guard<py::gil_scoped_release>());

@@ -1299,6 +1301,7 @@ void BindImperative(py::module *m_ptr) {
   m.def("assign_group_by_size", &imperative::AssignGroupBySize, py::arg("vars"),
         py::arg("is_sparse_gradient"),
         py::arg("group_size_limits") = std::vector<size_t>{25 * 1024 * 1024},
+        py::arg("tensor_indices") = std::vector<int64_t>{},
         py::call_guard<py::gil_scoped_release>());
 #endif
 }
python/paddle/distributed/fleet/base/distributed_strategy.py (view file @ 2ef9e0e2)
@@ -18,7 +18,6 @@ from paddle.fluid.framework import Variable, set_flags, core
 from paddle.fluid.wrapped_decorator import wrap_decorator
 import google.protobuf.text_format
 import google.protobuf
-from paddle.fluid.framework import dygraph_only

 __all__ = ["DistributedStrategy"]
python/paddle/fluid/dygraph/parallel.py (view file @ 2ef9e0e2)
@@ -441,10 +441,11 @@ class DataParallel(layers.Layer):
             "ParallelContext must be initialized before. You should use init_parallel_env() before" \
             "constructing the DataParallel."

-        self._reducer = core.Reducer(trainable_parameters,
-                                     list(reversed(self.group_indices)),
-                                     is_sparse_gradient,
-                                     parallel_helper.__parallel_ctx__clz__)
+        self._reducer = core.Reducer(
+            trainable_parameters,
+            list(reversed(self.group_indices)), is_sparse_gradient,
+            parallel_helper.__parallel_ctx__clz__,
+            [self.last_comm_buffer_size, self.comm_buffer_size])

     def forward(self, *inputs, **kwargs):
         if self._strategy.nranks > 1:
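DataParallel now forwards two limits, [self.last_comm_buffer_size, self.comm_buffer_size], as group_size_limits. Plausibly (the assign_group_by_size call itself sits outside the visible hunk) the first entry caps the group of the earliest-constructed parameters, whose gradients arrive last in backward; keeping that last communication buffer small shortens the tail allreduce that nothing can overlap with. A sketch of how a limit list could map onto successive groups, under the assumption that the grouping advances through the list once per closed group and then sticks to the last entry:

#include <algorithm>
#include <cstdio>
#include <vector>

int main() {
  // Assumed byte values; DataParallel's defaults are expressed in MB.
  std::vector<size_t> group_size_limits = {1 * 1024 * 1024, 25 * 1024 * 1024};
  size_t limit_index = 0;
  for (int group = 0; group < 4; ++group) {
    size_t limit = group_size_limits[limit_index];
    std::printf("group %d closes at %zu bytes\n", group, limit);
    // Advance once per closed group, then stay on the last entry.
    limit_index = std::min(limit_index + 1, group_size_limits.size() - 1);
  }
  return 0;
}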
python/paddle/fluid/tests/unittests/test_imperative_group.py (view file @ 2ef9e0e2)
@@ -155,6 +155,30 @@ class TestDataParallelGroup(unittest.TestCase):
             var_list, [True, False, False, False, False, True], [200, 400])
         self.assertEqual([[0], [1], [2], [3], [4], [5]], res)

+    def test_construct_group8(self):
+        # one dtype & one limit capability & have tensor_indices
+        var_list = []
+        var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
+        var_list.append(
+            self.create_varbase(core.VarDesc.VarType.FP32, [2, 100]))
+        var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 50]))
+        var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
+        res = core.assign_group_by_size(var_list, [False, False, False, False],
+                                        [400], [3, 0, 1, 2])
+        self.assertEqual([[3, 0], [1], [2]], res)
+
+    def test_construct_group9(self):
+        # one dtype & one limit capability & have tensor_indices
+        var_list = []
+        var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
+        var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
+        var_list.append(self.create_varbase(core.VarDesc.VarType.FP32, [2, 25]))
+        var_list.append(
+            self.create_varbase(core.VarDesc.VarType.FP32, [2, 1000]))
+        res = core.assign_group_by_size(var_list, [False, False, False, True],
+                                        [300], [1, 0, 2, 3])
+        self.assertEqual([[1, 0], [3], [2]], res)
+

 if __name__ == '__main__':
     unittest.main()