Commit 85c6937b (unverified)
Authored Jul 21, 2022 by zhaocaibei123; committed via GitHub on Jul 21, 2022.
add slot attr for push sparse op (#44422)
* add slot attr for push sparse op
* add pybind
* remove fleet
* add unittest
* fix
Parent: 1a7f2de3
Showing 13 changed files with 100 additions and 36 deletions.
paddle/fluid/distributed/ps/wrapper/fleet.cc                 +11 -12
paddle/fluid/distributed/ps/wrapper/fleet.h                   +2  -0
paddle/fluid/framework/data_feed.cc                           +7  -7
paddle/fluid/framework/data_set.cc                            +1  -1
paddle/fluid/framework/io/fs.cc                              +33  -7
paddle/fluid/framework/io/fs.h                                +8  -2
paddle/fluid/framework/io/test_fs.cc                         +13  -0
paddle/fluid/operators/lookup_table_op.cc                     +1  -0
paddle/fluid/operators/lookup_table_v2_op.cc                  +1  -0
paddle/fluid/operators/pscore/distributed_push_sparse_op.cc   +5  -0
paddle/fluid/operators/pscore/distributed_push_sparse_op.h    +3  -0
python/paddle/distributed/passes/ps_trainer_pass.py           +8  -4
python/paddle/fluid/contrib/layers/nn.py                      +7  -3
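The user-facing entry point is the new `slot` argument on paddle.fluid.contrib.layers.sparse_embedding (last file below): the lookup_table op records it, the trainer pass collects one slot id per fused op into the `slots` attribute of distributed_push_sparse, and FleetWrapper writes it into the first float of each pushed value. A minimal usage sketch; variable names and sizes are illustrative, only the `slot=` argument is what this commit adds:

    # illustrative sketch; only slot= is new in this commit
    import paddle.fluid as fluid
    from paddle.fluid.contrib.layers import sparse_embedding

    ids = fluid.data(name="ids", shape=[None, 1], dtype="int64")
    emb = sparse_embedding(input=ids,
                           size=[1024, 8],  # [table height, embedding dim]
                           slot=101)        # slot id recorded on the lookup_table op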
paddle/fluid/distributed/ps/wrapper/fleet.cc
@@ -529,10 +529,12 @@ void FleetWrapper::PushSparseFromTensorAsync(
     uint64_t padding_id,
     platform::Place place,
     std::vector<const LoDTensor*>* inputs,
+    std::vector<int>& slots,
     const LoDTensor* shows,
     const LoDTensor* clks,
     std::vector<LoDTensor*>* outputs,
     bool use_cvm_op) {
+  CHECK(slots.size() == inputs->size());
   int batch_size = -1;
   bool batch_size_consist = true;
   for (auto* input : *inputs) {
@@ -568,8 +570,8 @@ void FleetWrapper::PushSparseFromTensorAsync(
   // TODO(zhaocaibei123): check type of show/clk is int? float? uint64?
   // const long int* show_tensor = shows->data<int64_t>();
   // const long int* clk_tensor = clks->data<int64_t>();
-  const int64_t* show_tensor = shows->data<int64_t>();
-  const int64_t* clk_tensor = clks->data<int64_t>();
+  const float* show_tensor = shows->data<float>();
+  const float* clk_tensor = clks->data<float>();

   for (size_t index = 0; index < inputs->size(); ++index) {
     framework::LoDTensor* g_tensor = outputs->at(index);
@@ -603,15 +605,14 @@ void FleetWrapper::PushSparseFromTensorAsync(
           push_keys.emplace_back(real_id);
           if (use_cvm_op) {
             push_values.emplace_back(fea_dim + 1);
-            push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
+            push_values.back()[0] = static_cast<float>(slots[index]);
             float* data = push_values.back().data() + 1;
             memcpy(data, g + output_len, sizeof(float) * fea_dim);
           } else {
             push_values.emplace_back(fea_dim + 3);
             // slot show clk grad... consistent with CtrCommonPushValue defined
-            // in
-            // ctr_accessor.h
-            push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
+            // in ctr_accessor.h
+            push_values.back()[0] = static_cast<float>(slots[index]);
             push_values.back()[1] =
                 (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
             push_values.back()[2] =
@@ -631,18 +632,16 @@ void FleetWrapper::PushSparseFromTensorAsync(
         push_keys.emplace_back(real_id);
         if (use_cvm_op) {
           push_values.emplace_back(fea_dim + 1);
-          push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
+          push_values.back()[0] = static_cast<float>(slots[index]);
           float* data = push_values.back().data() + 1;
           memcpy(data, g + output_len, sizeof(float) * fea_dim);
         } else {
           push_values.emplace_back(fea_dim + 3);
           // slot show clk grad... consistent with CtrCommonPushValue defined in
           // ctr_accessor.h
-          push_values.back()[0] = 2;  // TODO(zhaocaibei123): slot
-          push_values.back()[1] =
-              (i >= show_size ? 1 : static_cast<float>(show_tensor[i]));
-          push_values.back()[2] =
-              (i >= clk_size ? 0 : static_cast<float>(clk_tensor[i]));
+          push_values.back()[0] = static_cast<float>(slots[index]);
+          push_values.back()[1] = (i >= show_size ? 1 : show_tensor[i]);
+          push_values.back()[2] = (i >= clk_size ? 0 : clk_tensor[i]);
           float* data = push_values.back().data() + 3;
           memcpy(data, g + output_len, sizeof(float) * fea_dim);
         }
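Tracing the layout the loops above fill: with use_cvm_op the pushed value is [slot, grad...] (fea_dim + 1 floats); otherwise it is [slot, show, clk, grad...] (fea_dim + 3 floats), the CtrCommonPushValue layout named in the comment. A standalone sketch with a hypothetical helper, not Paddle API:

    # hypothetical helper mirroring the packing above; grad is a list of
    # fea_dim floats (the gradient slice copied by memcpy)
    def pack_push_value(slot, show, clk, grad, use_cvm_op):
        if use_cvm_op:
            return [float(slot)] + grad                        # fea_dim + 1
        return [float(slot), float(show), float(clk)] + grad   # fea_dim + 3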
paddle/fluid/distributed/ps/wrapper/fleet.h
@@ -190,11 +190,13 @@ class FleetWrapper {
       const std::vector<std::string>& input_names,
       std::vector<const LoDTensor*>* inputs,    // NOLINT
       std::vector<const LoDTensor*>* outputs);  // NOLINT

   void PushSparseFromTensorAsync(const uint64_t table_id,
                                  int fea_dim,
                                  uint64_t padding_id,
                                  platform::Place place,
                                  std::vector<const LoDTensor*>* inputs,
+                                 std::vector<int>& slots,  // NOLINT
                                  const LoDTensor* shows,
                                  const LoDTensor* clicks,
                                  std::vector<LoDTensor*>* outputs,
paddle/fluid/framework/data_feed.cc
@@ -309,7 +309,7 @@ void PrivateQueueDataFeed<T>::ReadThread() {
   std::string filename;
   while (PickOneFile(&filename)) {
     int err_no = 0;
-    fp_ = fs_open_read(filename, &err_no, pipe_command_);
+    fp_ = fs_open_read(filename, &err_no, pipe_command_, true);
     __fsetlocking(&*fp_, FSETLOCKING_BYCALLER);
     T instance;
     while (ParseOneInstanceFromPipe(&instance)) {
@@ -538,7 +538,7 @@ void InMemoryDataFeed<T>::LoadIntoMemory() {
   } else {
 #endif
     int err_no = 0;
-    this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_);
+    this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_, true);
 #ifdef PADDLE_WITH_BOX_PS
   }
 #endif
@@ -574,7 +574,7 @@ void InMemoryDataFeed<T>::LoadIntoMemoryFromSo() {
     (defined PADDLE_WITH_PSLIB)
   VLOG(3) << "LoadIntoMemoryFromSo() begin, thread_id=" << thread_id_;
   int buf_len = 1024 * 1024 * 10;
-  char* buf = (char*)malloc(buf_len + 10);
+  char* buf = reinterpret_cast<char*>(malloc(buf_len + 10));
   auto ps_gpu_ptr = PSGPUWrapper::GetInstance();
   paddle::framework::CustomParser* parser =
@@ -681,7 +681,7 @@ void MultiSlotDataFeed::ReadThread() {
   std::string filename;
   while (PickOneFile(&filename)) {
     int err_no = 0;
-    fp_ = fs_open_read(filename, &err_no, pipe_command_);
+    fp_ = fs_open_read(filename, &err_no, pipe_command_, true);
     CHECK(fp_ != nullptr);
     __fsetlocking(&*fp_, FSETLOCKING_BYCALLER);
     std::vector<MultiSlotType> instance;
@@ -2175,7 +2175,7 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByFile(void) {
                               lines);
   } else {
     int err_no = 0;
-    this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_);
+    this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_, true);
     CHECK(this->fp_ != nullptr);
     __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);
@@ -2265,7 +2265,7 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByLine(void) {
     do {
       int err_no = 0;
-      this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_);
+      this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_, true);
       CHECK(this->fp_ != nullptr);
       __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);
       lines = line_reader.read_file(this->fp_.get(), line_func, lines);
@@ -2314,7 +2314,7 @@ void SlotRecordInMemoryDataFeed::LoadIntoMemoryByCommand(void) {
     do {
       int err_no = 0;
-      this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_);
+      this->fp_ = fs_open_read(filename, &err_no, this->pipe_command_, true);
       CHECK(this->fp_ != nullptr);
       __fsetlocking(&*(this->fp_), FSETLOCKING_BYCALLER);
paddle/fluid/framework/data_set.cc
@@ -102,7 +102,7 @@ void DatasetImpl<T>::SetHdfsConfig(const std::string& fs_name,
   cmd += " -D fs.default.name=" + fs_name;
   cmd += " -D hadoop.job.ugi=" + fs_ugi;
   cmd += " -Ddfs.client.block.write.retries=15 -Ddfs.rpc.timeout=500000";
-  paddle::framework::hdfs_set_command(cmd);
+  paddle::framework::dataset_hdfs_set_command(cmd);
 }

 template <typename T>
paddle/fluid/framework/io/fs.cc
@@ -230,6 +230,20 @@ const std::string& hdfs_command() { return hdfs_command_internal(); }

 void hdfs_set_command(const std::string& x) { hdfs_command_internal() = x; }

+// dataset and model may be on different afs cluster
+static std::string& dataset_hdfs_command_internal() {
+  static std::string x = "hadoop fs";
+  return x;
+}
+
+const std::string& dataset_hdfs_command() {
+  return dataset_hdfs_command_internal();
+}
+
+void dataset_hdfs_set_command(const std::string& x) {
+  dataset_hdfs_command_internal() = x;
+}
+
 static std::string& customized_download_cmd_internal() {
   static std::string x = "";
   return x;
@@ -243,19 +257,30 @@ void set_download_command(const std::string& x) {
 std::shared_ptr<FILE> hdfs_open_read(std::string path,
                                      int* err_no,
-                                     const std::string& converter) {
+                                     const std::string& converter,
+                                     bool read_data) {
   if (download_cmd() != "") {  // use customized download command
     path = string::format_string(
         "%s \"%s\"", download_cmd().c_str(), path.c_str());
   } else {
     if (fs_end_with_internal(path, ".gz")) {
-      path = string::format_string(
-          "%s -text \"%s\"", hdfs_command().c_str(), path.c_str());
+      if (read_data) {
+        path = string::format_string(
+            "%s -text \"%s\"", dataset_hdfs_command().c_str(), path.c_str());
+      } else {
+        path = string::format_string(
+            "%s -text \"%s\"", hdfs_command().c_str(), path.c_str());
+      }
     } else {
-      path = string::format_string(
-          "%s -cat \"%s\"", hdfs_command().c_str(), path.c_str());
+      if (read_data) {
+        path = string::format_string(
+            "%s -cat \"%s\"", dataset_hdfs_command().c_str(), path.c_str());
+      } else {
+        path = string::format_string(
+            "%s -cat \"%s\"", hdfs_command().c_str(), path.c_str());
+      }
     }
   }

   bool is_pipe = true;
   fs_add_read_converter_internal(path, is_pipe, converter);
@@ -370,13 +395,14 @@ int fs_select_internal(const std::string& path) {
 std::shared_ptr<FILE> fs_open_read(const std::string& path,
                                    int* err_no,
-                                   const std::string& converter) {
+                                   const std::string& converter,
+                                   bool read_data) {
   switch (fs_select_internal(path)) {
     case 0:
       return localfs_open_read(path, converter);
     case 1:
-      return hdfs_open_read(path, err_no, converter);
+      return hdfs_open_read(path, err_no, converter, read_data);
     default:
       PADDLE_THROW(platform::errors::Unimplemented(
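From the Python side, DatasetImpl::SetHdfsConfig (data_set.cc above) is reached through the dataset API, so after this change the command it builds feeds dataset_hdfs_command() while model save/load keeps using hdfs_command(). A minimal sketch, assuming the fluid dataset API of this era; the fs_name/fs_ugi values are placeholders:

    import paddle.fluid as fluid

    dataset = fluid.DatasetFactory().create_dataset("InMemoryDataset")
    # routes to dataset_hdfs_set_command() instead of hdfs_set_command()
    dataset.set_hdfs_config("afs://data-cluster:9902", "data_user,data_passwd")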
paddle/fluid/framework/io/fs.h
@@ -64,13 +64,18 @@ extern const std::string& hdfs_command();

 extern void hdfs_set_command(const std::string& x);

+extern const std::string& dataset_hdfs_command();
+
+extern void dataset_hdfs_set_command(const std::string& x);
+
 extern const std::string& download_cmd();

 extern void set_download_command(const std::string& x);

 extern std::shared_ptr<FILE> hdfs_open_read(std::string path,
                                             int* err_no,
-                                            const std::string& converter);
+                                            const std::string& converter,
+                                            bool read_data);

 extern std::shared_ptr<FILE> hdfs_open_write(std::string path,
                                              int* err_no,
@@ -91,7 +96,8 @@ extern void hdfs_mv(const std::string& src, const std::string& dest);

 // aut-detect fs
 extern std::shared_ptr<FILE> fs_open_read(const std::string& path,
                                           int* err_no,
-                                          const std::string& converter);
+                                          const std::string& converter,
+                                          bool read_data = false);

 extern std::shared_ptr<FILE> fs_open_write(const std::string& path,
                                            int* err_no,
paddle/fluid/framework/io/test_fs.cc
@@ -45,5 +45,18 @@ TEST(FS, mv) {
   } catch (...) {
     VLOG(3) << "test hdfs_mv, catch expected errors of unknown prefix";
   }
+
+  try {
+    paddle::framework::dataset_hdfs_set_command(
+        "hadoop -D hadoop.job.ugi=anotherxxx fs -text");
+    int err_no = 0;
+    paddle::framework::hdfs_open_read("afs:/none.gz", &err_no, "", true);
+    paddle::framework::hdfs_open_read("afs:/none.gz", &err_no, "", false);
+    paddle::framework::hdfs_open_read("afs:/none", &err_no, "", true);
+    paddle::framework::hdfs_open_read("afs:/none", &err_no, "", false);
+  } catch (...) {
+    VLOG(3) << "test hdfs_open_read, catch expected errors of unknown path";
+  }
 #endif
 }
paddle/fluid/operators/lookup_table_op.cc
@@ -134,6 +134,7 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
                  "in the order of input variables for mapping")
         .SetDefault({});
     AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.").SetDefault(0);
+    AddAttr<int>("slot", "slot of id").SetDefault(0).AsExtra();
     AddAttr<bool>("grad_inplace",
                   "(boolean, default false) "
                   "If the grad op reuse the input's variable.")
paddle/fluid/operators/lookup_table_v2_op.cc
@@ -105,6 +105,7 @@ class LookupTableV2OpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<int>("trainer_id", "trainer id from 0 ~ worker_num.")
         .SetDefault(0)
         .AsExtra();
+    AddAttr<int>("slot", "slot of id").SetDefault(0).AsExtra();
     AddAttr<std::vector<int64_t>>("height_sections",
                                   "Height for each output SelectedRows.")
         .SetDefault(std::vector<int64_t>({}))
paddle/fluid/operators/pscore/distributed_push_sparse_op.cc
@@ -113,6 +113,11 @@ class DistributedPushSparseOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<bool>("use_cvm_op", "(boolean, default false) Use cvm op or not.")
         .SetDefault(false);
+    AddAttr<std::vector<int>>("slots",
+                              "[slot_id1, slot_id2] Slots array of Ids.")
+        .SetDefault({})
+        .AsExtra();
     AddComment(R"DOC(
 Lookup Tablel Prefetch Operator.
 This operator is used to perform lookup on parameter W,
paddle/fluid/operators/pscore/distributed_push_sparse_op.h
@@ -33,6 +33,7 @@ class DistributedPushSparseKernel : public framework::OpKernel<T> {
     auto table_id = context.Attr<int>("table_id");
     auto emb_dim = context.Attr<int>("size");
     auto use_cvm_op = context.Attr<bool>("use_cvm_op");
+    auto slots = context.Attr<std::vector<int>>("slots");

     auto inputs = context.MultiInput<framework::LoDTensor>("Ids");
     auto shows = context.Input<framework::LoDTensor>("Shows");
@@ -47,6 +48,7 @@ class DistributedPushSparseKernel : public framework::OpKernel<T> {
           static_cast<uint64_t>(padding_idx),
           context.GetPlace(),
           &inputs,
+          slots,
           shows,
           clks,
           &outputs,
@@ -103,6 +105,7 @@ class DistributedPushSparseKernel : public framework::OpKernel<T> {
           static_cast<uint64_t>(padding_idx),
           context.GetPlace(),
           &tmp_input_vec,
+          slots,
           tmp_shows_tensor,
           tmp_clicks_tensor,
           &tmp_output_vec);
python/paddle/distributed/passes/ps_trainer_pass.py
@@ -150,7 +150,7 @@ class DistributedOpsPass(PassBase):
             print('ShowClickEntry not configured, will not use')
             show = _program.global_block().create_var(
                 name="show",
-                dtype=core.VarDesc.VarType.INT64,
+                dtype=core.VarDesc.VarType.FP32,
                 persistable=False,
                 stop_gradient=True)
             _program.global_block()._insert_op(
                 index=0,
@@ -165,7 +165,7 @@ class DistributedOpsPass(PassBase):
             clk = _program.global_block().create_var(
                 name="clk",
-                dtype=core.VarDesc.VarType.INT64,
+                dtype=core.VarDesc.VarType.FP32,
                 persistable=False,
                 stop_gradient=True)
             _program.global_block()._insert_op(
                 index=0,
@@ -190,6 +190,9 @@ class DistributedOpsPass(PassBase):
         padding_idx = ops[0].attr("padding_idx")
         is_distributed = ops[0].attr("is_distributed")
         op_type = ops[0].type
+        slots = [op.attr("slot") for op in ops]
+        print('debug zcb slots: ', slots)
+
         outputs = [
             _program.global_block().vars[op.input("Out@GRAD")[0]]
             for op in ops
@@ -204,7 +207,7 @@ class DistributedOpsPass(PassBase):
                 'W': w,
                 "Outputs": outputs,
                 "Shows": show,
-                "Clicks": clk
+                "Clicks": clk,
             },
             outputs={"Outputs": outputs},
             attrs={
@@ -213,7 +216,8 @@ class DistributedOpsPass(PassBase):
                 "padding_idx": padding_idx,
                 "table_id": table_id,
                 "size": self.emb_size[param],
-                "use_cvm_op": use_cvm_op
+                "use_cvm_op": use_cvm_op,
+                "slots": slots
             })

     def _pull_sparse_fuse(self, _program, pull_sparse_ops, attrs, send_ctx):
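The two dtype hunks line up with fleet.cc above: the show/clk placeholders are now FP32 because PushSparseFromTensorAsync reads them with shows->data<float>(). The slots list is collected in the same order as the fused ops' inputs, which is what the CHECK in fleet.cc relies on; a condensed, hypothetical standalone form:

    # condensed sketch of the invariant the pass establishes; `ops` are the
    # fused lookup_table ops, in input order
    slots = [op.attr("slot") for op in ops]  # one slot id per "Ids" input
    assert len(slots) == len(ops)            # mirrors CHECK(slots.size() == inputs->size())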
python/paddle/fluid/contrib/layers/nn.py
@@ -1073,7 +1073,8 @@ def sparse_embedding(input,
                      entry=None,
                      table_class="MemorySparseTable",
                      param_attr=None,
-                     dtype='float32'):
+                     dtype='float32',
+                     slot=None):
     r"""
     :api_attr: Static Graph
@@ -1220,6 +1221,9 @@ def sparse_embedding(input,
     entry_str = entry._to_attr()

+    if slot == None:
+        slot = 0
+
     helper.append_op(
         type='lookup_table',
         inputs={
             'Ids': input,
@@ -1233,9 +1237,9 @@ def sparse_embedding(input,
             'remote_prefetch': True,
             'is_test': is_test,
             'entry': entry_str,
-            'table_class': table_class
+            'table_class': table_class,
+            'slot': slot
         })

     return tmp