Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
47042a97
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
47042a97
编写于
3月 11, 2021
作者:
S
sandyhouse
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
d1c428da
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
180 addition
and
271 deletion
+180
-271
paddle/fluid/framework/pipeline_trainer.cc
paddle/fluid/framework/pipeline_trainer.cc
+4
-1
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
...e/distributed/fleet/meta_optimizers/sharding_optimizer.py
+127
-142
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+49
-128
未找到文件。
paddle/fluid/framework/pipeline_trainer.cc
浏览文件 @
47042a97
...
@@ -82,7 +82,10 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
...
@@ -82,7 +82,10 @@ void PipelineTrainer::CopyParameters(int microbatch_id,
for
(
auto
&
var
:
global_block
.
AllVars
())
{
for
(
auto
&
var
:
global_block
.
AllVars
())
{
bool
is_param_grad
=
false
;
bool
is_param_grad
=
false
;
size_t
pos
=
0
;
size_t
pos
=
0
;
if
((
pos
=
var
->
Name
().
find
(
kGradVarSuffix
))
!=
std
::
string
::
npos
)
{
// A magic suffix to indicated the merged gradient.
std
::
string
magicSuffix
=
"MERGED"
;
if
((
pos
=
var
->
Name
().
find
(
kGradVarSuffix
))
!=
std
::
string
::
npos
&&
var
->
Name
().
find
(
magicSuffix
)
!=
std
::
string
::
npos
)
{
auto
prefix_name
=
var
->
Name
().
substr
(
0
,
pos
);
auto
prefix_name
=
var
->
Name
().
substr
(
0
,
pos
);
if
(
param_map
.
find
(
prefix_name
)
!=
param_map
.
end
())
{
if
(
param_map
.
find
(
prefix_name
)
!=
param_map
.
end
())
{
is_param_grad
=
true
;
is_param_grad
=
true
;
...
...
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
浏览文件 @
47042a97
...
@@ -153,6 +153,9 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -153,6 +153,9 @@ class ShardingOptimizer(MetaOptimizerBase):
if
self
.
use_pipeline
:
if
self
.
use_pipeline
:
pp_optimizer
.
_rename_gradient_var_name
(
main_block
)
pp_optimizer
.
_rename_gradient_var_name
(
main_block
)
pp_optimizer
.
_accumulate_gradients
(
main_block
)
with
open
(
"main_%d"
%
self
.
role_maker
.
_worker_index
(),
'w'
)
as
f
:
f
.
writelines
(
str
(
main_program
))
# step1: set_up
# step1: set_up
self
.
_set_up
(
params_grads
)
self
.
_set_up
(
params_grads
)
...
@@ -210,23 +213,6 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -210,23 +213,6 @@ class ShardingOptimizer(MetaOptimizerBase):
# if self._shard.has_param(param_name):
# if self._shard.has_param(param_name):
# param_list.append(param_name)
# param_list.append(param_name)
#pp_optimizer._clear_gradients(main_block, param_list)
#pp_optimizer._clear_gradients(main_block, param_list)
accumulated_grad_names
=
pp_optimizer
.
_accumulate_gradients
(
main_block
,
pp_allreduce_in_optimize
=
self
.
pp_allreduce_in_optimize
)
# accumulated_grad_names = sorted(accumulated_grad_names)
if
self
.
pp_allreduce_in_optimize
:
print
(
"persistable FP32 grad: "
)
print
(
accumulated_grad_names
)
first_optimize_op_index
=
get_first_check_finite_and_unscale_op_idx
(
main_block
)
insert_reduce_ops
(
main_block
,
first_optimize_op_index
,
self
.
sharding_ring_id
,
accumulated_grad_names
,
self
.
_shard
,
core
.
op_proto_and_checker_maker
.
OpRole
.
Optimize
,
use_calc_stream
=
True
)
#if not self._shard.has_param(param_name): continue
#if not self._shard.has_param(param_name): continue
##if not main_block.has_var(grad_name): continue
##if not main_block.has_var(grad_name): continue
#assert main_block.has_var(grad_name)
#assert main_block.has_var(grad_name)
...
@@ -246,131 +232,130 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -246,131 +232,130 @@ class ShardingOptimizer(MetaOptimizerBase):
# 'op_role': core.op_proto_and_checker_maker.OpRole.LRSched,
# 'op_role': core.op_proto_and_checker_maker.OpRole.LRSched,
# })
# })
pass
#def _create_var(block, ref_var, name):
#def _create_var(block, ref_var, name):
# """
# """
# Create a new var for block, which has the same type,
# Create a new var for block, which has the same type,
# shape and dtype as ref_var, then rename it with the
# shape and dtype as ref_var, then rename it with the
# name `name`.
# name `name`.
# """
# """
# new_var = block.create_var(
# new_var = block.create_var(
# name=name,
# name=name,
# shape=ref_var.shape,
# shape=ref_var.shape,
# dtype=ref_var.dtype,
# dtype=ref_var.dtype,
# type=ref_var.type,
# type=ref_var.type,
# lod_level=ref_var.lod_level,
# lod_level=ref_var.lod_level,
# persistable=ref_var.persistable,
# persistable=ref_var.persistable,
# is_data=ref_var.is_data,
# is_data=ref_var.is_data,
# need_check_feed=ref_var.desc.need_check_feed())
# need_check_feed=ref_var.desc.need_check_feed())
# new_var.stop_gradient = ref_var.stop_gradient
# new_var.stop_gradient = ref_var.stop_gradient
# return new_var
# return new_var
#def _rename_arg(op, old_name, new_name):
#def _rename_arg(op, old_name, new_name):
# op_desc = op.desc
# op_desc = op.desc
# if isinstance(op_desc, tuple):
# if isinstance(op_desc, tuple):
# op_desc = op_desc[0]
# op_desc = op_desc[0]
# op_desc._rename_input(old_name, new_name)
# op_desc._rename_input(old_name, new_name)
# op_desc._rename_output(old_name, new_name)
# op_desc._rename_output(old_name, new_name)
#print("params_grads:", params_grads)
#print("params_grads:", params_grads)
#for param_name, grad_name in params_grads:
#for param_name, grad_name in params_grads:
# if not self._shard.has_param(param_name): continue
# if not self._shard.has_param(param_name): continue
# #if not main_block.has_var(grad_name): continue
# #if not main_block.has_var(grad_name): continue
# assert main_block.has_var(grad_name)
# assert main_block.has_var(grad_name)
# use_fp16 = False
# use_fp16 = False
# fp16_grad_name = param_name + '.cast_fp16@GRAD'
# fp16_grad_name = param_name + '.cast_fp16@GRAD'
# if main_block.has_var(grad_name):
# if main_block.has_var(grad_name):
# fp16_grad_var = main_block.vars[fp16_grad_name]
# fp16_grad_var = main_block.vars[fp16_grad_name]
# use_fp16 = True
# use_fp16 = True
# grad_var = main_block.vars[grad_name]
# grad_var = main_block.vars[grad_name]
# if use_fp16:
# if use_fp16:
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# grad_name)
# grad_name)
# cast_var = _create_var(main_block, fp16_grad_var,
# cast_var = _create_var(main_block, fp16_grad_var,
# cast_grad_var_name)
# cast_grad_var_name)
# cast_var.persistable = False
# cast_var.persistable = False
# main_block.append_op(
# main_block.append_op(
# #index=offset + 1,
# #index=offset + 1,
# type='cast',
# type='cast',
# inputs={'X': grad_var},
# inputs={'X': grad_var},
# outputs={'Out': cast_var},
# outputs={'Out': cast_var},
# attrs={
# attrs={
# 'in_dtype': grad_var.dtype,
# 'in_dtype': grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role':
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# core.op_proto_and_checker_maker.OpRole.Backward,
# })
# })
# #offset += 1
# #offset += 1
# main_block.append_op(
# main_block.append_op(
# #index=offset + 1,
# #index=offset + 1,
# type='sum',
# type='sum',
# inputs={'X': [fp16_grad_var, cast_var]},
# inputs={'X': [fp16_grad_var, cast_var]},
# outputs={'Out': fp16_grad_var},
# outputs={'Out': fp16_grad_var},
# attrs={
# attrs={
# 'op_role':
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# 'op_role_var': op_role_var
# })
# })
# for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
# for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
# offset = index
# offset = index
# if is_backward_op(op) and (
# if is_backward_op(op) and (
# 'op_role_var' in op.attr_names):
# 'op_role_var' in op.attr_names):
# op_role_var = op.all_attrs()['op_role_var']
# op_role_var = op.all_attrs()['op_role_var']
# if len(op_role_var) == 0:
# if len(op_role_var) == 0:
# continue
# continue
# assert len(op_role_var) % 2 == 0
# assert len(op_role_var) % 2 == 0
# offset = index
# offset = index
# for i in range(0, len(op_role_var), 2):
# for i in range(0, len(op_role_var), 2):
# grad_name = op_role_var[i + 1]
# grad_name = op_role_var[i + 1]
# if not main_block.has_var(grad_name): continue
# if not main_block.has_var(grad_name): continue
# grad_var = main_block.vars[grad_name]
# grad_var = main_block.vars[grad_name]
# if not 'cast_fp16' in grad_name:
# if not 'cast_fp16' in grad_name:
# new_grad_var_name = paddle.fluid.unique_name.generate(grad_name)
# new_grad_var_name = paddle.fluid.unique_name.generate(grad_name)
# new_var = _create_var(main_block, grad_var,
# new_var = _create_var(main_block, grad_var,
# new_grad_var_name)
# new_grad_var_name)
# new_var.persistable = False
# new_var.persistable = False
# _rename_arg(op, grad_name, new_grad_var_name)
# _rename_arg(op, grad_name, new_grad_var_name)
# main_block._insert_op(
# main_block._insert_op(
# index=offset + 1,
# index=offset + 1,
# type='sum',
# type='sum',
# inputs={'X': [grad_var, new_var]},
# inputs={'X': [grad_var, new_var]},
# outputs={'Out': grad_var},
# outputs={'Out': grad_var},
# attrs={
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# 'op_role_var': op_role_var
# })
# })
# offset += 1
# offset += 1
# if 'cast_fp16' in grad_name:
# if 'cast_fp16' in grad_name:
# param_name = op_role_var[i]
# param_name = op_role_var[i]
# fp32_grad_var_name = param_name + "@GRAD"
# fp32_grad_var_name = param_name + "@GRAD"
# fp32_grad_var = main_block.vars[grad_name]
# fp32_grad_var = main_block.vars[grad_name]
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# fp32_grad_var_name)
# fp32_grad_var_name)
# cast_var = _create_var(main_block, grad_var,
# cast_var = _create_var(main_block, grad_var,
# cast_grad_var_name)
# cast_grad_var_name)
# cast_var.persistable = False
# cast_var.persistable = False
# main_block._insert_op(
# main_block._insert_op(
# index=offset + 1,
# index=offset + 1,
# type='cast',
# type='cast',
# inputs={'X': fp32_grad_var},
# inputs={'X': fp32_grad_var},
# outputs={'Out': cast_var},
# outputs={'Out': cast_var},
# attrs={
# attrs={
# 'in_dtype': fp32_grad_var.dtype,
# 'in_dtype': fp32_grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# # self._op_role_var_key: op_role_var
# # self._op_role_var_key: op_role_var
# })
# })
# offset += 1
# offset += 1
# main_block._insert_op(
# main_block._insert_op(
# index=offset + 1,
# index=offset + 1,
# type='sum',
# type='sum',
# inputs={'X': [grad_var, cast_var]},
# inputs={'X': [grad_var, cast_var]},
# outputs={'Out': grad_var},
# outputs={'Out': grad_var},
# attrs={
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var})
# 'op_role_var': op_role_var})
main_block
.
_sync_with_cpp
()
main_block
.
_sync_with_cpp
()
# TODO(wangxi): add optimize offload
# TODO(wangxi): add optimize offload
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
47042a97
...
@@ -4064,11 +4064,8 @@ class PipelineOptimizer(object):
...
@@ -4064,11 +4064,8 @@ class PipelineOptimizer(object):
return
None
return
None
def
_rename_arg
(
self
,
op
,
old_name
,
new_name
):
def
_rename_arg
(
self
,
op
,
old_name
,
new_name
):
op_desc
=
op
.
desc
op
.
_rename_input
(
old_name
,
new_name
)
if
isinstance
(
op_desc
,
tuple
):
op
.
_rename_output
(
old_name
,
new_name
)
op_desc
=
op_desc
[
0
]
op_desc
.
_rename_input
(
old_name
,
new_name
)
op_desc
.
_rename_output
(
old_name
,
new_name
)
def
_create_var
(
self
,
block
,
ref_var
,
name
):
def
_create_var
(
self
,
block
,
ref_var
,
name
):
"""
"""
...
@@ -4823,48 +4820,33 @@ class PipelineOptimizer(object):
...
@@ -4823,48 +4820,33 @@ class PipelineOptimizer(object):
def
_rename_gradient_var_name
(
self
,
block
):
def
_rename_gradient_var_name
(
self
,
block
):
for
index
,
op
in
enumerate
(
block
.
ops
):
for
index
,
op
in
enumerate
(
block
.
ops
):
if
self
.
_is_backward_op
(
op
)
and
(
if
not
self
.
_is_optimize_op
(
op
):
continue
self
.
_op_role_var_key
in
op
.
attr_names
):
input_names
=
op
.
input_arg_names
op_role_var
=
op
.
attr
(
self
.
_op_role_var_key
)
output_names
=
op
.
output_arg_names
in_out_names
=
input_names
+
output_names
if
len
(
op_role_var
)
==
0
:
# append "MERGED" to the names of parameter gradients,
continue
# and mofify the op_role_var attribute (by rename_arg func).
for
i
in
range
(
0
,
len
(
op_role_var
),
2
):
for
name
in
in_out_names
:
grad_name
=
op_role_var
[
i
+
1
]
if
not
core
.
grad_var_suffix
()
in
name
:
continue
grad_var
=
block
.
vars
[
grad_name
]
param_name
=
name
.
strip
(
core
.
grad_var_suffix
())
new_grad_var_name
=
unique_name
.
generate
(
grad_name
)
new_grad_name
=
name
+
"@MERGED"
new_var
=
self
.
_create_var
(
block
,
grad_var
,
self
.
_rename_arg
(
op
,
name
,
new_grad_name
)
new_grad_var_name
)
new_var
.
persistable
=
False
self
.
_rename_arg
(
op
,
grad_name
,
new_grad_var_name
)
def
_accumulate_gradients
(
self
,
block
,
pp_allreduce_in_optimize
=
False
):
def
_accumulate_gradients
(
self
,
block
,
pp_allreduce_in_optimize
=
False
):
"""
"""
Accumulate the gradients generated in microbatch to the one in mini-batch.
Create a new merged gradient for each parameter and accumulate the
corresponding gradient to it.
"""
"""
# the name of real grad vars that should be allreduce
# accumulated_gradient_names = []
first_optimize_op_index
=
None
accumulated_grad_names
=
[]
for
index
,
op
in
reversed
(
tuple
(
enumerate
(
list
(
block
.
ops
)))):
for
index
,
op
in
reversed
(
tuple
(
enumerate
(
list
(
block
.
ops
)))):
# remove the cast op of fp16 grad to fp32 grad
# remove the cast op of fp16 grad to fp32 grad
if
self
.
_is_optimize_op
(
op
)
and
op
.
type
==
'cast'
:
if
self
.
_is_optimize_op
(
op
)
and
op
.
type
==
'cast'
:
in_name
=
op
.
input_arg_names
[
0
]
in_name
=
op
.
input_arg_names
[
0
]
out_name
=
op
.
output_arg_names
[
0
]
out_name
=
op
.
output_arg_names
[
0
]
if
out_name
.
strip
(
'@GRAD'
)
in
self
.
_param_device_map
:
if
out_name
.
strip
(
'@GRAD
@MERGED
'
)
in
self
.
_param_device_map
:
assert
in_name
.
replace
(
'.cast_fp16'
,
''
)
==
out_name
assert
in_name
.
replace
(
'.cast_fp16'
,
''
)
==
out_name
block
.
_remove_op
(
index
)
block
.
_remove_op
(
index
)
continue
continue
if
not
self
.
_is_optimize_op
(
op
)
and
not
first_optimize_op_index
:
first_optimize_op_index
=
index
+
1
if
block
.
ops
[
first_optimize_op_index
].
type
==
'c_sync_comm_stream'
:
block
.
ops
[
first_optimize_op_index
].
_set_attr
(
self
.
_op_role_key
,
self
.
_op_role
.
Backward
)
first_optimize_op_index
+=
1
if
self
.
_is_backward_op
(
op
)
and
(
if
self
.
_is_backward_op
(
op
)
and
(
self
.
_op_role_var_key
in
op
.
attr_names
):
self
.
_op_role_var_key
in
op
.
attr_names
):
op_role_var
=
op
.
attr
(
self
.
_op_role_var_key
)
op_role_var
=
op
.
attr
(
self
.
_op_role_var_key
)
...
@@ -4872,143 +4854,80 @@ class PipelineOptimizer(object):
...
@@ -4872,143 +4854,80 @@ class PipelineOptimizer(object):
if
len
(
op_role_var
)
==
0
:
if
len
(
op_role_var
)
==
0
:
continue
continue
assert
len
(
op_role_var
)
%
2
==
0
assert
len
(
op_role_var
)
%
2
==
0
op
.
_remove_attr
(
self
.
_op_role_var_key
)
for
i
in
range
(
0
,
len
(
op_role_var
),
2
):
for
i
in
range
(
0
,
len
(
op_role_var
),
2
):
offset
=
0
offset
=
1
param_name
=
op_role_var
[
i
]
param_name
=
op_role_var
[
i
]
assert
block
.
has_var
(
param_name
),
(
if
not
pp_allreduce_in_optimize
:
"parameter {} not in "
if
not
block
.
has_var
(
param_name
):
"current block."
.
format
(
param_name
))
continue
if
'@BroadCast'
in
param_name
:
param_name
=
param_name
[
0
:
param_name
.
find
(
'@BroadCast'
)]
# clear gradient
# clear gradient
assert
param_name
in
self
.
origin_main_block
.
vars
,
"[{}] not in original main block"
.
format
(
assert
param_name
in
self
.
origin_main_block
.
vars
,
"[{}] not in original main block"
.
format
(
param_name
)
param_name
)
param_grad_name
=
self
.
_append_grad_suffix
(
param_name
)
param_grad_name
=
self
.
_append_grad_suffix
(
param_name
)
if
not
block
.
has_var
(
param_grad_name
):
merged_param_grad_name
=
param_grad_name
+
'@MERGED'
self
.
_create_var
(
if
not
block
.
has_var
(
merged_param_grad_name
):
block
,
self
.
origin_main_
block
.
vars
[
param_name
],
self
.
_create_var
(
block
,
block
.
vars
[
param_name
],
param_grad_name
)
merged_
param_grad_name
)
assert
block
.
has_var
(
param_grad_name
)
assert
block
.
has_var
(
merged_
param_grad_name
)
param_grad_var
=
block
.
var
(
param_grad_name
)
param_grad_var
=
block
.
var
(
param_grad_name
)
param_grad_var
.
persistable
=
True
merged_param_grad_var
=
block
.
var
(
merged_param_grad_name
)
merged_param_grad_var
.
persistable
=
True
block
.
_insert_op
(
block
.
_insert_op
(
index
=
first_optimize_op_
index
+
offset
,
index
=
index
+
offset
,
type
=
'fill_constant'
,
type
=
'fill_constant'
,
inputs
=
{},
inputs
=
{},
outputs
=
{
'Out'
:
[
param_grad_var
]},
outputs
=
{
'Out'
:
[
merged_
param_grad_var
]},
attrs
=
{
attrs
=
{
'shape'
:
param_grad_var
.
shape
,
'shape'
:
merged_
param_grad_var
.
shape
,
'dtype'
:
param_grad_var
.
dtype
,
'dtype'
:
merged_
param_grad_var
.
dtype
,
'value'
:
float
(
0
),
'value'
:
float
(
0
),
# self._op_device_key: device,
# a trick to run this op once per mini-batch
# a trick to run this op once per mini-batch
self
.
_op_role_key
:
self
.
_op_role
.
Optimize
.
LRSched
,
self
.
_op_role_key
:
self
.
_op_role
.
Optimize
.
LRSched
,
})
})
#
offset += 1
offset
+=
1
grad_name
=
op_role_var
[
i
+
1
]
# with _0 suffix
grad_name
=
op_role_var
[
i
+
1
]
grad_var
=
block
.
vars
[
grad_name
]
grad_var
=
block
.
vars
[
grad_name
]
#real_grad_name = grad_name[0:grad_name.find(
# '@GRAD')] + '@GRAD' # without _0 suffix
#real_grad_var = block.vars[
# real_grad_name] # without _0 suffix
# new_grad_var_name = unique_name.generate(grad_name)
# new_var = self._create_var(block, grad_var,
# new_grad_var_name)
# new_var.persistable = False
# self._rename_arg(op, grad_name, new_grad_var_name)
if
not
'cast_fp16'
in
grad_name
:
if
not
'cast_fp16'
in
grad_name
:
block
.
_insert_op
(
block
.
_insert_op
(
index
=
index
+
1
,
index
=
index
+
offset
,
type
=
'sum'
,
type
=
'sum'
,
inputs
=
{
'X'
:
[
grad_var
,
param_grad_var
]},
inputs
=
{
'X'
:
[
grad_var
,
merged_
param_grad_var
]},
outputs
=
{
'Out'
:
param_grad_var
},
outputs
=
{
'Out'
:
merged_
param_grad_var
},
attrs
=
{
attrs
=
{
#self._op_device_key: device,
self
.
_op_role_key
:
self
.
_op_role
.
Backward
,
self
.
_op_role_key
:
self
.
_op_role
.
Backward
,
#self._op_role_var_key: op_role_var
})
})
#offset += 1
offset
+=
1
accumulated_grad_names
.
append
(
param_grad_var
.
name
)
else
:
else
:
grad_name
=
op_role_var
[
i
+
1
]
# with _0 suffix
# cast gradient to fp32 to accumulate to merged gradient
grad_var
=
block
.
vars
[
grad_name
]
cast_grad_var_name
=
param_grad_name
+
'@TMP'
#fp32_grad_var_name = param_name + core.grad_var_suffix(
#) # without _0 suffix
#fp32_grad_var = block.vars[fp32_grad_var_name]
#fp32_grad_var.persistable = True
cast_grad_var_name
=
unique_name
.
generate
(
param_grad_name
)
cast_grad_var
=
self
.
_create_var
(
block
,
param_grad_var
,
cast_grad_var
=
self
.
_create_var
(
block
,
param_grad_var
,
cast_grad_var_name
)
cast_grad_var_name
)
cast_grad_var
.
persistable
=
False
cast_grad_var
.
persistable
=
False
block
.
_insert_op
(
block
.
_insert_op
(
index
=
index
+
1
,
index
=
index
+
offset
,
type
=
'cast'
,
type
=
'cast'
,
inputs
=
{
'X'
:
grad_var
},
inputs
=
{
'X'
:
grad_var
},
outputs
=
{
'Out'
:
cast_grad_var
},
outputs
=
{
'Out'
:
cast_grad_var
},
attrs
=
{
attrs
=
{
'in_dtype'
:
grad_var
.
dtype
,
'in_dtype'
:
grad_var
.
dtype
,
'out_dtype'
:
cast_grad_var
.
dtype
,
'out_dtype'
:
cast_grad_var
.
dtype
,
# self._op_device_key: device,
self
.
_op_role_key
:
self
.
_op_role
.
Backward
,
self
.
_op_role_key
:
self
.
_op_role
.
Backward
,
# self._op_role_var_key: op_role_var
})
})
offset
+=
1
offset
+=
1
block
.
_insert_op
(
block
.
_insert_op
(
index
=
index
+
2
,
index
=
index
+
offset
,
type
=
'sum'
,
type
=
'sum'
,
inputs
=
{
'X'
:
[
param_grad_var
,
cast_grad_var
]},
inputs
=
{
outputs
=
{
'Out'
:
param_grad_var
},
'X'
:
[
merged_param_grad_var
,
cast_grad_var
]
},
outputs
=
{
'Out'
:
merged_param_grad_var
},
attrs
=
{
attrs
=
{
# self._op_device_key: device,
# self._op_device_key: device,
self
.
_op_role_key
:
self
.
_op_role
.
Backward
,
self
.
_op_role_key
:
self
.
_op_role
.
Backward
,
#
self._op_role_var_key: op_role_var
self
.
_op_role_var_key
:
op_role_var
})
})
offset
+=
1
offset
+=
1
accumulated_grad_names
.
append
(
param_grad_var
.
name
)
#real_grad_name = grad_name[0:grad_name.find(
# '@GRAD')] + '@GRAD'
#real_grad_var = block.vars[
# real_grad_name] # without _0 suffix
#block._insert_op(
# index=first_optimize_op_index + offset,
# type='cast',
# inputs={'X': fp32_grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': fp32_grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# # self._op_device_key: device,
# self._op_role_key: self._op_role.Backward,
# # self._op_role_var_key: op_role_var
# })
#offset += 1
#block._insert_op(
# index=first_optimize_op_index + offset,
# type='sum',
# inputs={'X': [grad_var, cast_var]},
# outputs={'Out': real_grad_var},
# attrs={
# # self._op_device_key: device,
# self._op_role_key: self._op_role.Backward,
# # self._op_role_var_key: op_role_var
# })
#offset += 1
#block._insert_op(
# index=first_optimize_op_index + offset,
# type='cast',
# inputs={'X': real_grad_var},
# outputs={'Out': fp32_grad_var},
# attrs={
# 'in_dtype': real_grad_var.dtype,
# 'out_dtype': fp32_grad_var.dtype,
# # self._op_device_key: device,
# self._op_role_key: self._op_role.Backward,
# # self._op_role_var_key: op_role_var
# })
return
accumulated_grad_names
def
_add_sub_blocks
(
self
,
main_block
,
program_list
):
def
_add_sub_blocks
(
self
,
main_block
,
program_list
):
main_program
=
main_block
.
program
main_program
=
main_block
.
program
...
@@ -5351,7 +5270,9 @@ class PipelineOptimizer(object):
...
@@ -5351,7 +5270,9 @@ class PipelineOptimizer(object):
if
real_block
.
has_var
(
param
):
param_list
.
append
(
param
)
if
real_block
.
has_var
(
param
):
param_list
.
append
(
param
)
#self._clear_gradients(real_block, param_list)
#self._clear_gradients(real_block, param_list)
self
.
_rename_gradient_var_name
(
real_block
)
self
.
_rename_gradient_var_name
(
real_block
)
real_block
.
_sync_with_cpp
()
self
.
_accumulate_gradients
(
real_block
)
self
.
_accumulate_gradients
(
real_block
)
real_block
.
_sync_with_cpp
()
place_id
=
int
(
os
.
getenv
(
"FLAGS_selected_gpus"
,
"0"
))
place_id
=
int
(
os
.
getenv
(
"FLAGS_selected_gpus"
,
"0"
))
main_program
.
_pipeline_opt
=
{
main_program
.
_pipeline_opt
=
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录