Commit 866d6bfe

Authored Nov 07, 2018 by Qiao Longfei

dist table support other optimize and regular config

Parent: 6449faec
Showing 3 changed files with 85 additions and 36 deletions (+85 -36):

    python/paddle/fluid/optimizer.py                                        +13  -6
    python/paddle/fluid/transpiler/details/distribute_lookuptable_utils.py  +66  -0
    python/paddle/fluid/transpiler/distribute_transpiler.py                 +6   -30
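In short: Optimizer.minimize() now splits a distributed lookup table's
(param, grad) pair out of the regular optimization pass, so the dense
parameters can use any optimizer and any regularization config, while the
table itself is still updated by a dedicated SGD op.

Below is a hedged sketch of the trainer-side setup this commit targets; the
layer sizes and parameter names are illustrative, not taken from the diff.

import paddle.fluid as fluid

# Embedding backed by a distributed lookup table: with is_distributed=True
# the table is sharded across parameter servers and the lookup_table op
# only fetches the rows it needs.
words = fluid.layers.data(name='words', shape=[1], dtype='int64', lod_level=1)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
emb = fluid.layers.embedding(
    input=words,
    size=[100000, 64],  # illustrative vocabulary / embedding sizes
    is_distributed=True,
    param_attr=fluid.ParamAttr(name='shared_w'))
pooled = fluid.layers.sequence_pool(input=emb, pool_type='sum')
predict = fluid.layers.fc(input=pooled, size=2, act='softmax')
avg_cost = fluid.layers.mean(
    fluid.layers.cross_entropy(input=predict, label=label))

# What the commit enables: the optimizer (and its regularization) for the
# dense parameters no longer has to be plain SGD just because the model
# contains a distributed table.
opt = fluid.optimizer.Adam(
    learning_rate=0.001,
    regularization=fluid.regularizer.L2Decay(1e-4))
opt.minimize(avg_cost)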
python/paddle/fluid/optimizer.py

@@ -13,21 +13,23 @@
 # limitations under the License.
 
 from __future__ import print_function
+
 import re
 import sys
 from collections import defaultdict
+from contextlib import contextmanager
+
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
+import paddle.fluid.transpiler.details.distribute_lookuptable_utils as distribute_lookuptable_utils
+
 from . import framework
 from . import layers
+from . import unique_name
 from .backward import append_backward
+from .clip import append_gradient_clip_ops, error_clip_callback
 from .framework import program_guard
-from . import unique_name
 from .initializer import Constant
 from .layer_helper import LayerHelper
-from .clip import append_gradient_clip_ops, error_clip_callback
-from contextlib import contextmanager
+from .regularizer import append_regularization_ops
 from .layers import ops
-from .regularizer import append_regularization_ops
 
 __all__ = [
     'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',

@@ -260,6 +262,9 @@ class Optimizer(object):
         params_grads = sorted(params_grads, key=lambda x: x[0].name)
 
+        params_grads, table_param_and_grad, table_optimize_op = \
+            distribute_lookuptable_utils.process_distribute_lookuptable(
+                loss.block.program, params_grads, self._learning_rate)
+
         params_grads = append_gradient_clip_ops(params_grads)
 
         # Add regularization if any

@@ -268,6 +273,8 @@ class Optimizer(object):
         optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                       startup_program)
+        optimize_ops.append(table_optimize_op)
+        params_grads.append(table_param_and_grad)
         return optimize_ops, params_grads
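Reading the three hunks together, the new order of operations inside
Optimizer.minimize() is roughly the following (a paraphrase of the diff, not
the literal source):

# params_grads = sorted(params_grads, ...)         # deterministic op order
# params_grads, table_pg, table_op = \
#     process_distribute_lookuptable(...)          # pull the dist table out
# params_grads = append_gradient_clip_ops(...)     # clip dense grads only
# ... regularization, then _create_optimization_pass(...)
# optimize_ops.append(table_op)                    # table keeps its own SGD op
# params_grads.append(table_pg)                    # re-report it to the caller

Because the table pair is removed before gradient clipping and
regularization run, those passes only ever see dense parameters; the table's
sparse rows are updated solely by the dedicated SGD op.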
python/paddle/fluid/transpiler/details/distribute_lookuptable_utils.py
new file (0 → 100644)

# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle.fluid.optimizer as optimizer
import paddle.fluid.framework as framework

LOOKUP_TABLE_TYPE = "lookup_table"


def find_distributed_lookup_table(program):
    # process lookup_table_op
    # 1. check all lookup_table_op is distributed
    # 2. check all lookup_table_op share the same table.
    distributed_lookup_table_ops = []
    # support only one distributed_lookup_table now
    table_name = None

    for op in program.global_block().ops:
        if op.type == LOOKUP_TABLE_TYPE:
            if op.attr('is_distributed') is True:
                if table_name is None:
                    table_name = op.input("W")[0]
                if table_name != op.input("W")[0]:
                    raise RuntimeError("all distributed lookup_table_ops"
                                       " should have only one table")
                distributed_lookup_table_ops.append(op)
            else:
                if table_name is not None:
                    assert op.input("W")[0] != table_name

    return table_name


def process_distribute_lookuptable(program, param_grads, learning_rate):
    table_name = find_distributed_lookup_table(program)
    table_param = None
    table_grad = None
    new_param_grads = []
    for p, g in param_grads:
        if p.name == table_name:
            if table_param is not None:
                raise RuntimeError(
                    "multi dist table var found, only support one now!")
            table_param = p
            table_grad = g
        else:
            new_param_grads.append((p, g))
    sgd_op = None
    if table_param is not None:
        with table_param.block.program._optimized_guard(
            [table_param, table_grad]), framework.name_scope("optimizer"):
            sgd_optimizer = optimizer.SGD(learning_rate)
            sgd_op = sgd_optimizer._append_optimize_op(table_param.block, (
                table_param, table_grad))
    return new_param_grads, (table_param, table_grad), sgd_op
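A minimal usage sketch of the two helpers, under the assumption that a model
and its backward pass already exist; `params_grads` is the usual list of
(parameter, gradient) pairs produced by append_backward(), shown empty here
only so the snippet stands alone:

import paddle.fluid as fluid
from paddle.fluid.transpiler.details.distribute_lookuptable_utils import (
    find_distributed_lookup_table, process_distribute_lookuptable)

main = fluid.default_main_program()
params_grads = []  # normally: append_backward(loss)

# Returns the table parameter's name, or None if the program contains no
# distributed lookup_table op.
table_name = find_distributed_lookup_table(main)

# Splits the table's (param, grad) pair out of params_grads and builds a
# dedicated SGD op for it; the dense pairs come back untouched.
dense_pgs, table_pg, table_op = process_distribute_lookuptable(
    main, params_grads, learning_rate=0.01)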
python/paddle/fluid/transpiler/distribute_transpiler.py

@@ -31,18 +31,17 @@ Steps to transpile pserver:
 """
 
 import math
 import sys
 import numpy as np
 import collections
 import six
 import logging
 
-from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
+from .ps_dispatcher import RoundRobin, PSDispatcher
 from .. import core, framework, unique_name
 from ..framework import Program, default_main_program, \
     default_startup_program, Block, \
     Parameter, grad_var_name
 from .details import *
+from .details.distribute_lookuptable_utils import find_distributed_lookup_table
 from functools import reduce
 
 LOOKUP_TABLE_TYPE = "lookup_table"

@@ -292,7 +291,8 @@ class DistributeTranspiler(object):
         self.optimize_ops, self.params_grads = self._get_optimize_pass()
 
         ps_dispatcher = self.config.split_method(self.pserver_endpoints)
-        self.has_distributed_lookup_table = self._has_distributed_lookup_table()
+        self.table_name = find_distributed_lookup_table(self.origin_program)
+        self.has_distributed_lookup_table = self.table_name != None
         self.param_name_to_grad_name = dict()
         self.grad_name_to_param_name = dict()
         for param_var, grad_var in self.params_grads:

@@ -966,28 +966,6 @@ to transpile() call.")
 
     # ====================== private transpiler functions =====================
 
-    def _has_distributed_lookup_table(self):
-        # process lookup_table_op
-        # 1. check all lookup_table_op is distributed
-        # 2. check all lookup_table_op share the same table.
-        distributed_lookup_table_ops = []
-        # support only one distributed_lookup_table now
-        self.table_name = None
-        for op in self.origin_program.global_block().ops:
-            if op.type == LOOKUP_TABLE_TYPE:
-                if op.attr('is_distributed') is True:
-                    if self.table_name is None:
-                        self.table_name = op.input("W")[0]
-                    if self.table_name != op.input("W")[0]:
-                        raise RuntimeError("all distributed lookup_table_ops"
-                                           " should have only one table")
-                    distributed_lookup_table_ops.append(op)
-                else:
-                    if self.table_name is not None:
-                        assert op.input("W")[0] != self.table_name
-
-        return len(distributed_lookup_table_ops) > 0
-
     def _update_dist_lookup_table_vars(self, param_list, grad_list,
                                        params_grads):
         # TODO(wuyi): put find a way to put dist lookup table stuff all together.

@@ -1259,9 +1237,8 @@ to transpile() call.")
         # create table param and grad var in pserver program
         # create table optimize block in pserver program
         table_opt_op = [
-            op for op in self.optimize_ops
-            if 'Param' in op.input_names and
-            op.input("Param")[0] == self.table_name
+            op for op in self.optimize_ops if 'Param' in op.input_names and
+            op.input("Param")[0] == self.table_name
         ][0]
 
         origin_param_var = self.origin_program.global_block().vars[

@@ -1341,7 +1318,6 @@ to transpile() call.")
         """
         create a new block to handle save checkpoint.
         """
-        import os
 
         pserver_program.global_block().create_var(
             name="kLookupTablePath",
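With the duplicated scan gone, both the optimizer and the transpiler now
discover the table through the same find_distributed_lookup_table() helper.
A hedged sketch of the transpile flow that exercises this path; it assumes a
model plus an optimizer.minimize() call were already built into the default
program, and the endpoints are placeholders:

import paddle.fluid as fluid

t = fluid.DistributeTranspiler()
t.transpile(
    trainer_id=0,
    program=fluid.default_main_program(),
    pservers="127.0.0.1:6174,127.0.0.1:6175",
    trainers=2)
# Inside transpile(), self.table_name is now filled by
# find_distributed_lookup_table(self.origin_program), and
# has_distributed_lookup_table is derived from it.
trainer_prog = t.get_trainer_program()
pserver_prog = t.get_pserver_program("127.0.0.1:6174")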