Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
46a73e64
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
46a73e64
编写于
1月 13, 2021
作者:
H
huangxu96
提交者:
GitHub
1月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add amp example document (#30315)
上级
428c884f
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
128 addition
and
12 deletion
+128
-12
python/paddle/fluid/contrib/mixed_precision/decorator.py
python/paddle/fluid/contrib/mixed_precision/decorator.py
+108
-11
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
+4
-1
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
+16
-0
未找到文件。
python/paddle/fluid/contrib/mixed_precision/decorator.py
浏览文件 @
46a73e64
...
...
@@ -44,7 +44,7 @@ class OptimizerWithMixedPrecision(object):
Args:
optimizer (Optimizer): A common Optimizer object.
amp_lists (
AutoMixedPrecisionLists): An AutoMixedPrecision
Lists object.
amp_lists (
CustomOpLists): An CustomOp
Lists object.
init_loss_scaling (float): The initial loss scaling factor.
use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
incr_every_n_steps(int): Increases loss scaling every n consecutive
...
...
@@ -196,12 +196,56 @@ class OptimizerWithMixedPrecision(object):
Init the amp training, such as cast fp32 parameters to fp16 type.
Args:
place(C
PUPlace|CUDAPlace): place is used to initialize
place(C
UDAPlace): place is used to initialize
fp16 parameters with fp32 values.
scope(Scope): The scope is used to find fp32 parameters.
test_program(Program): The program is used for testing.
use_fp16_test(bool): Whether to use fp16 testing.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
paddle.enable_static()
def run_example_code():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
# 1) Use fp16_guard to control the range of fp16 kernels used.
with paddle.static.amp.fp16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can avoid the poor accuracy
# or the slow convergence in a way.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_black_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
custom_black_list=['pool2d'])
# 4) The entry of Paddle AMP.
# Enable pure fp16 training by setting `use_pure_fp16` to True.
optimizer = paddle.static.amp.decorate(
optimizer,
amp_list,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True)
# If you don't use the default_startup_program(), you sholud pass
# your defined `startup_program` into `minimize`.
optimizer.minimize(loss)
exe.run(paddle.static.default_startup_program())
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
"""
assert
self
.
_train_program
is
not
None
,
\
"Please call the minimize method first."
...
...
@@ -383,7 +427,7 @@ def decorate(optimizer,
Args:
optimizer(Optimizer): A common Optimizer.
amp_lists (
AutoMixedPrecisionLists): An AutoMixedPrecision
Lists object.
amp_lists (
CustomOpLists): An CustomOp
Lists object.
init_loss_scaling(float): The initial loss scaling factor.
incr_every_n_steps(int): Increases loss scaling every n consecutive
steps with finite gradients.
...
...
@@ -403,17 +447,70 @@ def decorate(optimizer,
An optimizer acting like a normal one but with mixed-precision training
enabled.
Examples:
Examples
1
:
.. code-block:: python
loss = network()
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
# black&white list based strategy example
import paddle
import paddle.static as static
paddle.enable_static()
mp_optimizer = fluid.contrib.mixed_precision.decorate(
data = static.data(name='X', shape=[None, 1], dtype='float32')
hidden = static.nn.fc(x=data, size=10)
loss = paddle.mean(hidden)
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
mp_optimizer = static.amp.decorate(
optimizer=optimizer, init_loss_scaling=8.0)
ops, param_grads = mp_optimizer.minimize(loss)
scaled_loss = mp_optimizer.get_scaled_loss()
Examples 2:
.. code-block:: python
# pure fp16 training example
import numpy as np
import paddle
import paddle.nn.functional as F
def run_example_code():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
# 1) Use fp16_guard to control the range of fp16 kernels used.
with paddle.static.amp.fp16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can avoid the poor accuracy
# or the slow convergence in a way.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_black_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
custom_black_list=['pool2d'])
# 4) The entry of Paddle AMP.
# Enable pure fp16 training by setting `use_pure_fp16` to True.
optimizer = paddle.static.amp.decorate(
optimizer,
amp_list,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True)
# If you don't use the default_startup_program(), you sholud pass
# your defined `startup_program` into `minimize`.
optimizer.minimize(loss)
exe.run(paddle.static.default_startup_program())
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
"""
if
amp_lists
is
None
:
amp_lists
=
AutoMixedPrecisionLists
()
...
...
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
浏览文件 @
46a73e64
...
...
@@ -14,7 +14,7 @@
import
copy
__all__
=
[
"AutoMixedPrecisionLists"
]
__all__
=
[
"
CustomOpLists"
,
"
AutoMixedPrecisionLists"
]
class
AutoMixedPrecisionLists
(
object
):
...
...
@@ -27,6 +27,7 @@ class AutoMixedPrecisionLists(object):
Args:
custom_white_list (set): Users' custom white list.
custom_black_list (set): Users' custom black list.
custom_black_varnames (set): Users' custom black varibles' names.
"""
def
__init__
(
self
,
...
...
@@ -284,3 +285,5 @@ unsupported_fp16_list = {
'generate_proposal_labels'
,
'generate_mask_labels'
,
}
CustomOpLists
=
AutoMixedPrecisionLists
python/paddle/fluid/contrib/mixed_precision/fp16_utils.py
浏览文件 @
46a73e64
...
...
@@ -282,6 +282,22 @@ def fp16_guard():
As for the pure fp16 training, if users set `use_fp16_guard` to True,
only those ops created in the context manager `fp16_guard` will be
transformed as float16 type.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
paddle.enable_static()
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
with paddle.static.amp.fp16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
"""
with
framework
.
name_scope
(
prefix
=
_fp16_guard_pattern
):
yield
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录