Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
c1650120
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c1650120
编写于
2月 02, 2019
作者:
D
dongdaxiang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine device_worker and trainer code
test=develop
上级
8a335b50
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
318 addition
and
132 deletion
+318
-132
paddle/fluid/framework/CMakeLists.txt
paddle/fluid/framework/CMakeLists.txt
+2
-2
paddle/fluid/framework/async_executor.cc
paddle/fluid/framework/async_executor.cc
+3
-1
paddle/fluid/framework/async_executor.h
paddle/fluid/framework/async_executor.h
+0
-8
paddle/fluid/framework/device_worker.h
paddle/fluid/framework/device_worker.h
+4
-6
paddle/fluid/framework/device_worker_factory.cc
paddle/fluid/framework/device_worker_factory.cc
+0
-21
paddle/fluid/framework/device_worker_factory.h
paddle/fluid/framework/device_worker_factory.h
+50
-0
paddle/fluid/framework/dist_multi_trainer.cc
paddle/fluid/framework/dist_multi_trainer.cc
+6
-1
paddle/fluid/framework/downpour_worker.cc
paddle/fluid/framework/downpour_worker.cc
+29
-13
paddle/fluid/framework/fleet/CMakeLists.txt
paddle/fluid/framework/fleet/CMakeLists.txt
+1
-1
paddle/fluid/framework/fleet/fleet_wrapper.cc
paddle/fluid/framework/fleet/fleet_wrapper.cc
+42
-5
paddle/fluid/framework/fleet/fleet_wrapper.h
paddle/fluid/framework/fleet/fleet_wrapper.h
+4
-2
paddle/fluid/framework/hogwild_worker.cc
paddle/fluid/framework/hogwild_worker.cc
+8
-6
paddle/fluid/framework/multi_trainer.cc
paddle/fluid/framework/multi_trainer.cc
+2
-0
paddle/fluid/framework/pull_dense_worker.cc
paddle/fluid/framework/pull_dense_worker.cc
+11
-0
paddle/fluid/framework/trainer.cc
paddle/fluid/framework/trainer.cc
+0
-2
paddle/fluid/framework/trainer.h
paddle/fluid/framework/trainer.h
+1
-1
paddle/fluid/framework/trainer_desc.proto
paddle/fluid/framework/trainer_desc.proto
+0
-1
paddle/fluid/framework/trainer_factory.cc
paddle/fluid/framework/trainer_factory.cc
+0
-19
paddle/fluid/framework/trainer_factory.h
paddle/fluid/framework/trainer_factory.h
+47
-0
python/paddle/fluid/async_executor.py
python/paddle/fluid/async_executor.py
+12
-8
python/paddle/fluid/device_worker.py
python/paddle/fluid/device_worker.py
+75
-0
python/paddle/fluid/trainer_desc.py
python/paddle/fluid/trainer_desc.py
+21
-35
未找到文件。
paddle/fluid/framework/CMakeLists.txt
浏览文件 @
c1650120
...
...
@@ -203,7 +203,7 @@ if(WITH_PSLIB)
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table
trainer_desc_proto glog lod_rank_table
fleet_wrapper
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper pslib_brpc pslib timer
)
else
()
...
...
@@ -212,7 +212,7 @@ else()
trainer_factory.cc trainer.cc device_worker.cc hogwild_worker.cc
downpour_worker.cc pull_dense_worker.cc device_worker_factory.cc
DEPS op_registry device_context scope framework_proto
trainer_desc_proto glog lod_rank_table
trainer_desc_proto glog lod_rank_table
fleet_wrapper
feed_fetch_method graph_to_program_pass async_executor_proto
variable_helper timer
)
endif
(
WITH_PSLIB
)
...
...
paddle/fluid/framework/async_executor.cc
浏览文件 @
c1650120
...
...
@@ -26,7 +26,9 @@ limitations under the License. */
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_desc.pb.h"
#include "paddle/fluid/framework/trainer_factory.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/pybind/pybind.h"
...
...
@@ -161,7 +163,7 @@ void AsyncExecutor::RunFromFile(const ProgramDesc& main_program,
trainer
=
TrainerFactory
::
CreateTrainer
(
trainer_desc
.
class_name
());
// initialize trainer
trainer
->
Initialize
(
trainer_desc
);
// trainer->SetRoo
tScope(root_scope_);
trainer
->
Se
tScope
(
root_scope_
);
trainer
->
SetDebug
(
debug
);
// prepare training environment and helper environment
trainer
->
InitTrainerEnv
(
main_program
,
place_
);
...
...
paddle/fluid/framework/async_executor.h
浏览文件 @
c1650120
...
...
@@ -75,14 +75,6 @@ class AsyncExecutor {
void
InitModel
();
void
SaveModel
(
const
std
::
string
&
path
);
private:
void
CreateThreads
(
ExecutorThreadWorker
*
worker
,
const
ProgramDesc
&
main_program
,
const
std
::
shared_ptr
<
DataFeed
>&
reader
,
const
std
::
vector
<
std
::
string
>&
fetch_var_names
,
Scope
*
root_scope
,
const
int
thread_index
,
const
bool
debug
);
public:
std
::
shared_ptr
<
paddle
::
framework
::
FleetWrapper
>
fleet_ptr_
;
Scope
*
root_scope_
;
...
...
paddle/fluid/framework/device_worker.h
浏览文件 @
c1650120
...
...
@@ -39,12 +39,11 @@ namespace framework {
class
PullDenseWorker
{
public:
PullDenseWorker
()
{}
virtual
~
PullDenseWorker
()
{}
virtual
void
Initialize
(
const
TrainerDesc
&
param
);
int
Start
();
void
Stop
();
void
SetScope
(
Scope
*
scope
)
{
root_scope_
=
scope
;
}
void
Set
Root
Scope
(
Scope
*
scope
)
{
root_scope_
=
scope
;
}
void
IncreaseThreadVersion
(
int
thread_id
,
uint64_t
table_id
);
void
ResetThreadVersion
(
uint64_t
table_id
);
void
Wait
(
std
::
vector
<::
std
::
future
<
int32_t
>>*
status_vec
);
...
...
@@ -57,6 +56,7 @@ class PullDenseWorker {
}
private:
PullDenseWorker
()
:
root_scope_
(
NULL
)
{}
void
Run
();
bool
CheckUpdateParam
(
uint64_t
table_id
);
...
...
@@ -137,20 +137,18 @@ class HogwildWorker : public CPUWorkerBase {
protected:
void
CreateThreadOperators
(
const
ProgramDesc
&
program
);
void
CreateThreadScope
(
const
ProgramDesc
&
program
);
std
::
shared_ptr
<
DataFeed
>
thread_reader_
;
std
::
vector
<
std
::
string
>
op_names_
;
std
::
vector
<
OperatorBase
*>
ops_
;
Scope
*
thread_scope_
;
std
::
vector
<
std
::
string
>
fetch_var_names_
;
std
::
vector
<
std
::
vector
<
float
>>
fetch_values_
;
platform
::
Place
place_
;
};
class
DownpourWorker
:
public
HogwildWorker
{
public:
DownpourWorker
()
{}
virtual
~
DownpourWorker
()
{}
virtual
void
Initilize
(
const
TrainerDesc
&
desc
);
virtual
void
Initi
a
lize
(
const
TrainerDesc
&
desc
);
virtual
void
TrainFiles
();
protected:
...
...
@@ -163,7 +161,7 @@ class DownpourWorker : public HogwildWorker {
private:
DownpourWorkerParameter
param_
;
// just save the value in param_ for easy access
std
::
string
label_var_name_
;
std
::
map
<
uint64_t
,
std
::
string
>
label_var_name_
;
std
::
map
<
uint64_t
,
std
::
vector
<
std
::
string
>>
sparse_key_names_
;
std
::
map
<
uint64_t
,
std
::
vector
<
std
::
string
>>
sparse_value_names_
;
std
::
map
<
uint64_t
,
std
::
vector
<
std
::
string
>>
sparse_grad_names_
;
...
...
paddle/fluid/framework/device_worker_factory.cc
浏览文件 @
c1650120
...
...
@@ -19,25 +19,6 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
DeviceWorker
>
(
*
Createdevice_workerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
Createdevice_workerFunction
>
device_workerMap
;
device_workerMap
g_device_worker_map
;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
std
::
string
DeviceWorkerFactory
::
DeviceWorkerTypeList
()
{
std
::
string
device_worker_types
;
...
...
@@ -59,7 +40,5 @@ std::shared_ptr<DeviceWorker> DeviceWorkerFactory::CreateDeviceWorker(
return
g_device_worker_map
[
device_worker_class
]();
}
REGISTER_DEVICE_WORKER_CLASS
(
HogwildWorker
);
REGISTER_DEVICE_WORKER_CLASS
(
DownpourWorker
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/device_worker_factory.h
0 → 100644
浏览文件 @
c1650120
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/device_worker.h"
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
DeviceWorker
>
(
*
Createdevice_workerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
Createdevice_workerFunction
>
device_workerMap
;
device_workerMap
g_device_worker_map
;
#define REGISTER_DEVICE_WORKER_CLASS(device_worker_class) \
namespace { \
std::shared_ptr<DeviceWorker> Creator_##device_worker_class() { \
return std::shared_ptr<DeviceWorker>(new device_worker_class); \
} \
class __Registerer_##device_worker_class { \
public: \
__Registerer_##device_worker_class() { \
g_device_worker_map[#device_worker_class] = \
&Creator_##device_worker_class; \
} \
}; \
__Registerer_##device_worker_class g_registerer_##device_worker_class; \
} // namespace
class
DeviceWorkerFactory
{
public:
static
std
::
string
DeviceWorkerTypeList
();
static
std
::
shared_ptr
<
DeviceWorker
>
CreateDeviceWorker
(
std
::
string
device_worker_class
);
};
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/dist_multi_trainer.cc
浏览文件 @
c1650120
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -34,6 +35,7 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
workers_
[
i
]
->
SetDeviceIndex
(
i
);
readers_
[
i
]
->
Init
(
trainer_desc
.
data_desc
());
workers_
[
i
]
->
SetDataFeed
(
readers_
[
i
]);
workers_
[
i
]
->
Initialize
(
trainer_desc
);
}
std
::
vector
<
std
::
string
>
filelist_vec
;
...
...
@@ -41,13 +43,15 @@ void DistMultiTrainer::Initialize(const TrainerDesc& trainer_desc) {
filelist_vec
.
push_back
(
trainer_desc
.
filelist
(
i
));
}
readers_
[
0
]
->
SetFileList
(
filelist_vec
);
fleet_ptr_
=
FleetWrapper
::
GetInstance
();
pull_dense_worker_
=
PullDenseWorker
::
GetInstance
();
pull_dense_worker_
->
Initialize
(
trainer_desc
);
}
void
DistMultiTrainer
::
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
{
pull_dense_worker_
->
SetScope
(
root_scope_
);
pull_dense_worker_
->
Set
Root
Scope
(
root_scope_
);
pull_dense_worker_
->
Start
();
}
...
...
@@ -58,5 +62,6 @@ void DistMultiTrainer::Finalize() {
pull_dense_worker_
->
Stop
();
}
REGISTER_TRAINER_CLASS
(
DistMultiTrainer
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/downpour_worker.cc
浏览文件 @
c1650120
/* Copyright (c) 201
8
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 201
9
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
@@ -13,14 +13,14 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
namespace
paddle
{
namespace
framework
{
void
DownpourWorker
::
Initilize
(
const
TrainerDesc
&
desc
)
{
void
DownpourWorker
::
Initi
a
lize
(
const
TrainerDesc
&
desc
)
{
param_
=
desc
.
downpour_param
();
for
(
size_t
i
=
0
;
i
<
param_
.
sparse_table_size
();
++
i
)
{
uint64_t
table_id
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
i
).
table_id
());
...
...
@@ -37,6 +37,7 @@ void DownpourWorker::Initilize(const TrainerDesc& desc) {
for
(
size_t
j
=
0
;
j
<
table
.
sparse_grad_name_size
();
++
j
)
{
sparse_grad_names_
[
table_id
][
j
]
=
table
.
sparse_grad_name
(
j
);
}
label_var_name_
[
table_id
]
=
table
.
label_var_name
();
}
for
(
size_t
i
=
0
;
i
<
param_
.
dense_table_size
();
++
i
)
{
...
...
@@ -56,15 +57,18 @@ void DownpourWorker::Initilize(const TrainerDesc& desc) {
for
(
size_t
i
=
0
;
i
<
param_
.
skip_ops_size
();
++
i
)
{
skip_ops_
[
i
]
=
param_
.
skip_ops
(
i
);
}
label_var_name_
=
param_
.
label_var_name
();
skip_ops_
.
resize
(
param_
.
skip_ops_size
());
}
void
DownpourWorker
::
CollectLabelInfo
(
size_t
table_id
)
{
void
DownpourWorker
::
CollectLabelInfo
(
size_t
table_idx
)
{
auto
table
=
param_
.
sparse_table
(
table_idx
);
uint64_t
table_id
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
table_idx
).
table_id
());
auto
&
feature
=
features_
[
table_id
];
auto
&
feature_label
=
feature_labels_
[
table_id
];
feature_label
.
resize
(
feature
.
size
());
Variable
*
var
=
thread_scope_
->
FindVar
(
label_var_name_
);
Variable
*
var
=
thread_scope_
->
FindVar
(
label_var_name_
[
table_id
]
);
LoDTensor
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
int64_t
*
label_ptr
=
tensor
->
data
<
int64_t
>
();
...
...
@@ -75,13 +79,14 @@ void DownpourWorker::CollectLabelInfo(size_t table_id) {
int64_t
*
ids
=
tensor
->
data
<
int64_t
>
();
int
fea_idx
=
0
;
// tensor->lod()[0].size() == batch_size + 1
for
(
auto
ins_idx
=
0u
;
ins_idx
<
tensor
->
lod
()[
0
].
size
()
-
1
;
++
ins
_idx
)
{
for
(;
fea_idx
<
tensor
->
lod
()[
0
][
ins
_idx
];
++
fea_idx
)
{
for
(
auto
lod_idx
=
1u
;
lod_idx
<
tensor
->
lod
()[
0
].
size
();
++
lod
_idx
)
{
for
(;
fea_idx
<
tensor
->
lod
()[
0
][
lod
_idx
];
++
fea_idx
)
{
// should be skipped feasign defined in protobuf
if
(
ids
[
fea_idx
]
==
0u
)
{
continue
;
}
feature_label
[
global_index
++
]
=
static_cast
<
float
>
(
label_ptr
[
ins_idx
]);
feature_label
[
global_index
++
]
=
static_cast
<
float
>
(
label_ptr
[
lod_idx
-
1
]);
}
}
}
...
...
@@ -128,10 +133,10 @@ void DownpourWorker::FillSparseValue(size_t table_idx) {
void
DownpourWorker
::
TrainFiles
()
{
platform
::
SetNumThreads
(
1
);
thread
_reader_
->
Start
();
device
_reader_
->
Start
();
int
batch_cnt
=
0
;
int
cur_batch
;
while
((
cur_batch
=
thread
_reader_
->
Next
())
>
0
)
{
while
((
cur_batch
=
device
_reader_
->
Next
())
>
0
)
{
// pull sparse here
for
(
size_t
i
=
0
;
i
<
param_
.
sparse_table_size
();
++
i
)
{
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
sparse_table
(
i
).
table_id
());
...
...
@@ -144,7 +149,16 @@ void DownpourWorker::TrainFiles() {
// do computation here
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
*
thread_scope_
,
place_
);
bool
need_skip
=
false
;
for
(
auto
t
=
0u
;
t
<
skip_ops_
.
size
();
++
t
)
{
if
(
op
->
Type
().
find
(
skip_ops_
[
t
])
!=
std
::
string
::
npos
)
{
need_skip
=
true
;
break
;
}
}
if
(
!
need_skip
)
{
op
->
Run
(
*
thread_scope_
,
place_
);
}
}
// push gradients here
...
...
@@ -198,10 +212,12 @@ void DownpourWorker::TrainFiles() {
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
dense_table
(
i
).
table_id
());
pull_dense_worker_
->
IncreaseThreadVersion
(
thread_id_
,
tid
);
}
thread_scope_
->
DropKids
();
++
batch_cnt
;
}
}
REGISTER_DEVICE_WORKER_CLASS
(
DownpourWorker
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/fleet/CMakeLists.txt
浏览文件 @
c1650120
cc_library
(
fleet_wrapper SRCS fleet_wrapper.cc
)
cc_library
(
fleet_wrapper SRCS fleet_wrapper.cc
DEPS pslib_brpc pslib
)
paddle/fluid/framework/fleet/fleet_wrapper.cc
浏览文件 @
c1650120
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
...
...
@@ -19,10 +33,16 @@ namespace framework {
const
uint32_t
MAX_FEASIGN_NUM
=
1024
*
100
*
100
;
std
::
shared_ptr
<
FleetWrapper
>
FleetWrapper
::
s_instance_
=
NULL
;
bool
FleetWrapper
::
is_initialized_
=
false
;
#ifdef PADDLE_WITH_PSLIB
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
FleetWrapper
::
pslib_ptr_
=
NULL
;
#endif
void
FleetWrapper
::
InitServer
(
const
std
::
string
&
dist_desc
,
int
index
)
{
#ifdef PADDLE_WITH_PSLIB
if
(
!
is_initialized_
)
{
LOG
(
WARNING
)
<<
"Going to init server"
;
pslib_ptr_
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
new
paddle
::
distributed
::
PSlib
());
pslib_ptr_
->
init_server
(
dist_desc
,
index
);
...
...
@@ -38,6 +58,7 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
int
node_num
,
int
index
)
{
#ifdef PADDLE_WITH_PSLIB
if
(
!
is_initialized_
)
{
LOG
(
WARNING
)
<<
"Going to init server"
;
pslib_ptr_
=
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
(
new
paddle
::
distributed
::
PSlib
());
pslib_ptr_
->
init_worker
(
dist_desc
,
...
...
@@ -52,12 +73,14 @@ void FleetWrapper::InitWorker(const std::string& dist_desc,
void
FleetWrapper
::
StopServer
()
{
#ifdef PADDLE_WITH_PSLIB
LOG
(
WARNING
)
<<
"Going to stop server"
;
pslib_ptr_
->
stop_server
();
#endif
}
uint64_t
FleetWrapper
::
RunServer
()
{
#ifdef PADDLE_WITH_PSLIB
LOG
(
WARNING
)
<<
"Going to run server"
;
return
pslib_ptr_
->
run_server
();
#else
return
0
;
...
...
@@ -67,6 +90,7 @@ uint64_t FleetWrapper::RunServer() {
void
FleetWrapper
::
GatherServers
(
const
std
::
vector
<
uint64_t
>&
host_sign_list
,
int
node_num
)
{
#ifdef PADDLE_WITH_PSLIB
LOG
(
WARNING
)
<<
"Going to gather server ips"
;
pslib_ptr_
->
gather_servers
(
const_cast
<
uint64_t
*>
(
host_sign_list
.
data
()),
node_num
);
#endif
...
...
@@ -122,13 +146,13 @@ void FleetWrapper::PullDenseVarsAsync(
std
::
vector
<::
std
::
future
<
int32_t
>>*
pull_dense_status
)
{
#ifdef PADDLE_WITH_PSLIB
std
::
vector
<
paddle
::
ps
::
Region
>
regions
;
regions
.
res
erv
e
(
var_names
.
size
());
for
(
auto
&
t
:
var_names
)
{
Variable
*
var
=
scope
.
FindVar
(
t
);
regions
.
res
iz
e
(
var_names
.
size
());
for
(
auto
i
=
0u
;
i
<
var_names
.
size
();
++
i
)
{
Variable
*
var
=
scope
.
FindVar
(
var_names
[
i
]
);
LoDTensor
*
tensor
=
var
->
GetMutable
<
LoDTensor
>
();
float
*
w
=
tensor
->
data
<
float
>
();
paddle
::
ps
::
Region
reg
(
w
,
tensor
->
numel
());
regions
.
emplace_back
(
std
::
move
(
reg
)
);
regions
[
i
]
=
std
::
move
(
reg
);
}
auto
status
=
pslib_ptr_
->
_worker_ptr
->
pull_dense
(
regions
.
data
(),
regions
.
size
(),
tid
);
...
...
@@ -186,7 +210,10 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
int
offset
=
2
;
uint64_t
fea_idx
=
0u
;
for
(
size_t
i
=
0
;
i
<
sparse_key_names
.
size
();
++
i
)
{
Variable
*
g_var
=
scope
.
FindVar
(
sparse_key_names
[
i
]);
LOG
(
WARNING
)
<<
"sparse key names["
<<
i
<<
"]: "
<<
sparse_key_names
[
i
];
LOG
(
WARNING
)
<<
"sparse grad names["
<<
i
<<
"]: "
<<
sparse_grad_names
[
i
];
Variable
*
g_var
=
scope
.
FindVar
(
sparse_grad_names
[
i
]);
CHECK
(
g_var
!=
nullptr
)
<<
"var["
<<
sparse_grad_names
[
i
]
<<
"] not found"
;
LoDTensor
*
g_tensor
=
g_var
->
GetMutable
<
LoDTensor
>
();
if
(
g_tensor
==
NULL
)
{
LOG
(
ERROR
)
<<
"var["
<<
sparse_key_names
[
i
]
<<
"] not found"
;
...
...
@@ -201,16 +228,26 @@ void FleetWrapper::PushSparseVarsWithLabelAsync(
exit
(
-
1
);
}
int
len
=
tensor
->
numel
();
LOG
(
WARNING
)
<<
" tensor len: "
<<
len
;
int64_t
*
ids
=
tensor
->
data
<
int64_t
>
();
push_values
->
resize
(
fea_keys
.
size
()
+
1
);
for
(
auto
&
t
:
*
push_values
)
{
t
.
resize
(
emb_dim
+
offset
);
}
for
(
auto
id_idx
=
0u
;
id_idx
<
len
;
++
id_idx
)
{
if
(
ids
[
id_idx
]
==
0
)
{
g
+=
emb_dim
;
continue
;
}
LOG
(
WARNING
)
<<
"going to memcpy"
;
memcpy
((
*
push_values
)[
fea_idx
].
data
()
+
offset
,
g
,
sizeof
(
float
)
*
emb_dim
);
LOG
(
WARNING
)
<<
"show"
;
(
*
push_values
)[
fea_idx
][
0
]
=
1.0
f
;
LOG
(
WARNING
)
<<
"click"
;
(
*
push_values
)[
fea_idx
][
1
]
=
static_cast
<
float
>
(
fea_labels
[
fea_idx
]);
LOG
(
WARNING
)
<<
"offset"
;
g
+=
emb_dim
;
fea_idx
++
;
}
...
...
paddle/fluid/framework/fleet/fleet_wrapper.h
浏览文件 @
c1650120
...
...
@@ -47,7 +47,6 @@ namespace framework {
class
FleetWrapper
{
public:
FleetWrapper
()
{}
virtual
~
FleetWrapper
()
{}
// Pull sparse variables from server in Sync mode
...
...
@@ -122,8 +121,11 @@ class FleetWrapper {
static
std
::
shared_ptr
<
paddle
::
distributed
::
PSlib
>
pslib_ptr_
;
#endif
private:
FleetWrapper
()
{}
protected:
bool
is_initialized_
;
static
bool
is_initialized_
;
DISABLE_COPY_AND_ASSIGN
(
FleetWrapper
);
};
...
...
paddle/fluid/framework/hogwild_worker.cc
浏览文件 @
c1650120
...
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/platform/cpu_helper.h"
namespace
paddle
{
...
...
@@ -50,9 +51,9 @@ void HogwildWorker::CreateThreadScope(const ProgramDesc& program) {
void
HogwildWorker
::
BindingDataFeedMemory
()
{
const
std
::
vector
<
std
::
string
>&
input_feed
=
thread
_reader_
->
GetUseSlotAlias
();
device
_reader_
->
GetUseSlotAlias
();
for
(
auto
name
:
input_feed
)
{
thread
_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
device
_reader_
->
AddFeedVar
(
thread_scope_
->
Var
(
name
),
name
);
}
}
...
...
@@ -63,7 +64,7 @@ void HogwildWorker::CreateDeviceResource(const ProgramDesc& main_prog) {
void
HogwildWorker
::
TrainFilesWithProfiler
()
{
platform
::
SetNumThreads
(
1
);
thread
_reader_
->
Start
();
device
_reader_
->
Start
();
std
::
vector
<
double
>
op_total_time
;
std
::
vector
<
std
::
string
>
op_name
;
for
(
auto
&
op
:
ops_
)
{
...
...
@@ -79,7 +80,7 @@ void HogwildWorker::TrainFilesWithProfiler() {
int
cur_batch
;
int
batch_cnt
=
0
;
timeline
.
Start
();
while
((
cur_batch
=
thread
_reader_
->
Next
())
>
0
)
{
while
((
cur_batch
=
device
_reader_
->
Next
())
>
0
)
{
timeline
.
Pause
();
read_time
+=
timeline
.
ElapsedSec
();
total_time
+=
timeline
.
ElapsedSec
();
...
...
@@ -115,10 +116,10 @@ void HogwildWorker::TrainFiles() {
platform
::
SetNumThreads
(
1
);
// how to accumulate fetched values here
thread
_reader_
->
Start
();
device
_reader_
->
Start
();
int
cur_batch
;
int
batch_cnt
=
0
;
while
((
cur_batch
=
thread
_reader_
->
Next
())
>
0
)
{
while
((
cur_batch
=
device
_reader_
->
Next
())
>
0
)
{
for
(
auto
&
op
:
ops_
)
{
op
->
Run
(
*
thread_scope_
,
place_
);
}
...
...
@@ -128,5 +129,6 @@ void HogwildWorker::TrainFiles() {
}
}
REGISTER_DEVICE_WORKER_CLASS
(
HogwildWorker
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/multi_trainer.cc
浏览文件 @
c1650120
...
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/data_feed_factory.h"
#include "paddle/fluid/framework/device_worker_factory.h"
#include "paddle/fluid/framework/trainer.h"
#include "paddle/fluid/framework/trainer_factory.h"
namespace
paddle
{
namespace
framework
{
...
...
@@ -65,5 +66,6 @@ void MultiTrainer::Finalize() {
}
}
REGISTER_TRAINER_CLASS
(
MultiTrainer
);
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/pull_dense_worker.cc
浏览文件 @
c1650120
...
...
@@ -20,24 +20,31 @@ namespace framework {
std
::
shared_ptr
<
PullDenseWorker
>
PullDenseWorker
::
s_instance_
=
NULL
;
void
PullDenseWorker
::
Initialize
(
const
TrainerDesc
&
param
)
{
LOG
(
WARNING
)
<<
"going to initialize pull dense worker"
;
running_
=
false
;
param_
=
param
.
pull_dense_param
();
threshold_
=
param_
.
threshold
();
thread_num_
=
param_
.
device_num
();
sleep_time_ms_
=
param_
.
sleep_time_ms
();
LOG
(
WARNING
)
<<
"dense table size: "
<<
param_
.
dense_table_size
();
LOG
(
WARNING
)
<<
"thread num: "
<<
thread_num_
;
for
(
size_t
i
=
0
;
i
<
param_
.
dense_table_size
();
++
i
)
{
// setup dense variables for each table
int
var_num
=
param_
.
dense_table
(
i
).
dense_value_name_size
();
LOG
(
WARNING
)
<<
"var num: "
<<
var_num
;
uint64_t
tid
=
static_cast
<
uint64_t
>
(
param_
.
dense_table
(
i
).
table_id
());
dense_value_names_
[
tid
].
resize
(
var_num
);
for
(
int
j
=
0
;
j
<
var_num
;
++
j
)
{
dense_value_names_
[
tid
][
j
]
=
param_
.
dense_table
(
i
).
dense_value_name
(
j
);
LOG
(
WARNING
)
<<
"dense value names "
<<
j
<<
" "
<<
dense_value_names_
[
tid
][
j
];
}
// setup training version for each table
training_versions_
[
tid
].
resize
(
thread_num_
,
0
);
last_versions_
[
tid
]
=
0
;
current_version_
[
tid
]
=
0
;
}
LOG
(
WARNING
)
<<
"initialize pull dense worker done."
;
}
void
PullDenseWorker
::
Wait
(
std
::
vector
<::
std
::
future
<
int32_t
>>*
status_vec
)
{
...
...
@@ -56,6 +63,7 @@ void PullDenseWorker::Wait(std::vector<::std::future<int32_t>>* status_vec) {
<<
" Times"
;
exit
(
-
1
);
}
status_vec
->
resize
(
0
);
}
void
PullDenseWorker
::
Stop
()
{
...
...
@@ -90,7 +98,10 @@ void PullDenseWorker::Run() {
}
void
PullDenseWorker
::
IncreaseThreadVersion
(
int
thread_id
,
uint64_t
table_id
)
{
LOG
(
WARNING
)
<<
"increase thread version input: "
<<
thread_id
<<
" table id "
<<
table_id
;
std
::
lock_guard
<
std
::
mutex
>
lock
(
mutex_for_version_
);
LOG
(
WARNING
)
<<
"going to increase"
;
training_versions_
[
table_id
][
thread_id
]
++
;
}
...
...
paddle/fluid/framework/trainer.cc
浏览文件 @
c1650120
...
...
@@ -19,7 +19,5 @@ namespace framework {
void
TrainerBase
::
SetScope
(
Scope
*
root_scope
)
{
root_scope_
=
root_scope
;
}
void
TrainerBase
::
Initialize
(
const
TrainerDesc
&
trainer_desc
)
{
return
;
}
}
// end namespace framework
}
// end namespace paddle
paddle/fluid/framework/trainer.h
浏览文件 @
c1650120
...
...
@@ -39,8 +39,8 @@ class TrainerBase {
virtual
~
TrainerBase
()
{}
// model memory are hosted in root_scope
void
SetScope
(
Scope
*
root_scope
);
void
Initialize
(
const
TrainerDesc
&
trainer_desc
);
void
SetDebug
(
const
bool
debug
)
{
debug_
=
debug
;
}
virtual
void
Initialize
(
const
TrainerDesc
&
trainer_desc
)
=
0
;
virtual
void
InitTrainerEnv
(
const
ProgramDesc
&
main_program
,
const
platform
::
Place
&
place
)
=
0
;
virtual
void
InitOtherEnv
(
const
ProgramDesc
&
main_program
)
=
0
;
...
...
paddle/fluid/framework/trainer_desc.proto
浏览文件 @
c1650120
...
...
@@ -43,7 +43,6 @@ message DownpourWorkerParameter {
repeated
TableParameter
sparse_table
=
1
;
repeated
TableParameter
dense_table
=
2
;
repeated
string
skip_ops
=
3
;
optional
string
label_var_name
=
4
;
}
message
PullDenseWorkerParameter
{
...
...
paddle/fluid/framework/trainer_factory.cc
浏览文件 @
c1650120
...
...
@@ -21,23 +21,6 @@ limitations under the License. */
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
TrainerBase
>
(
*
CreatetrainerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
CreatetrainerFunction
>
trainerMap
;
trainerMap
g_trainer_map
;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
std
::
string
TrainerFactory
::
TrainerTypeList
()
{
std
::
string
trainer_types
;
...
...
@@ -58,7 +41,5 @@ std::shared_ptr<TrainerBase> TrainerFactory::CreateTrainer(
return
g_trainer_map
[
trainer_class
]();
}
REGISTER_TRAINER_CLASS
(
MultiTrainer
);
REGISTER_TRAINER_CLASS
(
DistMultiTrainer
);
}
// namespace framework
}
// namespace paddle
paddle/fluid/framework/trainer_factory.h
0 → 100644
浏览文件 @
c1650120
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/framework/trainer.h"
namespace
paddle
{
namespace
framework
{
typedef
std
::
shared_ptr
<
TrainerBase
>
(
*
CreatetrainerFunction
)();
typedef
std
::
unordered_map
<
std
::
string
,
CreatetrainerFunction
>
trainerMap
;
trainerMap
g_trainer_map
;
#define REGISTER_TRAINER_CLASS(trainer_class) \
namespace { \
std::shared_ptr<TrainerBase> Creator_##trainer_class() { \
return std::shared_ptr<TrainerBase>(new trainer_class); \
} \
class __Registerer_##trainer_class { \
public: \
__Registerer_##trainer_class() { \
g_trainer_map[#trainer_class] = &Creator_##trainer_class; \
} \
}; \
__Registerer_##trainer_class g_registerer_##trainer_class; \
} // namespace
class
TrainerFactory
{
public:
static
std
::
string
TrainerTypeList
();
static
std
::
shared_ptr
<
TrainerBase
>
CreateTrainer
(
std
::
string
trainer_class
);
};
}
// namespace framework
}
// namespace paddle
python/paddle/fluid/async_executor.py
浏览文件 @
c1650120
...
...
@@ -110,15 +110,17 @@ class AsyncExecutor(object):
is_local
=
self
.
instance
==
None
trainer
=
None
if
is_local
:
trainer
=
MultiTrainer
(
data_feed
=
data_feed
,
worker
=
"Hogwild"
)
trainer
=
MultiTrainer
()
else
:
trainer
=
DistMultiTrainer
(
data_feed
,
worker
=
"Downpour"
,
fleet_desc
=
self
.
dist_desc
)
# define a trainer and a device_worker here
trainer
=
DistMultiTrainer
()
trainer
.
gen_trainer_desc
(
dataset
=
data_feed
,
fleet_desc
=
self
.
dist_desc
,
worker
=
"downpour"
)
trainer
.
set_thread
(
thread_num
)
trainer
.
set_filelist
(
filelist
)
trainer
.
set_data_feed
(
data_feed
)
with
open
(
"trainer_desc.proto"
,
"w"
)
as
fout
:
fout
.
write
(
trainer
.
_desc
())
# define a trainer and a device_worker here
self
.
executor
.
run_from_files
(
program_desc
,
trainer
.
_desc
(),
debug
)
'''
...
...
@@ -284,8 +286,9 @@ class AsyncExecutor(object):
raise
ValueError
(
'instance is None, please run config_distributed_nodes init instance'
)
self
.
init_desc
=
init_desc
self
.
executor
.
init_server
(
dist_desc
,
self
.
instance
.
_rankid
)
self
.
dist_desc_str
=
text_format
.
MessageToString
(
dist_desc
)
self
.
dist_desc
=
dist_desc
self
.
executor
.
init_server
(
self
.
dist_desc_str
,
self
.
instance
.
_rankid
)
ip
=
self
.
executor
.
start_server
()
self
.
instance
.
set_ip
(
ip
)
self
.
instance
.
barrier_all
()
#wait all server start
...
...
@@ -306,6 +309,7 @@ class AsyncExecutor(object):
'instance is None, please run config_distributed_nodes init instance'
)
self
.
dist_desc_str
=
text_format
.
MessageToString
(
dist_desc
)
self
.
dist_desc
=
dist_desc
place
=
core
.
CPUPlace
()
executor
=
Executor
(
place
)
...
...
@@ -313,7 +317,7 @@ class AsyncExecutor(object):
self
.
instance
.
barrier_all
()
#wait all server start
ips
=
self
.
instance
.
gather_ips
()
self
.
executor
.
init_worker
(
dist_desc
,
ips
,
self
.
executor
.
init_worker
(
self
.
dist_desc_str
,
ips
,
self
.
instance
.
get_node_cnt
(),
self
.
instance
.
_rankid
)
self
.
instance
.
barrier_all
()
#wait all worker start
...
...
python/paddle/fluid/device_worker.py
0 → 100644
浏览文件 @
c1650120
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class
DeviceWorker
(
object
):
def
__init__
(
self
):
pass
def
gen_worker_desc
(
self
,
trainer_desc
,
fleet_desc
):
pass
class
Hogwild
(
DeviceWorker
):
def
__init__
(
self
):
super
(
Hogwild
,
self
).
__init__
()
def
gen_worker_desc
(
self
,
trainer_desc
,
fleet_desc
):
trainer_desc
.
device_worker_name
=
"HogwildWorker"
class
Downpour
(
DeviceWorker
):
def
__init__
(
self
):
super
(
Downpour
,
self
).
__init__
()
def
gen_worker_desc
(
self
,
trainer_desc
,
fleet_desc
):
trainer_desc
.
device_worker_name
=
"DownpourWorker"
pull_thread
=
trainer_desc
.
pull_dense_param
pull_thread
.
device_num
=
trainer_desc
.
thread_num
dense_table
=
pull_thread
.
dense_table
.
add
()
dense_table
.
dense_value_name
.
extend
(
fleet_desc
.
trainer_param
.
dense_table
[
0
].
dense_variable_name
)
dense_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
dense_table
[
0
].
table_id
downpour
=
trainer_desc
.
downpour_param
sparse_table
=
downpour
.
sparse_table
.
add
()
sparse_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
sparse_table
[
0
].
table_id
sparse_table
.
sparse_key_name
.
extend
(
fleet_desc
.
trainer_param
.
sparse_table
[
0
].
slot_key
)
sparse_table
.
sparse_value_name
.
extend
(
fleet_desc
.
trainer_param
.
sparse_table
[
0
].
slot_value
)
sparse_table
.
sparse_grad_name
.
extend
(
fleet_desc
.
trainer_param
.
sparse_table
[
0
].
slot_gradient
)
sparse_table
.
emb_dim
=
fleet_desc
.
server_param
.
downpour_server_param
.
downpour_table_param
[
0
].
accessor
.
fea_dim
-
2
sparse_table
.
fea_dim
=
sparse_table
.
emb_dim
+
2
sparse_table
.
label_var_name
=
"click"
dense_table
=
downpour
.
dense_table
.
add
()
dense_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
dense_table
[
0
].
table_id
dense_table
.
dense_value_name
.
extend
(
fleet_desc
.
trainer_param
.
dense_table
[
0
].
dense_variable_name
)
dense_table
.
dense_grad_name
.
extend
(
fleet_desc
.
trainer_param
.
dense_table
[
0
].
dense_gradient_variable_name
)
downpour
.
skip_ops
.
extend
(
fleet_desc
.
trainer_param
.
skip_op
)
class
DeviceWorkerFactory
(
object
):
def
create_device_worker
(
self
,
worker_type
):
classname
=
worker_type
.
capitalize
()
print
(
"------------"
)
print
(
classname
)
return
globals
()[
classname
]()
python/paddle/fluid/trainer_desc.py
浏览文件 @
c1650120
...
...
@@ -13,7 +13,8 @@
# limitations under the License.
from
paddle.fluid.proto
import
trainer_desc_pb2
import
ps_pb2
as
pslib
from
distributed
import
ps_pb2
as
ps_pb2
from
device_worker
import
DeviceWorkerFactory
from
google.protobuf
import
text_format
__all__
=
[
'TrainerDesc'
,
'MultiTrainer'
,
'DistMultiTrainer'
]
...
...
@@ -28,16 +29,22 @@ class TrainerDesc(object):
text_format.Parse(f.read(), self.proto_desc)
'''
self
.
proto_desc
=
trainer_desc_pb2
.
TrainerDesc
()
self
.
proto_desc
.
thread_num
=
12
def
set_thread
(
self
,
thread_num
):
self
.
proto_desc
.
thread_num
=
thread_num
def
set_filelist
(
self
,
filelist
):
self
.
proto_desc
.
filelist
.
extend
(
filelist
)
self
.
proto_desc
.
thread_num
=
min
(
len
(
filelist
),
self
.
proto_desc
.
thread_num
)
def
set_data_feed
(
self
,
datafeed
):
self
.
proto_desc
.
data_desc
.
CopyFrom
(
datafeed
.
proto_desc
)
def
gen_trainer_desc
(
self
,
dataset
=
None
,
fleet_desc
=
None
,
worker
=
None
):
pass
def
_desc
(
self
):
return
text_format
.
MessageToString
(
self
.
proto_desc
)
...
...
@@ -52,41 +59,20 @@ class MultiTrainer(TrainerDesc):
raise
ValueError
(
'ValueError: DeviceWorker %s '
'is not supported in MultiTrainer'
%
worker
)
def
gen_trainer_desc
(
self
,
dataset
=
None
,
fleet_desc
=
None
,
worker
=
"Hogwild"
):
super
(
MultiTrainer
,
self
).
gen_trainer_desc
(
fleet_desc
,
worker
)
class
DistMultiTrainer
(
TrainerDesc
):
def
__init__
(
self
,
dataset
=
None
,
worker
=
'Downpour'
,
fleet_desc
=
None
):
def
__init__
(
self
):
super
(
DistMultiTrainer
,
self
).
__init__
()
if
worker
==
"Downpour"
:
self
.
proto_desc
.
device_worker_name
=
worker
+
"Worker"
self
.
proto_desc
.
class_name
=
"DistMultiTrainer"
self
.
proto_desc
.
data_feed
.
CopyFrom
(
dataset
)
downpour
=
self
.
proto_desc
.
downpour_param
.
add
()
# sparse table should specify:
sparse_table
=
downpour
.
sparse_table
.
add
()
sparse_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
sparse_table
.
table_id
sparse_table
.
sparse_key_name
.
CopyFrom
(
fleet_desc
.
trainer_param
()
.
sparse_table
().
slot_key
())
sparse_table
.
sparse_value_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
sparse_table
().
slot_value
())
sparse_table
.
sparse_grad_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
sparse_table
().
slot_gradient
())
sparse_table
.
emb_dim
=
fleet_desc
.
server_param
.
downpour_server_param
.
downpour_table_param
.
accessor
.
fea_dim
-
2
sparse_table
.
fea_dim
=
downpour
.
emb_dim
+
2
sparse_table
.
label_var_name
=
"click"
pass
# dense table should specify:
dense_table
=
downpour
.
dense_table
.
add
()
dense_table
.
table_id
=
\
fleet_desc
.
trainer_param
.
dense_table
.
table_id
# dense_value_name
dense_table
.
dense_value_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
dense_table
().
dense_variable_name
)
# dense_grad_name
dense_table
.
dense_grad_name
.
CopyFrom
(
fleet_desc
.
trainer_param
(
).
dense_table
().
dense_gradient_name
)
downpour
.
skipped_ops
.
extend
(
fleet_desc
.
trainer_param
.
skip_op
)
print
(
str
(
self
.
proto_desc
))
else
:
raise
ValueError
(
'ValueError: DeviceWorker %s '
'is not supported in DistMultiTrainer'
%
worker
)
def
gen_trainer_desc
(
self
,
dataset
=
None
,
fleet_desc
=
None
,
worker
=
"Downpour"
):
super
(
DistMultiTrainer
,
self
).
gen_trainer_desc
(
fleet_desc
,
worker
)
self
.
proto_desc
.
class_name
=
"DistMultiTrainer"
self
.
proto_desc
.
data_desc
.
CopyFrom
(
dataset
.
proto_desc
)
worker_builder
=
DeviceWorkerFactory
()
device_worker
=
worker_builder
.
create_device_worker
(
"Downpour"
)
device_worker
.
gen_worker_desc
(
self
.
proto_desc
,
fleet_desc
)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录