Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
0f6412c0
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
0f6412c0
编写于
4月 02, 2022
作者:
L
Leo Chen
提交者:
GitHub
4月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
do not use scope in op kernel (#41316)
上级
1b58ce14
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
17 addition
and
31 deletion
+17
-31
paddle/fluid/operators/pscore/distributed_lookup_table_op.h
paddle/fluid/operators/pscore/distributed_lookup_table_op.h
+17
-31
未找到文件。
paddle/fluid/operators/pscore/distributed_lookup_table_op.h
浏览文件 @
0f6412c0
...
...
@@ -26,17 +26,13 @@ template <typename DeviceContext, typename T>
class
DistributedLookupTableKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
&
scope
=
context
.
scope
();
auto
padding_idx
=
context
.
Attr
<
int64_t
>
(
"padding_idx"
);
auto
table_id
=
context
.
Attr
<
int
>
(
"table_id"
);
bool
is_test
=
context
.
Attr
<
bool
>
(
"is_test"
);
auto
embedding_name
=
context
.
InputNames
(
"W"
).
front
(
);
auto
*
var
=
context
.
InputVar
(
"W"
);
int64_t
emb_dim
=
0
;
auto
*
var
=
scope
.
FindVar
(
embedding_name
);
if
(
var
->
IsType
<
framework
::
LoDTensor
>
())
{
emb_dim
=
var
->
Get
<
framework
::
LoDTensor
>
().
dims
()[
1
];
}
else
if
(
var
->
IsType
<
phi
::
SelectedRows
>
())
{
...
...
@@ -61,35 +57,31 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
}
else
{
auto
inputs_variable
=
context
.
MultiInputVar
(
"Ids"
);
auto
outputs_variable
=
context
.
MultiOutputVar
(
"Outputs"
);
auto
inputs_name
=
context
.
InputNames
(
"Ids"
);
auto
outputs_name
=
context
.
OutputNames
(
"Outputs"
);
auto
cpu_place
=
platform
::
CPUPlace
();
framework
::
Scope
*
tmp_scope
=
scope
.
NewTmpScope
().
release
();
std
::
vector
<
const
framework
::
LoDTensor
*>
tmp_input_vec
;
auto
input_var_size
=
inputs_variable
.
size
();
std
::
vector
<
framework
::
LoDTensor
*>
tmp_output_vec
;
auto
output_var_size
=
outputs_variable
.
size
();
std
::
vector
<
std
::
shared_ptr
<
framework
::
LoDTensor
>>
tmp_tensors
;
// create temp input
for
(
size_t
idx
=
0
;
idx
<
input_var_size
;
++
idx
)
{
framework
::
Variable
*
tmp_input_var
=
tmp_scope
->
Var
(
inputs_name
[
idx
]);
framework
::
LoDTensor
*
tmp_input_tensor
=
tmp_input_var
->
GetMutable
<
framework
::
LoDTensor
>
();
tmp_tensors
.
emplace_back
(
std
::
make_shared
<
framework
::
LoDTensor
>
());
auto
*
p
=
tmp_tensors
.
back
().
get
();
framework
::
TensorCopy
(
inputs_variable
[
idx
]
->
Get
<
framework
::
LoDTensor
>
(),
cpu_place
,
context
.
device_context
(),
tmp_input_tensor
);
tmp_input_vec
.
push_back
(
tmp_input_tensor
);
cpu_place
,
context
.
device_context
(),
p
);
tmp_input_vec
.
push_back
(
p
);
}
// create temp output
for
(
size_t
idx
=
0
;
idx
<
output_var_size
;
++
idx
)
{
framework
::
Variable
*
tmp_output_var
=
tmp_scope
->
Var
(
outputs_name
[
idx
]);
framework
::
LoDTensor
*
tmp_output_tensor
=
tmp_output_var
->
GetMutable
<
framework
::
LoDTensor
>
();
tmp_output_tensor
->
Resize
(
outputs
[
idx
]
->
dims
());
tmp_output_vec
.
push_back
(
tmp_output_tensor
);
tmp_tensors
.
emplace_back
(
std
::
make_shared
<
framework
::
LoDTensor
>
());
auto
*
p
=
tmp_tensors
.
back
().
get
();
p
->
Resize
(
outputs
[
idx
]
->
dims
());
tmp_output_vec
.
push_back
(
p
);
}
// use fleet->PullSparse
...
...
@@ -100,27 +92,21 @@ class DistributedLookupTableKernel : public framework::OpKernel<T> {
// cp temp to origin
for
(
size_t
idx
=
0
;
idx
<
output_var_size
;
++
idx
)
{
framework
::
Variable
*
tmp_output_var
=
tmp_scope
->
Var
(
outputs_name
[
idx
]);
framework
::
LoDTensor
*
tmp_output_tensor
=
tmp_output_var
->
GetMutable
<
framework
::
LoDTensor
>
();
framework
::
TensorCopy
(
*
tmp_output_
tensor
,
context
.
GetPlace
(),
context
.
device_context
(),
*
tmp_output_
vec
[
idx
]
,
context
.
GetPlace
(),
context
.
device_context
(),
outputs_variable
[
idx
]
->
GetMutable
<
framework
::
LoDTensor
>
());
}
delete
tmp_scope
;
}
auto
id_names
=
context
.
InputNames
(
"Ids"
);
auto
out_names
=
context
.
OutputNames
(
"Outputs"
);
auto
lookup_table_version
=
context
.
Attr
<
std
::
string
>
(
"lookup_table_version"
);
auto
id_vars
=
context
.
MultiInputVar
(
"Ids"
);
auto
out_vars
=
context
.
MultiOutputVar
(
"Outputs"
);
if
(
lookup_table_version
==
"lookup_table_v2"
)
{
for
(
size_t
i
=
0
;
i
<
id_names
.
size
();
++
i
)
{
auto
*
id_var
=
scope
.
FindVar
(
id_names
[
i
]);
auto
*
out_var
=
scope
.
FindVar
(
out_names
[
i
]);
auto
*
id_tensor
=
id_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
out_tensor
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
for
(
size_t
i
=
0
;
i
<
id_vars
.
size
();
++
i
)
{
auto
*
id_tensor
=
id_vars
[
i
]
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
*
out_tensor
=
out_vars
[
i
]
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
id_dims
=
id_tensor
->
dims
();
out_tensor
->
Resize
(
phi
::
make_ddim
({
static_cast
<
int64_t
>
(
id_dims
[
0
]),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录