Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
dccd013b
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
dccd013b
编写于
4月 26, 2018
作者:
Y
Yancey1989
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
refine distribute transpiler
上级
e393c86c
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
61 addition
and
11 deletion
+61
-11
paddle/fluid/operators/lookup_sparse_table_op.cc
paddle/fluid/operators/lookup_sparse_table_op.cc
+7
-4
paddle/fluid/operators/sgd_op.cc
paddle/fluid/operators/sgd_op.cc
+20
-1
paddle/fluid/operators/uniform_random_op.cc
paddle/fluid/operators/uniform_random_op.cc
+22
-2
python/paddle/fluid/distribute_transpiler.py
python/paddle/fluid/distribute_transpiler.py
+12
-4
未找到文件。
paddle/fluid/operators/lookup_sparse_table_op.cc
浏览文件 @
dccd013b
...
...
@@ -55,12 +55,16 @@ class LookupSparseTableOp : public framework::OperatorBase {
"The type of Out var should be LodTensor."
);
PADDLE_ENFORCE
(
w_var
->
IsType
<
framework
::
SelectedRows
>
(),
"The type of W var should be SelectedRows."
);
PADDLE_ENFORCE
(
ids_var
->
IsType
<
framework
::
SelectedRows
>
(),
PADDLE_ENFORCE
(
ids_var
->
IsType
<
framework
::
LoDTensor
>
(),
"The type of Ids var should be SelectedRows."
);
auto
&
ids_t
=
ids_var
->
Get
<
framework
::
SelectedRows
>
();
auto
&
ids_t
=
ids_var
->
Get
<
framework
::
LoDTensor
>
();
auto
out_t
=
out_var
->
GetMutable
<
framework
::
LoDTensor
>
();
auto
w_t
=
w_var
->
GetMutable
<
framework
::
SelectedRows
>
();
auto
keys
=
ids_t
.
rows
();
std
::
vector
<
int64_t
>
keys
;
keys
.
resize
(
ids_t
.
numel
());
for
(
size_t
i
=
0
;
i
<
ids_t
.
numel
();
++
i
)
{
keys
[
i
]
=
ids_t
.
data
<
int64_t
>
()[
i
];
}
// TODO(Yancey1989): support CUDA Place for the sparse table
platform
::
CPUPlace
cpu
;
...
...
@@ -68,7 +72,6 @@ class LookupSparseTableOp : public framework::OperatorBase {
out_shape
[
0
]
=
keys
.
size
();
out_t
->
Resize
(
out_shape
);
out_t
->
mutable_data
(
cpu
,
w_t
->
value
().
type
());
PADDLE_ENFORCE_EQ
(
framework
::
ToDataType
(
w_t
->
value
().
type
()),
framework
::
proto
::
VarType
::
FP32
,
"The sparse table only support FP32"
);
...
...
paddle/fluid/operators/sgd_op.cc
浏览文件 @
dccd013b
...
...
@@ -48,6 +48,24 @@ class SGDOp : public framework::OperatorWithKernel {
}
};
class
SGDOpInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
input_var
=
op_desc
.
Input
(
"Param"
)[
0
];
for
(
auto
&
out_var
:
op_desc
.
Output
(
"ParamOut"
))
{
if
(
block
->
FindRecursiveOrCreateVar
(
input_var
).
GetType
()
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
block
->
FindRecursiveOrCreateVar
(
out_var
).
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
}
};
class
SGDOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
SGDOpMaker
(
OpProto
*
proto
,
OpAttrChecker
*
op_checker
)
...
...
@@ -74,5 +92,6 @@ $$param\_out = param - learning\_rate * grad$$
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
sgd
,
ops
::
SGDOp
,
ops
::
SGDOpMaker
);
REGISTER_OPERATOR
(
sgd
,
ops
::
SGDOp
,
ops
::
SGDOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
ops
::
SGDOpInferVarType
);
REGISTER_OP_CPU_KERNEL
(
sgd
,
ops
::
SGDOpKernel
<
float
>
,
ops
::
SGDOpKernel
<
double
>
);
paddle/fluid/operators/uniform_random_op.cc
浏览文件 @
dccd013b
...
...
@@ -116,11 +116,31 @@ uniform distribution.
.
SetDefault
(
framework
::
proto
::
VarType
::
FP32
);
}
};
class
UniformRandomOpVarTypeInference
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
const
framework
::
OpDesc
&
op_desc
,
framework
::
BlockDesc
*
block
)
const
override
{
auto
out_var_name
=
op_desc
.
Output
(
"Out"
).
front
();
if
(
block
->
FindRecursiveOrCreateVar
(
out_var_name
).
GetType
()
==
framework
::
proto
::
VarType
::
SELECTED_ROWS
)
{
block
->
FindRecursiveOrCreateVar
(
out_var_name
)
.
SetType
(
framework
::
proto
::
VarType
::
SELECTED_ROWS
);
}
else
{
block
->
FindRecursiveOrCreateVar
(
out_var_name
)
.
SetType
(
framework
::
proto
::
VarType
::
LOD_TENSOR
);
}
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_WITHOUT_GRADIENT
(
uniform_random
,
paddle
::
operators
::
UniformRandomOp
,
paddle
::
operators
::
UniformRandomOpMaker
);
REGISTER_OPERATOR
(
uniform_random
,
paddle
::
operators
::
UniformRandomOp
,
paddle
::
operators
::
UniformRandomOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
,
paddle
::
operators
::
UniformRandomOpVarTypeInference
);
REGISTER_OP_CPU_KERNEL
(
uniform_random
,
paddle
::
operators
::
CPUUniformRandomKernel
<
float
>
,
paddle
::
operators
::
CPUUniformRandomKernel
<
double
>
);
...
...
python/paddle/fluid/distribute_transpiler.py
浏览文件 @
dccd013b
...
...
@@ -650,7 +650,7 @@ class DistributeTranspiler:
shape
=
trainer_out
.
shape
,
dtype
=
trainer_out
.
dtype
)
prefetch_block
.
append_op
(
type
=
LOOKUP_TABLE_TYPE
,
type
=
"lookup_sparse_table"
,
inputs
=
{
'Ids'
:
pserver_ids
,
"W"
:
table_var
},
outputs
=
{
"Out"
:
pserver_out
},
...
...
@@ -674,9 +674,17 @@ class DistributeTranspiler:
# STEP: create table optimize block
# create table param and grad var in pserver program
param_var
=
_clone_var
(
pserver_program
.
global_block
(),
self
.
origin_program
.
global_block
().
vars
[
self
.
table_name
])
#param_var = _clone_var(
# pserver_program.global_block(),
# self.origin_program.global_block().vars[self.table_name])
origin_param_var
=
self
.
origin_program
.
global_block
().
vars
[
self
.
table_name
]
param_var
=
pserver_program
.
global_block
().
create_var
(
name
=
origin_param_var
.
name
,
shape
=
origin_param_var
.
shape
,
dtype
=
origin_param_var
.
dtype
,
type
=
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
,
persistable
=
True
)
grad_var
=
_clone_var
(
pserver_program
.
global_block
(),
self
.
origin_program
.
global_block
().
vars
[
framework
.
grad_var_name
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录