PaddlePaddle / Paddle — Commit 9e80551d (unverified)
Authored Feb 18, 2020 by 123malin; committed via GitHub on Feb 18, 2020

support dumping params/grads in transpiler mode (#22490) (#22649)

Parent: 5515597c
Showing 18 changed files with 434 additions and 132 deletions (+434 / -132).
paddle/fluid/framework/CMakeLists.txt  +2  -0
paddle/fluid/framework/device_worker.cc  +68  -0
paddle/fluid/framework/device_worker.h  +16  -9
paddle/fluid/framework/device_worker_test.cc  +55  -2
paddle/fluid/framework/downpour_worker.cc  +6  -76
paddle/fluid/framework/downpour_worker_opt.cc  +1  -1
paddle/fluid/framework/hogwild_worker.cc  +86  -0
paddle/fluid/framework/multi_trainer.cc  +86  -2
paddle/fluid/framework/trainer.h  +15  -11
python/paddle/fluid/executor.py  +1  -1
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py  +19  -2
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py  +17  -0
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py  +1  -1
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py  +11  -0
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py  +6  -1
python/paddle/fluid/tests/unittests/test_distributed_strategy.py  +25  -0
python/paddle/fluid/tests/unittests/test_downpoursgd.py  +4  -15
python/paddle/fluid/trainer_factory.py  +15  -11
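Before the per-file diffs, a quick orientation: this commit lets transpiler-mode (fleet) training dump parameters and gradients to text files for debugging. Below is a minimal sketch of how the new options are switched on from user code, built from the same calls exercised by test_distributed_strategy.py and test_dist_fleet_base.py further down; the variable names and output directory are illustrative, and fleet.init(role) is assumed to have been called first.

    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet
    from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory

    strategy = StrategyFactory.create_sync_strategy()
    # Parameters and fields (vars / @GRAD vars) to dump each batch, plus an output dir.
    strategy.set_debug_opt({
        "dump_param": ["fc_0.tmp_0"],
        "dump_fields": ["fc_0.tmp_0", "fc_0.tmp_0@GRAD"],
        "dump_fields_path": "dump_text/",
    })
    optimizer = fleet.distributed_optimizer(optimizer, strategy)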
paddle/fluid/framework/CMakeLists.txt
@@ -66,9 +66,11 @@ else()
  cc_test(mixed_vector_test SRCS mixed_vector_test.cc DEPS place memory device_context tensor)
endif()
cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto version)
cc_library(device_worker SRCS device_worker.cc DEPS trainer_desc_proto lod_tensor)
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
cc_test(device_worker_test SRCS device_worker_test.cc DEPS device_worker)
cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory gflags glog)
paddle/fluid/framework/device_worker.cc
@@ -23,5 +23,73 @@ void DeviceWorker::SetDataFeed(DataFeed* data_feed) {
  device_reader_ = data_feed;
}

template <typename T>
std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
  auto count = tensor->numel();
  if (start < 0 || end > count) {
    VLOG(3) << "access violation";
    return "access violation";
  }
  std::ostringstream os;
  for (int64_t i = start; i < end; i++) {
    os << ":" << tensor->data<T>()[i];
  }
  return os.str();
}

std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start, int64_t end) {
  auto count = tensor->numel();
  if (start < 0 || end > count) {
    VLOG(3) << "access violation";
    return "access violation";
  }
  std::ostringstream os;
  for (int64_t i = start; i < end; i++) {
    os << ":" << static_cast<uint64_t>(tensor->data<int64_t>()[i]);
  }
  return os.str();
}

std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
  std::string out_val;
  if (tensor->type() == proto::VarType::FP32) {
    out_val = PrintLodTensorType<float>(tensor, start, end);
  } else if (tensor->type() == proto::VarType::INT64) {
    out_val = PrintLodTensorIntType(tensor, start, end);
  } else if (tensor->type() == proto::VarType::FP64) {
    out_val = PrintLodTensorType<double>(tensor, start, end);
  } else {
    out_val = "unsupported type";
  }
  return out_val;
}

std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index) {
  auto& dims = tensor->dims();
  if (tensor->lod().size() != 0) {
    auto& lod = tensor->lod()[0];
    return {lod[index] * dims[1], lod[index + 1] * dims[1]};
  } else {
    return {index * dims[1], (index + 1) * dims[1]};
  }
}

bool CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
  auto& dims = tensor->dims();
  if (dims.size() != 2) return false;
  if (tensor->lod().size() != 0) {
    auto& lod = tensor->lod()[0];
    if (lod.size() != batch_size + 1) {
      return false;
    }
  } else {
    if (dims[0] != static_cast<int>(batch_size)) {
      return false;
    }
  }
  return true;
}

}  // namespace framework
}  // namespace paddle
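The three helpers above are now free functions shared by the device workers (they were previously DownpourWorker methods, see downpour_worker.cc below): PrintLodTensor serializes a slice of a tensor as ":v1:v2:...", GetTensorBound maps an instance index to its [start, end) element range using the LoD when one is present, and CheckValidOutput skips tensors whose leading dimension does not match the batch. A rough Python sketch of the same value formatting, useful only for reading the dumped text (illustrative, not part of the commit):

    def print_lod_tensor(values, start, end):
        # Mirrors PrintLodTensorType: every element is prefixed with ':'.
        if start < 0 or end > len(values):
            return "access violation"
        return "".join(":" + str(v) for v in values[start:end])

    print_lod_tensor([0.2, 0.5], 0, 2)  # -> ':0.2:0.5', matching device_worker_test.cc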
paddle/fluid/framework/device_worker.h
@@ -45,6 +45,10 @@ limitations under the License. */
namespace paddle {
namespace framework {

std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end);
std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);

class FleetWrapper;

#define SEC_LOG \

@@ -168,6 +172,8 @@ class HogwildWorker : public CPUWorkerBase {
  virtual void Initialize(const TrainerDesc& desc);
  virtual void TrainFiles();
  virtual void TrainFilesWithProfiler();
  virtual void SetNeedDump(bool need_dump_field);
  virtual void SetChannelWriter(ChannelObject<std::string>* queue);
  virtual void PrintFetchVars();
  virtual void CreateDeviceResource(const ProgramDesc& main_prog);
  virtual void BindingDataFeedMemory();

@@ -177,6 +183,8 @@ class HogwildWorker : public CPUWorkerBase {
 protected:
  void CreateThreadOperators(const ProgramDesc& program);
  void CreateThreadScope(const ProgramDesc& program);
  virtual void DumpParam(const int batch_id);
  std::vector<std::string> op_names_;
  std::vector<OperatorBase*> ops_;
  bool thread_barrier_;

@@ -184,6 +192,12 @@ class HogwildWorker : public CPUWorkerBase {
  HogwildWorkerParameter param_;
  std::vector<std::string> skip_ops_;
  std::map<std::string, int> stat_var_name_map_;
  // dump params or grads for debug
  bool need_dump_param_;
  bool need_dump_field_;
  std::vector<std::string> dump_param_;
  std::vector<std::string> dump_fields_;
  ChannelWriter<std::string> writer_;
};

class DownpourWorker : public HogwildWorker {

@@ -203,13 +217,11 @@ class DownpourWorker : public HogwildWorker {
  void PushGradients();
  void CollectLabelInfo(size_t table_id);
  void AdjustInsWeight();
  void DumpParam();
  void CopySparseTable();
  void CopyDenseTable();
  void CopyDenseVars();
  std::string PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end);
  std::pair<int64_t, int64_t> GetTensorBound(LoDTensor* tensor, int index);
  bool CheckValidOutput(LoDTensor* tensor, size_t batch_size);
  virtual void DumpParam(const int batch_id);

  DownpourWorkerParameter param_;
  // copy table
  CopyTableConfig copy_table_config_;

@@ -236,16 +248,11 @@ class DownpourWorker : public HogwildWorker {
  std::vector<::std::future<int32_t>> push_sparse_status_;
  bool dump_slot_;
  bool need_to_push_dense_;
  bool need_dump_field_;
  bool need_dump_param_;
  std::map<uint64_t, std::vector<std::string>> dense_grad_names_;
  float scale_datanorm_;
  std::vector<::std::future<int32_t>> push_dense_status_;
  std::vector<std::string> dump_fields_;
  ChannelWriter<std::string> writer_;
  // skipped ops
  std::vector<std::string> skip_ops_;
  std::vector<std::string> dump_param_;
  // just save the value in param_ for easy access
  std::map<uint64_t, std::string> label_var_name_;
  std::map<uint64_t, std::vector<std::string>> dense_value_names_;
paddle/fluid/framework/device_worker_test.cc
@@ -12,13 +12,66 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/device_worker.h"
#include <gtest/gtest.h>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/trainer.h"

namespace paddle {
namespace framework {

TEST() {
  // create hogwild device worker
}

TEST(LodTensor, PrintLodTensor) {
  LoDTensor tensor1;
  tensor1.Resize({2});
  tensor1.mutable_data<float>(platform::CPUPlace());
  tensor1.data<float>()[0] = 0.2;
  tensor1.data<float>()[1] = 0.5;
  std::string res = PrintLodTensor(&tensor1, -1, 2);
  ASSERT_EQ(res, "access violation");
  res = PrintLodTensor(&tensor1, 0, 2);
  ASSERT_EQ(res, ":0.2:0.5");

  LoDTensor tensor2;
  tensor2.Resize({2});
  tensor2.mutable_data<int64_t>(platform::CPUPlace());
  tensor2.data<int64_t>()[0] = 1;
  tensor2.data<int64_t>()[1] = 2;
  res = PrintLodTensor(&tensor2, -1, 2);
  ASSERT_EQ(res, "access violation");
  res = PrintLodTensor(&tensor2, 0, 2);
  ASSERT_EQ(res, ":1:2");

  LoDTensor tensor3;
  tensor3.Resize({2});
  tensor3.mutable_data<double>(platform::CPUPlace());
  tensor3.data<double>()[0] = 0.1;
  tensor3.data<double>()[1] = 0.2;
  res = PrintLodTensor(&tensor3, 0, 2);
  ASSERT_EQ(res, ":0.1:0.2");
}

TEST(LodTensor, GetTensorBound) {
  LoD lod{{0, 2}};
  LoDTensor tensor;
  tensor.set_lod(lod);
  tensor.Resize({2, 1});
  tensor.mutable_data<float>(platform::CPUPlace());
  tensor.data<float>()[0] = 0;
  tensor.data<float>()[1] = 1;
  std::pair<int64_t, int64_t> res = GetTensorBound(&tensor, 0);
  ASSERT_EQ(res.first, 0);
  ASSERT_EQ(res.second, 2);
}

TEST(LodTensor, CheckValidOutput) {
  LoD lod{{0, 1, 2}};
  LoDTensor tensor;
  tensor.set_lod(lod);
  tensor.Resize({2, 1});
  tensor.mutable_data<float>(platform::CPUPlace());
  tensor.data<float>()[0] = 0;
  tensor.data<float>()[1] = 1;
  ASSERT_TRUE(CheckValidOutput(&tensor, 2));
}

}  // namespace framework
}  // namespace paddle
paddle/fluid/framework/downpour_worker.cc
@@ -129,89 +129,19 @@ void DownpourWorker::SetNeedDump(bool need_dump_field) {
  need_dump_field_ = need_dump_field;
}

template <typename T>
std::string PrintLodTensorType(LoDTensor* tensor, int64_t start, int64_t end) {
  auto count = tensor->numel();
  if (start < 0 || end > count) {
    VLOG(3) << "access violation";
    return "access violation";
  }
  std::ostringstream os;
  for (int64_t i = start; i < end; i++) {
    os << ":" << tensor->data<T>()[i];
  }
  return os.str();
}

std::string PrintLodTensorIntType(LoDTensor* tensor, int64_t start, int64_t end) {
  auto count = tensor->numel();
  if (start < 0 || end > count) {
    VLOG(3) << "access violation";
    return "access violation";
  }

void DownpourWorker::DumpParam(const int batch_id) {
  std::ostringstream os;
  for (int64_t i = start; i < end; i++) {
    os << ":" << static_cast<uint64_t>(tensor->data<int64_t>()[i]);
  }
  return os.str();
}

std::string DownpourWorker::PrintLodTensor(LoDTensor* tensor, int64_t start, int64_t end) {
  std::string out_val;
  if (tensor->type() == proto::VarType::FP32) {
    out_val = PrintLodTensorType<float>(tensor, start, end);
  } else if (tensor->type() == proto::VarType::INT64) {
    out_val = PrintLodTensorIntType(tensor, start, end);
  } else if (tensor->type() == proto::VarType::FP64) {
    out_val = PrintLodTensorType<double>(tensor, start, end);
  } else {
    out_val = "unsupported type";
  }
  return out_val;
}

std::pair<int64_t, int64_t> DownpourWorker::GetTensorBound(LoDTensor* tensor, int index) {
  auto& dims = tensor->dims();
  if (tensor->lod().size() != 0) {
    auto& lod = tensor->lod()[0];
    return {lod[index] * dims[1], lod[index + 1] * dims[1]};
  } else {
    return {index * dims[1], (index + 1) * dims[1]};
  }
}

bool DownpourWorker::CheckValidOutput(LoDTensor* tensor, size_t batch_size) {
  auto& dims = tensor->dims();
  if (dims.size() != 2) return false;
  if (tensor->lod().size() != 0) {
    auto& lod = tensor->lod()[0];
    if (lod.size() != batch_size + 1) {
      return false;
    }
  } else {
    if (dims[0] != static_cast<int>(batch_size)) {
      return false;
    }
  }
  return true;
}

void DownpourWorker::DumpParam() {
  std::string os;
  for (auto& param : dump_param_) {
    os.clear();
    os = param;
    os.str("");
    Variable* var = thread_scope_->FindVar(param);
    if (var == nullptr) {
      continue;
    }
    LoDTensor* tensor = var->GetMutable<LoDTensor>();
    int64_t len = tensor->numel();
    os += PrintLodTensor(tensor, 0, len);
    writer_ << os;
    os << "(" << batch_id << "," << param << ")" << PrintLodTensor(tensor, 0, len);
    writer_ << os.str();
  }
}

@@ -1022,7 +952,7 @@ void DownpourWorker::TrainFiles() {
        writer_ << ars[i];
      }
      if (need_dump_param_ && thread_id_ == 0) {
        DumpParam();
        DumpParam(batch_cnt);
      }
    }
paddle/fluid/framework/downpour_worker_opt.cc
@@ -564,7 +564,7 @@ void DownpourWorkerOpt::TrainFiles() {
        writer_ << ars[i];
      }
      if (need_dump_param_ && thread_id_ == 0) {
        DumpParam();
        DumpParam(batch_cnt);
      }
    }
paddle/fluid/framework/hogwild_worker.cc
@@ -31,6 +31,20 @@ void HogwildWorker::Initialize(const TrainerDesc &desc) {
  }
  use_cvm_ = desc.use_cvm();
  thread_barrier_ = desc.thread_barrier();

  dump_fields_.resize(desc.dump_fields_size());
  for (int i = 0; i < desc.dump_fields_size(); ++i) {
    dump_fields_[i] = desc.dump_fields(i);
  }

  need_dump_param_ = false;
  dump_param_.resize(desc.dump_param_size());
  for (int i = 0; i < desc.dump_param_size(); ++i) {
    dump_param_[i] = desc.dump_param(i);
  }
  if (desc.dump_param_size() != 0) {
    need_dump_param_ = true;
  }
}

void HogwildWorker::CreateThreadOperators(const ProgramDesc& program) {

@@ -143,6 +157,49 @@ void HogwildWorker::TrainFilesWithProfiler() {
      op_total_time[i] += timeline.ElapsedSec();
      total_time += timeline.ElapsedSec();
    }

    if (need_dump_field_) {
      size_t batch_size = device_reader_->GetCurBatchSize();
      std::vector<std::string> ars(batch_size);
      for (auto& ar : ars) {
        ar.clear();
      }
      auto& ins_id_vec = device_reader_->GetInsIdVec();
      auto& ins_content_vec = device_reader_->GetInsContentVec();
      for (size_t i = 0; i < ins_id_vec.size(); i++) {
        ars[i] += ins_id_vec[i];
        ars[i] = ars[i] + "\t" + ins_content_vec[i];
      }
      for (auto& field : dump_fields_) {
        Variable* var = thread_scope_->FindVar(field);
        if (var == nullptr) {
          continue;
        }
        LoDTensor* tensor = var->GetMutable<LoDTensor>();
        if (!CheckValidOutput(tensor, batch_size)) {
          continue;
        }
        for (size_t i = 0; i < batch_size; ++i) {
          auto output_dim = tensor->dims()[1];
          std::string output_dimstr = boost::lexical_cast<std::string>(output_dim);
          ars[i] = ars[i] + "\t" + field + ":" + output_dimstr;
          auto bound = GetTensorBound(tensor, i);
          ars[i] += PrintLodTensor(tensor, bound.first, bound.second);
        }
      }
      // #pragma omp parallel for
      for (size_t i = 0; i < ars.size(); i++) {
        if (ars[i].length() == 0) {
          continue;
        }
        writer_ << ars[i];
      }
      if (need_dump_param_ && thread_id_ == 0) {
        DumpParam(batch_cnt);
      }
    }

    total_inst += cur_batch;
    ++batch_cnt;
    PrintFetchVars();

@@ -160,6 +217,11 @@ void HogwildWorker::TrainFilesWithProfiler() {
    thread_scope_->DropKids();
    timeline.Start();
  }
  if (need_dump_field_) {
    writer_.Flush();
  }
#ifdef PADDLE_WITH_DISTRIBUTE
  if (thread_barrier_) {
    operators::distributed::Communicator::GetInstance()

@@ -168,6 +230,10 @@ void HogwildWorker::TrainFilesWithProfiler() {
#endif
}

void HogwildWorker::SetChannelWriter(ChannelObject<std::string>* queue) {
  writer_.Reset(queue);
}

void HogwildWorker::TrainFiles() {
  platform::SetNumThreads(1);

@@ -214,5 +280,25 @@ void HogwildWorker::PrintFetchVars() {
  }
}

void HogwildWorker::SetNeedDump(bool need_dump_field) {
  need_dump_field_ = need_dump_field;
}

void HogwildWorker::DumpParam(const int batch_id) {
  std::ostringstream os;
  for (auto& param : dump_param_) {
    os.str("");
    Variable* var = thread_scope_->FindVar(param);
    if (var == nullptr) {
      continue;
    }
    LoDTensor* tensor = var->GetMutable<LoDTensor>();
    int64_t len = tensor->numel();
    os << "(" << batch_id << "," << param << ")" << PrintLodTensor(tensor, 0, len);
    writer_ << os.str();
  }
}

}  // end namespace framework
}  // end namespace paddle
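HogwildWorker now assembles one text record per instance and, on thread 0, one record per dumped parameter. From the code above, a field record looks like "<ins_id>\t<ins_content>\t<field>:<dim>:v1:v2..." (one "\t<field>..." chunk per dumped field) and a parameter record looks like "(<batch_id>,<param>):v1:v2...". A hedged Python sketch of reading a field record back, assuming field names contain no ':' (this parser is illustrative, not shipped by the commit):

    def parse_field_record(line):
        # "<ins_id>\t<ins_content>\t<field>:<dim>:v1:v2..." per HogwildWorker::TrainFilesWithProfiler.
        ins_id, ins_content, *chunks = line.split("\t")
        fields = {}
        for chunk in chunks:
            name, dim, *values = chunk.split(":")
            fields[name] = (int(dim), [float(v) for v in values])
        return ins_id, ins_content, fields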
paddle/fluid/framework/multi_trainer.cc
@@ -14,6 +14,7 @@ limitations under the License. */
#include <string>
#include <vector>
#include "io/fs.h"
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"

@@ -25,12 +26,29 @@ namespace framework {
void MultiTrainer::Initialize(const TrainerDesc& trainer_desc, Dataset* dataset) {
  thread_num_ = trainer_desc.thread_num();
  SetDataset(dataset);

  dump_fields_path_ = trainer_desc.dump_fields_path();
  dump_converter_ = trainer_desc.dump_converter();
  need_dump_field_ = false;
  if (trainer_desc.dump_fields_size() != 0 && dump_fields_path_ != "") {
    need_dump_field_ = true;
  }
  if (need_dump_field_) {
    auto& file_list = dataset->GetFileList();
    if (file_list.size() == 0) {
      need_dump_field_ = false;
    }
  }
  mpi_rank_ = trainer_desc.mpi_rank();
  mpi_size_ = trainer_desc.mpi_size();
  dump_file_num_ = trainer_desc.dump_file_num();
  for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size(); i++) {
    need_merge_var_names_.push_back(trainer_desc.downpour_param().stat_var_names(i));
  }
  SetDataset(dataset);
  // get filelist from trainer_desc here
  const std::vector<paddle::framework::DataFeed*> readers = dataset->GetReaders();

@@ -53,12 +71,66 @@ void MultiTrainer::Initialize(const TrainerDesc& trainer_desc,
    workers_[i]->Initialize(trainer_desc);
    workers_[i]->SetDeviceIndex(i);
    workers_[i]->SetDataFeed(readers[i]);
    workers_[i]->SetNeedDump(need_dump_field_);
  }

  // set debug here
  SetDebug(trainer_desc.debug());
}

void MultiTrainer::DumpWork(int tid) {
#ifdef _LINUX
  int err_no = 0;
  std::string path = string::format_string("%s/part-%03d-%05d", dump_fields_path_.c_str(), mpi_rank_, tid);

  std::shared_ptr<FILE> fp = fs_open_write(path, &err_no, dump_converter_);
  while (1) {
    std::string out_str;
    if (!queue_->Get(out_str)) {
      break;
    }
    size_t write_count = fwrite_unlocked(out_str.data(), 1, out_str.length(), fp.get());
    if (write_count != out_str.length()) {
      VLOG(3) << "dump text failed";
      continue;
    }
    write_count = fwrite_unlocked("\n", 1, 1, fp.get());
    if (write_count != 1) {
      VLOG(3) << "dump text failed";
      continue;
    }
  }
#endif
}

void MultiTrainer::InitDumpEnv() {
  queue_ = paddle::framework::MakeChannel<std::string>();
  for (int i = 0; i < thread_num_; ++i) {
    workers_[i]->SetChannelWriter(queue_.get());
  }
  dump_thread_num_ = 1;
  if (dump_file_num_ > mpi_size_) {
    dump_thread_num_ = dump_file_num_ / mpi_size_;
    if (dump_file_num_ % mpi_size_ > mpi_rank_) {
      dump_thread_num_ += 1;
    }
  }
  for (int i = 0; i < dump_thread_num_; i++) {
    dump_thread_.push_back(std::thread(std::bind(&MultiTrainer::DumpWork, this, i)));
  }
}

void MultiTrainer::FinalizeDumpEnv() {
  queue_->Close();
  for (auto& th : dump_thread_) {
    th.join();
  }
  queue_.reset();
}

// call only after all resources are set in current trainer
void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program, const platform::Place& place) {

@@ -71,6 +143,13 @@ void MultiTrainer::InitTrainerEnv(const ProgramDesc& main_program,
  }
}

void MultiTrainer::InitOtherEnv(const ProgramDesc& main_program) {
  if (need_dump_field_) {
    InitDumpEnv();
  }
  VLOG(3) << "init other env done.";
}

Scope* MultiTrainer::GetWorkerScope(int thread_id) {
  return workers_[thread_id]->GetThreadScope();
}

@@ -91,7 +170,12 @@ void MultiTrainer::Run() {
  }
}

void MultiTrainer::Finalize() { root_scope_->DropKids(); }
void MultiTrainer::Finalize() {
  if (need_dump_field_) {
    FinalizeDumpEnv();
  }
  root_scope_->DropKids();
}

}  // end namespace framework
}  // end namespace paddle
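MultiTrainer drains the shared channel with one or more dump threads per rank; each thread writes its own part file. An illustrative recap of the file naming and per-rank thread count computed in InitDumpEnv/DumpWork above (the concrete values are examples, not defaults):

    dump_fields_path, mpi_rank, mpi_size, dump_file_num = "dump_text", 1, 4, 16

    # Same arithmetic as MultiTrainer::InitDumpEnv.
    dump_thread_num = 1
    if dump_file_num > mpi_size:
        dump_thread_num = dump_file_num // mpi_size
        if dump_file_num % mpi_size > mpi_rank:
            dump_thread_num += 1

    # Same pattern as MultiTrainer::DumpWork ("%s/part-%03d-%05d").
    paths = ["%s/part-%03d-%05d" % (dump_fields_path, mpi_rank, tid)
             for tid in range(dump_thread_num)]
    # -> ['dump_text/part-001-00000', ..., 'dump_text/part-001-00003']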
paddle/fluid/framework/trainer.h
@@ -68,10 +68,13 @@ class MultiTrainer : public TrainerBase {
  virtual void Initialize(const TrainerDesc& trainer_desc, Dataset* data_set);
  virtual void InitTrainerEnv(const ProgramDesc& main_program, const platform::Place& place);
  virtual void InitOtherEnv(const ProgramDesc& main_program) {}
  virtual void InitOtherEnv(const ProgramDesc& main_program);
  virtual void Run();
  virtual void Finalize();
  virtual void FinalizeDumpEnv();
  virtual void InitDumpEnv();
  virtual Scope* GetWorkerScope(int thread_id);
  virtual void DumpWork(int tid);

 protected:
  int thread_num_;

@@ -79,6 +82,17 @@ class MultiTrainer : public TrainerBase {
  std::vector<DataFeed*> readers_;
  std::vector<std::shared_ptr<DeviceWorker>> workers_;
  std::vector<std::string> need_merge_var_names_;

  bool need_dump_field_;
  std::string dump_fields_path_;
  std::string dump_converter_;
  int mpi_rank_;
  int mpi_size_;
  int dump_file_num_;

  std::vector<std::thread> dump_thread_;
  int dump_thread_num_;
  std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;
};

class DistMultiTrainer : public MultiTrainer {

@@ -98,16 +112,6 @@ class DistMultiTrainer : public MultiTrainer {
 protected:
  std::shared_ptr<paddle::framework::PullDenseWorker> pull_dense_worker_;
  std::vector<std::thread> dump_thread_;
  int dump_thread_num_;
  std::shared_ptr<paddle::framework::ChannelObject<std::string>> queue_;

  bool need_dump_field_;
  std::string dump_fields_path_;
  std::string dump_converter_;
  int mpi_rank_;
  int mpi_size_;
  int dump_file_num_;
};

#if defined(PADDLE_WITH_NCCL)
python/paddle/fluid/executor.py
@@ -919,7 +919,7 @@ class Executor(object):
    def _dump_debug_info(self, program=None, trainer=None):
        with open(str(id(program)) + "_train_desc.prototxt", "w") as fout:
            fout.write(str(trainer))
        if program._fleet_opt:
        if program._fleet_opt and "fleet_desc" in program._fleet_opt:
            with open("fleet_desc.prototxt", "w") as fout:
                fout.write(str(program._fleet_opt["fleet_desc"]))
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/__init__.py
@@ -333,6 +333,12 @@ class DistributedTranspiler(Fleet):
            self._transpiler.get_pserver_programs(
                self.server_endpoints()[self.server_index()])

    def _set_opt_info(self, opt_info):
        """
        this function saves the result from DistributedOptimizer.minimize()
        """
        self._opt_info = opt_info


fleet = DistributedTranspiler()

@@ -358,9 +364,11 @@ class TranspilerOptimizer(DistributedOptimizer):
    def __init__(self, optimizer, strategy=None):
        super(TranspilerOptimizer, self).__init__(optimizer, strategy)

        self.opt_info = dict()
        if strategy:
            if isinstance(strategy, DistributeTranspilerConfig) or isinstance(strategy, DistributedStrategy):
            if isinstance(strategy, DistributeTranspilerConfig):
                self._strategy = strategy
            elif isinstance(strategy, DistributedStrategy):
                self._strategy = strategy
            else:
                raise TypeError(

@@ -369,6 +377,14 @@ class TranspilerOptimizer(DistributedOptimizer):
        else:
            self._strategy = StrategyFactory.create_sync_strategy()

        if isinstance(self._strategy, DistributedStrategy):
            self.opt_info = self._strategy.get_debug_opt()
            self.opt_info["mpi_rank"] = fleet.worker_index()
            self.opt_info["mpi_size"] = fleet.worker_num()
            self.opt_info["trainer"] = "MultiTrainer"
            self.opt_info["device_worker"] = "Hogwild"
            fleet._set_opt_info(self.opt_info)

    def backward(self, loss, startup_program=None,

@@ -456,4 +472,5 @@ class TranspilerOptimizer(DistributedOptimizer):
        optimize_ops, params_grads = self._optimizer.minimize(
            loss, startup_program, parameter_list, no_grad_set)
        fleet._transpile(config=self._strategy)
        loss.block.program._fleet_opt = self.opt_info
        return optimize_ops, params_grads
python/paddle/fluid/incubate/fleet/parameter_server/distribute_transpiler/distributed_strategy.py
@@ -69,6 +69,23 @@ class DistributedStrategy(object):
        self._execute_strategy.num_threads = num_threads
        if num_threads > 1:
            self._build_strategy.reduce_strategy = fluid.BuildStrategy.ReduceStrategy.Reduce
        self.debug_opt = None

    def set_debug_opt(self, opt_info):
        self.debug_opt = opt_info

    def get_debug_opt(self):
        opt_info = dict()
        if self.debug_opt is not None and isinstance(self.debug_opt, dict):
            opt_info["dump_slot"] = bool(self.debug_opt.get("dump_slot", 0))
            opt_info["dump_converter"] = str(self.debug_opt.get("dump_converter", ""))
            opt_info["dump_fields"] = self.debug_opt.get("dump_fields", [])
            opt_info["dump_file_num"] = self.debug_opt.get("dump_file_num", 16)
            opt_info["dump_fields_path"] = self.debug_opt.get("dump_fields_path", "")
            opt_info["dump_param"] = self.debug_opt.get("dump_param", [])
        return opt_info

    def get_program_config(self):
        return self._program_config
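For reference, a worked example of what get_debug_opt() returns when only two keys are supplied; the remaining keys fall back to the defaults visible above (dump_slot=False, dump_converter='', dump_file_num=16, and so on). Values are illustrative:

    strategy = StrategyFactory.create_sync_strategy()
    strategy.set_debug_opt({"dump_fields": ["fc_0.tmp_0@GRAD"],
                            "dump_fields_path": "dump_text/"})
    strategy.get_debug_opt()
    # {'dump_slot': False, 'dump_converter': '', 'dump_fields': ['fc_0.tmp_0@GRAD'],
    #  'dump_file_num': 16, 'dump_fields_path': 'dump_text/', 'dump_param': []}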
python/paddle/fluid/tests/unittests/dist_fleet_ctr.py
@@ -229,7 +229,7 @@ class TestDistCTR2x2(FleetDistRunnerBase):
            fetch_list=[self.avg_cost],
            fetch_info=["cost"],
            print_period=2,
            debug=False)
            debug=int(os.getenv("Debug", "0")))
        pass_time = time.time() - pass_start
        res_dict = dict()
python/paddle/fluid/tests/unittests/test_dist_fleet_base.py
@@ -79,6 +79,17 @@ class FleetDistRunnerBase(object):
        elif args.mode == "geo":
            self.strategy = StrategyFactory.create_geo_strategy(
                args.geo_sgd_need_push_nums)
        self.dump_param = os.getenv("dump_param", "").split(",")
        self.dump_fields = os.getenv("dump_fields", "").split(",")
        self.dump_fields_path = os.getenv("dump_fields_path", "")
        debug = int(os.getenv("Debug", "0"))
        if debug:
            self.strategy.set_debug_opt({
                "dump_param": self.dump_param,
                "dump_fields": self.dump_fields,
                "dump_fields_path": self.dump_fields_path
            })
        return self.strategy

    def build_optimizer(self, avg_cost, strategy):
python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py
@@ -16,6 +16,7 @@ from __future__ import print_function
import os
import unittest
import tempfile
from test_dist_fleet_base import TestFleetBase

@@ -99,7 +100,11 @@ class TestDistMnistAsyncDataset2x2(TestFleetBase):
            "LD_LIBRARY_PATH": os.getenv("LD_LIBRARY_PATH", ""),
            "FLAGS_rpc_deadline": "5000",  # 5sec to fail fast
            "http_proxy": "",
            "SAVE_MODEL": "1"
            "SAVE_MODEL": "1",
            "dump_param": "concat_0.tmp_0",
            "dump_fields": "dnn-fc-3.tmp_0,dnn-fc-3.tmp_0@GRAD",
            "dump_fields_path": tempfile.mkdtemp(),
            "Debug": "1"
        }
        required_envs.update(need_envs)
python/paddle/fluid/tests/unittests/test_distributed_strategy.py
@@ -198,5 +198,30 @@ class TestHalfAsyncStrategy(unittest.TestCase):
            optimizer = fleet.distributed_optimizer(optimizer, half_async_config)


class TestDebugInfo(unittest.TestCase):
    def test_debug_info(self):
        x = fluid.layers.data(name='x', shape=[1], dtype='float32')
        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
        y_predict = fluid.layers.fc(input=x, size=1, act=None)
        cost = fluid.layers.square_error_cost(input=y_predict, label=y)
        avg_cost = fluid.layers.mean(cost)

        role = role_maker.UserDefinedRoleMaker(
            current_id=0,
            role=role_maker.Role.WORKER,
            worker_num=2,
            server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"])
        fleet.init(role)

        optimizer = fluid.optimizer.SGD(0.0001)
        strategy = StrategyFactory.create_sync_strategy()
        strategy.set_debug_opt({
            "dump_param": ["fc_0.tmp_0"],
            "dump_fields": ["fc_0.tmp_0", "fc_0.tmp_0@GRAD"],
            "dump_fields_path": "dump_text/"
        })
        optimizer = fleet.distributed_optimizer(optimizer, strategy)


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/test_downpoursgd.py
@@ -29,6 +29,7 @@ from paddle.fluid.device_worker import DownpourSGD, DownpourSGDOPT
from paddle.fluid.incubate.fleet.parameter_server.pslib.node import DownpourWorker
from google.protobuf import text_format
import paddle.fluid.incubate.fleet.parameter_server.pslib.ps_pb2 as pslib
from paddle.fluid.trainer_factory import TrainerFactory


class TestListenAndServOp(unittest.TestCase):

@@ -87,12 +88,8 @@ class TestListenAndServOp(unittest.TestCase):
            opt_info["program_id_to_worker"] = {program_id: worker}
            main_program._fleet_opt = opt_info
            trainer = DistMultiTrainer()
            trainer = TrainerFactory()._create_trainer(main_program._fleet_opt)
            trainer._set_program(main_program)
            device_worker = DownpourSGD()
            device_worker._set_fleet_desc(fleet_desc)
            trainer._set_device_worker(device_worker)
            trainer._set_fleet_desc(fleet_desc)
            trainer._gen_trainer_desc()
            cmd = "rm fleet_desc.prototxt*"
            os.system(cmd)

@@ -147,12 +144,8 @@ class TestListenAndServOp(unittest.TestCase):
            opt_info["program_id_to_worker"] = {program_id: worker}
            main_program._fleet_opt = opt_info
            trainer = DistMultiTrainer()
            trainer = TrainerFactory()._create_trainer(main_program._fleet_opt)
            trainer._set_program(main_program)
            device_worker = DownpourSGD()
            device_worker._set_fleet_desc(fleet_desc)
            trainer._set_device_worker(device_worker)
            trainer._set_fleet_desc(fleet_desc)
            trainer._gen_trainer_desc()
            cmd = "rm fleet_desc.prototxt*"
            os.system(cmd)

@@ -207,12 +200,8 @@ class TestListenAndServOp(unittest.TestCase):
            opt_info["program_id_to_worker"] = {program_id: worker}
            main_program._fleet_opt = opt_info
            trainer = DistMultiTrainer()
            trainer = TrainerFactory()._create_trainer(main_program._fleet_opt)
            trainer._set_program(main_program)
            device_worker = DownpourSGDOPT()
            device_worker._set_fleet_desc(fleet_desc)
            trainer._set_device_worker(device_worker)
            trainer._set_fleet_desc(fleet_desc)
            trainer._gen_trainer_desc()
            cmd = "rm fleet_desc.prototxt*"
            os.system(cmd)
python/paddle/fluid/trainer_factory.py
@@ -53,15 +53,9 @@ class TrainerFactory(object):
            device_worker_class = opt_info["device_worker"]
            trainer = globals()[trainer_class]()
            device_worker = globals()[device_worker_class]()
            if "fleet_desc" in opt_info:
                device_worker._set_fleet_desc(opt_info["fleet_desc"])
                trainer._set_fleet_desc(opt_info["fleet_desc"])
                if opt_info.get("use_cvm") is not None:
                    trainer._set_use_cvm(opt_info["use_cvm"])
                if opt_info.get("no_cvm") is not None:
                    trainer._set_no_cvm(opt_info["no_cvm"])
                if opt_info.get("scale_datanorm") is not None:
                    trainer._set_scale_datanorm(opt_info["scale_datanorm"])

            # for debug tools
            if opt_info is not None:
                if opt_info.get("dump_slot") is not None:
                    trainer._set_dump_slot(opt_info["dump_slot"])
                if opt_info.get("mpi_rank") is not None:

@@ -76,6 +70,18 @@ class TrainerFactory(object):
                    trainer._set_dump_file_num(opt_info["dump_file_num"])
                if opt_info.get("dump_converter") is not None:
                    trainer._set_dump_converter(opt_info["dump_converter"])
                if opt_info.get("dump_param") is not None:
                    trainer._set_dump_param(opt_info["dump_param"])

            if "fleet_desc" in opt_info:
                device_worker._set_fleet_desc(opt_info["fleet_desc"])
                trainer._set_fleet_desc(opt_info["fleet_desc"])
                if opt_info.get("use_cvm") is not None:
                    trainer._set_use_cvm(opt_info["use_cvm"])
                if opt_info.get("no_cvm") is not None:
                    trainer._set_no_cvm(opt_info["no_cvm"])
                if opt_info.get("scale_datanorm") is not None:
                    trainer._set_scale_datanorm(opt_info["scale_datanorm"])
                if opt_info.get("adjust_ins_weight") is not None:
                    trainer._set_adjust_ins_weight(opt_info["adjust_ins_weight"])

@@ -84,8 +90,6 @@ class TrainerFactory(object):
                if opt_info.get("check_nan_var_names") is not None:
                    trainer._set_check_nan_var_names(opt_info["check_nan_var_names"])
                if opt_info.get("dump_param") is not None:
                    trainer._set_dump_param(opt_info["dump_param"])
                if opt_info.get("loss_names") is not None:
                    trainer._set_loss_names(opt_info["loss_names"])
            trainer._set_device_worker(device_worker)
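Net effect of the reordering above: the debug setters (dump_slot, dump_fields, dump_fields_path, dump_file_num, dump_converter, dump_param) are now applied from opt_info even when no fleet_desc is present, which is what lets the MultiTrainer/Hogwild combination built by TranspilerOptimizer receive the dump configuration. A sketch of the opt_info shape consumed here in that path, with illustrative values:

    opt_info = {
        "dump_slot": False, "dump_converter": "", "dump_fields": ["fc_0.tmp_0@GRAD"],
        "dump_file_num": 16, "dump_fields_path": "dump_text/", "dump_param": ["fc_0.tmp_0"],
        "mpi_rank": 0, "mpi_size": 2,
        "trainer": "MultiTrainer", "device_worker": "Hogwild",
    }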