Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
fbcdb29d
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fbcdb29d
编写于
11月 07, 2018
作者:
Q
Qiao Longfei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix import issue
上级
866d6bfe
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
29 addition
and
28 deletion
+29
-28
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+29
-4
python/paddle/fluid/transpiler/details/distribute_lookuptable_utils.py
.../fluid/transpiler/details/distribute_lookuptable_utils.py
+0
-24
未找到文件。
python/paddle/fluid/optimizer.py
浏览文件 @
fbcdb29d
...
...
@@ -18,7 +18,7 @@ from collections import defaultdict
from
contextlib
import
contextmanager
from
paddle.fluid.framework
import
Program
,
Variable
,
name_scope
,
default_main_program
import
paddle.fluid.transpiler.details.distribute_lookuptable_utils
as
distribute_lookuptable_utils
from
paddle.fluid.transpiler.details.distribute_lookuptable_utils
import
find_distributed_lookup_table
from
.
import
framework
from
.
import
layers
...
...
@@ -40,6 +40,30 @@ __all__ = [
]
def
_process_distribute_lookuptable
(
program
,
param_grads
,
learning_rate
):
table_name
=
find_distributed_lookup_table
(
program
)
table_param
=
None
table_grad
=
None
new_param_grads
=
[]
for
p
,
g
in
param_grads
:
if
p
.
name
==
table_name
:
if
table_param
is
not
None
:
raise
RuntimeError
(
"multi dist table var found, only support one now!"
)
table_param
=
p
table_grad
=
g
else
:
new_param_grads
.
append
((
p
,
g
))
sgd_op
=
None
if
table_param
is
not
None
:
with
table_param
.
block
.
program
.
_optimized_guard
(
[
table_param
,
table_grad
]),
framework
.
name_scope
(
"optimizer"
):
sgd_optimizer
=
SGD
(
learning_rate
)
sgd_op
=
sgd_optimizer
.
_append_optimize_op
(
table_param
.
block
,
(
table_param
,
table_grad
))
return
new_param_grads
,
(
table_param
,
table_grad
),
sgd_op
class
Optimizer
(
object
):
"""Optimizer Base class.
...
...
@@ -263,7 +287,7 @@ class Optimizer(object):
params_grads
=
sorted
(
params_grads
,
key
=
lambda
x
:
x
[
0
].
name
)
params_grads
,
table_param_and_grad
,
table_optimize_op
=
\
distribute_lookuptable_utils
.
process_distribute_lookuptable
(
loss
.
block
.
program
,
params_grads
,
self
.
_learning_rate
)
_
process_distribute_lookuptable
(
loss
.
block
.
program
,
params_grads
,
self
.
_learning_rate
)
params_grads
=
append_gradient_clip_ops
(
params_grads
)
...
...
@@ -273,8 +297,9 @@ class Optimizer(object):
optimize_ops
=
self
.
_create_optimization_pass
(
params_grads
,
loss
,
startup_program
)
optimize_ops
.
append
(
table_optimize_op
)
params_grads
.
append
(
table_param_and_grad
)
if
table_optimize_op
is
not
None
:
optimize_ops
.
append
(
table_optimize_op
)
params_grads
.
append
(
table_param_and_grad
)
return
optimize_ops
,
params_grads
...
...
python/paddle/fluid/transpiler/details/distribute_lookuptable_utils.py
浏览文件 @
fbcdb29d
...
...
@@ -40,27 +40,3 @@ def find_distributed_lookup_table(program):
assert
op
.
input
(
"W"
)[
0
]
!=
table_name
return
table_name
def
process_distribute_lookuptable
(
program
,
param_grads
,
learning_rate
):
table_name
=
find_distributed_lookup_table
(
program
)
table_param
=
None
table_grad
=
None
new_param_grads
=
[]
for
p
,
g
in
param_grads
:
if
p
.
name
==
table_name
:
if
table_param
is
not
None
:
raise
RuntimeError
(
"multi dist table var found, only support one now!"
)
table_param
=
p
table_grad
=
g
else
:
new_param_grads
.
append
((
p
,
g
))
sgd_op
=
None
if
table_param
is
not
None
:
with
table_param
.
block
.
program
.
_optimized_guard
(
[
table_param
,
table_grad
]),
framework
.
name_scope
(
"optimizer"
):
sgd_optimizer
=
optimizer
.
SGD
(
learning_rate
)
sgd_op
=
sgd_optimizer
.
_append_optimize_op
(
table_param
.
block
,
(
table_param
,
table_grad
))
return
new_param_grads
,
(
table_param
,
table_grad
),
sgd_op
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录