Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
93535c59
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
93535c59
编写于
4月 29, 2021
作者:
A
arlesniak
提交者:
GitHub
4月 29, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added pure_bf16 mode (#32281) (#32681)
This is cherry-pick of #32281
上级
ca2ef414
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
699 addition
and
86 deletion
+699
-86
paddle/fluid/operators/assign_op.cc
paddle/fluid/operators/assign_op.cc
+1
-0
python/paddle/fluid/contrib/mixed_precision/__init__.py
python/paddle/fluid/contrib/mixed_precision/__init__.py
+0
-3
python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py
python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py
+3
-1
python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
...on/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
+11
-3
python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py
...on/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py
+213
-6
python/paddle/fluid/contrib/mixed_precision/bf16/decorator.py
...on/paddle/fluid/contrib/mixed_precision/bf16/decorator.py
+318
-0
python/paddle/fluid/contrib/tests/test_bf16_utils.py
python/paddle/fluid/contrib/tests/test_bf16_utils.py
+13
-13
python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py
python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py
+23
-13
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+2
-1
python/paddle/fluid/layers/tensor.py
python/paddle/fluid/layers/tensor.py
+3
-4
python/paddle/fluid/tests/book/test_fit_a_line.py
python/paddle/fluid/tests/book/test_fit_a_line.py
+53
-25
python/paddle/fluid/tests/book/test_word2vec_book.py
python/paddle/fluid/tests/book/test_word2vec_book.py
+30
-9
python/paddle/fluid/tests/unittests/test_optimizer_grad.py
python/paddle/fluid/tests/unittests/test_optimizer_grad.py
+28
-4
python/paddle/static/amp/__init__.py
python/paddle/static/amp/__init__.py
+1
-4
未找到文件。
paddle/fluid/operators/assign_op.cc
浏览文件 @
93535c59
...
...
@@ -162,6 +162,7 @@ REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
ops
::
AssignKernel
,
int
,
ops
::
AssignKernel
,
int64_t
,
ops
::
AssignKernel
,
bool
,
ops
::
AssignKernel
,
plat
::
float16
,
ops
::
AssignKernel
,
plat
::
bfloat16
,
ops
::
AssignKernel
);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
...
...
python/paddle/fluid/contrib/mixed_precision/__init__.py
浏览文件 @
93535c59
...
...
@@ -20,10 +20,7 @@ from . import fp16_lists
from
.fp16_lists
import
*
from
.
import
fp16_utils
from
.fp16_utils
import
*
from
.
import
bf16
from
.bf16
import
*
__all__
=
decorator
.
__all__
__all__
+=
fp16_lists
.
__all__
__all__
+=
fp16_utils
.
__all__
__all__
+=
bf16
.
__all__
python/paddle/fluid/contrib/mixed_precision/bf16/__init__.py
浏览文件 @
93535c59
...
...
@@ -18,7 +18,9 @@ from . import amp_lists
from
.amp_lists
import
*
from
.
import
amp_utils
from
.amp_utils
import
*
from
.
import
decorator
from
.decorator
import
*
__all__
=
[]
__all__
=
decorator
.
__all__
__all__
+=
amp_lists
.
__all__
__all__
+=
amp_utils
.
__all__
python/paddle/fluid/contrib/mixed_precision/bf16/amp_lists.py
浏览文件 @
93535c59
...
...
@@ -13,8 +13,10 @@
# limitations under the License.
import
copy
from
paddle.fluid
import
core
from
..fp16_lists
import
white_list
as
white_list_fp16
,
black_list
as
black_list_fp16
,
\
gray_list
as
gray_list_fp16
,
unsupported_fp16_list
gray_list
as
gray_list_fp16
__all__
=
[
"AutoMixedPrecisionListsBF16"
]
...
...
@@ -82,11 +84,17 @@ bf16_list = {'elementwise_add', }
# depends on the prev_op type
gray_list
=
{
'cast'
,
'fill_constant'
,
'reduce_mean'
,
'reshape2'
,
'
lookup_tab
le'
,
'
sca
le'
,
}
unsupported_list
=
unsupported_fp16_list
.
copy
().
copy
()
_
,
_
,
_sys_unsupported_bf16_list
=
core
.
op_supported_infos
(
'CPU'
,
core
.
VarDesc
.
VarType
.
BF16
)
unsupported_list
=
_sys_unsupported_bf16_list
fp32_list
=
black_list_fp16
.
copy
().
copy
()
fp32_list
|=
white_list_fp16
fp32_list
|=
gray_list_fp16
...
...
python/paddle/fluid/contrib/mixed_precision/bf16/amp_utils.py
浏览文件 @
93535c59
...
...
@@ -14,18 +14,25 @@
# limitations under the License.
from
__future__
import
print_function
import
struct
from
....
import
core
from
....
import
framework
from
....
import
global_scope
from
....log_helper
import
get_logger
from
....wrapped_decorator
import
signature_safe_contextmanager
from
.amp_lists
import
AutoMixedPrecisionListsBF16
from
..fp16_utils
import
find_true_prev_op
,
find_true_post_op
,
_rename_arg
,
find_op_index
from
..fp16_utils
import
find_true_prev_op
,
find_true_post_op
,
_rename_arg
,
\
find_op_index
,
_rename_op_input
import
collections
import
struct
import
logging
import
numpy
as
np
__all__
=
[
"bf16_guard"
,
"rewrite_program_bf16"
,
"convert_float_to_uint16"
]
__all__
=
[
"bf16_guard"
,
"rewrite_program_bf16"
,
"cast_model_to_bf16"
,
"cast_parameters_to_bf16"
,
"convert_float_to_uint16"
]
_logger
=
get_logger
(
__name__
,
logging
.
INFO
,
fmt
=
'%(asctime)s-%(levelname)s: %(message)s'
)
...
...
@@ -126,7 +133,41 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
return
num_cast_ops
def
_insert_cast_post_op
(
block
,
op
,
idx
,
src_dtype
,
dest_dtype
,
target_name
,
op_var_rename_map
):
num_cast_ops
=
0
target_var
=
block
.
var
(
target_name
)
if
target_var
.
type
not
in
_valid_types
or
target_var
.
dtype
==
dest_dtype
:
return
num_cast_ops
assert
target_var
.
dtype
==
src_dtype
,
\
"The real dtype({}) is not equal to the src dtype({})"
.
format
(
_dtype_to_str
(
target_var
.
dtype
),
_dtype_to_str
(
src_dtype
))
cast_name
=
target_var
.
name
+
'.cast_'
+
_dtype_to_str
(
dest_dtype
)
cast_var
=
block
.
vars
.
get
(
cast_name
)
if
cast_var
is
None
or
cast_var
.
dtype
!=
dest_dtype
:
cast_var
=
block
.
create_var
(
name
=
cast_name
,
dtype
=
dest_dtype
,
persistable
=
False
,
stop_gradient
=
target_var
.
stop_gradient
)
block
.
_insert_op
(
idx
,
type
=
"cast"
,
inputs
=
{
"X"
:
target_var
},
outputs
=
{
"Out"
:
cast_var
},
attrs
=
{
"in_dtype"
:
target_var
.
dtype
,
"out_dtype"
:
cast_var
.
dtype
})
num_cast_ops
+=
1
op_var_rename_map
[
block
.
idx
][
target_var
.
name
]
=
cast_var
.
name
return
num_cast_ops
def
_is_in_fp32_varnames
(
op
,
amp_lists
):
if
not
amp_lists
.
fp32_varnames
:
return
False
for
in_name
in
op
.
input_arg_names
:
if
in_name
in
amp_lists
.
fp32_varnames
:
return
True
...
...
@@ -191,7 +232,174 @@ def bf16_guard():
yield
def
rewrite_program_bf16
(
main_prog
,
amp_lists
=
None
,
use_bf16_guard
=
False
):
def
cast_model_to_bf16
(
program
,
amp_lists
=
None
,
use_bf16_guard
=
True
):
"""
Traverse all ops in the whole model and set their inputs and outputs
to the bf16 data type. This function will do some special processing for
the batch normalization, which will keep the batchnorm's computations in FP32.
Args:
program (Program): The used program.
amp_lists (AutoMixedPrecisionListsBF16): An AutoMixedPrecisionListsBF16 object.
use_bf16_guard(bool): Determine whether to use `bf16_guard` when
constructing the program. Default True.
"""
if
amp_lists
is
None
:
amp_lists
=
AutoMixedPrecisionListsBF16
()
global_block
=
program
.
global_block
()
keep_fp32_ops
=
set
()
to_bf16_var_names
=
set
()
to_bf16_pre_cast_ops
=
set
()
origin_ops
=
[]
for
block
in
program
.
blocks
:
origin_ops
.
extend
(
block
.
ops
)
for
block
in
program
.
blocks
:
ops
=
block
.
ops
for
op
in
ops
:
if
op
.
type
==
'create_py_reader'
or
op
.
type
==
'read'
:
continue
if
_need_keep_fp32
(
op
,
amp_lists
.
unsupported_list
,
use_bf16_guard
):
keep_fp32_ops
.
add
(
op
)
continue
# processed below
for
in_name
in
op
.
input_names
:
if
op
.
type
in
{
'batch_norm'
,
'fused_bn_add_activation'
,
'layer_norm'
}
and
in_name
not
in
{
'X'
,
'Z'
}:
continue
for
in_var_name
in
op
.
input
(
in_name
):
in_var
=
None
try
:
in_var
=
block
.
var
(
in_var_name
)
except
ValueError
as
e
:
_logger
.
debug
(
"-- {}, try to get it in the global block --"
.
format
(
e
))
in_var
=
global_block
.
var
(
in_var_name
)
if
in_var
is
not
None
:
_logger
.
debug
(
"-- var {} is got in the global block --"
.
format
(
in_var_name
))
if
in_var
is
None
or
in_var
.
type
not
in
_valid_types
:
continue
if
in_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
in_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
BF16
)
to_bf16_var_names
.
add
(
in_var_name
)
_logger
.
debug
(
"-- op type: {}, in var name: {}, in var dtype: {} --"
.
format
(
op
.
type
,
in_var_name
,
in_var
.
dtype
))
for
out_name
in
op
.
output_names
:
if
op
.
type
in
{
'batch_norm'
,
'fused_bn_add_activation'
,
'layer_norm'
}
and
out_name
!=
'Y'
:
continue
for
out_var_name
in
op
.
output
(
out_name
):
out_var
=
None
try
:
out_var
=
block
.
var
(
out_var_name
)
except
ValueError
as
e
:
_logger
.
debug
(
"-- {}, try to get it in the global block --"
.
format
(
e
))
out_var
=
global_block
.
var
(
out_var_name
)
if
out_var
is
not
None
:
_logger
.
debug
(
"-- var {} is got in the global block --"
.
format
(
out_var_name
))
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
continue
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
BF16
)
_logger
.
debug
(
"-- op type: {}, out var name: {}, out var dtype: {} --"
.
format
(
op
.
type
,
out_var_name
,
out_var
.
dtype
))
for
attr_name
in
[
'in_dtype'
,
'out_dtype'
,
'dtype'
]:
if
op
.
has_attr
(
attr_name
)
and
op
.
attr
(
attr_name
)
==
core
.
VarDesc
.
VarType
.
FP32
:
op
.
_set_attr
(
attr_name
,
core
.
VarDesc
.
VarType
.
BF16
)
if
op
.
has_attr
(
'use_mkldnn'
):
op
.
_set_attr
(
'use_mkldnn'
,
True
)
if
op
.
has_attr
(
'mkldnn_data_type'
):
op
.
_set_attr
(
'mkldnn_data_type'
,
'bfloat16'
)
# process ops in keep_fp32_ops
op_var_rename_map
=
[
collections
.
OrderedDict
()
for
_
in
range
(
len
(
program
.
blocks
))
]
for
block
in
program
.
blocks
:
ops
=
block
.
ops
idx
=
0
while
idx
<
len
(
ops
):
op
=
ops
[
idx
]
num_cast_ops
=
0
if
op
not
in
keep_fp32_ops
:
if
op
in
to_bf16_pre_cast_ops
:
in_var_cast_num
=
_insert_cast_op
(
block
,
op
,
idx
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
BF16
)
num_cast_ops
+=
in_var_cast_num
else
:
pre_cast_num
=
_insert_cast_op
(
block
,
op
,
idx
,
core
.
VarDesc
.
VarType
.
BF16
,
core
.
VarDesc
.
VarType
.
FP32
)
num_cast_ops
+=
pre_cast_num
for
out_var_name
in
op
.
output_arg_names
:
out_var
=
block
.
vars
.
get
(
out_var_name
)
if
out_var
is
None
or
out_var
.
type
not
in
_valid_types
:
continue
if
out_var
.
dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
out_var
.
desc
.
set_dtype
(
core
.
VarDesc
.
VarType
.
FP32
)
post_ops
=
find_true_post_op
(
ops
,
op
,
out_var_name
)
for
post_op
in
post_ops
:
if
post_op
in
keep_fp32_ops
:
continue
post_cast_num
=
_insert_cast_post_op
(
block
,
op
,
idx
+
pre_cast_num
+
1
,
core
.
VarDesc
.
VarType
.
FP32
,
core
.
VarDesc
.
VarType
.
BF16
,
out_var_name
,
op_var_rename_map
)
num_cast_ops
+=
post_cast_num
idx
+=
num_cast_ops
+
1
_rename_op_input
(
program
,
op_var_rename_map
,
origin_ops
,
keep_fp32_ops
)
return
to_bf16_var_names
def
cast_parameters_to_bf16
(
place
,
program
,
scope
=
None
,
to_bf16_var_names
=
None
):
"""
Traverse all parameters in the whole model and set them to the BF16 data type.
Whereas, this function will keep parameters of batchnorms in FP32.
Args:
place(fluid.CPUPlace|fluid.CUDAPlace): `place` is used to restore the BF16 weight tensors.
program (Program): The used program.
scope(fluid.Scope, optional): `scope` is used to get the FP32 weight tensor values.
Default is None.
to_bf16_var_names(set|list, optional): The data types of vars in `to_bf16_var_names`
will be set to BF16. Usually, it is the returned
value of `cast_model_to_bf16` API.
"""
all_parameters
=
[]
for
block
in
program
.
blocks
:
all_parameters
.
extend
(
block
.
all_parameters
())
bf16_var_names
=
to_bf16_var_names
if
to_bf16_var_names
else
set
()
var_scope
=
scope
if
scope
else
global_scope
()
for
param
in
all_parameters
:
if
param
.
name
in
bf16_var_names
:
_logger
.
debug
(
"---- cast {} to bf16 dtype ----"
.
format
(
param
.
name
))
param_t
=
var_scope
.
find_var
(
param
.
name
).
get_tensor
()
data
=
np
.
array
(
param_t
)
param_t
.
set
(
convert_float_to_uint16
(
data
),
place
)
def
rewrite_program_bf16
(
main_prog
,
amp_lists
=
None
):
"""
Traverse all ops in current block and insert cast op according to
which set current op belongs to.
...
...
@@ -231,8 +439,7 @@ def rewrite_program_bf16(main_prog, amp_lists=None, use_bf16_guard=False):
fp32_op_set
.
add
(
op
)
continue
if
op
.
type
in
amp_lists
.
fp32_list
or
_need_keep_fp32
(
op
,
amp_lists
.
unsupported_list
,
use_bf16_guard
):
if
op
.
type
in
amp_lists
.
fp32_list
:
fp32_op_set
.
add
(
op
)
elif
op
.
type
in
amp_lists
.
bf16_list
:
bf16_op_set
.
add
(
op
)
...
...
python/paddle/fluid/contrib/mixed_precision/bf16/decorator.py
0 → 100644
浏览文件 @
93535c59
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
paddle.fluid
import
(
core
,
default_main_program
,
layers
,
program_guard
,
unique_name
)
from
.amp_utils
import
(
rewrite_program_bf16
,
cast_model_to_bf16
,
cast_parameters_to_bf16
)
from
.amp_lists
import
AutoMixedPrecisionListsBF16
import
types
import
warnings
__all__
=
[
"decorate_bf16"
]
class
OptimizerWithMixedPrecision
(
object
):
"""
Optimizer with mixed-precision (MP) training. This is a wrapper of a common
optimizer, plus the support of mixed-precision pre-training. The object
of this class almost has the same behavior as the common optimizer, with the
methods `minimize()`, `backward()`, `apply_gradients()` implemented.
Additionally, it enables the MP training automatically, i.e, the creation
and maintenance of master parameters, scaling of loss, etc.
Args:
optimizer (Optimizer): A common Optimizer object.
amp_lists (CustomOpLists): An CustomOpLists object.
use_pure_bf16(bool): Whether to use the pure bf16 training.
use_bf16_guard(bool): Whether to use `bf16_guard` when constructing the program.
"""
def
__init__
(
self
,
optimizer
,
amp_lists
,
use_pure_bf16
,
use_bf16_guard
):
self
.
_optimizer
=
optimizer
self
.
_amp_lists
=
amp_lists
self
.
_param_grads
=
None
self
.
_train_program
=
None
self
.
_learning_rate
=
optimizer
.
_learning_rate
self
.
_learning_rate_map
=
optimizer
.
_learning_rate_map
self
.
_use_pure_bf16
=
use_pure_bf16
self
.
_use_bf16_guard
=
use_bf16_guard
self
.
_to_bf16_var_names
=
None
def
_init_amp_var
(
self
):
# Ensure the data type of learning rate vars is float32 (same as the
# master parameter dtype)
if
isinstance
(
self
.
_optimizer
.
_learning_rate
,
float
):
self
.
_optimizer
.
_learning_rate_map
[
default_main_program
()]
=
\
layers
.
create_global_var
(
name
=
unique_name
.
generate
(
"learning_rate"
),
shape
=
[
1
],
value
=
float
(
self
.
_optimizer
.
_learning_rate
),
dtype
=
'float32'
,
persistable
=
True
)
def
backward
(
self
,
loss
,
startup_program
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
,
callbacks
=
None
):
"""
Backward propagation or auto differentiation for gradients' computation.
Args:
loss (Variable): The loss Variable to minimize.
startup_program (Program|None): The startup Program for initializing
parameters in `parameter_list`.
parameter_list (list|None): A list of Variables to update.
no_grad_set (set|None): A set of Variables should be ignored.
callbacks (list|None): A list of callable objects to run when appending
backward operator for one parameter.
Returns:
A list of (param, grad), which is a tuple of a parameter and its
gradient respectively, and the scaled loss.
"""
train_program
=
loss
.
block
.
program
self
.
_train_program
=
train_program
with
program_guard
(
self
.
_train_program
,
startup_program
):
self
.
_init_amp_var
()
if
self
.
_use_pure_bf16
:
self
.
_to_bf16_var_names
=
cast_model_to_bf16
(
self
.
_train_program
,
self
.
_amp_lists
,
self
.
_use_bf16_guard
)
else
:
rewrite_program_bf16
(
self
.
_train_program
,
self
.
_amp_lists
)
if
loss
.
dtype
!=
core
.
VarDesc
.
VarType
.
FP32
:
loss
=
loss
.
astype
(
'float32'
)
params_grads
=
self
.
_optimizer
.
backward
(
loss
,
startup_program
,
parameter_list
,
no_grad_set
,
callbacks
)
return
params_grads
def
amp_init
(
self
,
place
,
scope
=
None
,
test_program
=
None
,
use_bf16_test
=
False
):
"""
Init the amp training, such as cast fp32 parameters to bf16 type.
Args:
place(CPUPlace): place is used to initialize
bf16 parameters with fp32 values.
scope(Scope): The scope is used to find fp32 parameters.
test_program(Program): The program is used for testing.
use_bf16_test(bool): Whether to use bf16 testing.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
paddle.enable_static()
def run_example_code():
place = paddle.CPUPlace(0)
exe = paddle.static.Executor(place)
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
# 1) Use bf16_guard to control the range of bf16 kernels used.
with paddle.static.amp.bf16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can avoid the poor accuracy
# or the slow convergence in a way.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_fp32_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
custom_fp32_list=['pool2d'])
# 4) The entry of Paddle AMP.
# Enable pure bf16 training by setting `use_pure_bf16` to True.
optimizer = paddle.static.amp.bf16.decorate_bf16(
optimizer,
amp_list,
use_pure_bf16=True)
# If you don't use the default_startup_program(), you sholud pass
# your defined `startup_program` into `minimize`.
optimizer.minimize(loss)
exe.run(paddle.static.default_startup_program())
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
"""
assert
self
.
_train_program
is
not
None
,
\
"Please call the minimize method first."
if
self
.
_use_pure_bf16
:
cast_parameters_to_bf16
(
place
,
self
.
_train_program
,
scope
,
self
.
_to_bf16_var_names
)
if
test_program
is
not
None
:
if
self
.
_use_pure_bf16
:
cast_model_to_bf16
(
test_program
,
self
.
_amp_lists
,
self
.
_use_bf16_guard
)
elif
use_bf16_test
:
rewrite_program_bf16
(
test_program
,
self
.
_amp_lists
)
def
apply_gradients
(
self
,
params_grads
):
"""
Apply gradients.
Args:
params_grads (list): A list of params.
Returns:
A list of optimize operators.
"""
return
self
.
_optimizer
.
apply_gradients
(
params_grads
)
def
apply_optimize
(
self
,
loss
,
startup_program
,
params_grads
):
program
=
loss
.
block
.
program
with
program_guard
(
program
,
startup_program
):
optimize_ops
=
self
.
apply_gradients
(
params_grads
)
return
optimize_ops
def
minimize
(
self
,
loss
,
startup_program
=
None
,
parameter_list
=
None
,
no_grad_set
=
None
):
"""
Perform optimization by minimizing the given loss.
Args:
loss (Variable): The loss Variable.
startup_program (Program): startup_program for initializing parameters
in `parameter_list`.
parameter_list (list): list of Variables to update.
no_grad_set (set|None): set of Variables should be ignored.
Returns:
The scaled loss by scaling factor, the list of optimize ops, and a
list of scaled parameters and gradients.
"""
opt_dict
=
self
.
_optimizer
.
__class__
.
__dict__
if
'minimize'
in
opt_dict
and
isinstance
(
opt_dict
[
'minimize'
],
types
.
FunctionType
):
warnings
.
warn
(
"The decorated optimizer has its own `minimize` method, but it will not be executed."
)
params_grads
=
self
.
backward
(
loss
,
startup_program
=
startup_program
,
parameter_list
=
parameter_list
,
no_grad_set
=
no_grad_set
)
optimize_ops
=
self
.
apply_optimize
(
loss
,
startup_program
,
params_grads
)
return
optimize_ops
,
params_grads
def
decorate_bf16
(
optimizer
,
amp_lists
=
None
,
use_pure_bf16
=
False
,
use_bf16_guard
=
None
):
"""
Decorate the given optimizer to adapt to the mixed-precision training.
Args:
optimizer(Optimizer): A common Optimizer.
amp_lists (CustomOpLists): An CustomOpLists object.
use_pure_bf16(bool): Whether to use the pure bf16 training. Default False.
use_bf16_guard(bool): Whether to use `bf16_guard` when constructing the program.
Default None, which means that its value equals to `use_pure_bf16`.
Returns:
An optimizer acting like a normal one but with mixed-precision training
enabled.
Examples 1:
.. code-block:: python
# fp32&bf16 list based strategy example
import paddle
import paddle.static as static
paddle.enable_static()
data = static.data(name='X', shape=[None, 1], dtype='float32')
hidden = static.nn.fc(x=data, size=10)
loss = paddle.mean(hidden)
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
mp_optimizer = static.amp.decorate_bf16(optimizer=optimizer)
ops, param_grads = mp_optimizer.minimize(loss)
Examples 2:
.. code-block:: python
# pure bf16 training example
import numpy as np
import paddle
import paddle.nn.functional as F
def run_example_code():
place = paddle.CPUPlace(0)
exe = paddle.static.Executor(place)
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
# 1) Use bf16_guard to control the range of bf16 kernels used.
with paddle.static.amp.bf16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can avoid the poor accuracy
# or the slow convergence in a way.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_fp32_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
custom_fp32_list=['pool2d'])
# 4) The entry of Paddle AMP.
# Enable pure bf16 training by setting `use_pure_bf16` to True.
optimizer = paddle.static.amp.decorate_bf16(
optimizer,
amp_list,
use_pure_bf16=True)
# If you don't use the default_startup_program(), you sholud pass
# your defined `startup_program` into `minimize`.
optimizer.minimize(loss)
exe.run(paddle.static.default_startup_program())
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
"""
if
amp_lists
is
None
:
amp_lists
=
AutoMixedPrecisionListsBF16
()
if
use_bf16_guard
is
None
:
use_bf16_guard
=
use_pure_bf16
mp_optimizer
=
OptimizerWithMixedPrecision
(
optimizer
,
amp_lists
,
use_pure_bf16
,
use_bf16_guard
)
return
mp_optimizer
python/paddle/fluid/contrib/tests/test_bf16_utils.py
浏览文件 @
93535c59
...
...
@@ -14,7 +14,7 @@
import
copy
import
unittest
import
paddle.fluid
as
fluid
import
paddle.
fluid.contrib.mixed_precision
as
amp
import
paddle.
static.amp
as
amp
from
paddle.fluid
import
core
import
paddle
...
...
@@ -34,34 +34,34 @@ class AMPTest(unittest.TestCase):
self
.
assertEqual
(
self
.
amp_lists_
.
gray_list
,
self
.
gray_list
)
def
test_amp_lists
(
self
):
self
.
amp_lists_
=
amp
.
AutoMixedPrecisionListsBF16
()
self
.
amp_lists_
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
()
def
test_amp_lists_1
(
self
):
# 1. w={'exp}, b=None
self
.
bf16_list
.
add
(
'exp'
)
self
.
fp32_list
.
remove
(
'exp'
)
self
.
amp_lists_
=
amp
.
AutoMixedPrecisionListsBF16
({
'exp'
})
self
.
amp_lists_
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
({
'exp'
})
def
test_amp_lists_2
(
self
):
# 2. w={'tanh'}, b=None
self
.
fp32_list
.
remove
(
'tanh'
)
self
.
bf16_list
.
add
(
'tanh'
)
self
.
amp_lists_
=
amp
.
AutoMixedPrecisionListsBF16
({
'tanh'
})
self
.
amp_lists_
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
({
'tanh'
})
def
test_amp_lists_3
(
self
):
# 3. w={'lstm'}, b=None
self
.
bf16_list
.
add
(
'lstm'
)
self
.
amp_lists_
=
amp
.
AutoMixedPrecisionListsBF16
({
'lstm'
})
self
.
amp_lists_
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
({
'lstm'
})
def
test_amp_lists_4
(
self
):
# 4. w=None, b={'elementwise_add'}
self
.
bf16_list
.
remove
(
'elementwise_add'
)
self
.
fp32_list
.
add
(
'elementwise_add'
)
self
.
amp_lists_
=
amp
.
AutoMixedPrecisionListsBF16
(
self
.
amp_lists_
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_fp32_list
=
{
'elementwise_add'
})
def
test_amp_lists_5
(
self
):
...
...
@@ -69,28 +69,28 @@ class AMPTest(unittest.TestCase):
self
.
fp32_list
.
add
(
'elementwise_add'
)
self
.
bf16_list
.
remove
(
'elementwise_add'
)
self
.
amp_lists_
=
amp
.
AutoMixedPrecisionListsBF16
(
self
.
amp_lists_
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_fp32_list
=
{
'elementwise_add'
})
def
test_amp_lists_6
(
self
):
# 6. w=None, b={'lstm'}
self
.
fp32_list
.
add
(
'lstm'
)
self
.
amp_lists_
=
amp
.
AutoMixedPrecisionListsBF16
(
self
.
amp_lists_
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_fp32_list
=
{
'lstm'
})
def
test_amp_lists_7
(
self
):
self
.
fp32_list
.
add
(
'reshape2'
)
self
.
gray_list
.
remove
(
'reshape2'
)
self
.
amp_lists_
=
amp
.
AutoMixedPrecisionListsBF16
(
self
.
amp_lists_
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_fp32_list
=
{
'reshape2'
})
def
test_amp_list_8
(
self
):
self
.
bf16_list
.
add
(
'reshape2'
)
self
.
gray_list
.
remove
(
'reshape2'
)
self
.
amp_lists_
=
amp
.
AutoMixedPrecisionListsBF16
(
self
.
amp_lists_
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_bf16_list
=
{
'reshape2'
})
...
...
@@ -98,7 +98,7 @@ class AMPTest2(unittest.TestCase):
def
test_amp_lists_
(
self
):
# 7. w={'lstm'} b={'lstm'}
# raise ValueError
self
.
assertRaises
(
ValueError
,
amp
.
AutoMixedPrecisionListsBF16
,
self
.
assertRaises
(
ValueError
,
amp
.
bf16
.
AutoMixedPrecisionListsBF16
,
{
'lstm'
},
{
'lstm'
})
def
test_find_op_index
(
self
):
...
...
@@ -117,10 +117,10 @@ class AMPTest2(unittest.TestCase):
type
=
"abs"
,
inputs
=
{
"X"
:
[
var1
]},
outputs
=
{
"Out"
:
[
var2
]})
op2
=
block
.
append_op
(
type
=
"abs"
,
inputs
=
{
"X"
:
[
var2
]},
outputs
=
{
"Out"
:
[
var3
]})
amp_lists_1
=
amp
.
AutoMixedPrecisionListsBF16
(
amp_lists_1
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_fp32_varnames
=
{
'X'
})
assert
amp
.
bf16
.
amp_utils
.
_is_in_fp32_varnames
(
op1
,
amp_lists_1
)
amp_lists_2
=
amp
.
AutoMixedPrecisionListsBF16
(
amp_lists_2
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_fp32_varnames
=
{
'Y'
})
assert
amp
.
bf16
.
amp_utils
.
_is_in_fp32_varnames
(
op2
,
amp_lists_2
)
assert
amp
.
bf16
.
amp_utils
.
_is_in_fp32_varnames
(
op1
,
amp_lists_2
)
...
...
python/paddle/fluid/contrib/tests/test_model_cast_to_bf16.py
浏览文件 @
93535c59
...
...
@@ -65,13 +65,13 @@ class TestModelCastBF16(unittest.TestCase):
fetch_list
=
fetch_list
,
return_numpy
=
(
not
with_lod
))
def
test_graph_rewrite
(
self
):
def
_graph_common
(
self
,
_amp_fun
):
size
=
3
n
=
np
.
ones
([
size
,
size
],
dtype
=
'float32'
)
*
3.2
nn
=
np
.
ones
([
size
,
size
],
dtype
=
'float32'
)
*
-
2.7
n_bf16
=
amp
.
convert_float_to_uint16
(
n
)
nn_bf16
=
amp
.
convert_float_to_uint16
(
nn
)
n_bf16
=
amp
.
bf16
.
convert_float_to_uint16
(
n
)
nn_bf16
=
amp
.
bf16
.
convert_float_to_uint16
(
nn
)
with
self
.
static_graph
():
t_bf16
=
layers
.
data
(
...
...
@@ -85,12 +85,12 @@ class TestModelCastBF16(unittest.TestCase):
ret
=
layers
.
elementwise_mul
(
ret
,
t
)
ret
=
layers
.
reshape
(
ret
,
[
0
,
0
])
with
amp
.
bf16_guard
():
with
amp
.
bf16
.
bf16
_guard
():
ret_bf16
=
layers
.
elementwise_add
(
t_bf16
,
tt_bf16
)
ret_bf16
=
layers
.
elementwise_mul
(
ret_bf16
,
t_bf16
)
ret_bf16
=
layers
.
reshape
(
ret_bf16
,
[
0
,
0
])
with
amp
.
bf16_guard
():
with
amp
.
bf16
.
bf16
_guard
():
ret_fp32bf16
=
layers
.
elementwise_add
(
t
,
tt
)
ret_fp32bf16
=
layers
.
elementwise_mul
(
ret_fp32bf16
,
t
)
ret_fp32bf16
=
layers
.
reshape
(
ret_fp32bf16
,
[
0
,
0
])
...
...
@@ -103,7 +103,7 @@ class TestModelCastBF16(unittest.TestCase):
'tt_bf16'
:
nn_bf16
,
},
fetch_list
=
[
ret_bf16
,
ret
,
ret_fp32bf16
],
amp_fun
=
lambda
prog
:
amp
.
rewrite_program_bf16
(
prog
,
use_bf16_guard
=
True
))
amp_fun
=
lambda
prog
:
amp
.
bf16
.
rewrite_program_bf16
(
prog
))
self
.
assertTrue
(
np
.
allclose
(
static_ret_bf16
,
static_ret
,
1e-2
))
self
.
assertTrue
(
np
.
allclose
(
static_ret_bf16
,
ret_fp32bf16
,
1e-2
))
...
...
@@ -112,7 +112,7 @@ class TestModelCastBF16(unittest.TestCase):
t
=
layers
.
data
(
name
=
't'
,
shape
=
[
size
,
size
],
dtype
=
'float32'
)
tt
=
layers
.
data
(
name
=
'tt'
,
shape
=
[
size
,
size
],
dtype
=
'float32'
)
with
amp
.
bf16_guard
():
with
amp
.
bf16
.
bf16
_guard
():
ret
=
layers
.
elementwise_add
(
t
,
tt
)
ret
=
layers
.
reshape
(
ret
,
[
0
,
0
],
act
=
'elu'
)
ret
=
layers
.
elementwise_mul
(
ret
,
t
)
...
...
@@ -122,17 +122,27 @@ class TestModelCastBF16(unittest.TestCase):
self
.
get_static_graph_result
(
feed
=
{
't'
:
n
,
'tt'
:
nn
},
fetch_list
=
[
ret
],
amp_fun
=
lambda
prog
:
amp
.
rewrite_program_bf16
(
prog
,
amp
.
AutoMixedPrecisionListsBF16
(
custom_fp32_varnames
=
{
'elementwise_add_0.tmp_0'
}),
use_bf16_guard
=
True
)
amp_fun
=
_amp_fun
)
self
.
assertTrue
(
static_ret_bf16
,
np
.
ones
(
[
size
,
size
],
dtype
=
'float32'
)
*
-
1.1
)
def
test_graph_rewrite
(
self
):
self
.
_graph_common
(
lambda
prog
:
amp
.
bf16
.
rewrite_program_bf16
(
prog
,
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_fp32_varnames
=
{
'elementwise_add_0.tmp_0'
}),
))
def
test_graph_cast
(
self
):
self
.
_graph_common
(
lambda
prog
:
amp
.
bf16
.
cast_model_to_bf16
(
prog
,
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_fp32_list
=
{
'elementwise_mul'
}),
use_bf16_guard
=
True
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/layers/nn.py
浏览文件 @
93535c59
...
...
@@ -332,7 +332,8 @@ def fc(input,
for i, input_x in enumerate(input):
check_type(input_x, 'input[' + str(i) + ']', Variable, 'fc')
dtype = helper.input_dtype()
check_dtype(dtype, 'input', ['float16', 'float32', 'float64'], 'fc')
check_dtype(dtype, 'input', ['float16', 'uint16', 'float32', 'float64'],
'fc')
mul_results = []
for input_var, param_attr in helper.iter_inputs_and_params():
input_shape = input_var.shape
...
...
python/paddle/fluid/layers/tensor.py
浏览文件 @
93535c59
...
...
@@ -582,10 +582,9 @@ def assign(input, output=None):
input
=
numpy
.
array
(
input
)
if
isinstance
(
input
,
Variable
):
check_dtype
(
input
.
dtype
,
'input'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'bool'
],
'assign'
,
'(When the type of input in assign is Variable.)'
)
check_dtype
(
input
.
dtype
,
'input'
,
[
'float16'
,
'uint16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'bool'
],
'assign'
,
'(When the type of input in assign is Variable.)'
)
if
output
is
None
:
output
=
helper
.
create_variable_for_type_inference
(
dtype
=
input
.
dtype
)
...
...
python/paddle/fluid/tests/book/test_fit_a_line.py
浏览文件 @
93535c59
...
...
@@ -16,6 +16,8 @@ from __future__ import print_function
import
paddle
import
paddle.fluid
as
fluid
import
paddle.static.amp
as
amp
import
contextlib
import
numpy
import
unittest
...
...
@@ -26,19 +28,34 @@ import os
paddle
.
enable_static
()
def
train
(
use_cuda
,
save_dirname
,
is_local
,
use_bf16
):
def
train
(
use_cuda
,
save_dirname
,
is_local
,
use_bf16
,
pure_bf16
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
13
],
dtype
=
'float32'
)
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
1
],
dtype
=
'float32'
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
if
use_bf16
:
if
not
pure_bf16
:
with
amp
.
bf16
.
bf16_guard
():
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
else
:
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
with
amp
.
bf16
.
bf16_guard
():
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
else
:
y_predict
=
fluid
.
layers
.
fc
(
input
=
x
,
size
=
1
,
act
=
None
)
cost
=
fluid
.
layers
.
square_error_cost
(
input
=
y_predict
,
label
=
y
)
avg_cost
=
fluid
.
layers
.
mean
(
cost
)
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
if
use_bf16
:
paddle
.
static
.
amp
.
rewrite_program_bf16
(
fluid
.
default_main_program
())
sgd_optimizer
=
amp
.
bf16
.
decorate_bf16
(
sgd_optimizer
,
amp_lists
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(),
use_bf16_guard
=
False
,
use_pure_bf16
=
pure_bf16
)
sgd_optimizer
.
minimize
(
avg_cost
)
BATCH_SIZE
=
20
...
...
@@ -54,6 +71,10 @@ def train(use_cuda, save_dirname, is_local, use_bf16):
def
train_loop
(
main_program
):
feeder
=
fluid
.
DataFeeder
(
place
=
place
,
feed_list
=
[
x
,
y
])
exe
.
run
(
fluid
.
default_startup_program
())
test_prog
=
main_program
.
clone
(
for_test
=
True
)
if
pure_bf16
:
sgd_optimizer
.
amp_init
(
exe
.
place
,
test_program
=
test_prog
,
use_bf16_test
=
True
)
PASS_NUM
=
100
for
pass_id
in
range
(
PASS_NUM
):
...
...
@@ -61,9 +82,8 @@ def train(use_cuda, save_dirname, is_local, use_bf16):
avg_loss_value
,
=
exe
.
run
(
main_program
,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
avg_cost
])
print
(
avg_loss_value
)
if
avg_loss_value
[
0
]
<
10.0
:
if
save_dirname
is
not
None
:
if
avg_loss_value
[
0
]
<
10.0
or
pure_bf16
:
if
save_dirname
is
not
None
and
not
pure_bf16
:
fluid
.
io
.
save_inference_model
(
save_dirname
,
[
'x'
],
[
y_predict
],
exe
)
return
...
...
@@ -97,7 +117,7 @@ def train(use_cuda, save_dirname, is_local, use_bf16):
train_loop
(
t
.
get_trainer_program
())
def
infer
(
use_cuda
,
save_dirname
=
None
):
def
infer
(
use_cuda
,
save_dirname
=
None
,
use_bf16
=
False
):
if
save_dirname
is
None
:
return
...
...
@@ -135,7 +155,7 @@ def infer(use_cuda, save_dirname=None):
print
(
"ground truth: "
,
test_label
)
def
main
(
use_cuda
,
is_local
=
True
,
use_bf16
=
False
):
def
main
(
use_cuda
,
is_local
=
True
,
use_bf16
=
False
,
pure_bf16
=
False
):
if
use_cuda
and
not
fluid
.
core
.
is_compiled_with_cuda
():
return
...
...
@@ -145,11 +165,22 @@ def main(use_cuda, is_local=True, use_bf16=False):
# Directory for saving the trained model
save_dirname
=
"fit_a_line.inference.model"
train
(
use_cuda
,
save_dirname
,
is_local
,
use_bf16
)
infer
(
use_cuda
,
save_dirname
)
train
(
use_cuda
,
save_dirname
,
is_local
,
use_bf16
,
pure_bf16
)
infer
(
use_cuda
,
save_dirname
,
use_bf16
)
class
TestFitALineBase
(
unittest
.
TestCase
):
@
contextlib
.
contextmanager
def
program_scope_guard
(
self
):
prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
scope
=
fluid
.
core
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
with
fluid
.
program_guard
(
prog
,
startup_prog
):
yield
class
TestFitALine
(
unittest
.
TestC
ase
):
class
TestFitALine
(
TestFitALineB
ase
):
def
test_cpu
(
self
):
with
self
.
program_scope_guard
():
main
(
use_cuda
=
False
)
...
...
@@ -158,20 +189,17 @@ class TestFitALine(unittest.TestCase):
with
self
.
program_scope_guard
():
main
(
use_cuda
=
True
)
@
unittest
.
skipIf
(
not
fluid
.
core
.
supports_bfloat16
(),
"place does not support BF16 evaluation"
)
@
unittest
.
skipIf
(
not
fluid
.
core
.
supports_bfloat16
(),
"place does not support BF16 evaluation"
)
class
TestFitALineBF16
(
TestFitALineBase
):
def
test_bf16
(
self
):
with
self
.
program_scope_guard
():
main
(
use_cuda
=
False
,
use_bf16
=
True
)
@
contextlib
.
contextmanager
def
program_scope_guard
(
self
):
prog
=
fluid
.
Program
()
startup_prog
=
fluid
.
Program
()
scope
=
fluid
.
core
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
with
fluid
.
program_guard
(
prog
,
startup_prog
):
yield
def
test_pure_bf16
(
self
):
with
self
.
program_scope_guard
():
main
(
use_cuda
=
False
,
use_bf16
=
True
,
pure_bf16
=
True
)
if
__name__
==
'__main__'
:
...
...
python/paddle/fluid/tests/book/test_word2vec_book.py
浏览文件 @
93535c59
...
...
@@ -44,7 +44,8 @@ def train(target,
is_parallel
,
save_dirname
,
is_local
=
True
,
use_bf16
=
False
):
use_bf16
=
False
,
pure_bf16
=
False
):
PASS_NUM
=
100
EMBED_SIZE
=
32
HIDDEN_SIZE
=
256
...
...
@@ -107,7 +108,13 @@ def train(target,
sgd_optimizer
=
fluid
.
optimizer
.
SGD
(
learning_rate
=
0.001
)
if
use_bf16
:
paddle
.
static
.
amp
.
rewrite_program_bf16
(
fluid
.
default_main_program
())
sgd_optimizer
=
paddle
.
static
.
amp
.
bf16
.
decorate_bf16
(
sgd_optimizer
,
amp_lists
=
paddle
.
static
.
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_fp32_list
=
{
'softmax'
,
'concat'
},
),
use_bf16_guard
=
False
,
use_pure_bf16
=
pure_bf16
)
sgd_optimizer
.
minimize
(
avg_cost
)
train_reader
=
paddle
.
batch
(
...
...
@@ -121,6 +128,8 @@ def train(target,
def
train_loop
(
main_program
):
exe
.
run
(
fluid
.
default_startup_program
())
if
pure_bf16
:
sgd_optimizer
.
amp_init
(
exe
.
place
)
for
pass_id
in
range
(
PASS_NUM
):
for
data
in
train_reader
():
...
...
@@ -128,7 +137,7 @@ def train(target,
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
avg_cost
])
if
avg_cost_np
[
0
]
<
5.0
:
if
save_dirname
is
not
None
:
if
save_dirname
is
not
None
and
not
pure_bf16
:
fluid
.
io
.
save_inference_model
(
save_dirname
,
[
'firstw'
,
'secondw'
,
'thirdw'
,
'forthw'
],
[
predict_word
],
exe
)
...
...
@@ -246,7 +255,7 @@ def infer(target, save_dirname=None):
assert
np
.
isclose
(
a
,
b
,
rtol
=
5e-5
),
"a: {}, b: {}"
.
format
(
a
,
b
)
def
main
(
target
,
is_sparse
,
is_parallel
,
use_bf16
):
def
main
(
target
,
is_sparse
,
is_parallel
,
use_bf16
,
pure_bf16
):
if
target
==
"cuda"
and
not
fluid
.
core
.
is_compiled_with_cuda
():
return
if
target
==
"xpu"
and
not
fluid
.
core
.
is_compiled_with_xpu
():
...
...
@@ -265,7 +274,13 @@ def main(target, is_sparse, is_parallel, use_bf16):
# so only inference is turned on.
train
(
"cpu"
,
is_sparse
,
is_parallel
,
save_dirname
)
else
:
train
(
target
,
is_sparse
,
is_parallel
,
save_dirname
,
use_bf16
=
use_bf16
)
train
(
target
,
is_sparse
,
is_parallel
,
save_dirname
,
use_bf16
=
use_bf16
,
pure_bf16
=
pure_bf16
)
infer
(
target
,
save_dirname
)
...
...
@@ -278,10 +293,15 @@ class W2VTest(unittest.TestCase):
pass
def
inject_test_method
(
target
,
is_sparse
,
is_parallel
,
use_bf16
=
False
):
def
inject_test_method
(
target
,
is_sparse
,
is_parallel
,
use_bf16
=
False
,
pure_bf16
=
False
):
fn_name
=
"test_{0}_{1}_{2}{3}"
.
format
(
target
,
"sparse"
if
is_sparse
else
"dense"
,
"parallel"
if
is_parallel
else
"normal"
,
"_bf16"
if
is_parallel
else
"normal"
,
"_purebf16"
if
pure_bf16
else
"_bf16"
if
use_bf16
else
""
)
def
__impl__
(
*
args
,
**
kwargs
):
...
...
@@ -290,7 +310,7 @@ def inject_test_method(target, is_sparse, is_parallel, use_bf16=False):
scope
=
fluid
.
core
.
Scope
()
with
fluid
.
scope_guard
(
scope
):
with
fluid
.
program_guard
(
prog
,
startup_prog
):
main
(
target
,
is_sparse
,
is_parallel
,
use_bf16
)
main
(
target
,
is_sparse
,
is_parallel
,
use_bf16
,
pure_bf16
)
if
(
not
fluid
.
core
.
is_compiled_with_cuda
()
or
target
==
"cuda"
)
and
is_sparse
:
...
...
@@ -307,7 +327,8 @@ for target in ("cuda", "cpu", "xpu"):
for
is_sparse
in
(
False
,
True
):
for
is_parallel
in
(
False
,
):
inject_test_method
(
target
,
is_sparse
,
is_parallel
)
inject_test_method
(
"cpu"
,
False
,
False
,
use_bf16
=
True
)
inject_test_method
(
"cpu"
,
False
,
False
,
True
)
inject_test_method
(
"cpu"
,
False
,
False
,
True
,
True
)
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_optimizer_grad.py
浏览文件 @
93535c59
...
...
@@ -64,7 +64,7 @@ class SimpleNetWithCond(object):
return
grads
def
build_net
(
self
,
cond_i
):
def
build_net
(
self
,
cond_i
,
use_bf16
=
False
):
"""
pseudo code:
sum_xy = x + y
...
...
@@ -122,13 +122,22 @@ class SimpleNetWithCond(object):
sum_cond
=
fluid
.
layers
.
cond
(
cond_i
>
1.0
,
cond_true
,
cond_false
)
sum_all
=
fluid
.
layers
.
sum
([
sum_xy
,
sub_yz
,
sum_cond
])
mean_out
=
fluid
.
layers
.
mean
(
sum_all
)
if
use_bf16
:
import
paddle.static.amp
as
amp
self
.
optimizer
=
amp
.
bf16
.
decorate_bf16
(
self
.
optimizer
,
amp_lists
=
amp
.
bf16
.
AutoMixedPrecisionListsBF16
(
custom_fp32_list
=
{
'elementwise_add'
}),
use_bf16_guard
=
False
,
use_pure_bf16
=
True
)
self
.
optimizer
.
minimize
(
mean_out
)
fetch_list
=
[
"param_x"
,
"param_z"
]
if
self
.
y_no_grad
else
[
"param_x"
,
"param_y"
,
"param_z"
]
fetch_list
+=
[
_append_grad_suffix_
(
param
)
for
param
in
fetch_list
]
return
fetch_list
return
fetch_list
,
self
.
optimizer
class
TestOptimizer
(
unittest
.
TestCase
):
...
...
@@ -180,7 +189,7 @@ class TestOptimizer(unittest.TestCase):
for
key
in
[
'x'
,
'y'
,
'z'
]:
self
.
param_attr
[
key
]
=
self
.
attr
.
copy
()
def
_check_grads
(
self
):
def
_check_grads
(
self
,
use_bf16
=
False
):
"""
main logic code to check the validity of apply_optimize.
"""
...
...
@@ -204,10 +213,16 @@ class TestOptimizer(unittest.TestCase):
lambda
:
dict
())
test_net
=
self
.
NetClass
(
self
.
optimizer
,
param_lr
,
y_no_grad
)
fetch_list
=
test_net
.
build_net
(
cond_i
)
fetch_list
,
decorated_optimizer
=
test_net
.
build_net
(
cond_i
,
use_bf16
)
if
use_bf16
:
self
.
optimizer
=
decorated_optimizer
exe
=
fluid
.
Executor
(
place
)
exe
.
run
(
init_program
)
if
use_bf16
:
self
.
optimizer
.
amp_init
(
exe
.
place
)
# Train 2 steps to check validity
for
batch_i
in
range
(
2
):
...
...
@@ -222,6 +237,15 @@ class TestOptimizer(unittest.TestCase):
param_grads
[
i
])
@
unittest
.
skipIf
(
not
fluid
.
core
.
supports_bfloat16
(),
"place does not support BF16 evaluation"
)
class
TestSGDOptimizer
(
TestOptimizer
):
def
test_optimizer_multiblock_except
(
self
):
with
self
.
assertRaisesRegexp
(
ValueError
,
"var param_y not in this block"
):
self
.
_check_grads
(
use_bf16
=
True
)
class
TestAdamOptimizer
(
TestOptimizer
):
"""
inherit TestOptimizer and shall override two functions as follows:
...
...
python/paddle/static/amp/__init__.py
浏览文件 @
93535c59
...
...
@@ -18,9 +18,6 @@ from ...fluid.contrib.mixed_precision import AutoMixedPrecisionLists # noqa: F4
from
...fluid.contrib.mixed_precision
import
fp16_guard
# noqa: F401
from
...fluid.contrib.mixed_precision
import
cast_model_to_fp16
# noqa: F401
from
...fluid.contrib.mixed_precision
import
cast_parameters_to_fp16
# noqa: F401
from
...fluid.contrib.mixed_precision
import
AutoMixedPrecisionListsBF16
# noqa: F401
from
...fluid.contrib.mixed_precision
import
bf16_guard
# noqa: F401
from
...fluid.contrib.mixed_precision
import
rewrite_program_bf16
# noqa: F401
from
...fluid.contrib.mixed_precision
import
convert_float_to_uint16
# noqa: F401
from
...fluid.contrib.mixed_precision
import
bf16
# noqa: F401
__all__
=
[]
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录