Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
f32f84a4
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
f32f84a4
编写于
3月 11, 2021
作者:
W
WangXi
提交者:
sandyhouse
3月 22, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
update
上级
c7472f16
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
254 addition
and
124 deletion
+254
-124
python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
...tributed/fleet/meta_optimizers/sharding/offload_helper.py
+124
-0
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
...e/distributed/fleet/meta_optimizers/sharding_optimizer.py
+130
-124
未找到文件。
python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
0 → 100644
浏览文件 @
f32f84a4
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
..common
import
is_optimizer_op
,
OP_ROLE_KEY
,
OpRole
from
paddle.fluid
import
unique_name
class
OffloadHelper
(
object
):
cpu_place_type
=
0
cuda_place_type
=
1
cuda_pinned_place_type
=
2
def
__init__
(
self
):
pass
"0: dst is on CPUPlace. "
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
def
_insert_memcpy_op
(
self
,
block
,
idx
,
src_name
,
dst_name
,
dst_place_type
):
src_var
=
block
.
var
(
src_name
)
dst_var
=
block
.
var
(
dst_name
)
block
.
_insert_op_without_sync
(
idx
,
type
=
'memcpy'
,
inputs
=
{
'X'
:
src_var
},
outputs
=
{
'Out'
:
dst_var
},
attrs
=
{
'dst_place_type'
:
dst_place_type
,
OP_ROLE_KEY
:
OpRole
.
Optimize
,
})
def
_insert_fetch_op
(
self
,
block
,
idx
,
src_name
,
dst_name
):
self
.
_insert_memcpy_op
(
block
,
idx
,
src_name
,
dst_name
,
OffloadHelper
.
cuda_place_type
)
def
_insert_offload_op
(
self
,
block
,
idx
,
src_name
,
dst_name
):
self
.
_insert_memcpy_op
(
block
,
idx
,
src_name
,
dst_name
,
OffloadHelper
.
cuda_pinned_place_type
)
def
_get_offload_var_name
(
self
,
name
):
return
unique_name
.
generate
(
name
+
'@offload'
)
def
_create_offload_var
(
self
,
var_name
,
offload_var_name
,
blocks
):
for
block
in
blocks
:
var
=
block
.
var
(
var_name
)
var
.
persistable
=
False
offload_var
=
block
.
create_var
(
name
=
offload_var_name
,
shape
=
var
.
shape
,
dtype
=
var
.
dtype
,
persistable
=
True
)
def
offload
(
self
,
block
,
startup_block
):
"""
(m1, m2) = prefetch(m1@offload, m2@offload)
(m1out, m2out, pout) = adam(m1, m2, p)
(m1@offload, m2@offload) = memcpy(m1, m2)
"""
vars_name_to_offload_name
=
dict
()
# main_block add offload
for
idx
,
op
in
reversed
(
list
(
enumerate
(
block
.
ops
))):
if
not
is_optimizer_op
(
op
):
break
vars_name
=
[]
if
op
.
type
==
"adam"
:
# {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} =
# adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']})
vars_name
.
append
(
op
.
desc
.
input
(
"Moment1"
)[
0
])
vars_name
.
append
(
op
.
desc
.
input
(
"Moment1"
)[
0
])
elif
op
.
type
==
'momentum'
:
pass
elif
op
.
type
==
'lars'
:
pass
elif
op
.
type
==
'lamb'
:
pass
# step1: create and init offload_var
for
var_name
in
vars_name
:
assert
var_name
not
in
vars_name_to_offload_name
offload_var_name
=
self
.
_get_offload_var_name
(
var_name
)
vars_name_to_offload_name
[
var_name
]
=
offload_var_name
self
.
_create_offload_var
(
var_name
,
offload_var_name
,
[
block
,
startup_block
])
# step2: insert offload op
for
var_name
in
vars_name
:
offload_var_name
=
vars_name_to_offload_name
[
var_name
]
self
.
_insert_offload_op
(
block
,
idx
+
1
,
var_name
,
offload_var_name
)
# step3: insert fetch op
for
var_name
in
vars_name
:
offload_var_name
=
vars_name_to_offload_name
[
var_name
]
self
.
_insert_fetch_op
(
block
,
idx
,
offload_var_name
,
var_name
)
# startup_block add offload
visited_vars
=
set
()
for
idx
,
op
in
reversed
(
list
(
enumerate
(
startup_block
.
ops
))):
for
out_name
in
op
.
output_arg_names
:
if
out_name
in
visited_vars
:
continue
if
out_name
in
vars_name_to_offload_name
:
var_name
=
out_name
offload_var_name
=
vars_name_to_offload_name
[
var_name
]
# insert offload op after var is generated
self
.
_insert_offload_op
(
startup_block
,
idx
+
1
,
var_name
,
offload_var_name
)
visited_vars
.
add
(
out_name
)
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
浏览文件 @
f32f84a4
...
...
@@ -22,6 +22,7 @@ from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, Progr
from
paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper
import
FP16Utils
from
paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper
import
WeightDecayHelper
from
paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper
import
GradientClipHelper
from
.sharding.offload_helper
import
OffloadHelper
from
paddle.distributed.fleet.meta_optimizers.sharding.prune
import
ProgramDeps
from
paddle.distributed.fleet.meta_optimizers.sharding.utils
import
*
...
...
@@ -245,132 +246,137 @@ class ShardingOptimizer(MetaOptimizerBase):
# 'op_role': core.op_proto_and_checker_maker.OpRole.LRSched,
# })
#def _create_var(block, ref_var, name):
# """
# Create a new var for block, which has the same type,
# shape and dtype as ref_var, then rename it with the
# name `name`.
# """
# new_var = block.create_var(
# name=name,
# shape=ref_var.shape,
# dtype=ref_var.dtype,
# type=ref_var.type,
# lod_level=ref_var.lod_level,
# persistable=ref_var.persistable,
# is_data=ref_var.is_data,
# need_check_feed=ref_var.desc.need_check_feed())
# new_var.stop_gradient = ref_var.stop_gradient
# return new_var
#def _rename_arg(op, old_name, new_name):
# op_desc = op.desc
# if isinstance(op_desc, tuple):
# op_desc = op_desc[0]
# op_desc._rename_input(old_name, new_name)
# op_desc._rename_output(old_name, new_name)
#print("params_grads:", params_grads)
#for param_name, grad_name in params_grads:
# if not self._shard.has_param(param_name): continue
# #if not main_block.has_var(grad_name): continue
# assert main_block.has_var(grad_name)
# use_fp16 = False
# fp16_grad_name = param_name + '.cast_fp16@GRAD'
# if main_block.has_var(grad_name):
# fp16_grad_var = main_block.vars[fp16_grad_name]
# use_fp16 = True
# grad_var = main_block.vars[grad_name]
# if use_fp16:
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# grad_name)
# cast_var = _create_var(main_block, fp16_grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block.append_op(
# #index=offset + 1,
# type='cast',
# inputs={'X': grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# })
# #offset += 1
# main_block.append_op(
# #index=offset + 1,
# type='sum',
# inputs={'X': [fp16_grad_var, cast_var]},
# outputs={'Out': fp16_grad_var},
# attrs={
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
# offset = index
# if is_backward_op(op) and (
# 'op_role_var' in op.attr_names):
# op_role_var = op.all_attrs()['op_role_var']
# if len(op_role_var) == 0:
# continue
# assert len(op_role_var) % 2 == 0
# offset = index
# for i in range(0, len(op_role_var), 2):
# grad_name = op_role_var[i + 1]
# if not main_block.has_var(grad_name): continue
# grad_var = main_block.vars[grad_name]
# if not 'cast_fp16' in grad_name:
# new_grad_var_name = paddle.fluid.unique_name.generate(grad_name)
# new_var = _create_var(main_block, grad_var,
# new_grad_var_name)
# new_var.persistable = False
# _rename_arg(op, grad_name, new_grad_var_name)
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, new_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# offset += 1
# if 'cast_fp16' in grad_name:
# param_name = op_role_var[i]
# fp32_grad_var_name = param_name + "@GRAD"
# fp32_grad_var = main_block.vars[grad_name]
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# fp32_grad_var_name)
# cast_var = _create_var(main_block, grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block._insert_op(
# index=offset + 1,
# type='cast',
# inputs={'X': fp32_grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': fp32_grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# # self._op_role_var_key: op_role_var
# })
# offset += 1
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, cast_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var})
pass
#def _create_var(block, ref_var, name):
# """
# Create a new var for block, which has the same type,
# shape and dtype as ref_var, then rename it with the
# name `name`.
# """
# new_var = block.create_var(
# name=name,
# shape=ref_var.shape,
# dtype=ref_var.dtype,
# type=ref_var.type,
# lod_level=ref_var.lod_level,
# persistable=ref_var.persistable,
# is_data=ref_var.is_data,
# need_check_feed=ref_var.desc.need_check_feed())
# new_var.stop_gradient = ref_var.stop_gradient
# return new_var
#def _rename_arg(op, old_name, new_name):
# op_desc = op.desc
# if isinstance(op_desc, tuple):
# op_desc = op_desc[0]
# op_desc._rename_input(old_name, new_name)
# op_desc._rename_output(old_name, new_name)
#print("params_grads:", params_grads)
#for param_name, grad_name in params_grads:
# if not self._shard.has_param(param_name): continue
# #if not main_block.has_var(grad_name): continue
# assert main_block.has_var(grad_name)
# use_fp16 = False
# fp16_grad_name = param_name + '.cast_fp16@GRAD'
# if main_block.has_var(grad_name):
# fp16_grad_var = main_block.vars[fp16_grad_name]
# use_fp16 = True
# grad_var = main_block.vars[grad_name]
# if use_fp16:
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# grad_name)
# cast_var = _create_var(main_block, fp16_grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block.append_op(
# #index=offset + 1,
# type='cast',
# inputs={'X': grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# })
# #offset += 1
# main_block.append_op(
# #index=offset + 1,
# type='sum',
# inputs={'X': [fp16_grad_var, cast_var]},
# outputs={'Out': fp16_grad_var},
# attrs={
# 'op_role':
# core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# for index, op in reversed(tuple(enumerate(list(main_block.ops)))):
# offset = index
# if is_backward_op(op) and (
# 'op_role_var' in op.attr_names):
# op_role_var = op.all_attrs()['op_role_var']
# if len(op_role_var) == 0:
# continue
# assert len(op_role_var) % 2 == 0
# offset = index
# for i in range(0, len(op_role_var), 2):
# grad_name = op_role_var[i + 1]
# if not main_block.has_var(grad_name): continue
# grad_var = main_block.vars[grad_name]
# if not 'cast_fp16' in grad_name:
# new_grad_var_name = paddle.fluid.unique_name.generate(grad_name)
# new_var = _create_var(main_block, grad_var,
# new_grad_var_name)
# new_var.persistable = False
# _rename_arg(op, grad_name, new_grad_var_name)
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, new_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var
# })
# offset += 1
# if 'cast_fp16' in grad_name:
# param_name = op_role_var[i]
# fp32_grad_var_name = param_name + "@GRAD"
# fp32_grad_var = main_block.vars[grad_name]
# cast_grad_var_name = paddle.fluid.unique_name.generate(
# fp32_grad_var_name)
# cast_var = _create_var(main_block, grad_var,
# cast_grad_var_name)
# cast_var.persistable = False
# main_block._insert_op(
# index=offset + 1,
# type='cast',
# inputs={'X': fp32_grad_var},
# outputs={'Out': cast_var},
# attrs={
# 'in_dtype': fp32_grad_var.dtype,
# 'out_dtype': cast_var.dtype,
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# # self._op_role_var_key: op_role_var
# })
# offset += 1
# main_block._insert_op(
# index=offset + 1,
# type='sum',
# inputs={'X': [grad_var, cast_var]},
# outputs={'Out': grad_var},
# attrs={
# 'op_role': core.op_proto_and_checker_maker.OpRole.Backward,
# 'op_role_var': op_role_var})
main_block
.
_sync_with_cpp
()
# TODO(wangxi): add optimize offload
offload_helper
=
OffloadHelper
()
offload_helper
.
offload
(
main_block
,
startup_block
)
with
open
(
"start_sharding_%d"
%
self
.
role_maker
.
_worker_index
(),
'w'
)
as
f
:
f
.
writelines
(
str
(
startup_block
.
program
))
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录