Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6959eae5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6959eae5
编写于
4月 12, 2023
作者:
Y
Yiqun Liu
提交者:
GitHub
4月 12, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Unify the static amp codes of fp16 and bf16. Reimplement #52694 in release/2.4. (#52697)
上级
d1e8b1e2
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
920 addition
and
506 deletion
+920
-506
python/paddle/fluid/contrib/mixed_precision/__init__.py
python/paddle/fluid/contrib/mixed_precision/__init__.py
+1
-1
python/paddle/fluid/contrib/mixed_precision/amp_nn.py
python/paddle/fluid/contrib/mixed_precision/amp_nn.py
+68
-46
python/paddle/fluid/contrib/mixed_precision/decorator.py
python/paddle/fluid/contrib/mixed_precision/decorator.py
+301
-159
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
+46
-28
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+211
-118
python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
...paddle/fluid/dygraph/dygraph_to_static/partial_program.py
+293
-154
未找到文件。
python/paddle/fluid/contrib/mixed_precision/__init__.py
浏览文件 @
6959eae5
...
@@ -15,7 +15,7 @@
...
@@ -15,7 +15,7 @@
from
__future__
import
print_function
from
__future__
import
print_function
from
.
import
decorator
from
.
import
decorator
from
.decorator
import
*
from
.decorator
import
decorate
,
amp_decorate
from
.
import
fp16_lists
from
.
import
fp16_lists
from
.fp16_lists
import
*
from
.fp16_lists
import
*
from
.
import
fp16_utils
from
.
import
fp16_utils
...
...
python/paddle/fluid/contrib/mixed_precision/amp_nn.py
浏览文件 @
6959eae5
...
@@ -38,27 +38,35 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
...
@@ -38,27 +38,35 @@ def check_finite_and_unscale(x, scale, name=None, float_status=None):
"""
"""
check_type
(
x
,
'x'
,
(
tuple
,
list
),
'check_finite_and_unscale'
)
check_type
(
x
,
'x'
,
(
tuple
,
list
),
'check_finite_and_unscale'
)
for
e
in
x
:
for
e
in
x
:
check_variable_and_dtype
(
e
,
"x"
,
[
'float16'
,
'float32'
,
'float64'
],
check_variable_and_dtype
(
'check_finite_and_unscale'
)
e
,
"x"
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'check_finite_and_unscale'
,
)
helper
=
LayerHelper
(
"check_finite_and_unscale"
,
**
locals
())
helper
=
LayerHelper
(
"check_finite_and_unscale"
,
**
locals
())
found_inf
=
helper
.
create_variable_for_type_inference
(
dtype
=
'bool'
)
found_inf
=
helper
.
create_variable_for_type_inference
(
dtype
=
'bool'
)
inputs
=
{
'X'
:
x
,
'Scale'
:
scale
}
inputs
=
{
'X'
:
x
,
'Scale'
:
scale
}
if
core
.
is_compiled_with_npu
():
if
core
.
is_compiled_with_npu
():
check_variable_and_dtype
(
float_status
,
"float_status"
,
check_variable_and_dtype
(
float_status
,
"float_status"
,
[
'float16'
,
'float32'
],
[
'float16'
,
'float32'
],
'check_finite_and_unscale'
)
'check_finite_and_unscale'
,
)
inputs
[
'FloatStatus'
]
=
float_status
inputs
[
'FloatStatus'
]
=
float_status
outputs
=
{
'Out'
:
x
,
'FoundInfinite'
:
found_inf
}
outputs
=
{
'Out'
:
x
,
'FoundInfinite'
:
found_inf
}
helper
.
append_op
(
type
=
'check_finite_and_unscale'
,
helper
.
append_op
(
inputs
=
inputs
,
type
=
'check_finite_and_unscale'
,
inputs
=
inputs
,
outputs
=
outputs
outputs
=
outputs
)
)
return
x
,
found_inf
return
x
,
found_inf
def
update_loss_scaling
(
x
,
def
update_loss_scaling
(
x
,
found_inf
,
found_inf
,
prev_loss_scaling
,
prev_loss_scaling
,
num_good_steps
,
num_good_steps
,
...
@@ -68,7 +76,8 @@ def update_loss_scaling(x,
...
@@ -68,7 +76,8 @@ def update_loss_scaling(x,
incr_ratio
,
incr_ratio
,
decr_ratio
,
decr_ratio
,
stop_update
=
False
,
stop_update
=
False
,
name
=
None
):
name
=
None
,
):
"""
"""
Update loss scaling according to overall gradients. If all gradients is
Update loss scaling according to overall gradients. If all gradients is
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
finite after incr_every_n_steps, loss scaling will increase by incr_ratio.
...
@@ -96,17 +105,31 @@ def update_loss_scaling(x,
...
@@ -96,17 +105,31 @@ def update_loss_scaling(x,
loss scaling.
loss scaling.
"""
"""
check_variable_and_dtype
(
prev_loss_scaling
,
"prev_loss_scaling"
,
check_variable_and_dtype
(
[
'float32'
,
'float64'
],
"update_loss_scaling"
)
prev_loss_scaling
,
"prev_loss_scaling"
,
[
'float32'
,
'float64'
],
"update_loss_scaling"
,
)
check_type
(
x
,
'x'
,
(
tuple
,
list
),
'update_loss_scaling'
)
check_type
(
x
,
'x'
,
(
tuple
,
list
),
'update_loss_scaling'
)
for
e
in
x
:
for
e
in
x
:
check_variable_and_dtype
(
e
,
"x"
,
[
'float16'
,
'float32'
,
'float64'
],
check_variable_and_dtype
(
'update_loss_scaling'
)
e
,
if
e
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
"x"
,
assert
prev_loss_scaling
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
,
\
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
'update_loss_scaling'
,
)
if
(
e
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
or
e
.
dtype
==
core
.
VarDesc
.
VarType
.
BF16
):
assert
(
prev_loss_scaling
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
),
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else
:
else
:
assert
prev_loss_scaling
.
dtype
==
e
.
dtype
,
"The dtype of prev_loss_scaling should be equal to the dtype of x."
assert
(
prev_loss_scaling
.
dtype
==
e
.
dtype
),
"The dtype of prev_loss_scaling should be equal to the dtype of x."
helper
=
LayerHelper
(
"update_loss_scaling"
,
**
locals
())
helper
=
LayerHelper
(
"update_loss_scaling"
,
**
locals
())
...
@@ -115,14 +138,14 @@ def update_loss_scaling(x,
...
@@ -115,14 +138,14 @@ def update_loss_scaling(x,
'FoundInfinite'
:
found_inf
,
'FoundInfinite'
:
found_inf
,
'PrevLossScaling'
:
prev_loss_scaling
,
'PrevLossScaling'
:
prev_loss_scaling
,
'InGoodSteps'
:
num_good_steps
,
'InGoodSteps'
:
num_good_steps
,
'InBadSteps'
:
num_bad_steps
'InBadSteps'
:
num_bad_steps
,
}
}
outputs
=
{
outputs
=
{
'Out'
:
x
,
'Out'
:
x
,
'LossScaling'
:
prev_loss_scaling
,
'LossScaling'
:
prev_loss_scaling
,
'OutGoodSteps'
:
num_good_steps
,
'OutGoodSteps'
:
num_good_steps
,
'OutBadSteps'
:
num_bad_steps
'OutBadSteps'
:
num_bad_steps
,
}
}
attrs
=
{
attrs
=
{
...
@@ -137,9 +160,8 @@ def update_loss_scaling(x,
...
@@ -137,9 +160,8 @@ def update_loss_scaling(x,
else
:
else
:
attrs
[
'stop_update'
]
=
stop_update
attrs
[
'stop_update'
]
=
stop_update
helper
.
append_op
(
type
=
'update_loss_scaling'
,
helper
.
append_op
(
inputs
=
inputs
,
type
=
'update_loss_scaling'
,
inputs
=
inputs
,
outputs
=
outputs
,
attrs
=
attrs
outputs
=
outputs
,
)
attrs
=
attrs
)
return
x
return
x
python/paddle/fluid/contrib/mixed_precision/decorator.py
浏览文件 @
6959eae5
...
@@ -63,10 +63,20 @@ class OptimizerWithMixedPrecision(object):
...
@@ -63,10 +63,20 @@ class OptimizerWithMixedPrecision(object):
"""
"""
def
__init__
(
self
,
optimizer
,
amp_lists
,
init_loss_scaling
,
def
__init__
(
use_dynamic_loss_scaling
,
incr_every_n_steps
,
self
,
decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
,
use_pure_fp16
,
optimizer
,
use_fp16_guard
):
amp_lists
,
init_loss_scaling
,
use_dynamic_loss_scaling
,
incr_every_n_steps
,
decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
,
use_pure_fp16
,
use_fp16_guard
,
use_bf16
=
False
,
):
self
.
_optimizer
=
optimizer
self
.
_optimizer
=
optimizer
self
.
_amp_lists
=
amp_lists
self
.
_amp_lists
=
amp_lists
self
.
_param_grads
=
None
self
.
_param_grads
=
None
...
@@ -77,11 +87,23 @@ class OptimizerWithMixedPrecision(object):
...
@@ -77,11 +87,23 @@ class OptimizerWithMixedPrecision(object):
self
.
_loss_scaling
=
None
self
.
_loss_scaling
=
None
self
.
_init_loss_scaling
=
init_loss_scaling
self
.
_init_loss_scaling
=
init_loss_scaling
self
.
_use_dynamic_loss_scaling
=
use_dynamic_loss_scaling
self
.
_use_dynamic_loss_scaling
=
use_dynamic_loss_scaling
if
use_bf16
:
if
use_dynamic_loss_scaling
:
self
.
_use_dynamic_loss_scaling
=
False
self
.
_init_loss_scaling
=
1.0
warnings
.
warn
(
"Dynamic loss scaling for bfloat16 amp training is disabled, and the init_loss_scaling is changed to 1.0 automatically by PaddlePaddle."
)
self
.
_amp_dtype
=
core
.
VarDesc
.
VarType
.
BF16
else
:
self
.
_amp_dtype
=
core
.
VarDesc
.
VarType
.
FP16
self
.
_learning_rate
=
optimizer
.
_learning_rate
self
.
_learning_rate
=
optimizer
.
_learning_rate
self
.
_learning_rate_map
=
optimizer
.
_learning_rate_map
self
.
_learning_rate_map
=
optimizer
.
_learning_rate_map
self
.
_use_pure_fp16
=
use_pure_fp16
self
.
_use_pure_fp16
=
use_pure_fp16
self
.
_use_fp16_guard
=
use_fp16_guard
self
.
_use_fp16_guard
=
use_fp16_guard
self
.
_to_fp16_var_names
=
None
self
.
_to_fp16_var_names
=
None
self
.
_use_bf16
=
use_bf16
if
self
.
_use_dynamic_loss_scaling
:
if
self
.
_use_dynamic_loss_scaling
:
self
.
_incr_every_n_steps
=
incr_every_n_steps
self
.
_incr_every_n_steps
=
incr_every_n_steps
self
.
_decr_every_n_nan_or_inf
=
decr_every_n_nan_or_inf
self
.
_decr_every_n_nan_or_inf
=
decr_every_n_nan_or_inf
...
@@ -97,9 +119,10 @@ class OptimizerWithMixedPrecision(object):
...
@@ -97,9 +119,10 @@ class OptimizerWithMixedPrecision(object):
self
.
_is_distributed
=
flag
self
.
_is_distributed
=
flag
def
get_loss_scaling
(
self
):
def
get_loss_scaling
(
self
):
"""Return the real-time loss scaling factor.
"""Return the real-time loss scaling factor."""
"""
assert
(
assert
self
.
_loss_scaling
is
not
None
,
'Please call minimize() before calling get_loss_scaling().'
self
.
_loss_scaling
is
not
None
),
'Please call minimize() before calling get_loss_scaling().'
return
self
.
_loss_scaling
return
self
.
_loss_scaling
def
get_scaled_loss
(
self
):
def
get_scaled_loss
(
self
):
...
@@ -117,7 +140,8 @@ class OptimizerWithMixedPrecision(object):
...
@@ -117,7 +140,8 @@ class OptimizerWithMixedPrecision(object):
shape
=
[
1
],
shape
=
[
1
],
value
=
self
.
_init_loss_scaling
,
value
=
self
.
_init_loss_scaling
,
dtype
=
'float32'
,
dtype
=
'float32'
,
persistable
=
True
)
persistable
=
True
,
)
if
self
.
_use_dynamic_loss_scaling
:
if
self
.
_use_dynamic_loss_scaling
:
self
.
_num_good_steps
=
layers
.
create_global_var
(
self
.
_num_good_steps
=
layers
.
create_global_var
(
...
@@ -125,31 +149,37 @@ class OptimizerWithMixedPrecision(object):
...
@@ -125,31 +149,37 @@ class OptimizerWithMixedPrecision(object):
shape
=
[
1
],
shape
=
[
1
],
value
=
0
,
value
=
0
,
dtype
=
'int32'
,
dtype
=
'int32'
,
persistable
=
True
)
persistable
=
True
,
)
self
.
_num_bad_steps
=
layers
.
create_global_var
(
self
.
_num_bad_steps
=
layers
.
create_global_var
(
name
=
unique_name
.
generate
(
"num_bad_steps"
),
name
=
unique_name
.
generate
(
"num_bad_steps"
),
shape
=
[
1
],
shape
=
[
1
],
value
=
0
,
value
=
0
,
dtype
=
'int32'
,
dtype
=
'int32'
,
persistable
=
True
)
persistable
=
True
,
)
# Ensure the data type of learning rate vars is float32 (same as the
# Ensure the data type of learning rate vars is float32 (same as the
# master parameter dtype)
# master parameter dtype)
if
isinstance
(
self
.
_optimizer
.
_learning_rate
,
float
):
if
isinstance
(
self
.
_optimizer
.
_learning_rate
,
float
):
self
.
_optimizer
.
_learning_rate_map
[
default_main_program
()]
=
\
self
.
_optimizer
.
_learning_rate_map
[
layers
.
create_global_var
(
default_main_program
()
]
=
layers
.
create_global_var
(
name
=
unique_name
.
generate
(
"learning_rate"
),
name
=
unique_name
.
generate
(
"learning_rate"
),
shape
=
[
1
],
shape
=
[
1
],
value
=
float
(
self
.
_optimizer
.
_learning_rate
),
value
=
float
(
self
.
_optimizer
.
_learning_rate
),
dtype
=
'float32'
,
dtype
=
'float32'
,
persistable
=
True
)
persistable
=
True
,
)
def
backward
(
self
,
def
backward
(
self
,
loss
,
loss
,
startup_program
=
None
,
startup_program
=
None
,
parameter_list
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
,
no_grad_set
=
None
,
callbacks
=
None
):
callbacks
=
None
,
):
"""
"""
Backward propagation or auto differentiation for gradients' computation.
Backward propagation or auto differentiation for gradients' computation.
...
@@ -171,9 +201,9 @@ class OptimizerWithMixedPrecision(object):
...
@@ -171,9 +201,9 @@ class OptimizerWithMixedPrecision(object):
# NOTE(zhiqiu): _float_status is only used for NPU.
# NOTE(zhiqiu): _float_status is only used for NPU.
if
core
.
is_compiled_with_npu
():
if
core
.
is_compiled_with_npu
():
float_status
=
paddle
.
static
.
data
(
name
=
"float_status"
,
float_status
=
paddle
.
static
.
data
(
shape
=
[
8
],
name
=
"float_status"
,
shape
=
[
8
],
dtype
=
'float32'
dtype
=
'float32'
)
)
self
.
_train_program
.
global_block
().
append_op
(
self
.
_train_program
.
global_block
().
append_op
(
type
=
"alloc_float_status"
,
type
=
"alloc_float_status"
,
outputs
=
{
"FloatStatus"
:
float_status
},
outputs
=
{
"FloatStatus"
:
float_status
},
...
@@ -192,9 +222,15 @@ class OptimizerWithMixedPrecision(object):
...
@@ -192,9 +222,15 @@ class OptimizerWithMixedPrecision(object):
if
self
.
_use_pure_fp16
:
if
self
.
_use_pure_fp16
:
self
.
_to_fp16_var_names
=
cast_model_to_fp16
(
self
.
_to_fp16_var_names
=
cast_model_to_fp16
(
self
.
_train_program
,
self
.
_amp_lists
,
self
.
_use_fp16_guard
)
self
.
_train_program
,
self
.
_amp_lists
,
self
.
_use_fp16_guard
,
self
.
_amp_dtype
,
)
else
:
else
:
rewrite_program
(
self
.
_train_program
,
self
.
_amp_lists
)
rewrite_program
(
self
.
_train_program
,
self
.
_amp_lists
,
self
.
_amp_dtype
)
if
loss
.
dtype
!=
core
.
VarDesc
.
VarType
.
FP32
:
if
loss
.
dtype
!=
core
.
VarDesc
.
VarType
.
FP32
:
loss
=
loss
.
astype
(
'float32'
)
loss
=
loss
.
astype
(
'float32'
)
...
@@ -205,10 +241,13 @@ class OptimizerWithMixedPrecision(object):
...
@@ -205,10 +241,13 @@ class OptimizerWithMixedPrecision(object):
else
:
else
:
self
.
_scaled_loss
=
loss
self
.
_scaled_loss
=
loss
params_grads
=
self
.
_optimizer
.
backward
(
self
.
_scaled_loss
,
params_grads
=
self
.
_optimizer
.
backward
(
self
.
_scaled_loss
,
startup_program
,
startup_program
,
parameter_list
,
no_grad_set
,
parameter_list
,
callbacks
)
no_grad_set
,
callbacks
,
)
if
self
.
_supports_check_nan_inf
():
if
self
.
_supports_check_nan_inf
():
self
.
_add_cast_ops_to_startup_program
(
startup_program
)
self
.
_add_cast_ops_to_startup_program
(
startup_program
)
return
params_grads
return
params_grads
...
@@ -216,8 +255,11 @@ class OptimizerWithMixedPrecision(object):
...
@@ -216,8 +255,11 @@ class OptimizerWithMixedPrecision(object):
def
_add_cast_ops_to_startup_program
(
self
,
startup_program
):
def
_add_cast_ops_to_startup_program
(
self
,
startup_program
):
names
=
list
(
self
.
_to_fp16_var_names
)
if
self
.
_to_fp16_var_names
else
[]
names
=
list
(
self
.
_to_fp16_var_names
)
if
self
.
_to_fp16_var_names
else
[]
names
.
sort
()
names
.
sort
()
startup_program
=
default_startup_program
(
startup_program
=
(
)
if
startup_program
is
None
else
startup_program
default_startup_program
()
if
startup_program
is
None
else
startup_program
)
block
=
startup_program
.
global_block
()
block
=
startup_program
.
global_block
()
param_names
=
[
p
.
name
for
p
in
block
.
all_parameters
()]
param_names
=
[
p
.
name
for
p
in
block
.
all_parameters
()]
for
name
in
names
:
for
name
in
names
:
...
@@ -225,23 +267,23 @@ class OptimizerWithMixedPrecision(object):
...
@@ -225,23 +267,23 @@ class OptimizerWithMixedPrecision(object):
continue
continue
tmp
=
block
.
create_var
(
dtype
=
core
.
VarDesc
.
VarType
.
FP32
)
tmp
=
block
.
create_var
(
dtype
=
core
.
VarDesc
.
VarType
.
FP32
)
block
.
append_op
(
type
=
'assign'
,
block
.
append_op
(
inputs
=
{
'X'
:
[
name
]},
type
=
'assign'
,
inputs
=
{
'X'
:
[
name
]},
outputs
=
{
'Out'
:
[
tmp
]}
outputs
=
{
'Out'
:
[
tmp
]})
)
block
.
append_op
(
type
=
'cast'
,
block
.
append_op
(
type
=
'cast'
,
inputs
=
{
'X'
:
[
tmp
]},
inputs
=
{
'X'
:
[
tmp
]},
outputs
=
{
'Out'
:
[
name
]},
outputs
=
{
'Out'
:
[
name
]},
attrs
=
{
attrs
=
{
'in_dtype'
:
core
.
VarDesc
.
VarType
.
FP32
,
'in_dtype'
:
core
.
VarDesc
.
VarType
.
FP32
,
'out_dtype'
:
core
.
VarDesc
.
VarType
.
FP16
,
'out_dtype'
:
self
.
_amp_dtype
,
})
},
)
self
.
_to_fp16_var_names
=
None
self
.
_to_fp16_var_names
=
None
def
amp_init
(
self
,
def
amp_init
(
place
,
self
,
place
,
scope
=
None
,
test_program
=
None
,
use_fp16_test
=
False
scope
=
None
,
):
test_program
=
None
,
use_fp16_test
=
False
):
"""
"""
Init the amp training, such as cast fp32 parameters to fp16 type.
Init the amp training, such as cast fp32 parameters to fp16 type.
...
@@ -297,17 +339,27 @@ class OptimizerWithMixedPrecision(object):
...
@@ -297,17 +339,27 @@ class OptimizerWithMixedPrecision(object):
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
run_example_code()
"""
"""
assert
self
.
_train_program
is
not
None
,
\
assert
(
"Please call the minimize method first."
self
.
_train_program
is
not
None
),
"Please call the minimize method first."
if
self
.
_use_pure_fp16
:
if
self
.
_use_pure_fp16
:
cast_parameters_to_fp16
(
place
,
self
.
_train_program
,
scope
,
cast_parameters_to_fp16
(
self
.
_to_fp16_var_names
)
place
,
self
.
_train_program
,
scope
,
self
.
_to_fp16_var_names
,
self
.
_amp_dtype
,
)
if
test_program
is
not
None
:
if
test_program
is
not
None
:
if
self
.
_use_pure_fp16
:
if
self
.
_use_pure_fp16
:
cast_model_to_fp16
(
test_program
,
self
.
_amp_lists
,
cast_model_to_fp16
(
self
.
_use_fp16_guard
)
test_program
,
self
.
_amp_lists
,
self
.
_use_fp16_guard
,
self
.
_amp_dtype
,
)
elif
use_fp16_test
:
elif
use_fp16_test
:
rewrite_program
(
test_program
,
self
.
_amp_lists
)
rewrite_program
(
test_program
,
self
.
_amp_lists
,
self
.
_amp_dtype
)
def
apply_gradients
(
self
,
params_grads
):
def
apply_gradients
(
self
,
params_grads
):
"""
"""
...
@@ -327,7 +379,10 @@ class OptimizerWithMixedPrecision(object):
...
@@ -327,7 +379,10 @@ class OptimizerWithMixedPrecision(object):
# When not using dynamic loss scaling and the init loss scaling value is equal to 1.0,
# When not using dynamic loss scaling and the init loss scaling value is equal to 1.0,
# the model can be optimized.
# the model can be optimized.
if
not
self
.
_use_dynamic_loss_scaling
and
self
.
_init_loss_scaling
==
1.0
:
if
(
not
self
.
_use_dynamic_loss_scaling
and
self
.
_init_loss_scaling
==
1.0
):
return
self
.
_optimizer
.
apply_gradients
(
params_grads
)
return
self
.
_optimizer
.
apply_gradients
(
params_grads
)
if
self
.
_supports_check_nan_inf
():
if
self
.
_supports_check_nan_inf
():
...
@@ -338,7 +393,10 @@ class OptimizerWithMixedPrecision(object):
...
@@ -338,7 +393,10 @@ class OptimizerWithMixedPrecision(object):
return
optimize_ops
return
optimize_ops
found_inf
=
self
.
_check_finite_and_unscale
(
params_grads
)
found_inf
=
self
.
_check_finite_and_unscale
(
params_grads
)
if
self
.
_use_dynamic_loss_scaling
:
if
(
self
.
_use_dynamic_loss_scaling
and
self
.
_amp_dtype
==
core
.
VarDesc
.
VarType
.
FP16
):
self
.
_add_dynamic_loss_scaling
(
params_grads
,
found_inf
)
self
.
_add_dynamic_loss_scaling
(
params_grads
,
found_inf
)
# Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
# Pass found_inf to adam, to skip update for not only param, but also momentum and beta_pow
...
@@ -346,13 +404,16 @@ class OptimizerWithMixedPrecision(object):
...
@@ -346,13 +404,16 @@ class OptimizerWithMixedPrecision(object):
real_optimizer
=
self
.
_optimizer
real_optimizer
=
self
.
_optimizer
while
hasattr
(
real_optimizer
,
"inner_opt"
):
while
hasattr
(
real_optimizer
,
"inner_opt"
):
real_optimizer
=
real_optimizer
.
inner_opt
real_optimizer
=
real_optimizer
.
inner_opt
if
isinstance
(
real_optimizer
,
if
isinstance
(
(
paddle
.
fluid
.
optimizer
.
Adam
,
paddle
.
optimizer
.
AdamW
)):
real_optimizer
,
(
paddle
.
fluid
.
optimizer
.
Adam
,
paddle
.
optimizer
.
AdamW
),
):
# NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
# NOTE(zhiqiu): Since found_inf needs to be on cpu in adam op, we
# copy it in advance to avoid multiple time copies.
# copy it in advance to avoid multiple time copies.
with
self
.
_train_program
.
_optimized_guard
([]):
with
self
.
_train_program
.
_optimized_guard
([]):
found_inf
=
paddle
.
tensor
.
creation
.
_memcpy
(
found_inf
=
paddle
.
tensor
.
creation
.
_memcpy
(
found_inf
,
paddle
.
CPUPlace
())
found_inf
,
paddle
.
CPUPlace
()
)
real_optimizer
.
_set_auxiliary_var
(
'found_inf'
,
found_inf
)
real_optimizer
.
_set_auxiliary_var
(
'found_inf'
,
found_inf
)
elif
hasattr
(
real_optimizer
,
"_set_auxiliary_var"
):
elif
hasattr
(
real_optimizer
,
"_set_auxiliary_var"
):
real_optimizer
.
_set_auxiliary_var
(
'found_inf'
,
found_inf
)
real_optimizer
.
_set_auxiliary_var
(
'found_inf'
,
found_inf
)
...
@@ -362,9 +423,10 @@ class OptimizerWithMixedPrecision(object):
...
@@ -362,9 +423,10 @@ class OptimizerWithMixedPrecision(object):
def
_split_grads
(
self
,
params_grads
):
def
_split_grads
(
self
,
params_grads
):
grads
=
[
g
for
_
,
g
in
params_grads
]
grads
=
[
g
for
_
,
g
in
params_grads
]
fp32_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
]
fp32_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
]
fp16_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
]
fp16_grads
=
[
g
for
g
in
grads
if
g
.
dtype
==
self
.
_amp_dtype
]
assert
len
(
fp32_grads
)
+
len
(
fp16_grads
)
==
len
(
grads
),
\
assert
len
(
fp32_grads
)
+
len
(
fp16_grads
)
==
len
(
"Data types of all grads must be either fp16 or fp32."
grads
),
"Data types of all grads must be either fp16/bf16 or fp32."
return
grads
,
fp32_grads
,
fp16_grads
return
grads
,
fp32_grads
,
fp16_grads
def
_check_finite_and_unscale
(
self
,
params_grads
):
def
_check_finite_and_unscale
(
self
,
params_grads
):
...
@@ -380,7 +442,8 @@ class OptimizerWithMixedPrecision(object):
...
@@ -380,7 +442,8 @@ class OptimizerWithMixedPrecision(object):
grads
,
grads
,
self
.
_loss_scaling
,
self
.
_loss_scaling
,
name
=
"find_infinite_scale"
,
name
=
"find_infinite_scale"
,
float_status
=
self
.
_float_status
)
float_status
=
self
.
_float_status
,
)
found_infs
.
append
(
found_inf
)
found_infs
.
append
(
found_inf
)
else
:
else
:
for
p
,
g
in
params_grads
:
for
p
,
g
in
params_grads
:
...
@@ -391,7 +454,8 @@ class OptimizerWithMixedPrecision(object):
...
@@ -391,7 +454,8 @@ class OptimizerWithMixedPrecision(object):
],
],
self
.
_loss_scaling
,
self
.
_loss_scaling
,
name
=
"find_infinite_scale"
,
name
=
"find_infinite_scale"
,
float_status
=
self
.
_float_status
)
float_status
=
self
.
_float_status
,
)
found_infs
.
append
(
found_inf
)
found_infs
.
append
(
found_inf
)
elif
self
.
_use_pure_fp16
:
elif
self
.
_use_pure_fp16
:
if
fp32_grads
:
if
fp32_grads
:
...
@@ -400,7 +464,8 @@ class OptimizerWithMixedPrecision(object):
...
@@ -400,7 +464,8 @@ class OptimizerWithMixedPrecision(object):
fp32_grads
,
fp32_grads
,
self
.
_loss_scaling
,
self
.
_loss_scaling
,
name
=
"find_infinite_scale_fp32"
,
name
=
"find_infinite_scale_fp32"
,
float_status
=
self
.
_float_status
)
float_status
=
self
.
_float_status
,
)
found_infs
.
append
(
fp32_found_inf
)
found_infs
.
append
(
fp32_found_inf
)
if
fp16_grads
:
if
fp16_grads
:
with
self
.
_train_program
.
_optimized_guard
(
fp16_grads
):
with
self
.
_train_program
.
_optimized_guard
(
fp16_grads
):
...
@@ -408,7 +473,8 @@ class OptimizerWithMixedPrecision(object):
...
@@ -408,7 +473,8 @@ class OptimizerWithMixedPrecision(object):
fp16_grads
,
fp16_grads
,
self
.
_loss_scaling
,
self
.
_loss_scaling
,
name
=
"find_infinite_scale_fp16"
,
name
=
"find_infinite_scale_fp16"
,
float_status
=
self
.
_float_status
)
float_status
=
self
.
_float_status
,
)
found_infs
.
append
(
fp16_found_inf
)
found_infs
.
append
(
fp16_found_inf
)
else
:
else
:
with
self
.
_train_program
.
_optimized_guard
(
grads
):
with
self
.
_train_program
.
_optimized_guard
(
grads
):
...
@@ -416,7 +482,8 @@ class OptimizerWithMixedPrecision(object):
...
@@ -416,7 +482,8 @@ class OptimizerWithMixedPrecision(object):
grads
,
grads
,
self
.
_loss_scaling
,
self
.
_loss_scaling
,
name
=
"find_infinite_scale"
,
name
=
"find_infinite_scale"
,
float_status
=
self
.
_float_status
)
float_status
=
self
.
_float_status
,
)
if
self
.
_is_distributed
or
self
.
_use_pure_fp16
:
if
self
.
_is_distributed
or
self
.
_use_pure_fp16
:
with
self
.
_train_program
.
_optimized_guard
([]):
with
self
.
_train_program
.
_optimized_guard
([]):
...
@@ -439,7 +506,8 @@ class OptimizerWithMixedPrecision(object):
...
@@ -439,7 +506,8 @@ class OptimizerWithMixedPrecision(object):
self
.
_incr_ratio
,
self
.
_incr_ratio
,
self
.
_decr_ratio
,
self
.
_decr_ratio
,
stop_update
=
self
.
_optimizer
.
_get_stop_update_var
(),
stop_update
=
self
.
_optimizer
.
_get_stop_update_var
(),
name
=
"update_loss_scaling"
)
name
=
"update_loss_scaling"
,
)
return
return
grads
,
fp32_grads
,
fp16_grads
=
self
.
_split_grads
(
params_grads
)
grads
,
fp32_grads
,
fp16_grads
=
self
.
_split_grads
(
params_grads
)
...
@@ -447,7 +515,8 @@ class OptimizerWithMixedPrecision(object):
...
@@ -447,7 +515,8 @@ class OptimizerWithMixedPrecision(object):
stop_update
=
False
stop_update
=
False
with
self
.
_train_program
.
_optimized_guard
([]):
with
self
.
_train_program
.
_optimized_guard
([]):
if
fp32_grads
:
if
fp32_grads
:
update_loss_scaling
(
fp32_grads
,
update_loss_scaling
(
fp32_grads
,
found_inf
,
found_inf
,
self
.
_loss_scaling
,
self
.
_loss_scaling
,
self
.
_num_good_steps
,
self
.
_num_good_steps
,
...
@@ -457,10 +526,12 @@ class OptimizerWithMixedPrecision(object):
...
@@ -457,10 +526,12 @@ class OptimizerWithMixedPrecision(object):
self
.
_incr_ratio
,
self
.
_incr_ratio
,
self
.
_decr_ratio
,
self
.
_decr_ratio
,
stop_update
=
stop_update
,
stop_update
=
stop_update
,
name
=
"update_loss_scaling_fp32"
)
name
=
"update_loss_scaling_fp32"
,
)
stop_update
=
True
stop_update
=
True
if
fp16_grads
:
if
fp16_grads
:
update_loss_scaling
(
fp16_grads
,
update_loss_scaling
(
fp16_grads
,
found_inf
,
found_inf
,
self
.
_loss_scaling
,
self
.
_loss_scaling
,
self
.
_num_good_steps
,
self
.
_num_good_steps
,
...
@@ -470,10 +541,12 @@ class OptimizerWithMixedPrecision(object):
...
@@ -470,10 +541,12 @@ class OptimizerWithMixedPrecision(object):
self
.
_incr_ratio
,
self
.
_incr_ratio
,
self
.
_decr_ratio
,
self
.
_decr_ratio
,
stop_update
=
stop_update
,
stop_update
=
stop_update
,
name
=
"update_loss_scaling_fp16"
)
name
=
"update_loss_scaling_fp16"
,
)
else
:
else
:
with
self
.
_train_program
.
_optimized_guard
([]):
with
self
.
_train_program
.
_optimized_guard
([]):
update_loss_scaling
(
grads
,
update_loss_scaling
(
grads
,
found_inf
,
found_inf
,
self
.
_loss_scaling
,
self
.
_loss_scaling
,
self
.
_num_good_steps
,
self
.
_num_good_steps
,
...
@@ -482,7 +555,8 @@ class OptimizerWithMixedPrecision(object):
...
@@ -482,7 +555,8 @@ class OptimizerWithMixedPrecision(object):
self
.
_decr_every_n_nan_or_inf
,
self
.
_decr_every_n_nan_or_inf
,
self
.
_incr_ratio
,
self
.
_incr_ratio
,
self
.
_decr_ratio
,
self
.
_decr_ratio
,
name
=
"update_loss_scaling"
)
name
=
"update_loss_scaling"
,
)
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
program
=
loss
.
block
.
program
program
=
loss
.
block
.
program
...
@@ -490,11 +564,9 @@ class OptimizerWithMixedPrecision(object):
...
@@ -490,11 +564,9 @@ class OptimizerWithMixedPrecision(object):
optimize_ops
=
self
.
apply_gradients
(
params_grads
)
optimize_ops
=
self
.
apply_gradients
(
params_grads
)
return
optimize_ops
return
optimize_ops
def
minimize
(
self
,
def
minimize
(
loss
,
self
,
loss
,
startup_program
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
startup_program
=
None
,
):
parameter_list
=
None
,
no_grad_set
=
None
):
"""
"""
Perform optimization by minimizing the given loss.
Perform optimization by minimizing the given loss.
...
@@ -511,24 +583,29 @@ class OptimizerWithMixedPrecision(object):
...
@@ -511,24 +583,29 @@ class OptimizerWithMixedPrecision(object):
"""
"""
opt_dict
=
self
.
_optimizer
.
__class__
.
__dict__
opt_dict
=
self
.
_optimizer
.
__class__
.
__dict__
if
'minimize'
in
opt_dict
and
isinstance
(
opt_dict
[
'minimize'
],
if
'minimize'
in
opt_dict
and
isinstance
(
types
.
FunctionType
):
opt_dict
[
'minimize'
],
types
.
FunctionType
):
warnings
.
warn
(
warnings
.
warn
(
"The decorated optimizer has its own `minimize` method, but it will not be executed."
"The decorated optimizer has its own `minimize` method, but it will not be executed."
)
)
scaled_params_grads
=
self
.
backward
(
loss
,
scaled_params_grads
=
self
.
backward
(
loss
,
startup_program
=
startup_program
,
startup_program
=
startup_program
,
parameter_list
=
parameter_list
,
parameter_list
=
parameter_list
,
no_grad_set
=
no_grad_set
)
no_grad_set
=
no_grad_set
,
)
optimize_ops
=
self
.
apply_optimize
(
loss
,
startup_program
,
optimize_ops
=
self
.
apply_optimize
(
scaled_params_grads
)
loss
,
startup_program
,
scaled_params_grads
)
return
optimize_ops
,
scaled_params_grads
return
optimize_ops
,
scaled_params_grads
def
decorate
(
optimizer
,
def
decorate
(
optimizer
,
amp_lists
=
None
,
amp_lists
=
None
,
init_loss_scaling
=
2
**
15
,
init_loss_scaling
=
2
**
15
,
incr_every_n_steps
=
1000
,
incr_every_n_steps
=
1000
,
...
@@ -537,7 +614,9 @@ def decorate(optimizer,
...
@@ -537,7 +614,9 @@ def decorate(optimizer,
decr_ratio
=
0.8
,
decr_ratio
=
0.8
,
use_dynamic_loss_scaling
=
True
,
use_dynamic_loss_scaling
=
True
,
use_pure_fp16
=
False
,
use_pure_fp16
=
False
,
use_fp16_guard
=
None
):
use_fp16_guard
=
None
,
use_bf16
=
False
,
):
"""
"""
Decorate the given optimizer to adapt to the mixed-precision training.
Decorate the given optimizer to adapt to the mixed-precision training.
...
@@ -628,15 +707,78 @@ def decorate(optimizer,
...
@@ -628,15 +707,78 @@ def decorate(optimizer,
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
run_example_code()
"""
"""
dtype
=
"bfloat16"
if
use_bf16
else
"float16"
if
amp_lists
is
None
:
if
amp_lists
is
None
:
amp_lists
=
AutoMixedPrecisionLists
()
amp_lists
=
AutoMixedPrecisionLists
(
dtype
=
dtype
)
if
use_fp16_guard
is
None
:
if
use_fp16_guard
is
None
:
use_fp16_guard
=
use_pure_fp16
use_fp16_guard
=
use_pure_fp16
mp_optimizer
=
OptimizerWithMixedPrecision
(
mp_optimizer
=
OptimizerWithMixedPrecision
(
optimizer
,
amp_lists
,
init_loss_scaling
,
use_dynamic_loss_scaling
,
optimizer
,
incr_every_n_steps
,
decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
,
amp_lists
,
use_pure_fp16
,
use_fp16_guard
)
init_loss_scaling
,
use_dynamic_loss_scaling
,
incr_every_n_steps
,
decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
,
use_pure_fp16
,
use_fp16_guard
,
use_bf16
,
)
return
mp_optimizer
def
amp_decorate
(
optimizer
,
amp_lists
=
None
,
level
=
'O1'
,
dtype
=
'float16'
,
init_loss_scaling
=
2
**
15
,
incr_every_n_steps
=
1000
,
decr_every_n_nan_or_inf
=
2
,
incr_ratio
=
2.0
,
decr_ratio
=
0.8
,
use_dynamic_loss_scaling
=
True
,
use_amp_guard
=
False
,
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
"""
# check amp_dtype: float16 or bfloat16
dtype
=
dtype
.
lower
()
if
not
(
dtype
in
[
'float16'
,
'bfloat16'
]):
raise
ValueError
(
"If enable AMP, dtype should be 'float16' or 'bfloat16'."
)
if
amp_lists
is
None
:
amp_lists
=
AutoMixedPrecisionLists
(
dtype
=
dtype
)
# check amp_level: O0-O2
level
=
level
.
upper
()
if
not
(
level
in
[
'O0'
,
'O1'
,
'O2'
]):
raise
ValueError
(
"level should be O0, O1 or O2. O0 represents fp32 train mode, O1 represents AMP train mode, O2 represents pure fp16/bf16 train mode."
)
use_pure_fp16
=
level
==
"O2"
use_fp16_guard
=
use_amp_guard
use_bf16
=
dtype
==
"bfloat16"
mp_optimizer
=
OptimizerWithMixedPrecision
(
optimizer
,
amp_lists
,
init_loss_scaling
,
use_dynamic_loss_scaling
,
incr_every_n_steps
,
decr_every_n_nan_or_inf
,
incr_ratio
,
decr_ratio
,
use_pure_fp16
,
use_fp16_guard
,
use_bf16
,
)
return
mp_optimizer
return
mp_optimizer
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
浏览文件 @
6959eae5
...
@@ -13,16 +13,47 @@
...
@@ -13,16 +13,47 @@
# limitations under the License.
# limitations under the License.
import
copy
import
copy
from
...
import
core
from
...
import
core
__all__
=
[
"CustomOpLists"
,
"AutoMixedPrecisionLists"
]
__all__
=
[
"CustomOpLists"
,
"AutoMixedPrecisionLists"
]
# lookup_table fp16 is slower than fp32, though fp16 is supported.
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_extra_unsupported_fp16_list
=
{
_extra_unsupported_list
=
{
'lookup_table'
,
'lookup_table_v2'
,
'scatter'
,
'scatter_grad'
'lookup_table'
,
'lookup_table_v2'
,
'scatter'
,
'scatter_grad'
,
}
}
def
_get_unsupported_list
(
dtype
):
if
dtype
==
"float16"
:
amp_dtype
=
core
.
VarDesc
.
VarType
.
FP16
elif
dtype
==
"bfloat16"
:
amp_dtype
=
core
.
VarDesc
.
VarType
.
BF16
else
:
raise
ValueError
(
"If enable AMP, dtype should be 'float16' or 'bfloat16'."
)
# The set of ops that don't support fp16 calculation
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_sys_unsupported_list
=
[]
# _sys_unsupported_bf16_list = []
if
core
.
is_compiled_with_xpu
():
_
,
_
,
_sys_unsupported_list
=
core
.
op_supported_infos
(
'XPU'
,
amp_dtype
)
elif
core
.
is_compiled_with_npu
():
_
,
_
,
_sys_unsupported_list
=
core
.
op_supported_infos
(
'NPU'
,
amp_dtype
)
elif
core
.
is_compiled_with_mlu
():
_
,
_
,
_sys_unsupported_list
=
core
.
op_supported_infos
(
'MLU'
,
amp_dtype
)
else
:
_
,
_
,
_sys_unsupported_list
=
core
.
op_supported_infos
(
'GPU'
,
amp_dtype
)
unsupported_list
=
_extra_unsupported_list
|
_sys_unsupported_list
return
unsupported_list
class
AutoMixedPrecisionLists
(
object
):
class
AutoMixedPrecisionLists
(
object
):
"""
"""
AutoMixedPrecisionLists is a class for black/white list. It can update
AutoMixedPrecisionLists is a class for black/white list. It can update
...
@@ -36,16 +67,20 @@ class AutoMixedPrecisionLists(object):
...
@@ -36,16 +67,20 @@ class AutoMixedPrecisionLists(object):
custom_black_varnames (set): Users' custom black varibles' names.
custom_black_varnames (set): Users' custom black varibles' names.
"""
"""
def
__init__
(
self
,
def
__init__
(
self
,
custom_white_list
=
None
,
custom_white_list
=
None
,
custom_black_list
=
None
,
custom_black_list
=
None
,
custom_black_varnames
=
None
):
custom_black_varnames
=
None
,
dtype
=
"float16"
,
):
self
.
_custom_white_list
=
custom_white_list
self
.
_custom_white_list
=
custom_white_list
self
.
_custom_black_list
=
custom_black_list
self
.
_custom_black_list
=
custom_black_list
self
.
amp_dtype
=
dtype
self
.
white_list
=
copy
.
copy
(
white_list
)
self
.
white_list
=
copy
.
copy
(
white_list
)
self
.
black_list
=
copy
.
copy
(
black_list
)
self
.
black_list
=
copy
.
copy
(
black_list
)
self
.
gray_list
=
copy
.
copy
(
gray_list
)
self
.
gray_list
=
copy
.
copy
(
gray_list
)
self
.
unsupported_list
=
copy
.
copy
(
unsupported_fp16_list
)
self
.
unsupported_list
=
copy
.
copy
(
_get_unsupported_list
(
self
.
amp_dtype
)
)
self
.
black_varnames
=
copy
.
copy
(
custom_black_varnames
)
self
.
black_varnames
=
copy
.
copy
(
custom_black_varnames
)
self
.
_update_list
()
self
.
_update_list
()
...
@@ -56,8 +91,9 @@ class AutoMixedPrecisionLists(object):
...
@@ -56,8 +91,9 @@ class AutoMixedPrecisionLists(object):
if
self
.
_custom_white_list
and
self
.
_custom_black_list
:
if
self
.
_custom_white_list
and
self
.
_custom_black_list
:
for
op_name
in
self
.
_custom_white_list
:
for
op_name
in
self
.
_custom_white_list
:
if
op_name
in
self
.
_custom_black_list
:
if
op_name
in
self
.
_custom_black_list
:
raise
ValueError
(
"Custom white list overlap "
raise
ValueError
(
"custom black list"
)
"Custom white list overlap "
"custom black list"
)
if
self
.
_custom_white_list
:
if
self
.
_custom_white_list
:
for
op_name
in
self
.
_custom_white_list
:
for
op_name
in
self
.
_custom_white_list
:
if
op_name
in
self
.
black_list
:
if
op_name
in
self
.
black_list
:
...
@@ -65,7 +101,7 @@ class AutoMixedPrecisionLists(object):
...
@@ -65,7 +101,7 @@ class AutoMixedPrecisionLists(object):
elif
op_name
in
self
.
gray_list
:
elif
op_name
in
self
.
gray_list
:
self
.
gray_list
.
remove
(
op_name
)
self
.
gray_list
.
remove
(
op_name
)
self
.
white_list
.
add
(
op_name
)
self
.
white_list
.
add
(
op_name
)
if
op_name
in
_extra_unsupported_
fp16_
list
:
if
op_name
in
_extra_unsupported_list
:
self
.
unsupported_list
.
remove
(
op_name
)
self
.
unsupported_list
.
remove
(
op_name
)
if
self
.
_custom_black_list
:
if
self
.
_custom_black_list
:
for
op_name
in
self
.
_custom_black_list
:
for
op_name
in
self
.
_custom_black_list
:
...
@@ -170,22 +206,4 @@ gray_list = {
...
@@ -170,22 +206,4 @@ gray_list = {
'fused_multi_transformer'
,
'fused_multi_transformer'
,
}
}
# The set of ops that don't support fp16 calculation
# lookup_table fp16 is slower than fp32, though fp16 is supported.
_sys_unsupported_fp16_list
=
[]
if
core
.
is_compiled_with_xpu
():
_
,
_
,
_sys_unsupported_fp16_list
=
core
.
op_supported_infos
(
'XPU'
,
core
.
VarDesc
.
VarType
.
FP16
)
elif
core
.
is_compiled_with_npu
():
_
,
_
,
_sys_unsupported_fp16_list
=
core
.
op_supported_infos
(
'NPU'
,
core
.
VarDesc
.
VarType
.
FP16
)
elif
core
.
is_compiled_with_mlu
():
_
,
_
,
_sys_unsupported_fp16_list
=
core
.
op_supported_infos
(
'MLU'
,
core
.
VarDesc
.
VarType
.
FP16
)
else
:
_
,
_
,
_sys_unsupported_fp16_list
=
core
.
op_supported_infos
(
'GPU'
,
core
.
VarDesc
.
VarType
.
FP16
)
unsupported_fp16_list
=
_extra_unsupported_fp16_list
|
_sys_unsupported_fp16_list
CustomOpLists
=
AutoMixedPrecisionLists
CustomOpLists
=
AutoMixedPrecisionLists
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
浏览文件 @
6959eae5
...
@@ -14,6 +14,7 @@
...
@@ -14,6 +14,7 @@
from
__future__
import
print_function
from
__future__
import
print_function
import
paddle
from
...
import
core
from
...
import
core
from
...
import
framework
from
...
import
framework
from
...
import
layers
from
...
import
layers
...
@@ -27,13 +28,14 @@ import numpy as np
...
@@ -27,13 +28,14 @@ import numpy as np
__all__
=
[
"fp16_guard"
,
"cast_model_to_fp16"
,
"cast_parameters_to_fp16"
]
__all__
=
[
"fp16_guard"
,
"cast_model_to_fp16"
,
"cast_parameters_to_fp16"
]
_logger
=
get_logger
(
__name__
,
_logger
=
get_logger
(
logging
.
INFO
,
__name__
,
logging
.
INFO
,
fmt
=
'%(asctime)s-%(levelname)s: %(message)s'
fmt
=
'%(asctime)s-%(levelname)s: %(message)s'
)
)
_valid_types
=
[
_valid_types
=
[
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR_ARRAY
,
]
]
_fp16_guard_pattern
=
"__use_fp16__"
_fp16_guard_pattern
=
"__use_fp16__"
...
@@ -75,7 +77,9 @@ def _dtype_to_str(dtype):
...
@@ -75,7 +77,9 @@ def _dtype_to_str(dtype):
Args:
Args:
dtype (VarType): Variable type.
dtype (VarType): Variable type.
"""
"""
if
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
dtype
in
[
core
.
VarDesc
.
VarType
.
FP16
,
core
.
VarDesc
.
VarType
.
BF16
]:
# TODO(Xreki): change the returned str to "bf16" for BF16 data type.
# Currently too many codes use "cast_fp16" as key.
return
'fp16'
return
'fp16'
else
:
else
:
return
'fp32'
return
'fp32'
...
@@ -108,7 +112,12 @@ def _keep_fp32_input(op, in_name):
...
@@ -108,7 +112,12 @@ def _keep_fp32_input(op, in_name):
return
in_name
not
in
{
'X'
,
'FilterX'
,
'Z'
,
'FilterZ'
}
return
in_name
not
in
{
'X'
,
'FilterX'
,
'Z'
,
'FilterZ'
}
if
op_type
in
[
'fused_attention'
,
'fused_feedforward'
]:
if
op_type
in
[
'fused_attention'
,
'fused_feedforward'
]:
return
in_name
in
{
return
in_name
in
{
'LnScale'
,
'LnBias'
,
'Ln2Scale'
,
'Ln2Bias'
,
"Ln1Scale"
,
"Ln1Bias"
'LnScale'
,
'LnBias'
,
'Ln2Scale'
,
'Ln2Bias'
,
"Ln1Scale"
,
"Ln1Bias"
,
}
}
if
op_type
==
'fused_multi_transformer'
:
if
op_type
==
'fused_multi_transformer'
:
return
in_name
in
{
'LnScale'
,
'LnBias'
,
'FFNLnScale'
,
'FFNLnBias'
}
return
in_name
in
{
'LnScale'
,
'LnBias'
,
'FFNLnScale'
,
'FFNLnBias'
}
...
@@ -125,8 +134,12 @@ def _keep_fp32_output(op, out_name):
...
@@ -125,8 +134,12 @@ def _keep_fp32_output(op, out_name):
return
out_name
not
in
{
'Y'
,
'ConvX'
,
'ConvZ'
}
return
out_name
not
in
{
'Y'
,
'ConvX'
,
'ConvZ'
}
if
op_type
in
[
'fused_attention'
,
'fused_feedforward'
]:
if
op_type
in
[
'fused_attention'
,
'fused_feedforward'
]:
return
out_name
in
{
return
out_name
in
{
'LnMean'
,
'LnVariance'
,
'Ln2Mean'
,
'Ln2Variance'
,
'Ln1Mean'
,
'LnMean'
,
'Ln1Variance'
'LnVariance'
,
'Ln2Mean'
,
'Ln2Variance'
,
'Ln1Mean'
,
'Ln1Variance'
,
}
}
return
False
return
False
...
@@ -149,7 +162,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
...
@@ -149,7 +162,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
for
in_name
in
op
.
input_names
:
for
in_name
in
op
.
input_names
:
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_input
(
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
_keep_fp32_input
(
op
,
in_name
):
op
,
in_name
):
continue
continue
for
in_var_name
in
op
.
input
(
in_name
):
for
in_var_name
in
op
.
input
(
in_name
):
in_var
=
block
.
_find_var_recursive
(
in_var_name
)
in_var
=
block
.
_find_var_recursive
(
in_var_name
)
...
@@ -165,11 +179,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
...
@@ -165,11 +179,15 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
# set cast_op device to `all`, can reduce send cast_var.
# set cast_op device to `all`, can reduce send cast_var.
# TODO: need remove this after we unified the dynamic
# TODO: need remove this after we unified the dynamic
# and static pipeline interface.
# and static pipeline interface.
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
in_var
.
stop_gradient
:
if
(
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
in_var
.
stop_gradient
):
prev_op
=
None
prev_op
=
None
if
in_var
.
op
is
op
:
if
in_var
.
op
is
op
:
prev_op
=
find_true_prev_op
(
block
.
ops
,
op
,
prev_op
=
find_true_prev_op
(
in_var_name
)
block
.
ops
,
op
,
in_var_name
)
elif
in_var
.
op
is
not
None
:
elif
in_var
.
op
is
not
None
:
prev_op
=
in_var
.
op
prev_op
=
in_var
.
op
...
@@ -177,33 +195,40 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
...
@@ -177,33 +195,40 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
if
prev_op
is
not
None
:
if
prev_op
is
not
None
:
prev_op_device
=
prev_op
.
attr
(
'op_device'
)
prev_op_device
=
prev_op
.
attr
(
'op_device'
)
if
prev_op_device
is
not
None
and
'all'
in
prev_op_device
:
if
(
prev_op_device
is
not
None
and
'all'
in
prev_op_device
):
op_device
=
prev_op_device
op_device
=
prev_op_device
out_var
=
block
.
create_var
(
out_var
=
block
.
create_var
(
name
=
cast_name
,
name
=
cast_name
,
dtype
=
dest_dtype
,
dtype
=
dest_dtype
,
persistable
=
False
,
persistable
=
False
,
stop_gradient
=
in_var
.
stop_gradient
)
stop_gradient
=
in_var
.
stop_gradient
,
)
block
.
_insert_op_without_sync
(
idx
,
block
.
_insert_op_without_sync
(
idx
,
type
=
"cast"
,
type
=
"cast"
,
inputs
=
{
"X"
:
in_var
},
inputs
=
{
"X"
:
in_var
},
outputs
=
{
"Out"
:
out_var
},
outputs
=
{
"Out"
:
out_var
},
attrs
=
{
attrs
=
{
"in_dtype"
:
in_var
.
dtype
,
"in_dtype"
:
in_var
.
dtype
,
"out_dtype"
:
"out_dtype"
:
out_var
.
dtype
,
out_var
.
dtype
,
"op_device"
:
op_device
,
"op_device"
:
op_device
,
"op_role"
:
"op_role"
:
op
.
attr
(
"op_role"
),
op
.
attr
(
"op_role"
)
,
}
,
}
)
)
num_cast_ops
+=
1
num_cast_ops
+=
1
_rename_arg
(
op
,
in_var
.
name
,
out_var
.
name
)
_rename_arg
(
op
,
in_var
.
name
,
out_var
.
name
)
else
:
else
:
if
op
.
has_attr
(
'in_dtype'
):
if
op
.
has_attr
(
'in_dtype'
):
op
.
_set_attr
(
'in_dtype'
,
dest_dtype
)
op
.
_set_attr
(
'in_dtype'
,
dest_dtype
)
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
dest_dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
src_dtype
==
core
.
VarDesc
.
VarType
.
FP32
and
dest_dtype
in
[
core
.
VarDesc
.
VarType
.
FP16
,
core
.
VarDesc
.
VarType
.
BF16
,
]:
for
out_name
in
op
.
output_names
:
for
out_name
in
op
.
output_names
:
if
_keep_fp32_output
(
op
,
out_name
):
if
_keep_fp32_output
(
op
,
out_name
):
continue
continue
...
@@ -212,32 +237,38 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
...
@@ -212,32 +237,38 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
if
out_var
.
type
not
in
_valid_types
:
if
out_var
.
type
not
in
_valid_types
:
continue
continue
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP16
)
out_var
.
desc
.
set_dtype
(
dest_dtype
)
if
op
.
has_attr
(
'out_dtype'
):
if
op
.
has_attr
(
'out_dtype'
):
op
.
_set_attr
(
'out_dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
op
.
_set_attr
(
'out_dtype'
,
dest_dtype
)
return
num_cast_ops
return
num_cast_ops
def
_insert_cast_post_op
(
block
,
op
,
idx
,
src_dtype
,
dest_dtype
,
target_name
,
def
_insert_cast_post_op
(
op_var_rename_map
):
block
,
op
,
idx
,
src_dtype
,
dest_dtype
,
target_name
,
op_var_rename_map
):
num_cast_ops
=
0
num_cast_ops
=
0
target_var
=
block
.
var
(
target_name
)
target_var
=
block
.
var
(
target_name
)
if
target_var
.
type
not
in
_valid_types
or
target_var
.
dtype
==
dest_dtype
:
if
target_var
.
type
not
in
_valid_types
or
target_var
.
dtype
==
dest_dtype
:
return
num_cast_ops
return
num_cast_ops
assert
target_var
.
dtype
==
src_dtype
,
\
assert
(
"The real dtype({}) is not equal to the src dtype({})"
.
format
(
target_var
.
dtype
==
src_dtype
_dtype_to_str
(
target_var
.
dtype
),
_dtype_to_str
(
src_dtype
))
),
"The real dtype({}) is not equal to the src dtype({})"
.
format
(
_dtype_to_str
(
target_var
.
dtype
),
_dtype_to_str
(
src_dtype
)
)
cast_name
=
target_var
.
name
+
'.cast_'
+
_dtype_to_str
(
dest_dtype
)
cast_name
=
target_var
.
name
+
'.cast_'
+
_dtype_to_str
(
dest_dtype
)
cast_var
=
block
.
vars
.
get
(
cast_name
)
cast_var
=
block
.
vars
.
get
(
cast_name
)
if
cast_var
is
None
or
cast_var
.
dtype
!=
dest_dtype
:
if
cast_var
is
None
or
cast_var
.
dtype
!=
dest_dtype
:
cast_var
=
block
.
create_var
(
name
=
cast_name
,
cast_var
=
block
.
create_var
(
name
=
cast_name
,
dtype
=
dest_dtype
,
dtype
=
dest_dtype
,
persistable
=
False
,
persistable
=
False
,
stop_gradient
=
target_var
.
stop_gradient
)
stop_gradient
=
target_var
.
stop_gradient
,
block
.
_insert_op
(
idx
,
)
block
.
_insert_op
(
idx
,
type
=
"cast"
,
type
=
"cast"
,
inputs
=
{
"X"
:
target_var
},
inputs
=
{
"X"
:
target_var
},
outputs
=
{
"Out"
:
cast_var
},
outputs
=
{
"Out"
:
cast_var
},
...
@@ -246,7 +277,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
...
@@ -246,7 +277,8 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
"out_dtype"
:
cast_var
.
dtype
,
"out_dtype"
:
cast_var
.
dtype
,
"op_device"
:
op
.
attr
(
"op_device"
),
"op_device"
:
op
.
attr
(
"op_device"
),
"op_role"
:
op
.
attr
(
"op_role"
),
"op_role"
:
op
.
attr
(
"op_role"
),
})
},
)
num_cast_ops
+=
1
num_cast_ops
+=
1
op_var_rename_map
[
block
.
idx
][
target_var
.
name
]
=
cast_var
.
name
op_var_rename_map
[
block
.
idx
][
target_var
.
name
]
=
cast_var
.
name
...
@@ -272,8 +304,10 @@ def find_true_prev_op(ops, cur_op, var_name):
...
@@ -272,8 +304,10 @@ def find_true_prev_op(ops, cur_op, var_name):
prev_op
.
append
(
op
)
prev_op
.
append
(
op
)
if
prev_op
:
if
prev_op
:
if
not
len
(
prev_op
)
==
1
:
if
not
len
(
prev_op
)
==
1
:
raise
ValueError
(
"There must be only one previous op "
raise
ValueError
(
"that outputs {0} variable"
.
format
(
var_name
))
"There must be only one previous op "
"that outputs {0} variable"
.
format
(
var_name
)
)
else
:
else
:
return
prev_op
[
0
]
return
prev_op
[
0
]
return
None
return
None
...
@@ -315,8 +349,7 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False):
...
@@ -315,8 +349,7 @@ def find_true_post_op(ops, cur_op, var_name, search_all=False):
def
find_op_index
(
block_desc
,
cur_op_desc
):
def
find_op_index
(
block_desc
,
cur_op_desc
):
"""
""" """
"""
for
idx
in
range
(
block_desc
.
op_size
()):
for
idx
in
range
(
block_desc
.
op_size
()):
if
cur_op_desc
==
block_desc
.
op
(
idx
):
if
cur_op_desc
==
block_desc
.
op
(
idx
):
return
idx
return
idx
...
@@ -350,8 +383,9 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
...
@@ -350,8 +383,9 @@ def _need_keep_fp32(op, unsupported_op_list, use_fp16_guard):
return
True
return
True
if
use_fp16_guard
:
if
use_fp16_guard
:
if
op
.
has_attr
(
"op_namescope"
)
and
\
if
op
.
has_attr
(
"op_namescope"
)
and
(
(
_fp16_guard_pattern
in
op
.
attr
(
"op_namescope"
)):
_fp16_guard_pattern
in
op
.
attr
(
"op_namescope"
)
):
# op in fp16 guard
# op in fp16 guard
return
False
return
False
else
:
else
:
...
@@ -388,7 +422,12 @@ def fp16_guard():
...
@@ -388,7 +422,12 @@ def fp16_guard():
yield
yield
def
cast_model_to_fp16
(
program
,
amp_lists
=
None
,
use_fp16_guard
=
True
):
def
cast_model_to_fp16
(
program
,
amp_lists
=
None
,
use_fp16_guard
=
True
,
dest_type
=
core
.
VarDesc
.
VarType
.
FP16
,
):
"""
"""
Traverse all ops in the whole model and set their inputs and outputs
Traverse all ops in the whole model and set their inputs and outputs
to the fp16 data type. This function will do some special process for
to the fp16 data type. This function will do some special process for
...
@@ -399,6 +438,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
...
@@ -399,6 +438,7 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
use_fp16_guard(bool): Determine whether to use `fp16_guard` when
use_fp16_guard(bool): Determine whether to use `fp16_guard` when
constructing the program. Default True.
constructing the program. Default True.
dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16.
"""
"""
if
amp_lists
is
None
:
if
amp_lists
is
None
:
...
@@ -421,7 +461,8 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
...
@@ -421,7 +461,8 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
for
in_name
in
op
.
input_names
:
for
in_name
in
op
.
input_names
:
# for ipu, all inputs must be converted to fp16
# for ipu, all inputs must be converted to fp16
if
not
core
.
is_compiled_with_ipu
()
and
_keep_fp32_input
(
if
not
core
.
is_compiled_with_ipu
()
and
_keep_fp32_input
(
op
,
in_name
):
op
,
in_name
):
continue
continue
for
in_var_name
in
op
.
input
(
in_name
):
for
in_var_name
in
op
.
input
(
in_name
):
in_var
=
None
in_var
=
None
...
@@ -429,29 +470,36 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
...
@@ -429,29 +470,36 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
in_var
=
block
.
var
(
in_var_name
)
in_var
=
block
.
var
(
in_var_name
)
except
ValueError
as
e
:
except
ValueError
as
e
:
_logger
.
debug
(
_logger
.
debug
(
"-- {}, try to get it in the global block --"
.
"-- {}, try to get it in the global block --"
.
format
(
format
(
e
))
e
)
)
in_var
=
global_block
.
var
(
in_var_name
)
in_var
=
global_block
.
var
(
in_var_name
)
if
in_var
is
not
None
:
if
in_var
is
not
None
:
_logger
.
debug
(
_logger
.
debug
(
"-- var {} is got in the global block --"
.
"-- var {} is got in the global block --"
.
format
(
format
(
in_var_name
))
in_var_name
)
)
if
in_var
is
None
or
in_var
.
type
not
in
_valid_types
:
if
in_var
is
None
or
in_var
.
type
not
in
_valid_types
:
continue
continue
if
in_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
if
in_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
in_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP16
)
in_var
.
desc
.
set_dtype
(
dest_type
)
to_fp16_var_names
.
add
(
in_var_name
)
to_fp16_var_names
.
add
(
in_var_name
)
_logger
.
debug
(
_logger
.
debug
(
"-- op type: {}, in var name: {}, in var dtype: {} --"
.
"-- op type: {}, in var name: {}, in var dtype: {} --"
.
format
(
format
(
op
.
type
,
in_var_name
,
in_var
.
dtype
))
op
.
type
,
in_var_name
,
in_var
.
dtype
)
)
for
out_name
in
op
.
output_names
:
for
out_name
in
op
.
output_names
:
# for ipu, all outputs must be converted to fp16
# for ipu, all outputs must be converted to fp16
if
not
core
.
is_compiled_with_ipu
()
and
_keep_fp32_output
(
if
not
core
.
is_compiled_with_ipu
()
and
_keep_fp32_output
(
op
,
out_name
):
op
,
out_name
):
continue
continue
for
out_var_name
in
op
.
output
(
out_name
):
for
out_var_name
in
op
.
output
(
out_name
):
out_var
=
None
out_var
=
None
...
@@ -459,32 +507,35 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
...
@@ -459,32 +507,35 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
out_var
=
block
.
var
(
out_var_name
)
out_var
=
block
.
var
(
out_var_name
)
except
ValueError
as
e
:
except
ValueError
as
e
:
_logger
.
debug
(
_logger
.
debug
(
"-- {}, try to get it in the global block --"
.
"-- {}, try to get it in the global block --"
.
format
(
format
(
e
))
e
)
)
out_var
=
global_block
.
var
(
out_var_name
)
out_var
=
global_block
.
var
(
out_var_name
)
if
out_var
is
not
None
:
if
out_var
is
not
None
:
_logger
.
debug
(
_logger
.
debug
(
"-- var {} is got in the global block --"
.
"-- var {} is got in the global block --"
.
format
(
format
(
out_var_name
))
out_var_name
)
)
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
continue
continue
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP16
)
out_var
.
desc
.
set_dtype
(
dest_type
)
_logger
.
debug
(
_logger
.
debug
(
"-- op type: {}, out var name: {}, out var dtype: {} --"
"-- op type: {}, out var name: {}, out var dtype: {} --"
.
format
(
.
format
(
op
.
type
,
out_var_name
,
out_var
.
dtype
))
op
.
type
,
out_var_name
,
out_var
.
dtype
if
op
.
has_attr
(
'in_dtype'
)
and
op
.
attr
(
)
'in_dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
)
op
.
_set_attr
(
'in_dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
for
attr_name
in
[
'in_dtype'
,
'out_dtype'
,
'dtype'
]:
if
op
.
has_attr
(
'out_dtype'
)
and
op
.
attr
(
if
(
'out_dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
op
.
has_attr
(
attr_name
)
op
.
_set_attr
(
'out_dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
and
op
.
attr
(
attr_name
)
==
core
.
VarDesc
.
VarType
.
FP32
if
op
.
has_attr
(
'dtype'
)
and
op
.
attr
(
):
'dtype'
)
==
core
.
VarDesc
.
VarType
.
FP32
:
op
.
_set_attr
(
attr_name
,
dest_type
)
op
.
_set_attr
(
'dtype'
,
core
.
VarDesc
.
VarType
.
FP16
)
# process ops in keep_fp32_ops
# process ops in keep_fp32_ops
op_var_rename_map
=
[
op_var_rename_map
=
[
...
@@ -497,25 +548,29 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
...
@@ -497,25 +548,29 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
op
=
ops
[
idx
]
op
=
ops
[
idx
]
num_cast_ops
=
0
num_cast_ops
=
0
if
op
in
keep_fp32_ops
:
if
op
in
keep_fp32_ops
:
pre_cast_num
=
_insert_cast_op
(
block
,
op
,
idx
,
pre_cast_num
=
_insert_cast_op
(
core
.
VarDesc
.
VarType
.
FP16
,
block
,
op
,
idx
,
dest_type
,
core
.
VarDesc
.
VarType
.
FP32
core
.
VarDesc
.
VarType
.
FP32
)
)
num_cast_ops
+=
pre_cast_num
num_cast_ops
+=
pre_cast_num
for
out_var_name
in
op
.
output_arg_names
:
for
out_var_name
in
op
.
output_arg_names
:
out_var
=
block
.
vars
.
get
(
out_var_name
)
out_var
=
block
.
vars
.
get
(
out_var_name
)
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
continue
continue
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
if
out_var
.
dtype
==
dest_type
:
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
post_ops
=
find_true_post_op
(
ops
,
op
,
out_var_name
)
post_ops
=
find_true_post_op
(
ops
,
op
,
out_var_name
)
for
post_op
in
post_ops
:
for
post_op
in
post_ops
:
if
post_op
in
keep_fp32_ops
:
if
post_op
in
keep_fp32_ops
:
continue
continue
post_cast_num
=
_insert_cast_post_op
(
post_cast_num
=
_insert_cast_post_op
(
block
,
op
,
idx
+
pre_cast_num
+
1
,
block
,
op
,
idx
+
pre_cast_num
+
1
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
FP16
,
out_var_name
,
dest_type
,
op_var_rename_map
)
out_var_name
,
op_var_rename_map
,
)
num_cast_ops
+=
post_cast_num
num_cast_ops
+=
post_cast_num
idx
+=
num_cast_ops
+
1
idx
+=
num_cast_ops
+
1
...
@@ -523,7 +578,22 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
...
@@ -523,7 +578,22 @@ def cast_model_to_fp16(program, amp_lists=None, use_fp16_guard=True):
return
to_fp16_var_names
return
to_fp16_var_names
def
cast_parameters_to_fp16
(
place
,
program
,
scope
=
None
,
to_fp16_var_names
=
None
):
def
_convert_float_to_bfloat16
(
place
,
fp32_array
):
paddle
.
disable_static
()
framework
.
_set_expected_place
(
place
)
fp32_tensor
=
paddle
.
to_tensor
(
fp32_array
)
bf16_array
=
paddle
.
cast
(
fp32_tensor
,
paddle
.
bfloat16
).
numpy
()
paddle
.
enable_static
()
return
bf16_array
def
cast_parameters_to_fp16
(
place
,
program
,
scope
=
None
,
to_fp16_var_names
=
None
,
dest_type
=
core
.
VarDesc
.
VarType
.
FP16
,
):
"""
"""
Traverse all parameters in the whole model and set them to the FP16 data type.
Traverse all parameters in the whole model and set them to the FP16 data type.
Whereas, this function will keep parameters of batchnorms in FP32.
Whereas, this function will keep parameters of batchnorms in FP32.
...
@@ -535,6 +605,7 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
...
@@ -535,6 +605,7 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
to_fp16_var_names(set|list, optional): The data types of vars in `to_fp16_var_names`
will be set to FP16. Usually, it is the returned
will be set to FP16. Usually, it is the returned
value of `cast_model_to_fp16` API.
value of `cast_model_to_fp16` API.
dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16.
"""
"""
all_parameters
=
[]
all_parameters
=
[]
for
block
in
program
.
blocks
:
for
block
in
program
.
blocks
:
...
@@ -544,13 +615,22 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
...
@@ -544,13 +615,22 @@ def cast_parameters_to_fp16(place, program, scope=None, to_fp16_var_names=None):
var_scope
=
scope
if
scope
else
global_scope
()
var_scope
=
scope
if
scope
else
global_scope
()
for
param
in
all_parameters
:
for
param
in
all_parameters
:
if
param
.
name
in
fp16_var_names
:
if
param
.
name
in
fp16_var_names
:
_logger
.
debug
(
"---- cast {} to fp16 dtype ----"
.
format
(
param
.
name
))
_logger
.
debug
(
"---- cast {} to fp16/bf16 dtype ----"
.
format
(
param
.
name
)
)
if
var_scope
.
find_var
(
param
.
name
):
param_t
=
var_scope
.
find_var
(
param
.
name
).
get_tensor
()
param_t
=
var_scope
.
find_var
(
param
.
name
).
get_tensor
()
data
=
np
.
array
(
param_t
)
data
=
np
.
array
(
param_t
)
if
dest_type
==
core
.
VarDesc
.
VarType
.
BF16
:
bf16_data
=
_convert_float_to_bfloat16
(
place
,
data
)
param_t
.
set
(
bf16_data
,
place
)
else
:
param_t
.
set
(
np
.
float16
(
data
),
place
)
param_t
.
set
(
np
.
float16
(
data
),
place
)
else
:
_logger
.
warning
(
f
"Cannot find
{
param
.
name
}
"
)
def
rewrite_program
(
main_prog
,
amp_lists
):
def
rewrite_program
(
main_prog
,
amp_lists
,
dest_type
=
core
.
VarDesc
.
VarType
.
FP16
):
"""
"""
Traverse all ops in current block and insert cast op according to
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
which set current op belongs to.
...
@@ -569,6 +649,7 @@ def rewrite_program(main_prog, amp_lists):
...
@@ -569,6 +649,7 @@ def rewrite_program(main_prog, amp_lists):
Args:
Args:
main_prog (Program): The main program for training.
main_prog (Program): The main program for training.
dest_type(core.VarDesc.VarType): the cast type. such as core.VarDesc.VarType.FP16 and core.VarDesc.VarType.BF16.
"""
"""
block
=
main_prog
.
global_block
()
block
=
main_prog
.
global_block
()
block
.
_sync_with_cpp
()
block
.
_sync_with_cpp
()
...
@@ -585,7 +666,8 @@ def rewrite_program(main_prog, amp_lists):
...
@@ -585,7 +666,8 @@ def rewrite_program(main_prog, amp_lists):
continue
continue
if
amp_lists
.
black_varnames
is
not
None
and
_is_in_black_varnames
(
if
amp_lists
.
black_varnames
is
not
None
and
_is_in_black_varnames
(
op
,
amp_lists
):
op
,
amp_lists
):
black_op_set
.
add
(
op
)
black_op_set
.
add
(
op
)
continue
continue
...
@@ -611,11 +693,15 @@ def rewrite_program(main_prog, amp_lists):
...
@@ -611,11 +693,15 @@ def rewrite_program(main_prog, amp_lists):
else
:
else
:
prev_op
=
in_var
.
op
prev_op
=
in_var
.
op
# if it's one of inputs
# if it's one of inputs
if
prev_op
in
black_op_set
or
\
if
(
prev_op
.
type
in
amp_lists
.
black_list
:
prev_op
in
black_op_set
or
prev_op
.
type
in
amp_lists
.
black_list
):
is_black_op
=
True
is_black_op
=
True
elif
prev_op
in
white_op_set
or
\
elif
(
prev_op
.
type
in
amp_lists
.
white_list
:
prev_op
in
white_op_set
or
prev_op
.
type
in
amp_lists
.
white_list
):
is_white_op
=
True
is_white_op
=
True
if
is_black_op
:
if
is_black_op
:
black_op_set
.
add
(
op
)
black_op_set
.
add
(
op
)
...
@@ -633,13 +719,13 @@ def rewrite_program(main_prog, amp_lists):
...
@@ -633,13 +719,13 @@ def rewrite_program(main_prog, amp_lists):
op
=
ops
[
idx
]
op
=
ops
[
idx
]
num_cast_ops
=
0
num_cast_ops
=
0
if
op
in
black_op_set
:
if
op
in
black_op_set
:
num_cast_ops
=
_insert_cast_op
(
block
,
op
,
idx
,
num_cast_ops
=
_insert_cast_op
(
core
.
VarDesc
.
VarType
.
FP16
,
block
,
op
,
idx
,
dest_type
,
core
.
VarDesc
.
VarType
.
FP32
core
.
VarDesc
.
VarType
.
FP32
)
)
elif
op
in
white_op_set
:
elif
op
in
white_op_set
:
num_cast_ops
=
_insert_cast_op
(
block
,
op
,
idx
,
num_cast_ops
=
_insert_cast_op
(
core
.
VarDesc
.
VarType
.
FP32
,
block
,
op
,
idx
,
core
.
VarDesc
.
VarType
.
FP32
,
dest_type
core
.
VarDesc
.
VarType
.
FP16
)
)
else
:
else
:
pass
pass
...
@@ -670,13 +756,16 @@ def update_role_var_grad(main_prog, params_grads):
...
@@ -670,13 +756,16 @@ def update_role_var_grad(main_prog, params_grads):
if
role
&
int
(
BACKWARD
)
and
op
.
has_attr
(
'op_role_var'
):
if
role
&
int
(
BACKWARD
)
and
op
.
has_attr
(
'op_role_var'
):
op
.
_remove_attr
(
"op_role_var"
)
op
.
_remove_attr
(
"op_role_var"
)
else
:
else
:
raise
ValueError
(
"The cast op {0} must be in BACKWARD role "
raise
ValueError
(
"and have op_role_var attr."
.
format
(
op
))
"The cast op {0} must be in BACKWARD role "
"and have op_role_var attr."
.
format
(
op
)
)
fp16_grad_name
=
op
.
input
(
op
.
input_names
[
0
])[
0
]
fp16_grad_name
=
op
.
input
(
op
.
input_names
[
0
])[
0
]
op_for_fp16_grad
=
find_true_prev_op
(
block
.
ops
,
op
,
fp16_grad_name
)
op_for_fp16_grad
=
find_true_prev_op
(
block
.
ops
,
op
,
fp16_grad_name
)
op_role_var_attr_name
=
\
op_role_var_attr_name
=
(
core
.
op_proto_and_checker_maker
.
kOpRoleVarAttrName
()
core
.
op_proto_and_checker_maker
.
kOpRoleVarAttrName
()
)
attr_val
=
[
p
.
name
,
fp16_grad_name
]
attr_val
=
[
p
.
name
,
fp16_grad_name
]
if
op_for_fp16_grad
.
has_attr
(
op_role_var_attr_name
):
if
op_for_fp16_grad
.
has_attr
(
op_role_var_attr_name
):
attr_val
.
extend
(
op_for_fp16_grad
.
attr
(
op_role_var_attr_name
))
attr_val
.
extend
(
op_for_fp16_grad
.
attr
(
op_role_var_attr_name
))
...
@@ -690,18 +779,22 @@ def update_role_var_grad(main_prog, params_grads):
...
@@ -690,18 +779,22 @@ def update_role_var_grad(main_prog, params_grads):
continue
continue
post_ops
=
find_true_post_op
(
block
.
ops
,
op
,
g
.
name
)
post_ops
=
find_true_post_op
(
block
.
ops
,
op
,
g
.
name
)
if
post_ops
:
if
post_ops
:
raise
ValueError
(
"The cast op {0}'s output should not be"
raise
ValueError
(
"The cast op {0}'s output should not be"
"used by a non-optimize op, however, it"
"used by a non-optimize op, however, it"
"is used by {1}"
.
format
(
op
,
post_ops
[
0
]))
"is used by {1}"
.
format
(
op
,
post_ops
[
0
])
)
# add new op in the python and cpp at the same time
# add new op in the python and cpp at the same time
new_op_desc
=
block
.
desc
.
append_op
()
new_op_desc
=
block
.
desc
.
append_op
()
new_op_desc
.
copy_from
(
op
.
desc
)
new_op_desc
.
copy_from
(
op
.
desc
)
new_op
=
framework
.
Operator
(
block
=
block
,
new_op
=
framework
.
Operator
(
block
=
block
,
desc
=
new_op_desc
,
desc
=
new_op_desc
,
type
=
None
,
type
=
None
,
inputs
=
None
,
inputs
=
None
,
outputs
=
None
,
outputs
=
None
,
attrs
=
None
)
attrs
=
None
,
)
block
.
ops
.
append
(
new_op
)
block
.
ops
.
append
(
new_op
)
op_idx
=
find_op_index
(
block
.
desc
,
op
.
desc
)
op_idx
=
find_op_index
(
block
.
desc
,
op
.
desc
)
if
op_idx
==
-
1
:
if
op_idx
==
-
1
:
...
...
python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py
浏览文件 @
6959eae5
...
@@ -18,19 +18,32 @@ import six
...
@@ -18,19 +18,32 @@ import six
import
paddle
import
paddle
from
paddle.fluid
import
framework
,
backward
,
core
,
program_guard
from
paddle.fluid
import
framework
,
backward
,
core
,
program_guard
from
paddle.fluid.executor
import
_is_enable_standalone_executor
,
_is_dy2st_enable_standalone_executor
from
paddle.fluid.executor
import
(
_is_enable_standalone_executor
,
_is_dy2st_enable_standalone_executor
,
)
from
paddle.fluid.dygraph
import
layers
from
paddle.fluid.dygraph
import
layers
from
paddle.fluid.dygraph.base
import
switch_to_static_graph
from
paddle.fluid.dygraph.base
import
switch_to_static_graph
from
paddle.fluid.dygraph.dygraph_to_static
import
logging_utils
from
paddle.fluid.dygraph.dygraph_to_static
import
logging_utils
from
paddle.fluid.dygraph.dygraph_to_static.return_transformer
import
RETURN_NO_VALUE_MAGIC_NUM
from
paddle.fluid.dygraph.dygraph_to_static.return_transformer
import
(
RETURN_NO_VALUE_MAGIC_NUM
,
)
from
paddle.fluid.layers.utils
import
flatten
from
paddle.fluid.layers.utils
import
flatten
from
paddle.fluid.layers.utils
import
pack_sequence_as
from
paddle.fluid.layers.utils
import
pack_sequence_as
from
paddle.fluid.layers.utils
import
_hash_with_id
from
paddle.fluid.layers.utils
import
_hash_with_id
from
paddle.fluid.compiler
import
BuildStrategy
from
paddle.fluid.compiler
import
BuildStrategy
from
paddle.fluid.framework
import
_apply_pass
from
paddle.fluid.framework
import
_apply_pass
from
paddle.fluid.contrib.mixed_precision.decorator
import
AutoMixedPrecisionLists
from
paddle.fluid.contrib.mixed_precision.decorator
import
(
from
paddle.fluid.contrib.mixed_precision.fp16_utils
import
rewrite_program
,
cast_model_to_fp16
AutoMixedPrecisionLists
,
from
paddle.fluid.dygraph.amp.auto_cast
import
_in_amp_guard
,
_in_pure_fp16_guard
)
from
paddle.fluid.contrib.mixed_precision.fp16_utils
import
(
rewrite_program
,
cast_model_to_fp16
,
)
from
paddle.fluid.dygraph.amp.auto_cast
import
(
_in_amp_guard
,
_in_pure_fp16_guard
,
)
import
paddle.compat
as
cpt
import
paddle.compat
as
cpt
from
paddle
import
_C_ops
,
_legacy_C_ops
from
paddle
import
_C_ops
,
_legacy_C_ops
...
@@ -64,7 +77,8 @@ class NestSequence(object):
...
@@ -64,7 +77,8 @@ class NestSequence(object):
var_ids
=
[]
var_ids
=
[]
for
idx
,
var
in
enumerate
(
self
.
__input_list
):
for
idx
,
var
in
enumerate
(
self
.
__input_list
):
if
isinstance
(
if
isinstance
(
var
,
(
framework
.
Variable
,
core
.
VarBase
,
core
.
eager
.
Tensor
)):
var
,
(
framework
.
Variable
,
core
.
VarBase
,
core
.
eager
.
Tensor
)
):
var_ids
.
append
(
idx
)
var_ids
.
append
(
idx
)
return
var_ids
return
var_ids
...
@@ -77,15 +91,17 @@ class NestSequence(object):
...
@@ -77,15 +91,17 @@ class NestSequence(object):
warning_types
=
set
()
warning_types
=
set
()
for
var
in
self
.
__input_list
:
for
var
in
self
.
__input_list
:
if
not
isinstance
(
if
not
isinstance
(
var
,
var
,
(
framework
.
Variable
,
core
.
VarBase
,
core
.
eager
.
Tensor
)
(
framework
.
Variable
,
core
.
VarBase
,
core
.
eager
.
Tensor
)
):
):
warning_types
.
add
(
type
(
var
))
warning_types
.
add
(
type
(
var
))
if
warning_types
:
if
warning_types
:
logging_utils
.
warn
(
logging_utils
.
warn
(
"Output of traced function contains non-tensor type values: {}. "
"Output of traced function contains non-tensor type values: {}. "
"Currently, We don't support to update them while training and will return "
"Currently, We don't support to update them while training and will return "
"what we first saw. Please try to return them as tensor."
.
"what we first saw. Please try to return them as tensor."
.
format
(
format
(
list
(
warning_types
)))
list
(
warning_types
)
)
)
@
property
@
property
def
var_ids
(
self
):
def
var_ids
(
self
):
...
@@ -139,12 +155,9 @@ class PartialProgramLayer:
...
@@ -139,12 +155,9 @@ class PartialProgramLayer:
Layer: A Layer object that run all ops internally in static mode.
Layer: A Layer object that run all ops internally in static mode.
"""
"""
def
__init__
(
self
,
def
__init__
(
main_program
,
self
,
main_program
,
inputs
,
outputs
,
parameters
=
None
,
**
kwargs
inputs
,
):
outputs
,
parameters
=
None
,
**
kwargs
):
super
(
PartialProgramLayer
,
self
).
__init__
()
super
(
PartialProgramLayer
,
self
).
__init__
()
self
.
_inputs
=
NestSequence
(
inputs
)
self
.
_inputs
=
NestSequence
(
inputs
)
self
.
_outputs
=
NestSequence
(
outputs
,
need_check
=
True
)
self
.
_outputs
=
NestSequence
(
outputs
,
need_check
=
True
)
...
@@ -160,14 +173,18 @@ class PartialProgramLayer:
...
@@ -160,14 +173,18 @@ class PartialProgramLayer:
# Set default mode to train
# Set default mode to train
self
.
training
=
True
self
.
training
=
True
custom_white_list
,
custom_black_list
=
None
,
None
amp_dtype
,
custom_white_list
,
custom_black_list
=
None
,
None
,
None
tracer
=
framework
.
_dygraph_tracer
()
tracer
=
framework
.
_dygraph_tracer
()
if
tracer
:
if
tracer
:
custom_white_list
,
custom_black_list
=
tracer
.
_get_amp_op_list
()
custom_white_list
,
custom_black_list
=
tracer
.
_get_amp_op_list
()
amp_dtype
=
tracer
.
_amp_dtype
if
amp_dtype
is
not
None
and
amp_dtype
in
[
'float16'
,
'bfloat16'
]:
# For AMP training
# For AMP training
self
.
_amp_list
=
AutoMixedPrecisionLists
(
self
.
_amp_list
=
AutoMixedPrecisionLists
(
custom_white_list
=
custom_white_list
,
custom_white_list
=
custom_white_list
,
custom_black_list
=
custom_black_list
)
custom_black_list
=
custom_black_list
,
dtype
=
amp_dtype
,
)
# program_id -> list(scope)
# program_id -> list(scope)
self
.
_scope_cache
=
{}
self
.
_scope_cache
=
{}
...
@@ -203,7 +220,8 @@ class PartialProgramLayer:
...
@@ -203,7 +220,8 @@ class PartialProgramLayer:
return
self
.
_origin_main_program
.
clone
(
for_test
=
is_infer_mode
)
return
self
.
_origin_main_program
.
clone
(
for_test
=
is_infer_mode
)
else
:
else
:
train_program
=
self
.
_append_backward_desc
(
train_program
=
self
.
_append_backward_desc
(
self
.
_origin_main_program
)
self
.
_origin_main_program
)
# Note: Only set grad type once after initializing train program. So we put it here.
# Note: Only set grad type once after initializing train program. So we put it here.
self
.
_set_grad_type
(
self
.
_params
,
train_program
)
self
.
_set_grad_type
(
self
.
_params
,
train_program
)
return
train_program
return
train_program
...
@@ -223,16 +241,18 @@ class PartialProgramLayer:
...
@@ -223,16 +241,18 @@ class PartialProgramLayer:
@
switch_to_static_graph
@
switch_to_static_graph
def
_create_pure_fp16_program
(
self
,
is_infer_mode
=
False
):
def
_create_pure_fp16_program
(
self
,
is_infer_mode
=
False
):
pure_fp16_program
=
self
.
_origin_main_program
.
clone
(
pure_fp16_program
=
self
.
_origin_main_program
.
clone
(
for_test
=
is_infer_mode
)
for_test
=
is_infer_mode
)
with
program_guard
(
pure_fp16_program
):
with
program_guard
(
pure_fp16_program
):
cast_model_to_fp16
(
pure_fp16_program
,
cast_model_to_fp16
(
self
.
_amp_list
,
pure_fp16_program
,
self
.
_amp_list
,
use_fp16_guard
=
False
use_fp16_guard
=
False
)
)
if
is_infer_mode
:
if
is_infer_mode
:
return
pure_fp16_program
return
pure_fp16_program
else
:
else
:
train_pure_fp16_program
=
self
.
_append_backward_desc
(
train_pure_fp16_program
=
self
.
_append_backward_desc
(
pure_fp16_program
)
pure_fp16_program
)
self
.
_set_grad_type
(
self
.
_params
,
train_pure_fp16_program
)
self
.
_set_grad_type
(
self
.
_params
,
train_pure_fp16_program
)
return
train_pure_fp16_program
return
train_pure_fp16_program
...
@@ -240,23 +260,27 @@ class PartialProgramLayer:
...
@@ -240,23 +260,27 @@ class PartialProgramLayer:
def
_create_forward_backward_train_program
(
self
):
def
_create_forward_backward_train_program
(
self
):
whole_program
=
self
.
_create_program
()
whole_program
=
self
.
_create_program
()
forward_end_op_index
=
self
.
_infer_program
.
desc
.
block
(
0
).
op_size
()
forward_end_op_index
=
self
.
_infer_program
.
desc
.
block
(
0
).
op_size
()
return
self
.
_get_forward_backward_program_form
(
whole_program
,
return
self
.
_get_forward_backward_program_form
(
forward_end_op_index
)
whole_program
,
forward_end_op_index
)
@
switch_to_static_graph
@
switch_to_static_graph
def
_create_forward_backward_train_amp_program
(
self
):
def
_create_forward_backward_train_amp_program
(
self
):
whole_program
=
self
.
_create_amp_program
()
whole_program
=
self
.
_create_amp_program
()
forward_end_op_index
=
self
.
_infer_amp_program
.
desc
.
block
(
0
).
op_size
()
forward_end_op_index
=
self
.
_infer_amp_program
.
desc
.
block
(
0
).
op_size
()
return
self
.
_get_forward_backward_program_form
(
whole_program
,
return
self
.
_get_forward_backward_program_form
(
forward_end_op_index
)
whole_program
,
forward_end_op_index
)
@
switch_to_static_graph
@
switch_to_static_graph
def
_create_forward_backward_train_pure_fp16_program
(
self
):
def
_create_forward_backward_train_pure_fp16_program
(
self
):
whole_program
=
self
.
_create_pure_fp16_program
()
whole_program
=
self
.
_create_pure_fp16_program
()
forward_end_op_index
=
self
.
_infer_pure_fp16_program
.
desc
.
block
(
forward_end_op_index
=
self
.
_infer_pure_fp16_program
.
desc
.
block
(
0
).
op_size
()
0
return
self
.
_get_forward_backward_program_form
(
whole_program
,
).
op_size
()
forward_end_op_index
)
return
self
.
_get_forward_backward_program_form
(
whole_program
,
forward_end_op_index
)
@
LazyInitialized
@
LazyInitialized
def
_train_program
(
self
):
def
_train_program
(
self
):
...
@@ -352,8 +376,9 @@ class PartialProgramLayer:
...
@@ -352,8 +376,9 @@ class PartialProgramLayer:
@
LazyInitialized
@
LazyInitialized
def
_train_program_id
(
self
):
def
_train_program_id
(
self
):
program_id
=
_hash_with_id
(
self
.
_train_program
,
self
)
program_id
=
_hash_with_id
(
self
.
_train_program
,
self
)
core
.
_set_cached_executor_build_strategy
(
program_id
,
core
.
_set_cached_executor_build_strategy
(
self
.
_build_strategy
)
program_id
,
self
.
_build_strategy
)
return
program_id
return
program_id
@
LazyInitialized
@
LazyInitialized
...
@@ -363,8 +388,9 @@ class PartialProgramLayer:
...
@@ -363,8 +388,9 @@ class PartialProgramLayer:
@
LazyInitialized
@
LazyInitialized
def
_train_amp_program_id
(
self
):
def
_train_amp_program_id
(
self
):
program_id
=
_hash_with_id
(
self
.
_train_amp_program
,
self
)
program_id
=
_hash_with_id
(
self
.
_train_amp_program
,
self
)
core
.
_set_cached_executor_build_strategy
(
program_id
,
core
.
_set_cached_executor_build_strategy
(
self
.
_build_strategy
)
program_id
,
self
.
_build_strategy
)
return
program_id
return
program_id
@
LazyInitialized
@
LazyInitialized
...
@@ -374,8 +400,9 @@ class PartialProgramLayer:
...
@@ -374,8 +400,9 @@ class PartialProgramLayer:
@
LazyInitialized
@
LazyInitialized
def
_train_pure_fp16_program_id
(
self
):
def
_train_pure_fp16_program_id
(
self
):
program_id
=
_hash_with_id
(
self
.
_train_pure_fp16_program
,
self
)
program_id
=
_hash_with_id
(
self
.
_train_pure_fp16_program
,
self
)
core
.
_set_cached_executor_build_strategy
(
program_id
,
core
.
_set_cached_executor_build_strategy
(
self
.
_build_strategy
)
program_id
,
self
.
_build_strategy
)
return
program_id
return
program_id
@
LazyInitialized
@
LazyInitialized
...
@@ -411,8 +438,9 @@ class PartialProgramLayer:
...
@@ -411,8 +438,9 @@ class PartialProgramLayer:
return
main_program
return
main_program
def
prepare_gradient_aggregation
(
self
,
start_idx
,
main_program
,
def
prepare_gradient_aggregation
(
target_program
):
self
,
start_idx
,
main_program
,
target_program
):
"""
"""
Why we need add gradient aggregation operation ?
Why we need add gradient aggregation operation ?
In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
In some cases, if non leaf nodes are used as output, gradient overwriting will occur, such as
...
@@ -431,7 +459,7 @@ class PartialProgramLayer:
...
@@ -431,7 +459,7 @@ class PartialProgramLayer:
"""
"""
if
not
isinstance
(
var
,
framework
.
Variable
)
or
var
.
type
not
in
[
if
not
isinstance
(
var
,
framework
.
Variable
)
or
var
.
type
not
in
[
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
LOD_TENSOR
,
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
core
.
VarDesc
.
VarType
.
SELECTED_ROWS
,
]:
]:
return
False
return
False
if
var
.
dtype
not
in
[
paddle
.
float32
,
paddle
.
float64
]:
if
var
.
dtype
not
in
[
paddle
.
float32
,
paddle
.
float64
]:
...
@@ -448,20 +476,28 @@ class PartialProgramLayer:
...
@@ -448,20 +476,28 @@ class PartialProgramLayer:
new_grad_name
=
var
.
name
+
suffix
+
"@GRAD"
new_grad_name
=
var
.
name
+
suffix
+
"@GRAD"
finded_ops
=
list
(
finded_ops
=
list
(
filter
(
filter
(
lambda
x
:
x
[
0
]
>=
start_idx
and
any
([
lambda
x
:
x
[
0
]
>=
start_idx
and
any
(
[
out_arg
==
var_grad_name
out_arg
==
var_grad_name
for
out_arg
in
x
[
1
].
output_arg_names
for
out_arg
in
x
[
1
].
output_arg_names
]),
enumerate
(
target_program
.
block
(
0
).
ops
)))
]
),
enumerate
(
target_program
.
block
(
0
).
ops
),
)
)
# len(finded_ops) may equals zero when stop_gradient works.
# len(finded_ops) may equals zero when stop_gradient works.
# len(finded_ops) may > 1, because we may have fill_constant op.
# len(finded_ops) may > 1, because we may have fill_constant op.
if
len
(
finded_ops
)
==
0
:
if
len
(
finded_ops
)
==
0
:
return
None
return
None
# step1: create a new var named var.name@GRAD
# step1: create a new var named var.name@GRAD
target_program
.
block
(
0
).
create_var
(
name
=
new_grad_name
,
target_program
.
block
(
0
).
create_var
(
name
=
new_grad_name
,
type
=
var
.
type
,
type
=
var
.
type
,
dtype
=
var
.
dtype
,
dtype
=
var
.
dtype
,
shape
=
var
.
shape
)
shape
=
var
.
shape
,
)
# step2: rename the var.name@GRAD to var.name@GRAD@dy2static
# step2: rename the var.name@GRAD to var.name@GRAD@dy2static
for
idx
,
op
in
finded_ops
:
for
idx
,
op
in
finded_ops
:
op
.
_rename_input
(
var_grad_name
,
new_grad_name
)
op
.
_rename_input
(
var_grad_name
,
new_grad_name
)
...
@@ -472,11 +508,13 @@ class PartialProgramLayer:
...
@@ -472,11 +508,13 @@ class PartialProgramLayer:
finded_ops
[
-
1
][
0
]
+
1
,
finded_ops
[
-
1
][
0
]
+
1
,
type
=
'sum'
,
type
=
'sum'
,
inputs
=
{
'X'
:
[
var_grad_name
,
new_grad_name
]},
inputs
=
{
'X'
:
[
var_grad_name
,
new_grad_name
]},
outputs
=
{
"Out"
:
var_grad_name
})
outputs
=
{
"Out"
:
var_grad_name
},
)
return
None
return
None
to_processed_vars
=
list
(
to_processed_vars
=
list
(
filter
(
_need_aggregation
,
self
.
_outputs
.
tolist
()))
filter
(
_need_aggregation
,
self
.
_outputs
.
tolist
())
)
for
_var
in
to_processed_vars
:
for
_var
in
to_processed_vars
:
_insert_aggregation_ops_for_var
(
target_program
,
_var
)
_insert_aggregation_ops_for_var
(
target_program
,
_var
)
...
@@ -492,8 +530,9 @@ class PartialProgramLayer:
...
@@ -492,8 +530,9 @@ class PartialProgramLayer:
if
targets
and
self
.
_params
:
if
targets
and
self
.
_params
:
backward
.
gradients
(
targets
=
targets
,
inputs
=
[])
backward
.
gradients
(
targets
=
targets
,
inputs
=
[])
start_idx
=
len
(
start_idx
=
len
(
main_program
.
block
(
0
).
ops
)
+
2
*
len
(
main_program
.
block
(
0
).
ops
)
+
2
*
len
(
self
.
_outputs
.
tolist
())
self
.
_outputs
.
tolist
()
)
self
.
prepare_gradient_aggregation
(
start_idx
,
main_program
,
program
)
self
.
prepare_gradient_aggregation
(
start_idx
,
main_program
,
program
)
...
@@ -512,7 +551,10 @@ class PartialProgramLayer:
...
@@ -512,7 +551,10 @@ class PartialProgramLayer:
found_param
=
False
found_param
=
False
for
block
in
program
.
blocks
:
for
block
in
program
.
blocks
:
for
op
in
block
.
ops
:
for
op
in
block
.
ops
:
if
param
.
name
in
op
.
input_arg_names
or
param
.
name
in
op
.
output_arg_names
:
if
(
param
.
name
in
op
.
input_arg_names
or
param
.
name
in
op
.
output_arg_names
):
required_params
.
append
(
param
)
required_params
.
append
(
param
)
found_param
=
True
found_param
=
True
break
break
...
@@ -529,15 +571,21 @@ class PartialProgramLayer:
...
@@ -529,15 +571,21 @@ class PartialProgramLayer:
var_desc
=
block
.
vars
[
name
].
desc
var_desc
=
block
.
vars
[
name
].
desc
var_base
=
None
var_base
=
None
if
not
framework
.
_in_eager_mode_
:
if
not
framework
.
_in_eager_mode_
:
var_base
=
core
.
VarBase
(
var_desc
.
dtype
(),
var_base
=
core
.
VarBase
(
var_desc
.
dtype
(),
var_desc
.
shape
(),
var_desc
.
shape
(),
var_desc
.
name
(),
var_desc
.
name
(),
var_desc
.
type
(),
False
)
var_desc
.
type
(),
False
,
)
else
:
else
:
var_base
=
core
.
eager
.
Tensor
(
var_desc
.
dtype
(),
var_base
=
core
.
eager
.
Tensor
(
var_desc
.
dtype
(),
var_desc
.
shape
(),
var_desc
.
shape
(),
var_desc
.
name
(),
var_desc
.
name
(),
var_desc
.
type
(),
False
)
var_desc
.
type
(),
False
,
)
double_grads
.
append
(
var_base
)
double_grads
.
append
(
var_base
)
return
self
.
_valid_vars
(
double_grads
)
return
self
.
_valid_vars
(
double_grads
)
...
@@ -557,36 +605,62 @@ class PartialProgramLayer:
...
@@ -557,36 +605,62 @@ class PartialProgramLayer:
attrs
=
[
attrs
=
[
'global_block'
,
'global_block'
,
self
.
program
.
desc
.
block
(
0
),
'start_op_index'
,
0
,
'end_op_index'
,
self
.
program
.
desc
.
block
(
0
),
self
.
_get_end_op_index
(),
'is_test'
,
not
self
.
training
,
'start_op_index'
,
'program_id'
,
self
.
program_id
0
,
'end_op_index'
,
self
.
_get_end_op_index
(),
'is_test'
,
not
self
.
training
,
'program_id'
,
self
.
program_id
,
]
]
if
self
.
_cuda_graph_capture_mode
:
if
self
.
_cuda_graph_capture_mode
:
attrs
.
extend
(
attrs
.
extend
(
(
'cuda_graph_capture_mode'
,
self
.
_cuda_graph_capture_mode
,
(
'cuda_graph_pool_id'
,
self
.
_cuda_graph_pool_id
))
'cuda_graph_capture_mode'
,
self
.
_cuda_graph_capture_mode
,
use_interpretorcore
=
_is_enable_standalone_executor
(
'cuda_graph_pool_id'
,
)
and
_is_dy2st_enable_standalone_executor
()
self
.
_cuda_graph_pool_id
,
)
)
use_interpretorcore
=
(
_is_enable_standalone_executor
()
and
_is_dy2st_enable_standalone_executor
()
)
attrs
.
extend
((
'use_interpretorcore'
,
use_interpretorcore
))
attrs
.
extend
((
'use_interpretorcore'
,
use_interpretorcore
))
if
use_interpretorcore
:
if
use_interpretorcore
:
attrs
.
extend
(
attrs
.
extend
(
(
'forward_global_block'
,
self
.
forward_program
.
desc
.
block
(
0
),
(
'backward_global_block'
,
self
.
backward_program
.
desc
.
block
(
0
)))
'forward_global_block'
,
self
.
forward_program
.
desc
.
block
(
0
),
'backward_global_block'
,
self
.
backward_program
.
desc
.
block
(
0
),
)
)
_legacy_C_ops
.
run_program
(
_legacy_C_ops
.
run_program
(
self
.
_valid_vars
(
in_vars
),
self
.
_valid_vars
(
self
.
_params
),
self
.
_valid_vars
(
in_vars
),
self
.
_valid_vars
(
self
.
_params
),
self
.
_valid_vars
(
out_vars
),
self
.
_valid_vars
(
out_vars
),
self
.
_create_scope_vec
(
program_id
=
self
.
program_id
,
self
.
_create_scope_vec
(
use_scope_cache
=
True
),
program_id
=
self
.
program_id
,
use_scope_cache
=
True
self
.
_double_grads
,
self
.
_cuda_graph_vec
,
*
attrs
)
),
self
.
_double_grads
,
self
.
_cuda_graph_vec
,
*
attrs
)
else
:
else
:
_legacy_C_ops
.
run_program
(
self
.
_valid_vars
(
in_vars
),
_legacy_C_ops
.
run_program
(
self
.
_valid_vars
(
in_vars
),
self
.
_valid_vars
(
self
.
_params
),
self
.
_valid_vars
(
self
.
_params
),
self
.
_valid_vars
(
out_vars
),
self
.
_valid_vars
(
out_vars
),
self
.
_create_scope_vec
(),
self
.
_create_scope_vec
(),
self
.
_double_grads
,
self
.
_cuda_graph_vec
,
self
.
_double_grads
,
*
attrs
)
self
.
_cuda_graph_vec
,
*
attrs
)
restored_nest_out
=
self
.
_restore_out
(
out_vars
)
restored_nest_out
=
self
.
_restore_out
(
out_vars
)
return
self
.
_remove_no_value
(
restored_nest_out
)
return
self
.
_remove_no_value
(
restored_nest_out
)
...
@@ -594,9 +668,11 @@ class PartialProgramLayer:
...
@@ -594,9 +668,11 @@ class PartialProgramLayer:
if
_in_pure_fp16_guard
():
if
_in_pure_fp16_guard
():
for
i
,
var
in
enumerate
(
in_vars
):
for
i
,
var
in
enumerate
(
in_vars
):
name
=
var
.
name
name
=
var
.
name
if
(
self
.
program
.
global_block
().
has_var
(
name
)
if
(
self
.
program
.
global_block
().
has_var
(
name
)
and
self
.
program
.
global_block
().
var
(
name
).
dtype
and
self
.
program
.
global_block
().
var
(
name
).
dtype
==
paddle
.
float16
):
==
paddle
.
float16
):
in_vars
[
i
]
=
var
.
astype
(
'float16'
)
in_vars
[
i
]
=
var
.
astype
(
'float16'
)
in_vars
[
i
].
name
=
name
in_vars
[
i
].
name
=
name
...
@@ -627,25 +703,32 @@ class PartialProgramLayer:
...
@@ -627,25 +703,32 @@ class PartialProgramLayer:
return
self
.
_infer_program
return
self
.
_infer_program
@
switch_to_static_graph
@
switch_to_static_graph
def
_get_forward_backward_program_form
(
self
,
whole_program
,
def
_get_forward_backward_program_form
(
forward_end_op_index
):
self
,
whole_program
,
forward_end_op_index
):
forward_builded_program
=
add_build_strategy_for
(
forward_builded_program
=
add_build_strategy_for
(
whole_program
,
0
,
forward_end_op_index
,
self
.
_build_strategy
)
whole_program
,
0
,
forward_end_op_index
,
self
.
_build_strategy
)
backward_start_op_index
=
forward_end_op_index
+
2
*
len
(
backward_start_op_index
=
forward_end_op_index
+
2
*
len
(
self
.
_outputs
.
var_ids
)
self
.
_outputs
.
var_ids
)
backward_end_op_index
=
whole_program
.
desc
.
block
(
0
).
op_size
()
backward_end_op_index
=
whole_program
.
desc
.
block
(
0
).
op_size
()
backward_builded_program
=
add_build_strategy_for
(
backward_builded_program
=
add_build_strategy_for
(
whole_program
,
backward_start_op_index
,
backward_end_op_index
,
whole_program
,
self
.
_build_strategy
)
backward_start_op_index
,
self
.
_apply_inplace_pass
(
forward_builded_program
,
backward_end_op_index
,
backward_builded_program
)
self
.
_build_strategy
,
)
self
.
_apply_inplace_pass
(
forward_builded_program
,
backward_builded_program
)
return
[
forward_builded_program
,
backward_builded_program
]
return
[
forward_builded_program
,
backward_builded_program
]
def
_apply_inplace_pass
(
self
,
forward_program
,
backward_program
):
def
_apply_inplace_pass
(
self
,
forward_program
,
backward_program
):
attr_types
=
{
attr_types
=
{
"use_cuda"
:
"bool"
,
"use_cuda"
:
"bool"
,
"mem_opt_skip_vars"
:
"list[str]"
,
"mem_opt_skip_vars"
:
"list[str]"
,
"for_partial_block"
:
"bool"
"for_partial_block"
:
"bool"
,
}
}
empty_startup_program
=
paddle
.
static
.
Program
()
empty_startup_program
=
paddle
.
static
.
Program
()
use_cuda
=
True
if
core
.
is_compiled_with_cuda
()
else
False
use_cuda
=
True
if
core
.
is_compiled_with_cuda
()
else
False
...
@@ -667,22 +750,33 @@ class PartialProgramLayer:
...
@@ -667,22 +750,33 @@ class PartialProgramLayer:
forward_mem_opt_skip_vars
.
append
(
var
.
desc
.
name
())
forward_mem_opt_skip_vars
.
append
(
var
.
desc
.
name
())
backward_mem_opt_skip_vars
.
append
(
var
.
desc
.
name
())
backward_mem_opt_skip_vars
.
append
(
var
.
desc
.
name
())
for
var_name
in
core
.
parse_safe_eager_deletion_skip_vars
(
for
var_name
in
core
.
parse_safe_eager_deletion_skip_vars
(
backward_program
.
desc
):
backward_program
.
desc
):
forward_mem_opt_skip_vars
.
append
(
var_name
)
forward_mem_opt_skip_vars
.
append
(
var_name
)
attrs
=
{
attrs
=
{
"use_cuda"
:
use_cuda
,
"use_cuda"
:
use_cuda
,
"mem_opt_skip_vars"
:
forward_mem_opt_skip_vars
,
"mem_opt_skip_vars"
:
forward_mem_opt_skip_vars
,
"for_partial_block"
:
True
"for_partial_block"
:
True
,
}
}
_apply_pass
(
forward_program
,
empty_startup_program
,
_apply_pass
(
"buffer_shared_inplace_pass"
,
attrs
,
attr_types
)
forward_program
,
empty_startup_program
,
"buffer_shared_inplace_pass"
,
attrs
,
attr_types
,
)
attrs
=
{
attrs
=
{
"use_cuda"
:
use_cuda
,
"use_cuda"
:
use_cuda
,
"mem_opt_skip_vars"
:
backward_mem_opt_skip_vars
,
"mem_opt_skip_vars"
:
backward_mem_opt_skip_vars
,
"for_partial_block"
:
True
"for_partial_block"
:
True
,
}
}
_apply_pass
(
backward_program
,
empty_startup_program
,
_apply_pass
(
"buffer_shared_inplace_pass"
,
attrs
,
attr_types
)
backward_program
,
empty_startup_program
,
"buffer_shared_inplace_pass"
,
attrs
,
attr_types
,
)
def
_prepare
(
self
,
inputs
):
def
_prepare
(
self
,
inputs
):
"""
"""
...
@@ -698,23 +792,28 @@ class PartialProgramLayer:
...
@@ -698,23 +792,28 @@ class PartialProgramLayer:
if
isinstance
(
value
,
np
.
ndarray
):
if
isinstance
(
value
,
np
.
ndarray
):
var
=
None
var
=
None
if
not
framework
.
_in_eager_mode_
:
if
not
framework
.
_in_eager_mode_
:
var
=
core
.
VarBase
(
value
=
value
,
var
=
core
.
VarBase
(
value
=
value
,
name
=
self
.
_inputs
[
i
].
desc
.
name
(),
name
=
self
.
_inputs
[
i
].
desc
.
name
(),
persistable
=
False
,
persistable
=
False
,
place
=
expected_place
,
place
=
expected_place
,
zero_copy
=
True
)
zero_copy
=
True
,
)
else
:
else
:
var
=
core
.
eager
.
Tensor
(
value
=
value
,
var
=
core
.
eager
.
Tensor
(
value
=
value
,
name
=
self
.
_inputs
[
i
].
desc
.
name
(),
name
=
self
.
_inputs
[
i
].
desc
.
name
(),
persistable
=
False
,
persistable
=
False
,
place
=
expected_place
,
place
=
expected_place
,
zero_copy
=
True
)
zero_copy
=
True
,
)
elif
isinstance
(
value
,
(
core
.
VarBase
,
core
.
eager
.
Tensor
)):
elif
isinstance
(
value
,
(
core
.
VarBase
,
core
.
eager
.
Tensor
)):
# NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
# NOTE(Aurelius84): If var is on CPUPlace, it will be transformed multi times
# into CUDAPlace when it's as input of multi Ops. so we move it in advance
# into CUDAPlace when it's as input of multi Ops. so we move it in advance
# to avoid this problem.
# to avoid this problem.
if
value
.
stop_gradient
and
not
value
.
place
.
_equals
(
if
value
.
stop_gradient
and
not
value
.
place
.
_equals
(
expected_place
):
expected_place
):
var
=
value
.
_copy_to
(
expected_place
,
False
)
var
=
value
.
_copy_to
(
expected_place
,
False
)
var
.
stop_gradient
=
True
var
.
stop_gradient
=
True
else
:
else
:
...
@@ -737,12 +836,21 @@ class PartialProgramLayer:
...
@@ -737,12 +836,21 @@ class PartialProgramLayer:
return
out_varbase_map
[
var_desc
.
name
()]
return
out_varbase_map
[
var_desc
.
name
()]
if
not
framework
.
_in_eager_mode_
:
if
not
framework
.
_in_eager_mode_
:
var_base
=
core
.
VarBase
(
var_desc
.
dtype
(),
var_desc
.
shape
(),
var_base
=
core
.
VarBase
(
var_desc
.
name
(),
var_desc
.
type
(),
False
)
var_desc
.
dtype
(),
var_desc
.
shape
(),
var_desc
.
name
(),
var_desc
.
type
(),
False
,
)
else
:
else
:
var_base
=
core
.
eager
.
Tensor
(
var_desc
.
dtype
(),
var_desc
.
shape
(),
var_base
=
core
.
eager
.
Tensor
(
var_desc
.
name
(),
var_desc
.
type
(),
var_desc
.
dtype
(),
False
)
var_desc
.
shape
(),
var_desc
.
name
(),
var_desc
.
type
(),
False
,
)
var_base
.
stop_gradient
=
var
.
stop_gradient
var_base
.
stop_gradient
=
var
.
stop_gradient
out_varbase_map
[
var_desc
.
name
()]
=
var_base
out_varbase_map
[
var_desc
.
name
()]
=
var_base
return
var_base
return
var_base
...
@@ -755,20 +863,30 @@ class PartialProgramLayer:
...
@@ -755,20 +863,30 @@ class PartialProgramLayer:
def
_create_scope_vec
(
self
,
program_id
=
None
,
use_scope_cache
=
False
):
def
_create_scope_vec
(
self
,
program_id
=
None
,
use_scope_cache
=
False
):
# Hold forward variables
# Hold forward variables
tmp_scope_vec
=
None
tmp_scope_vec
=
None
inner_scope
=
self
.
_get_scope
(
program_id
=
program_id
,
inner_scope
=
self
.
_get_scope
(
use_scope_cache
=
use_scope_cache
)
program_id
=
program_id
,
use_scope_cache
=
use_scope_cache
)
if
not
framework
.
_in_eager_mode_
:
if
not
framework
.
_in_eager_mode_
:
tmp_scope_vec
=
core
.
VarBase
(
core
.
VarDesc
.
VarType
.
FP32
,
[],
tmp_scope_vec
=
core
.
VarBase
(
core
.
VarDesc
.
VarType
.
FP32
,
[],
"program_out_scope"
,
"program_out_scope"
,
core
.
VarDesc
.
VarType
.
STEP_SCOPES
,
True
)
core
.
VarDesc
.
VarType
.
STEP_SCOPES
,
True
,
)
tmp_scope_vec
.
value
().
set_scope
(
inner_scope
)
tmp_scope_vec
.
value
().
set_scope
(
inner_scope
)
else
:
else
:
tmp_scope_vec
=
[
inner_scope
]
tmp_scope_vec
=
[
inner_scope
]
return
tmp_scope_vec
return
tmp_scope_vec
def
_create_cuda_graph_vec
(
self
):
def
_create_cuda_graph_vec
(
self
):
var
=
core
.
VarBase
(
core
.
VarDesc
.
VarType
.
FP32
,
[],
"cuda_graph"
,
var
=
core
.
VarBase
(
core
.
VarDesc
.
VarType
.
RAW
,
True
)
core
.
VarDesc
.
VarType
.
FP32
,
[],
"cuda_graph"
,
core
.
VarDesc
.
VarType
.
RAW
,
True
,
)
var
.
stop_gradient
=
True
var
.
stop_gradient
=
True
return
var
return
var
...
@@ -791,8 +909,9 @@ class PartialProgramLayer:
...
@@ -791,8 +909,9 @@ class PartialProgramLayer:
return
main_program
.
clone
(
for_test
=
True
)
return
main_program
.
clone
(
for_test
=
True
)
def
_is_no_value
(
self
,
var
):
def
_is_no_value
(
self
,
var
):
if
isinstance
(
var
,
if
isinstance
(
var
,
(
core
.
VarBase
,
core
.
eager
.
Tensor
))
and
var
.
shape
==
[
(
core
.
VarBase
,
core
.
eager
.
Tensor
))
and
var
.
shape
==
[
1
]:
1
]:
# NOTE: .numpy() will insert MemcpySync operation, it hits performance.
# NOTE: .numpy() will insert MemcpySync operation, it hits performance.
if
var
.
numpy
()[
0
]
==
RETURN_NO_VALUE_MAGIC_NUM
:
if
var
.
numpy
()[
0
]
==
RETURN_NO_VALUE_MAGIC_NUM
:
return
True
return
True
...
@@ -808,13 +927,14 @@ class PartialProgramLayer:
...
@@ -808,13 +927,14 @@ class PartialProgramLayer:
return
out_vars
return
out_vars
elif
isinstance
(
out_vars
,
(
tuple
,
list
)):
elif
isinstance
(
out_vars
,
(
tuple
,
list
)):
if
isinstance
(
out_vars
,
tuple
):
if
isinstance
(
out_vars
,
tuple
):
res
=
tuple
(
var
for
var
in
out_vars
res
=
tuple
(
if
not
self
.
_is_no_value
(
var
))
var
for
var
in
out_vars
if
not
self
.
_is_no_value
(
var
)
)
else
:
else
:
# isinstance(out_vars, list)
# isinstance(out_vars, list)
res
=
[
var
for
var
in
out_vars
if
not
self
.
_is_no_value
(
var
)]
res
=
[
var
for
var
in
out_vars
if
not
self
.
_is_no_value
(
var
)]
has_removed
=
(
len
(
out_vars
)
>
len
(
res
)
)
has_removed
=
len
(
out_vars
)
>
len
(
res
)
# len(out_vars) > len(res) means we have removed var. This is
# len(out_vars) > len(res) means we have removed var. This is
# preventing out_vars is empty or just one element at the beginning
# preventing out_vars is empty or just one element at the beginning
if
len
(
res
)
==
0
and
has_removed
:
if
len
(
res
)
==
0
and
has_removed
:
...
@@ -835,7 +955,8 @@ class PartialProgramLayer:
...
@@ -835,7 +955,8 @@ class PartialProgramLayer:
for
param
in
params
:
for
param
in
params
:
grad_name
=
param
.
name
+
core
.
grad_var_suffix
()
grad_name
=
param
.
name
+
core
.
grad_var_suffix
()
grad_var
=
train_program
.
desc
.
block
(
0
).
find_var
(
grad_var
=
train_program
.
desc
.
block
(
0
).
find_var
(
cpt
.
to_bytes
(
grad_name
))
cpt
.
to_bytes
(
grad_name
)
)
# NOTE: cannot find var desc maybe no problem, such as in batch_norm
# NOTE: cannot find var desc maybe no problem, such as in batch_norm
if
grad_var
is
None
:
if
grad_var
is
None
:
continue
continue
...
@@ -864,15 +985,18 @@ class PartialProgramLayer:
...
@@ -864,15 +985,18 @@ class PartialProgramLayer:
if
not
isinstance
(
self
.
_params
,
(
list
,
tuple
)):
if
not
isinstance
(
self
.
_params
,
(
list
,
tuple
)):
raise
TypeError
(
raise
TypeError
(
"Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
"Type of self._params in PartialProgramLayer should be list or tuple, but received %s."
%
type
(
self
.
_params
))
%
type
(
self
.
_params
)
)
param_and_buffer_names_set
=
set
()
param_and_buffer_names_set
=
set
()
for
i
,
var
in
enumerate
(
self
.
_params
):
for
i
,
var
in
enumerate
(
self
.
_params
):
# self._params constains parameters and buffers with persistable=True.
# self._params constains parameters and buffers with persistable=True.
if
not
isinstance
(
var
,
(
core
.
VarBase
,
core
.
eager
.
Tensor
)):
if
not
isinstance
(
var
,
(
core
.
VarBase
,
core
.
eager
.
Tensor
)):
raise
TypeError
(
raise
TypeError
(
'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'
'Type of self._params[{}] in PartialProgramLayer should be Parameter or Variable, but received {}.'
.
format
(
.
format
(
i
,
type
(
var
)))
i
,
type
(
var
)
)
)
param_and_buffer_names_set
.
add
(
var
.
name
)
param_and_buffer_names_set
.
add
(
var
.
name
)
for
block
in
main_program
.
blocks
:
for
block
in
main_program
.
blocks
:
...
@@ -886,7 +1010,8 @@ class PartialProgramLayer:
...
@@ -886,7 +1010,8 @@ class PartialProgramLayer:
"
\n\t
Revise suggestion: "
"
\n\t
Revise suggestion: "
"
\n\t\t
1. Please ensure all your sublayers are inheritted from nn.Layer."
"
\n\t\t
1. Please ensure all your sublayers are inheritted from nn.Layer."
"
\n\t\t
2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
"
\n\t\t
2. Please use nn.ParameterList and nn.LayerList as container instead of using a native Python container such as List"
%
name
)
%
name
)
def
_valid_vars
(
self
,
vars
):
def
_valid_vars
(
self
,
vars
):
"""
"""
...
@@ -903,13 +1028,23 @@ def _create_fake_var():
...
@@ -903,13 +1028,23 @@ def _create_fake_var():
"""
"""
if
not
framework
.
_in_eager_mode_
:
if
not
framework
.
_in_eager_mode_
:
return
[
return
[
core
.
VarBase
(
core
.
VarDesc
.
VarType
.
FP32
,
[],
"Fake_var"
,
core
.
VarBase
(
core
.
VarDesc
.
VarType
.
RAW
,
False
)
core
.
VarDesc
.
VarType
.
FP32
,
[],
"Fake_var"
,
core
.
VarDesc
.
VarType
.
RAW
,
False
,
)
]
]
else
:
else
:
return
[
return
[
core
.
eager
.
Tensor
(
core
.
VarDesc
.
VarType
.
FP32
,
[],
"Fake_var"
,
core
.
eager
.
Tensor
(
core
.
VarDesc
.
VarType
.
RAW
,
False
)
core
.
VarDesc
.
VarType
.
FP32
,
[],
"Fake_var"
,
core
.
VarDesc
.
VarType
.
RAW
,
False
,
)
]
]
...
@@ -918,23 +1053,27 @@ def partial_program_from(concrete_program):
...
@@ -918,23 +1053,27 @@ def partial_program_from(concrete_program):
if
inputs
and
isinstance
(
inputs
[
0
],
layers
.
Layer
):
if
inputs
and
isinstance
(
inputs
[
0
],
layers
.
Layer
):
inputs
=
inputs
[
1
:]
inputs
=
inputs
[
1
:]
return
PartialProgramLayer
(
concrete_program
.
main_program
,
inputs
,
return
PartialProgramLayer
(
concrete_program
.
main_program
,
inputs
,
concrete_program
.
outputs
,
concrete_program
.
outputs
,
concrete_program
.
parameters
,
concrete_program
.
parameters
,
**
concrete_program
.
kwargs
)
**
concrete_program
.
kwargs
)
@
switch_to_static_graph
@
switch_to_static_graph
def
add_build_strategy_for
(
program
,
def
add_build_strategy_for
(
start_op_index
,
program
,
start_op_index
,
end_op_index
,
build_strategy
=
None
end_op_index
,
):
build_strategy
=
None
):
if
start_op_index
<
end_op_index
:
if
(
start_op_index
<
end_op_index
):
compiled_program
=
paddle
.
static
.
CompiledProgram
(
compiled_program
=
paddle
.
static
.
CompiledProgram
(
core
.
Graph
(
program
.
desc
,
start_op_index
,
end_op_index
),
core
.
Graph
(
program
.
desc
,
start_op_index
,
end_op_index
),
build_strategy
=
build_strategy
)
build_strategy
=
build_strategy
,
compiled_program
.
_compile
(
core
.
Scope
(),
)
framework
.
_current_expected_place
())
compiled_program
.
_compile
(
core
.
Scope
(),
framework
.
_current_expected_place
()
)
ir_graph
=
framework
.
IrGraph
(
compiled_program
.
_graph
)
ir_graph
=
framework
.
IrGraph
(
compiled_program
.
_graph
)
builded_program
=
ir_graph
.
to_program
()
builded_program
=
ir_graph
.
to_program
()
if
hasattr
(
compiled_program
.
_program
,
'lr_sheduler'
):
if
hasattr
(
compiled_program
.
_program
,
'lr_sheduler'
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录