未验证 提交 342d62de 编写于 作者: H huangxu96 提交者: GitHub

add amp example document (#30314)

上级 b1d8ff45
......@@ -44,7 +44,7 @@ class OptimizerWithMixedPrecision(object):
Args:
optimizer (Optimizer): A common Optimizer object.
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
amp_lists (CustomOpLists): An CustomOpLists object.
init_loss_scaling (float): The initial loss scaling factor.
use_dynamic_loss_scaling (bool): Whether to use dynamic loss scaling.
incr_every_n_steps(int): Increases loss scaling every n consecutive
......@@ -196,12 +196,56 @@ class OptimizerWithMixedPrecision(object):
Init the amp training, such as cast fp32 parameters to fp16 type.
Args:
place(CPUPlace|CUDAPlace): place is used to initialize
place(CUDAPlace): place is used to initialize
fp16 parameters with fp32 values.
scope(Scope): The scope is used to find fp32 parameters.
test_program(Program): The program is used for testing.
use_fp16_test(bool): Whether to use fp16 testing.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
paddle.enable_static()
def run_example_code():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
# 1) Use fp16_guard to control the range of fp16 kernels used.
with paddle.static.amp.fp16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can avoid the poor accuracy
# or the slow convergence in a way.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_black_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
custom_black_list=['pool2d'])
# 4) The entry of Paddle AMP.
# Enable pure fp16 training by setting `use_pure_fp16` to True.
optimizer = paddle.static.amp.decorate(
optimizer,
amp_list,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True)
# If you don't use the default_startup_program(), you sholud pass
# your defined `startup_program` into `minimize`.
optimizer.minimize(loss)
exe.run(paddle.static.default_startup_program())
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
"""
assert self._train_program is not None, \
"Please call the minimize method first."
......@@ -383,7 +427,7 @@ def decorate(optimizer,
Args:
optimizer(Optimizer): A common Optimizer.
amp_lists (AutoMixedPrecisionLists): An AutoMixedPrecisionLists object.
amp_lists (CustomOpLists): An CustomOpLists object.
init_loss_scaling(float): The initial loss scaling factor.
incr_every_n_steps(int): Increases loss scaling every n consecutive
steps with finite gradients.
......@@ -403,17 +447,70 @@ def decorate(optimizer,
An optimizer acting like a normal one but with mixed-precision training
enabled.
Examples:
.. code-block:: python
Examples 1:
.. code-block:: python
# black&white list based strategy example
import paddle
import paddle.static as static
paddle.enable_static()
data = static.data(name='X', shape=[None, 1], dtype='float32')
hidden = static.nn.fc(x=data, size=10)
loss = paddle.mean(hidden)
optimizer = paddle.optimizer.Adam(learning_rate=0.001)
mp_optimizer = static.amp.decorate(
optimizer=optimizer, init_loss_scaling=8.0)
loss = network()
optimizer = fluid.optimizer.Adam(learning_rate=0.001)
mp_optimizer = fluid.contrib.mixed_precision.decorate(
optimizer=optimizer, init_loss_scaling=8.0)
ops, param_grads = mp_optimizer.minimize(loss)
scaled_loss = mp_optimizer.get_scaled_loss()
Examples 2:
.. code-block:: python
# pure fp16 training example
import numpy as np
import paddle
import paddle.nn.functional as F
def run_example_code():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
# 1) Use fp16_guard to control the range of fp16 kernels used.
with paddle.static.amp.fp16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can avoid the poor accuracy
# or the slow convergence in a way.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_black_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
custom_black_list=['pool2d'])
# 4) The entry of Paddle AMP.
# Enable pure fp16 training by setting `use_pure_fp16` to True.
optimizer = paddle.static.amp.decorate(
optimizer,
amp_list,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True)
# If you don't use the default_startup_program(), you sholud pass
# your defined `startup_program` into `minimize`.
optimizer.minimize(loss)
exe.run(paddle.static.default_startup_program())
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
"""
if amp_lists is None:
amp_lists = AutoMixedPrecisionLists()
......
......@@ -14,7 +14,7 @@
import copy
__all__ = ["AutoMixedPrecisionLists"]
__all__ = ["CustomOpLists", "AutoMixedPrecisionLists"]
class AutoMixedPrecisionLists(object):
......@@ -27,6 +27,7 @@ class AutoMixedPrecisionLists(object):
Args:
custom_white_list (set): Users' custom white list.
custom_black_list (set): Users' custom black list.
custom_black_varnames (set): Users' custom black varibles' names.
"""
def __init__(self,
......@@ -284,3 +285,5 @@ unsupported_fp16_list = {
'generate_proposal_labels',
'generate_mask_labels',
}
CustomOpLists = AutoMixedPrecisionLists
......@@ -282,6 +282,22 @@ def fp16_guard():
As for the pure fp16 training, if users set `use_fp16_guard` to True,
only those ops created in the context manager `fp16_guard` will be
transformed as float16 type.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
paddle.enable_static()
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
with paddle.static.amp.fp16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
"""
with framework.name_scope(prefix=_fp16_guard_pattern):
yield
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册