Crayon鑫 / Paddle (forked from PaddlePaddle / Paddle)
Commit eeca5ef6

Authored on Mar 09, 2021 by JZ-LIANG
Committed by sandyhouse on Mar 22, 2021

Commit message: update

Parent: 479efeeb
Showing 2 changed files with 22 additions and 10 deletions (+22 -10)

python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py  +18 -7
python/paddle/fluid/optimizer.py  +4 -3
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -16,7 +16,7 @@ from paddle.fluid import unique_name, core
 import paddle.fluid as fluid
 from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
-from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op, OpRole
+from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
 from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
 from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
 from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
@@ -208,7 +208,8 @@ class ShardingOptimizer(MetaOptimizerBase):
             #pp_optimizer._clear_gradients(main_block, param_list)
             accumulated_grad_names = pp_optimizer._accumulate_gradients(
                 main_block)
-            accumulated_grad_names = sorted(accumulated_grad_names)
+            # accumulated_grad_names = sorted(accumulated_grad_names)
+            print("persistable FP32 grad: ")
             print(accumulated_grad_names)
             first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
                 main_block)
@@ -218,7 +219,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                 self.sharding_ring_id,
                 accumulated_grad_names,
                 self._shard,
-                OpRole.Optimize,
+                core.op_proto_and_checker_maker.OpRole.Optimize,
                 use_calc_stream=True)
             #if not self._shard.has_param(param_name): continue
             ##if not main_block.has_var(grad_name): continue
@@ -470,10 +471,20 @@ class ShardingOptimizer(MetaOptimizerBase):
                                   self._main_program.global_block())
 
     def _wait(self, ):
-        endpoints = self.role_maker._get_trainer_endpoints()
-        current_endpoint = endpoints[self.role_maker._worker_index()]
-        if self.role_maker._worker_index() == 0:
-            self._collective_helper._wait(current_endpoint, endpoints)
+        # only the first parallelsm group that init nccl need to be wait.
+        if self._as_outer_parallelism:
+            endpoints = self.role_maker._get_trainer_endpoints()
+            current_endpoint = endpoints[self.role_maker._worker_index()]
+        else:
+            endpoints = self.sharding_group_endpoints[:]
+            current_endpoint = self.sharding_group_endpoints[self.sharding_rank]
+        if self._as_outer_parallelism:
+            if self.role_maker._worker_index() == 0:
+                self._collective_helper._wait(current_endpoint, endpoints)
+        else:
+            if self.sharding_rank == 0:
+                self._collective_helper._wait(current_endpoint, endpoints)
 
     # def _wait(self, ):
     #     # only the first parallelsm group that init nccl need to be wait.
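Note: the reworked _wait() above selects which endpoint group to synchronize, and which rank performs the wait, depending on whether sharding is used as the outer parallelism level. A minimal standalone sketch of that selection logic follows; worker_index, trainer_endpoints, sharding_group_endpoints, and sharding_rank are hypothetical stand-ins for the optimizer's attributes, not the actual Paddle API.

# Illustrative sketch only: mirrors the branching added to _wait() above.
def pick_wait_group(as_outer_parallelism, worker_index, trainer_endpoints,
                    sharding_group_endpoints, sharding_rank):
    if as_outer_parallelism:
        # synchronize across all trainer endpoints; only global rank 0 waits
        endpoints = trainer_endpoints
        current_endpoint = endpoints[worker_index]
        should_wait = (worker_index == 0)
    else:
        # synchronize only within the sharding group; only sharding rank 0 waits
        endpoints = sharding_group_endpoints[:]
        current_endpoint = sharding_group_endpoints[sharding_rank]
        should_wait = (sharding_rank == 0)
    return current_endpoint, endpoints, should_wait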
python/paddle/fluid/optimizer.py
@@ -4879,8 +4879,9 @@ class PipelineOptimizer(object):
                 if '@BroadCast' in param_name:
                     param_name = param_name[0:param_name.find('@BroadCast')]
                 # clear gradient
                 assert param_name in self.origin_main_block.vars, "[{}] not in original main block".format(param_name)
                 param_grad_name = self._append_grad_suffix(param_name)
+                accumulated_grad_names.append(param_grad_name)
                 if not block.has_var(param_grad_name):
                     self._create_var(
                         block, self.origin_main_block.vars[param_name],
@@ -4925,7 +4926,7 @@ class PipelineOptimizer(object):
                         #self._op_role_var_key: op_role_var
                     })
                 #offset += 1
-                # accumulated_gradient_names.append(param_grad_var.name)
+                accumulated_grad_names.append(param_grad_var.name)
             else:
                 grad_name = op_role_var[i + 1]  # with _0 suffix
                 grad_var = block.vars[grad_name]
@@ -4962,7 +4963,7 @@ class PipelineOptimizer(object):
                         # self._op_role_var_key: op_role_var
                     })
                 offset += 1
-                # accumulated_gradient_names.append(param_grad_var.name)
+                accumulated_grad_names.append(param_grad_var.name)
                 #real_grad_name = grad_name[0:grad_name.find(
                 #    '@GRAD')] + '@GRAD'
                 #real_grad_var = block.vars[
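Note: taken together, the optimizer.py changes make PipelineOptimizer._accumulate_gradients record the name of each accumulated (merged) gradient variable in accumulated_grad_names, which the sharding optimizer obtains from pp_optimizer._accumulate_gradients(main_block) and then passes, along with sharding_ring_id and use_calc_stream=True, into its gradient reduce insertion (see the sharding_optimizer.py hunks above). A minimal sketch of that producer/consumer data flow, using hypothetical helpers rather than the real Paddle objects:

# Illustrative sketch only: shows the data flow this commit sets up between
# the pipeline optimizer (producer) and the sharding optimizer (consumer).
def accumulate_gradients(param_names):
    # stand-in for PipelineOptimizer._accumulate_gradients: it returns the
    # names of the accumulated gradient variables it creates (the "@GRAD"
    # suffix is assumed here for illustration).
    accumulated_grad_names = []
    for param_name in param_names:
        accumulated_grad_names.append(param_name + "@GRAD")
    return accumulated_grad_names

def insert_reduce_ops(grad_names, sharding_ring_id):
    # stand-in for the sharding-side consumer that reduces each accumulated
    # gradient over the sharding communication ring on the calc stream.
    for name in grad_names:
        print("reduce {} on ring {}".format(name, sharding_ring_id))

grad_names = accumulate_gradients(["fc_0.w_0", "fc_0.b_0"])
insert_reduce_ops(grad_names, sharding_ring_id=1)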