Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
58560622
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
58560622
编写于
7月 02, 2018
作者:
F
fengjiayi
提交者:
GitHub
7月 02, 2018
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #11854 from JiayiFeng/dev_data_balance
Data balance for the ParallelExecutor
上级
87dd01d6
ff4317ce
变更
13
显示空白变更内容
内联
并排
Showing
13 changed file
with
469 addition
and
7 deletion
+469
-7
paddle/fluid/framework/details/CMakeLists.txt
paddle/fluid/framework/details/CMakeLists.txt
+2
-1
paddle/fluid/framework/details/build_strategy.h
paddle/fluid/framework/details/build_strategy.h
+2
-0
paddle/fluid/framework/details/data_balance_op_handle.cc
paddle/fluid/framework/details/data_balance_op_handle.cc
+154
-0
paddle/fluid/framework/details/data_balance_op_handle.h
paddle/fluid/framework/details/data_balance_op_handle.h
+59
-0
paddle/fluid/framework/details/fetch_op_handle.cc
paddle/fluid/framework/details/fetch_op_handle.cc
+1
-1
paddle/fluid/framework/details/multi_devices_graph_builder.cc
...le/fluid/framework/details/multi_devices_graph_builder.cc
+34
-2
paddle/fluid/framework/details/multi_devices_graph_builder.h
paddle/fluid/framework/details/multi_devices_graph_builder.h
+3
-0
paddle/fluid/framework/details/op_handle_base.cc
paddle/fluid/framework/details/op_handle_base.cc
+2
-0
paddle/fluid/framework/lod_tensor.cc
paddle/fluid/framework/lod_tensor.cc
+2
-0
paddle/fluid/operators/read_op.cc
paddle/fluid/operators/read_op.cc
+16
-2
paddle/fluid/pybind/pybind.cc
paddle/fluid/pybind/pybind.cc
+5
-1
python/paddle/fluid/tests/unittests/.gitignore
python/paddle/fluid/tests/unittests/.gitignore
+2
-0
python/paddle/fluid/tests/unittests/test_data_balance.py
python/paddle/fluid/tests/unittests/test_data_balance.py
+187
-0
未找到文件。
paddle/fluid/framework/details/CMakeLists.txt
浏览文件 @
58560622
...
...
@@ -25,11 +25,12 @@ else()
cc_library
(
broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
endif
()
cc_library
(
data_balance_op_handle SRCS data_balance_op_handle.cc DEPS op_handle_base scope lod_tensor
)
cc_library
(
gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor
)
cc_library
(
fuse_vars_op_handle SRCS fuse_vars_op_handle.cc DEPS op_handle_base scope
)
cc_library
(
multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle
)
scale_loss_grad_op_handle rpc_op_handle all_reduce_op_handle reduce_op_handle broadcast_op_handle
data_balance_op_handle
)
cc_library
(
ssa_graph_builder_factory SRCS ssa_graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer ssa_graph_checker
)
...
...
paddle/fluid/framework/details/build_strategy.h
浏览文件 @
58560622
...
...
@@ -33,6 +33,8 @@ struct BuildStrategy {
GradientScaleStrategy
gradient_scale_
{
GradientScaleStrategy
::
kCoeffNumDevice
};
std
::
string
debug_graphviz_path_
{
""
};
bool
enable_data_balance_
{
true
};
};
}
// namespace details
...
...
paddle/fluid/framework/details/data_balance_op_handle.cc
0 → 100644
浏览文件 @
58560622
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include <algorithm>
#include "paddle/fluid/framework/details/container_cast.h"
namespace
paddle
{
namespace
framework
{
namespace
details
{
#ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle
::
DataBalanceOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
NCCLContextMap
*
ctxs
)
:
local_scopes_
(
local_scopes
),
places_
(
places
)
{
if
(
ctxs
)
{
for
(
auto
&
p
:
places_
)
{
this
->
dev_ctxes_
[
p
]
=
ctxs
->
DevCtx
(
p
);
}
}
}
#else
DataBalanceOpHandle
::
DataBalanceOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
)
:
local_scopes_
(
local_scopes
),
places_
(
places
)
{}
#endif
std
::
string
DataBalanceOpHandle
::
Name
()
const
{
return
"data balance"
;
}
std
::
vector
<
std
::
array
<
int
,
3
>>
DataBalanceOpHandle
::
GetBalancePlan
(
const
std
::
vector
<
int
>
&
device_sizes
)
{
int
device_num
=
device_sizes
.
size
();
int
total_size
=
0
;
int
empty_num
=
0
;
std
::
vector
<
std
::
array
<
int
,
2
>>
size_device_vec
;
size_device_vec
.
reserve
(
device_num
);
for
(
int
i
=
0
;
i
<
device_num
;
++
i
)
{
if
(
device_sizes
[
i
]
==
0
)
{
++
empty_num
;
}
total_size
+=
device_sizes
[
i
];
size_device_vec
.
push_back
({{
device_sizes
[
i
],
i
}});
}
std
::
vector
<
std
::
array
<
int
,
3
>>
res
;
if
(
empty_num
==
0
)
{
// No need to do data balance.
return
res
;
}
if
(
total_size
<
device_num
)
{
// No enough data.
PADDLE_THROW
(
"There is no next data."
);
}
std
::
sort
(
size_device_vec
.
begin
(),
size_device_vec
.
end
(),
[](
const
std
::
array
<
int
,
2
>
&
a
,
const
std
::
array
<
int
,
2
>
&
b
)
{
return
a
[
0
]
>
b
[
0
];
});
int
expected_device_size
=
total_size
/
device_num
;
int
src_idx
=
0
;
for
(
int
dst_idx
=
device_num
-
empty_num
;
dst_idx
<
device_num
;
++
dst_idx
)
{
if
(
size_device_vec
[
src_idx
][
0
]
<=
expected_device_size
)
{
++
src_idx
;
PADDLE_ENFORCE_LT
(
src_idx
,
device_num
-
empty_num
,
"In current srategy an empty tensor should not be copy source."
);
}
size_device_vec
[
src_idx
][
0
]
-=
expected_device_size
;
size_device_vec
[
dst_idx
][
0
]
+=
expected_device_size
;
res
.
push_back
({{
size_device_vec
[
src_idx
][
1
],
size_device_vec
[
dst_idx
][
1
],
expected_device_size
}});
}
return
res
;
}
void
DataBalanceOpHandle
::
RunImpl
()
{
if
(
places_
.
size
()
==
1
)
{
return
;
}
auto
in_var_handles
=
DynamicCast
<
VarHandle
>
(
inputs_
);
auto
out_var_handles
=
DynamicCast
<
VarHandle
>
(
outputs_
);
PADDLE_ENFORCE
(
in_var_handles
.
size
()
%
places_
.
size
()
==
0
);
PADDLE_ENFORCE_EQ
(
in_var_handles
.
size
(),
out_var_handles
.
size
(),
"The NoDummyInputSize and NoDummyOutputSize should be equal."
);
int
data_num
=
in_var_handles
.
size
()
/
places_
.
size
();
WaitInputVarGenerated
();
std
::
vector
<
std
::
vector
<
LoDTensor
*>>
lod_tensors
(
data_num
);
std
::
vector
<
int
>
device_sizes
;
for
(
int
i
=
0
;
i
<
static_cast
<
int
>
(
in_var_handles
.
size
());
++
i
)
{
PADDLE_ENFORCE_EQ
(
in_var_handles
[
i
]
->
name_
,
out_var_handles
[
i
]
->
name_
,
"The name of input and output should be equal."
);
int
place_idx
=
i
/
data_num
;
int
data_idx
=
i
%
data_num
;
auto
*
local_scope
=
local_scopes_
[
place_idx
]
->
FindVar
(
kLocalExecScopeName
)
->
Get
<
Scope
*>
();
auto
*
tensor_var
=
local_scope
->
FindVar
(
in_var_handles
[
i
]
->
name_
);
PADDLE_ENFORCE
(
tensor_var
->
IsType
<
LoDTensor
>
());
auto
*
tensor
=
tensor_var
->
GetMutable
<
LoDTensor
>
();
lod_tensors
[
data_idx
].
push_back
(
tensor
);
int
ins_size
=
tensor
->
lod
().
empty
()
?
tensor
->
dims
()[
0
]
:
tensor
->
NumElements
();
if
(
data_idx
==
0
)
{
device_sizes
.
emplace_back
(
ins_size
);
}
else
{
PADDLE_ENFORCE_EQ
(
ins_size
,
device_sizes
.
at
(
place_idx
),
"All data on the same device shall have the same batch size."
);
}
}
const
auto
&
balance_plan
=
GetBalancePlan
(
device_sizes
);
for
(
const
auto
&
trans
:
balance_plan
)
{
for
(
int
data_idx
=
0
;
data_idx
<
data_num
;
++
data_idx
)
{
LoDTensor
*
src_tensor
=
lod_tensors
[
data_idx
][
trans
[
0
]];
LoDTensor
*
dst_tensor
=
lod_tensors
[
data_idx
][
trans
[
1
]];
int
trans_ins_size
=
trans
[
2
];
LoD
src_lod
=
src_tensor
->
lod
();
int
src_ins_size
=
src_lod
.
empty
()
?
src_tensor
->
dims
()[
0
]
:
src_tensor
->
NumElements
();
int
cut_point
=
src_ins_size
-
trans_ins_size
;
if
(
!
src_lod
.
empty
())
{
for
(
auto
&
level
:
src_lod
)
{
cut_point
=
level
[
cut_point
];
}
}
TensorCopySync
(
src_tensor
->
Slice
(
cut_point
,
src_tensor
->
dims
()[
0
]),
dst_tensor
->
place
(),
dst_tensor
);
src_tensor
->
ShareDataWith
(
src_tensor
->
Slice
(
0
,
cut_point
));
if
(
!
src_lod
.
empty
())
{
dst_tensor
->
set_lod
(
SliceInLevel
(
src_lod
,
0
,
src_ins_size
-
trans_ins_size
,
src_ins_size
));
src_tensor
->
set_lod
(
SliceInLevel
(
src_lod
,
0
,
0
,
src_ins_size
-
trans_ins_size
));
}
}
}
}
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/data_balance_op_handle.h
0 → 100644
浏览文件 @
58560622
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/op_handle_base.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif
namespace
paddle
{
namespace
framework
{
namespace
details
{
struct
DataBalanceOpHandle
:
public
OpHandleBase
{
public:
#ifdef PADDLE_WITH_CUDA
DataBalanceOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
,
const
platform
::
NCCLContextMap
*
ctxs
);
#else
DataBalanceOpHandle
(
const
std
::
vector
<
Scope
*>
&
local_scopes
,
const
std
::
vector
<
platform
::
Place
>
&
places
);
#endif
std
::
string
Name
()
const
override
;
bool
IsMultiDeviceTransfer
()
override
{
return
false
;
};
protected:
void
RunImpl
()
override
;
private:
// std::vector<(src_dev_id, dst_dev_id, trans_size)>
std
::
vector
<
std
::
array
<
int
,
3
>>
GetBalancePlan
(
const
std
::
vector
<
int
>
&
batch_size_per_device
);
const
std
::
vector
<
Scope
*>
local_scopes_
;
const
std
::
vector
<
platform
::
Place
>
places_
;
};
}
// namespace details
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/details/fetch_op_handle.cc
浏览文件 @
58560622
...
...
@@ -67,8 +67,8 @@ void FetchOpHandle::RunImpl() {
#endif
}
else
{
tensors_
[
i
].
ShareDataWith
(
t
);
tensors_
[
i
].
set_lod
(
t
.
lod
());
}
tensors_
[
i
].
set_lod
(
t
.
lod
());
}
this
->
WaitAndMergeCPUTensors
();
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.cc
浏览文件 @
58560622
...
...
@@ -20,6 +20,7 @@
#include "paddle/fluid/framework/details/all_reduce_op_handle.h"
#include "paddle/fluid/framework/details/broadcast_op_handle.h"
#include "paddle/fluid/framework/details/computation_op_handle.h"
#include "paddle/fluid/framework/details/data_balance_op_handle.h"
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/reduce_op_handle.h"
#include "paddle/fluid/framework/details/rpc_op_handle.h"
...
...
@@ -215,7 +216,14 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
}
else
{
// This op runs on all devices, and its output may have parameter's
// gradients.
if
(
op
->
Type
()
==
"read"
&&
strategy_
.
enable_data_balance_
)
{
op
->
SetAttr
(
"throw_eof_exp"
,
false
);
CreateComputationalOps
(
&
result
,
*
op
,
places_
.
size
());
const
auto
&
data_var_names
=
op
->
Output
(
"Out"
);
InsertDataBalanceOp
(
&
result
,
data_var_names
);
}
else
{
CreateComputationalOps
(
&
result
,
*
op
,
places_
.
size
());
}
if
(
!
is_forwarding
&&
places_
.
size
()
>
1
)
{
// Currently, we assume that once gradient is generated, it can be
...
...
@@ -360,6 +368,29 @@ void MultiDevSSAGraphBuilder::InsertAllReduceOp(SSAGraph *result,
}
}
void
MultiDevSSAGraphBuilder
::
InsertDataBalanceOp
(
SSAGraph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
{
#ifdef PADDLE_WITH_CUDA
result
->
ops_
.
emplace_back
(
new
DataBalanceOpHandle
(
local_scopes_
,
places_
,
nccl_ctxs_
));
#else
result
->
ops_
.
emplace_back
(
new
DataBalanceOpHandle
(
local_scopes_
,
places_
));
#endif
auto
*
op_handle
=
result
->
ops_
.
back
().
get
();
for
(
size_t
i
=
0
;
i
<
places_
.
size
();
++
i
)
{
auto
&
p
=
places_
[
i
];
SetCommunicationContext
(
op_handle
,
p
);
for
(
const
std
::
string
&
d_name
:
datas
)
{
auto
&
vars
=
result
->
vars_
[
i
][
d_name
];
PADDLE_ENFORCE
(
!
vars
.
empty
());
op_handle
->
AddInput
(
vars
.
back
().
get
());
auto
var
=
new
VarHandle
(
vars
.
size
(),
i
,
d_name
,
p
);
vars
.
emplace_back
(
var
);
op_handle
->
AddOutput
(
var
);
}
}
}
bool
MultiDevSSAGraphBuilder
::
IsParameterGradientOnce
(
const
std
::
string
&
og
,
std
::
unordered_set
<
std
::
string
>
*
og_has_been_broadcast
)
const
{
...
...
@@ -512,7 +543,8 @@ void MultiDevSSAGraphBuilder::CreateRPCOp(SSAGraph *result,
op_dev_id
=
GetVarDeviceID
(
op
.
InputArgumentNames
()[
0
]);
// the variable name which contains .block means it was splited by
// split_byref op
// so that we can balance the variable blocks to all the pserver instances.
// so that we can balance the variable blocks to all the pserver
// instances.
if
(
strategy_
.
reduce_
==
BuildStrategy
::
ReduceStrategy
::
kAllReduce
&&
op
.
InputArgumentNames
()[
0
].
find
(
".block"
)
==
std
::
string
::
npos
)
{
op_dev_id
=
GetAppropriateDeviceID
(
op
.
InputArgumentNames
());
...
...
paddle/fluid/framework/details/multi_devices_graph_builder.h
浏览文件 @
58560622
...
...
@@ -101,6 +101,9 @@ class MultiDevSSAGraphBuilder : public SSAGraphBuilder {
void
InsertAllReduceOp
(
SSAGraph
*
result
,
const
std
::
string
&
og
)
const
;
void
InsertDataBalanceOp
(
SSAGraph
*
result
,
const
std
::
vector
<
std
::
string
>
&
datas
)
const
;
void
CreateBroadcastOp
(
SSAGraph
*
result
,
const
std
::
string
&
p_name
,
size_t
src_dev_id
)
const
;
...
...
paddle/fluid/framework/details/op_handle_base.cc
浏览文件 @
58560622
...
...
@@ -58,8 +58,10 @@ void OpHandleBase::Run(bool use_cuda) {
void
OpHandleBase
::
RecordWaitEventOnCtx
(
platform
::
DeviceContext
*
waited_ctx
)
{
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_NOT_NULL
(
waited_ctx
);
if
(
platform
::
is_cpu_place
(
waited_ctx
->
GetPlace
())
||
events_
.
empty
())
{
for
(
auto
&
dev_ctx
:
dev_ctxes_
)
{
PADDLE_ENFORCE_NOT_NULL
(
dev_ctx
.
second
);
dev_ctx
.
second
->
Wait
();
}
}
else
{
...
...
paddle/fluid/framework/lod_tensor.cc
浏览文件 @
58560622
...
...
@@ -90,6 +90,7 @@ std::string LoDToString(const LoD &lod) {
LoD
SliceInLevel
(
const
LoD
&
in
,
size_t
level
,
size_t
elem_begin
,
size_t
elem_end
)
{
PADDLE_ENFORCE_LT
(
level
,
in
.
size
());
PADDLE_ENFORCE_LT
(
elem_begin
,
elem_end
);
PADDLE_ENFORCE_LT
(
elem_end
,
in
[
level
].
size
());
LoD
res
;
...
...
@@ -393,6 +394,7 @@ void LoDTensor::MergeLoDTensor(
new_dim
[
0
]
+=
t
->
dims
()[
0
];
auto
&
lod
=
t
->
lod
();
PADDLE_ENFORCE_EQ
(
new_lod
.
size
(),
lod
.
size
());
for
(
size_t
j
=
0
;
j
<
lod
.
size
();
++
j
)
{
auto
&
sub_lod
=
new_lod
[
j
];
auto
&
offset
=
sub_lod
.
back
();
...
...
paddle/fluid/operators/read_op.cc
浏览文件 @
58560622
...
...
@@ -66,9 +66,19 @@ class ReadOp : public framework::OperatorBase {
std
::
vector
<
std
::
string
>
out_arg_names
=
Outputs
(
"Out"
);
std
::
vector
<
framework
::
LoDTensor
>
ins
;
reader
->
ReadNext
(
&
ins
);
PADDLE_ENFORCE
(
!
ins
.
empty
(),
"There is no next data."
);
if
(
ins
.
empty
())
{
if
(
Attr
<
bool
>
(
"throw_eof_exp"
))
{
PADDLE_THROW
(
"There is no next data."
);
}
else
{
ins
.
resize
(
out_arg_names
.
size
());
for
(
auto
&
tensor
:
ins
)
{
// data type is not important for subsequent DataBalanceOpHandle
tensor
.
mutable_data
<
float
>
(
framework
::
make_ddim
({
0
}),
dev_place
);
}
}
}
PADDLE_ENFORCE_EQ
(
ins
.
size
(),
out_arg_names
.
size
());
for
(
size_t
i
=
0
;
i
<
in
s
.
size
();
++
i
)
{
for
(
size_t
i
=
0
;
i
<
out_arg_name
s
.
size
();
++
i
)
{
auto
*
out
=
scope
.
FindVar
(
out_arg_names
[
i
])
->
GetMutable
<
framework
::
LoDTensor
>
();
out
->
ShareDataWith
(
ins
[
i
]);
...
...
@@ -82,6 +92,10 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
void
Make
()
override
{
AddInput
(
"Reader"
,
"(ReaderHolder) The executed reader."
);
AddOutput
(
"Out"
,
"(LoDTensor) The output data."
).
AsDuplicable
();
AddAttr
<
bool
>
(
"throw_eof_exp"
,
"If set true, an exception will be thrown when the Reader "
"yields empty (which means there is no next data)."
)
.
SetDefault
(
true
);
AddComment
(
R"DOC(
Read Operator
...
...
paddle/fluid/pybind/pybind.cc
浏览文件 @
58560622
...
...
@@ -643,7 +643,11 @@ All parameter, weight, gradient are variables in Paddle.
[](
const
BuildStrategy
&
self
)
{
return
self
.
debug_graphviz_path_
;
},
[](
BuildStrategy
&
self
,
const
std
::
string
&
path
)
{
self
.
debug_graphviz_path_
=
path
;
});
})
.
def_property
(
"enable_data_balance"
,
[](
const
BuildStrategy
&
self
)
{
return
self
.
enable_data_balance_
;
},
[](
BuildStrategy
&
self
,
bool
b
)
{
self
.
enable_data_balance_
=
b
;
});
pe
.
def
(
py
::
init
<
const
std
::
vector
<
platform
::
Place
>
&
,
const
std
::
unordered_set
<
std
::
string
>
&
,
...
...
python/paddle/fluid/tests/unittests/.gitignore
浏览文件 @
58560622
...
...
@@ -4,3 +4,5 @@ mnist_1.recordio
mnist_2.recordio
flowers.recordio
wmt16.recordio
data_balance_test.recordio
data_balance_with_lod_test.recordio
python/paddle/fluid/tests/unittests/test_data_balance.py
0 → 100644
浏览文件 @
58560622
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
paddle.fluid
as
fluid
import
paddle.v2
as
paddle
import
numpy
as
np
class
TestDataBalance
(
unittest
.
TestCase
):
def
prepare_data
(
self
):
def
fake_data_generator
():
for
n
in
xrange
(
self
.
total_ins_num
):
yield
np
.
ones
((
3
,
4
))
*
n
,
n
# Prepare data
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
reader
=
paddle
.
batch
(
fake_data_generator
,
batch_size
=
self
.
batch_size
)
feeder
=
fluid
.
DataFeeder
(
feed_list
=
[
fluid
.
layers
.
data
(
name
=
'image'
,
shape
=
[
3
,
4
],
dtype
=
'float32'
),
fluid
.
layers
.
data
(
name
=
'label'
,
shape
=
[
1
],
dtype
=
'int64'
),
],
place
=
fluid
.
CPUPlace
())
self
.
num_batches
=
fluid
.
recordio_writer
.
convert_reader_to_recordio_file
(
self
.
data_file_name
,
reader
,
feeder
)
def
prepare_lod_data
(
self
):
def
fake_data_generator
():
for
n
in
xrange
(
1
,
self
.
total_ins_num
+
1
):
d1
=
(
np
.
ones
((
n
,
3
))
*
n
).
astype
(
'float32'
)
d2
=
(
np
.
array
(
n
).
reshape
((
1
,
1
))).
astype
(
'int32'
)
yield
d1
,
d2
# Prepare lod data
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
with
fluid
.
recordio_writer
.
create_recordio_writer
(
filename
=
self
.
lod_data_file_name
)
as
writer
:
eof
=
False
generator
=
fake_data_generator
()
while
(
not
eof
):
data_batch
=
[
np
.
array
([]).
reshape
((
0
,
3
)),
np
.
array
([]).
reshape
(
(
0
,
1
))
]
lod
=
[
0
]
for
_
in
xrange
(
self
.
batch_size
):
try
:
ins
=
generator
.
next
()
except
StopIteration
:
eof
=
True
break
for
i
,
d
in
enumerate
(
ins
):
data_batch
[
i
]
=
np
.
concatenate
(
(
data_batch
[
i
],
d
),
axis
=
0
)
lod
.
append
(
lod
[
-
1
]
+
ins
[
0
].
shape
[
0
])
if
data_batch
[
0
].
shape
[
0
]
>
0
:
for
i
,
d
in
enumerate
(
data_batch
):
t
=
fluid
.
LoDTensor
()
t
.
set
(
data_batch
[
i
],
fluid
.
CPUPlace
())
if
i
==
0
:
t
.
set_lod
([
lod
])
writer
.
append_tensor
(
t
)
writer
.
complete_append_tensor
()
def
setUp
(
self
):
self
.
use_cuda
=
fluid
.
core
.
is_compiled_with_cuda
()
self
.
data_file_name
=
'./data_balance_test.recordio'
self
.
lod_data_file_name
=
'./data_balance_with_lod_test.recordio'
self
.
total_ins_num
=
50
self
.
batch_size
=
10
self
.
prepare_data
()
self
.
prepare_lod_data
()
def
main
(
self
):
main_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
data_reader
=
fluid
.
layers
.
io
.
open_files
(
filenames
=
[
self
.
data_file_name
],
shapes
=
[[
-
1
,
3
,
4
],
[
-
1
,
1
]],
lod_levels
=
[
0
,
0
],
dtypes
=
[
'float32'
,
'int64'
])
if
self
.
use_cuda
:
data_reader
=
fluid
.
layers
.
double_buffer
(
data_reader
)
image
,
label
=
fluid
.
layers
.
read_file
(
data_reader
)
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
use_cuda
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
parallel_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
self
.
use_cuda
,
main_program
=
main_prog
)
if
(
parallel_exe
.
device_count
>
self
.
batch_size
):
print
(
"WARNING: Unittest TestDataBalance skipped.
\
For the result is not correct when device count
\
is larger than batch size."
)
exit
(
0
)
fetch_list
=
[
image
.
name
,
label
.
name
]
data_appeared
=
[
False
]
*
self
.
total_ins_num
while
(
True
):
try
:
image_val
,
label_val
=
parallel_exe
.
run
(
fetch_list
,
return_numpy
=
True
)
except
fluid
.
core
.
EnforceNotMet
as
ex
:
self
.
assertIn
(
"There is no next data."
,
ex
.
message
)
break
ins_num
=
image_val
.
shape
[
0
]
broadcasted_label
=
np
.
ones
(
(
ins_num
,
3
,
4
))
*
label_val
.
reshape
((
ins_num
,
1
,
1
))
self
.
assertEqual
(
image_val
.
all
(),
broadcasted_label
.
all
())
for
l
in
label_val
:
self
.
assertFalse
(
data_appeared
[
l
[
0
]])
data_appeared
[
l
[
0
]]
=
True
for
i
in
data_appeared
:
self
.
assertTrue
(
i
)
def
main_lod
(
self
):
main_prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
main_prog
,
startup_prog
):
data_reader
=
fluid
.
layers
.
io
.
open_files
(
filenames
=
[
self
.
lod_data_file_name
],
shapes
=
[[
-
1
,
3
],
[
-
1
,
1
]],
lod_levels
=
[
1
,
0
],
dtypes
=
[
'float32'
,
'int32'
],
thread_num
=
1
)
ins
,
label
=
fluid
.
layers
.
read_file
(
data_reader
)
place
=
fluid
.
CUDAPlace
(
0
)
if
self
.
use_cuda
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
startup_prog
)
parallel_exe
=
fluid
.
ParallelExecutor
(
use_cuda
=
self
.
use_cuda
,
main_program
=
main_prog
)
if
(
parallel_exe
.
device_count
>
self
.
batch_size
):
print
(
"WARNING: Unittest TestDataBalance skipped.
\
For the result is not correct when device count
\
is larger than batch size."
)
exit
(
0
)
fetch_list
=
[
ins
.
name
,
label
.
name
]
data_appeared
=
[
False
]
*
self
.
total_ins_num
while
(
True
):
try
:
ins_tensor
,
label_tensor
=
parallel_exe
.
run
(
fetch_list
,
return_numpy
=
False
)
except
fluid
.
core
.
EnforceNotMet
as
ex
:
self
.
assertIn
(
"There is no next data."
,
ex
.
message
)
break
ins_val
=
np
.
array
(
ins_tensor
)
label_val
=
np
.
array
(
label_tensor
)
ins_lod
=
ins_tensor
.
lod
()[
0
]
self
.
assertEqual
(
ins_val
.
shape
[
1
],
3
)
self
.
assertEqual
(
label_val
.
shape
[
1
],
1
)
self
.
assertEqual
(
len
(
ins_lod
)
-
1
,
label_val
.
shape
[
0
])
for
i
in
range
(
0
,
len
(
ins_lod
)
-
1
):
ins_elem
=
ins_val
[
ins_lod
[
i
]:
ins_lod
[
i
+
1
]][:]
label_elem
=
label_val
[
i
][
0
]
self
.
assertEqual
(
ins_elem
.
all
(),
label_elem
.
all
())
self
.
assertFalse
(
data_appeared
[
int
(
label_elem
-
1
)])
data_appeared
[
int
(
label_elem
-
1
)]
=
True
for
i
in
data_appeared
:
self
.
assertTrue
(
i
)
def
test_all
(
self
):
self
.
main
()
self
.
main_lod
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录