机器未来 / Paddle (forked from PaddlePaddle / Paddle)

Commit 3f12c2e0
Author: malin10
Date: September 14, 2020
Parent: 7c0196d4

Commit message: bug fix

Showing 17 changed files with 107 additions and 59 deletions (+107, -59)

Changed files:

  cmake/external/pslib.cmake                                +1   -1
  cmake/external/pslib_brpc.cmake                           +1   -1
  paddle/fluid/framework/CMakeLists.txt                     +1   -1
  paddle/fluid/framework/data_feed.cc                       +6   -2
  paddle/fluid/framework/data_feed.h                        +17  -0
  paddle/fluid/framework/data_set.cc                        +10  -2
  paddle/fluid/framework/data_set.h                         +4   -3
  paddle/fluid/framework/device_worker.h                    +1   -0
  paddle/fluid/framework/fleet/CMakeLists.txt               +1   -0
  paddle/fluid/framework/fleet/heter_wrapper.cc             +2   -1
  paddle/fluid/framework/fleet/tree_wrapper.cc              +18  -8
  paddle/fluid/framework/fleet/tree_wrapper.h               +20  -17
  python/paddle/fluid/dataset.py                            +2   -2
  python/paddle/fluid/executor.py                           +5   -4
  python/paddle/fluid/incubate/fleet/utils/fleet_util.py    +12  -11
  python/paddle/fluid/incubate/fleet/utils/hdfs.py          +1   -1
  python/paddle/fluid/layers/nn.py                          +5   -5

cmake/external/pslib.cmake

@@ -48,7 +48,7 @@ ExternalProject_Add(
     ${EXTERNAL_PROJECT_LOG_ARGS}
     PREFIX                ${PSLIB_SOURCE_DIR}
     DOWNLOAD_DIR          ${PSLIB_DOWNLOAD_DIR}
-    DOWNLOAD_COMMAND      wget --no-check-certificate ${PSLIB_URL} -c -q -O ${PSLIB_NAME}.tar.gz
+    DOWNLOAD_COMMAND      cp /home/malin10/baidu/paddlepaddle/pslib/pslib.tar.gz ${PSLIB_NAME}.tar.gz
                           && tar zxvf ${PSLIB_NAME}.tar.gz
     DOWNLOAD_NO_PROGRESS  1
     UPDATE_COMMAND        ""

cmake/external/pslib_brpc.cmake

@@ -47,7 +47,7 @@ ExternalProject_Add(
     ${EXTERNAL_PROJECT_LOG_ARGS}
     PREFIX                ${PSLIB_BRPC_SOURCE_DIR}
     DOWNLOAD_DIR          ${PSLIB_BRPC_DOWNLOAD_DIR}
-    DOWNLOAD_COMMAND      wget --no-check-certificate ${PSLIB_BRPC_URL} -c -q -O ${PSLIB_BRPC_NAME}.tar.gz
+    DOWNLOAD_COMMAND      cp /home/malin10/Paddle/pslib_brpc.tar.gz ${PSLIB_BRPC_NAME}.tar.gz
                           && tar zxvf ${PSLIB_BRPC_NAME}.tar.gz
     DOWNLOAD_NO_PROGRESS  1
     UPDATE_COMMAND        ""

paddle/fluid/framework/CMakeLists.txt

@@ -217,7 +217,7 @@ elseif(WITH_PSLIB)
     data_feed.cc device_worker.cc hogwild_worker.cc hetercpu_worker.cc downpour_worker.cc downpour_worker_opt.cc
     pull_dense_worker.cc section_worker.cc device_worker_factory.cc data_set.cc DEPS op_registry
     device_context scope framework_proto data_feed_proto heter_service_proto trainer_desc_proto glog
-    lod_rank_table fs shell fleet_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method
+    lod_rank_table fs shell fleet_wrapper tree_wrapper heter_wrapper box_wrapper lodtensor_printer feed_fetch_method
     graph_to_program_pass variable_helper timer monitor pslib_brpc)
     # TODO: Fix these unittest failed on Windows
     if(NOT WIN32)

paddle/fluid/framework/data_feed.cc

@@ -859,7 +859,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
   } else {
     const char* str = reader.get();
     std::string line = std::string(str);
-    // VLOG(3) << line;
+    VLOG(1) << line;
     char* endptr = const_cast<char*>(str);
     int pos = 0;
     if (parse_ins_id_) {

@@ -907,9 +907,11 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
       instance->rank = rank;
       pos += len + 1;
     }
+    std::stringstream ss;
     for (size_t i = 0; i < use_slots_index_.size(); ++i) {
       int idx = use_slots_index_[i];
       int num = strtol(&str[pos], &endptr, 10);
+      ss << "(" << idx << ", " << num << "); ";
       PADDLE_ENFORCE_NE(
           num, 0,
           platform::errors::InvalidArgument(

@@ -936,7 +938,8 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
         uint64_t feasign = (uint64_t)strtoull(endptr, &endptr, 10);
         // if uint64 feasign is equal to zero, ignore it
         // except when slot is dense
-        if (feasign == 0 && !use_slots_is_dense_[i]) {
+        if (feasign == 0 && !use_slots_is_dense_[i] &&
+            all_slots_[i] != "12345") {
           continue;
         }
         FeatureKey f;

@@ -954,6 +957,7 @@ bool MultiSlotInMemoryDataFeed::ParseOneInstanceFromPipe(Record* instance) {
       }
     }
   }
+  VLOG(1) << ss.str();
   instance->float_feasigns_.shrink_to_fit();
   instance->uint64_feasigns_.shrink_to_fit();
   fea_num_ += instance->uint64_feasigns_.size();

paddle/fluid/framework/data_feed.h

@@ -31,6 +31,7 @@ limitations under the License. */
 #include <utility>
 #include <vector>
+#include "glog/logging.h"
 #include "paddle/fluid/framework/archive.h"
 #include "paddle/fluid/framework/blocking_queue.h"
 #include "paddle/fluid/framework/channel.h"

@@ -94,6 +95,22 @@ struct Record {
   uint64_t search_id;
   uint32_t rank;
   uint32_t cmatch;
+
+  void Print() {
+    std::stringstream ss;
+    ss << "int64_feasigns: [";
+    for (uint64_t i = 0; i < uint64_feasigns_.size(); i++) {
+      ss << "(" << uint64_feasigns_[i].slot() << ", "
+         << uint64_feasigns_[i].sign().uint64_feasign_ << "); ";
+    }
+    ss << "]\t\tfloat64_feasigns:[";
+    for (uint64_t i = 0; i < float_feasigns_.size(); i++) {
+      ss << "(" << float_feasigns_[i].slot() << ", "
+         << float_feasigns_[i].sign().float_feasign_ << "); ";
+    }
+    ss << "]\n";
+    VLOG(1) << ss.str();
+  }
 };

 struct PvInstanceObject {

paddle/fluid/framework/data_set.cc

@@ -365,7 +365,8 @@ void DatasetImpl<T>::TDMDump(std::string name, const uint64_t table_id,
 // do sample
 template <typename T>
 void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
-                               const uint64_t type_slot) {
+                               const uint64_t type_slot,
+                               const uint64_t start_h) {
   VLOG(0) << "DatasetImpl<T>::Sample() begin";
   platform::Timer timeline;
   timeline.Start();

@@ -379,6 +380,7 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
       if (!multi_output_channel_[i] || multi_output_channel_[i]->Size() == 0) {
         continue;
       }
+      multi_output_channel_[i]->Close();
       multi_output_channel_[i]->ReadAll(data[i]);
     }
   } else {

@@ -388,17 +390,23 @@ void DatasetImpl<T>::TDMSample(const uint16_t sample_slot,
     input_channel_->ReadAll(data[data.size() - 1]);
   }
+  VLOG(1) << "finish read src data, data.size = " << data.size()
+          << "; details: ";
   auto tree_ptr = TreeWrapper::GetInstance();
   auto fleet_ptr = FleetWrapper::GetInstance();
   for (auto i = 0; i < data.size(); i++) {
+    VLOG(1) << "data[" << i << "]: size = " << data[i].size();
     std::vector<T> tmp_results;
-    tree_ptr->sample(sample_slot, type_slot, data[i], &tmp_results);
+    tree_ptr->sample(sample_slot, type_slot, &data[i], &tmp_results, start_h);
+    VLOG(1) << "sample_results(" << sample_slot << ", " << type_slot
+            << ") = " << tmp_results.size();
     sample_results.push_back(tmp_results);
   }
   auto output_channel_num = multi_output_channel_.size();
   for (auto i = 0; i < sample_results.size(); i++) {
     auto output_idx = fleet_ptr->LocalRandomEngine()() % output_channel_num;
+    multi_output_channel_[output_idx]->Open();
     multi_output_channel_[output_idx]->Write(std::move(sample_results[i]));
   }

paddle/fluid/framework/data_set.h

@@ -47,8 +47,8 @@ class Dataset {
   virtual ~Dataset() {}
   virtual void InitTDMTree(
       const std::vector<std::pair<std::string, std::string>> config) = 0;
-  virtual void TDMSample(const uint16_t sample_slot,
-                         const uint64_t type_slot) = 0;
+  virtual void TDMSample(const uint16_t sample_slot, const uint64_t type_slot,
+                         const uint64_t start_h) = 0;
   virtual void TDMDump(std::string name, const uint64_t table_id,
                        int fea_value_dim, const std::string tree_path) = 0;
   // set file list

@@ -168,7 +168,8 @@ class DatasetImpl : public Dataset {
   virtual void InitTDMTree(
       const std::vector<std::pair<std::string, std::string>> config);
-  virtual void TDMSample(const uint16_t sample_slot, const uint64_t type_slot);
+  virtual void TDMSample(const uint16_t sample_slot, const uint64_t type_slot,
+                         const uint64_t start_h);
   virtual void TDMDump(std::string name, const uint64_t table_id,
                        int fea_value_dim, const std::string tree_path);

paddle/fluid/framework/device_worker.h

@@ -171,6 +171,7 @@ class DeviceWorker {
     device_reader_->SetPlace(place);
   }
   virtual Scope* GetThreadScope() { return thread_scope_; }
+  virtual void GetXpuOpIndex() {}

 protected:
   virtual void DumpParam(const Scope& scope, const int batch_id);

paddle/fluid/framework/fleet/CMakeLists.txt

 if(WITH_PSLIB)
   cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope pslib_brpc pslib)
+  cc_library(tree_wrapper SRCS tree_wrapper.cc)
 else()
   cc_library(fleet_wrapper SRCS fleet_wrapper.cc DEPS framework_proto variable_helper scope)
 endif(WITH_PSLIB)

paddle/fluid/framework/fleet/heter_wrapper.cc

@@ -192,7 +192,8 @@ framework::proto::VarType::Type HeterWrapper::ToVarType(
     case VariableMessage::BOOL:
       return framework::proto::VarType::BOOL;  // NOLINT
     default:
+      VLOG(0) << "Not support type " << type;
       PADDLE_THROW(platform::errors::InvalidArgument(
           "ToVarType:Unsupported type %d", type));
   }
 }

paddle/fluid/framework/fleet/tree_wrapper.cc

@@ -12,20 +12,25 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 #pragma once
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
 #include <boost/algorithm/string.hpp>
 #include <boost/lexical_cast.hpp>
 #include "paddle/fluid/framework/data_feed.h"
+#include "paddle/fluid/framework/fleet/fleet_wrapper.h"
 #include "paddle/fluid/framework/fleet/tree_wrapper.h"
 #include "paddle/fluid/framework/io/fs.h"

 namespace paddle {
 namespace framework {

-int Tree::load(std::string path, std::string tree_pipe_command_) {
+std::shared_ptr<TreeWrapper> TreeWrapper::s_instance_(nullptr);
+
+int Tree::load(std::string path) {
   uint64_t linenum = 0;
   size_t idx = 0;
   std::vector<std::string> lines;

@@ -33,10 +38,10 @@ int Tree::load(std::string path, std::string tree_pipe_command_) {
   std::vector<std::string> items;
   int err_no;
-  std::shared_ptr<FILE> fp_ = fs_open_read(path, &err_no, tree_pipe_command_);
+  std::shared_ptr<FILE> fp_ = fs_open_read(path, &err_no, "");
   string::LineFileReader reader;
   while (reader.getline(&*(fp_.get()))) {
-    line = std::string(reader.get());
+    auto line = std::string(reader.get());
     strs.clear();
     boost::split(strs, line, boost::is_any_of("\t"));
     if (0 == linenum) {

@@ -132,16 +137,21 @@ int Tree::dump_tree(const uint64_t table_id, int fea_value_dim,
   std::shared_ptr<FILE> fp =
       paddle::framework::fs_open(tree_path, "w", &ret, "");
-  std::vector<uint64_t> fea_keys,
-  std::vector<float*> pull_result_ptr;
+  std::vector<uint64_t> fea_keys;
+  std::vector<float*> pull_result_ptr;
+  fea_keys.reserve(_total_node_num);
+  pull_result_ptr.reserve(_total_node_num);
   for (size_t i = 0; i != _total_node_num; ++i) {
     _nodes[i].embedding.resize(fea_value_dim);
-    fea_key.push_back(_nodes[i].id);
+    fea_keys.push_back(_nodes[i].id);
     pull_result_ptr.push_back(_nodes[i].embedding.data());
   }
+  auto fleet_ptr = FleetWrapper::GetInstance();
+  fleet_ptr->pslib_ptr_->_worker_ptr->pull_sparse(
+      pull_result_ptr.data(), table_id, fea_keys.data(), fea_keys.size());
   std::string first_line = boost::lexical_cast<std::string>(_total_node_num) +
                            "\t" + boost::lexical_cast<std::string>(_tree_height);

@@ -183,7 +193,7 @@
 bool Tree::trace_back(uint64_t id,
                       std::vector<std::pair<uint64_t, uint32_t>>* ids) {
-  ids.clear();
+  ids->clear();
   std::unordered_map<uint64_t, Node*>::iterator find_it =
       _leaf_node_map.find(id);
   if (find_it == _leaf_node_map.end()) {

paddle/fluid/framework/fleet/tree_wrapper.h

@@ -103,15 +103,14 @@ class TreeWrapper {
   }

   void sample(const uint16_t sample_slot, const uint64_t type_slot,
-              const std::vector<Record>& src_datas,
-              std::vector<Record>* sample_results) {
+              std::vector<Record>* src_datas,
+              std::vector<Record>* sample_results, const uint64_t start_h) {
     sample_results->clear();
-    auto debug_idx = 0;
-    for (auto& data : src_datas) {
-      if (debug_idx == 0) {
-        VLOG(0) << "src record";
-        data.Print();
-      }
+    for (auto& data : *src_datas) {
+      VLOG(1) << "src record";
+      data.Print();
+      uint64_t start_idx = sample_results->size();
+      VLOG(1) << "before sample, sample_results.size = " << start_idx;
       uint64_t sample_feasign_idx = -1, type_feasign_idx = -1;
       for (uint64_t i = 0; i < data.uint64_feasigns_.size(); i++) {
         if (data.uint64_feasigns_[i].slot() == sample_slot) {

@@ -121,6 +120,8 @@ class TreeWrapper {
           type_feasign_idx = i;
         }
       }
+      VLOG(1) << "sample_feasign_idx: " << sample_feasign_idx
+              << "; type_feasign_idx: " << type_feasign_idx;
       if (sample_feasign_idx > 0) {
         std::vector<std::pair<uint64_t, uint32_t>> trace_ids;
         for (std::unordered_map<std::string, TreePtr>::iterator ite =

@@ -139,18 +140,20 @@ class TreeWrapper {
           Record instance(data);
           instance.uint64_feasigns_[sample_feasign_idx].sign().uint64_feasign_ =
               trace_ids[i].first;
-          if (type_feasign_idx > 0)
-            instance.uint64_feasigns_[type_feasign_idx]
-                .sign()
-                .uint64_feasign_ += trace_ids[i].second * 100;
-          if (debug_idx == 0) {
-            VLOG(0) << "sample results:" << i;
-            instance.Print();
-          }
+          if (type_feasign_idx > 0 && trace_ids[i].second > start_h)
+            instance.uint64_feasigns_[type_feasign_idx].sign().uint64_feasign_ =
+                (instance.uint64_feasigns_[type_feasign_idx]
+                     .sign()
+                     .uint64_feasign_ +
+                 1) * 100 +
+                trace_ids[i].second;
           sample_results->push_back(instance);
         }
       }
-      debug_idx += 1;
+      for (auto i = start_idx; i < sample_results->size(); i++) {
+        sample_results->at(i).Print();
+      }
     }
     return;
   }

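For orientation, the following is a rough Python paraphrase of the reworked TreeWrapper::sample() loop above; it is a sketch only. The Rec class, the feasign layout, and trace_back() are stand-ins for the fork's C++ types, while the `(type_feasign + 1) * 100 + height` encoding and the `height > start_h` guard mirror the new branch introduced in this diff.

```python
from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Rec:
    # Hypothetical stand-in for the C++ Record: slot id -> uint64 feasign.
    feasigns: Dict[int, int] = field(default_factory=dict)

def tdm_sample(records: List[Rec], sample_slot: int, type_slot: int,
               start_h: int, trace_back) -> List[Rec]:
    """trace_back(item_id) is assumed to return [(ancestor_id, height), ...]."""
    results = []
    for rec in records:
        if sample_slot not in rec.feasigns:
            continue
        for ancestor_id, height in trace_back(rec.feasigns[sample_slot]):
            # One output record per tree ancestor of the original item.
            inst = Rec(feasigns=dict(rec.feasigns))
            inst.feasigns[sample_slot] = ancestor_id
            if type_slot in inst.feasigns and height > start_h:
                # Mirrors the new encoding in the diff: (type + 1) * 100 + height.
                inst.feasigns[type_slot] = (inst.feasigns[type_slot] + 1) * 100 + height
            results.append(inst)
    return results
```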
python/paddle/fluid/dataset.py

@@ -611,8 +611,8 @@ class InMemoryDataset(DatasetBase):
     def init_tdm_tree(self, configs):
         self.dataset.init_tdm_tree(configs)

-    def tdm_sample(self, sample_slot, type_slot):
-        self.dataset.tdm_sample(sample_slot, type_slot)
+    def tdm_sample(self, sample_slot, type_slot, start_h):
+        self.dataset.tdm_sample(sample_slot, type_slot, start_h)

     def tdm_dump(self, name, table_id, fea_value_dim, tree_path):
         self.dataset.tdm_dump(name, table_id, fea_value_dim, tree_path)

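A minimal usage sketch of the updated Python wrapper follows, assuming an in-memory dataset prepared in the usual fluid way. The slot ids, file name, tree config, and start_h value are placeholders; only the three-argument tdm_sample(sample_slot, type_slot, start_h) signature comes from this commit, and init_tdm_tree / tdm_sample are fork-specific APIs.

```python
import paddle.fluid as fluid

# Placeholder feed vars and slot numbers, purely for illustration.
feed_vars = [
    fluid.data(name="slot_%d" % i, shape=[None, 1], dtype="int64", lod_level=1)
    for i in range(2)
]

dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
dataset.set_use_var(feed_vars)
dataset.set_filelist(["train_data.txt"])       # placeholder file
dataset.load_into_memory()
dataset.init_tdm_tree([("tree", "tree_meta_path")])   # fork-specific API
dataset.tdm_sample(sample_slot=23, type_slot=24, start_h=3)
```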
python/paddle/fluid/executor.py

@@ -1353,10 +1353,11 @@ class Executor(object):
                            print_period=100):
-        is_heter = 0
-        if not program._fleet_opt is None:
-            if program._fleet_opt.get("worker_class", "") == "HeterCpuWorker":
-                is_heter = 1
-            if program._fleet_opt("trainer", "") == "HeterXpuTrainer":
-                is_heter = 1
+        is_heter = 0
+        #if program._fleet_opt.get("worker_class", "") == "HeterCpuWorker":
+        #    is_heter = 1
+        #if program._fleet_opt("trainer", "") == "HeterXpuTrainer":
+        #    is_heter = 1
         if scope is None:
             scope = global_scope()
         if fetch_list is None:

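The block disabled above also carried a latent bug: `program._fleet_opt("trainer", "")` calls the object instead of looking a key up. A minimal sketch of what the check presumably intended, assuming `_fleet_opt` is a plain dict (as the neighbouring `.get` call suggests):

```python
# Sketch only; `program` and `_fleet_opt` follow the names in the diff above.
fleet_opt = getattr(program, "_fleet_opt", None) or {}
is_heter = int(
    fleet_opt.get("worker_class", "") == "HeterCpuWorker"
    or fleet_opt.get("trainer", "") == "HeterXpuTrainer"
)
# program._fleet_opt("trainer", "") would raise "TypeError: 'dict' object is
# not callable" when _fleet_opt is a dict, which may be why the whole block
# is commented out in this commit.
```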
python/paddle/fluid/incubate/fleet/utils/fleet_util.py

@@ -24,8 +24,7 @@ import sys
 import time
 import paddle.fluid as fluid
 from paddle.fluid.log_helper import get_logger
-from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_pslib
-from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet as fleet_transpiler
+from paddle.fluid.incubate.fleet.parameter_server.pslib import fleet
 from . import hdfs
 from .hdfs import *
 from . import utils

@@ -35,7 +34,7 @@ __all__ = ["FleetUtil"]
 _logger = get_logger(
     __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
-fleet = fleet_pslib
+# fleet = fleet_pslib


 class FleetUtil(object):

@@ -52,14 +51,16 @@ class FleetUtil(object):
     """

     def __init__(self, mode="pslib"):
-        global fleet
-        if mode == "pslib":
-            fleet = fleet_pslib
-        elif mode == "transpiler":
-            fleet = fleet_transpiler
-        else:
-            raise ValueError(
-                "Please choose one mode from [\"pslib\", \"transpiler\"]")
+        pass
+        # global fleet
+        # if mode == "pslib":
+        #     fleet = fleet_pslib
+        # elif mode == "transpiler":
+        #     fleet = fleet_transpiler
+        # else:
+        #     raise ValueError(
+        #         "Please choose one mode from [\"pslib\", \"transpiler\"]")

     def rank0_print(self, s):
         """

python/paddle/fluid/incubate/fleet/utils/hdfs.py

@@ -79,7 +79,7 @@ class HDFSClient(FS):
                  time_out=5 * 60 * 1000,  #ms
                  sleep_inter=1000):  #ms
         # Raise exception if JAVA_HOME not exists.
-        java_home = os.environ["JAVA_HOME"]
+        # java_home = os.environ["JAVA_HOME"]
         self.pre_commands = []
         hadoop_bin = '%s/bin/hadoop' % hadoop_home

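The line commented out above read JAVA_HOME with direct indexing, which raises KeyError when the variable is unset. A tolerant variant is sketched below purely to illustrate the difference; it is not part of the commit.

```python
import os

# os.environ["JAVA_HOME"] raises KeyError if JAVA_HOME is not set;
# .get() returns a default and lets the caller decide how to react.
java_home = os.environ.get("JAVA_HOME", "")
if not java_home:
    # e.g. fall back to a configured path, or fail with a clearer message
    raise RuntimeError("JAVA_HOME is not set; the hadoop CLI needs a JRE")
```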
python/paddle/fluid/layers/nn.py

@@ -489,11 +489,11 @@ def embedding(input,
     check_dtype(dtype, 'dtype', ['float16', 'float32', 'float64'],
                 'fluid.layers.embedding')

-    if is_distributed:
-        is_distributed = False
-        warnings.warn(
-            "is_distributed is go out of use, `fluid.contrib.layers.sparse_embedding` is your needed"
-        )
+    # if is_distributed:
+    #     is_distributed = False
+    #     warnings.warn(
+    #         "is_distributed is go out of use, `fluid.contrib.layers.sparse_embedding` is your needed"
+    #     )

     remote_prefetch = True if is_sparse else False

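With the warning block disabled, fluid.layers.embedding in this fork no longer resets is_distributed to False. A hedged sketch of a call that relies on that behaviour; the vocabulary size, dimension, and names are placeholders, and fluid.contrib.layers.sparse_embedding remains the alternative the removed warning pointed to.

```python
import paddle.fluid as fluid

# Placeholder vocab size / embedding dim; only the is_distributed flag
# relates to the change above.
ids = fluid.data(name="item_ids", shape=[None, 1], dtype="int64")
emb = fluid.layers.embedding(
    input=ids,
    size=[100000, 64],
    is_sparse=True,
    is_distributed=True,  # no longer forced back to False in this fork
)
```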