Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
13862008
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
13862008
编写于
1月 20, 2021
作者:
H
huangxu96
提交者:
GitHub
1月 20, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add fleet amp_init() (#30572)
* add fleet amp.init() * add unittest for fleet_amp_init
上级
2d5758c4
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
150 addition
and
2 deletion
+150
-2
python/paddle/distributed/fleet/base/fleet_base.py
python/paddle/distributed/fleet/base/fleet_base.py
+64
-0
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
+6
-2
python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
+80
-0
未找到文件。
python/paddle/distributed/fleet/base/fleet_base.py
浏览文件 @
13862008
...
...
@@ -958,6 +958,70 @@ class Fleet(object):
# imitate target optimizer retrieval
return
self
.
user_defined_optimizer
.
clear_grad
()
def
amp_init
(
self
,
place
,
scope
=
None
,
test_program
=
None
,
use_fp16_test
=
False
):
"""
Init the amp training, such as cast fp32 parameters to fp16 type.
Args:
place(CUDAPlace): place is used to initialize
fp16 parameters with fp32 values.
scope(Scope): The scope is used to find fp32 parameters.
test_program(Program): The program is used for testing.
use_fp16_test(bool): Whether to use fp16 testing.
Examples:
.. code-block:: python
import numpy as np
import paddle
import paddle.nn.functional as F
paddle.enable_static()
def run_example_code():
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
data = paddle.static.data(name='X', shape=[None, 1, 28, 28], dtype='float32')
conv2d = paddle.static.nn.conv2d(input=data, num_filters=6, filter_size=3)
# 1) Use fp16_guard to control the range of fp16 kernels used.
with paddle.static.amp.fp16_guard():
bn = paddle.static.nn.batch_norm(input=conv2d, act="relu")
pool = F.max_pool2d(bn, kernel_size=2, stride=2)
hidden = paddle.static.nn.fc(pool, size=10)
loss = paddle.mean(hidden)
# 2) Create the optimizer and set `multi_precision` to True.
# Setting `multi_precision` to True can avoid the poor accuracy
# or the slow convergence in a way.
optimizer = paddle.optimizer.Momentum(learning_rate=0.01, multi_precision=True)
# 3) These ops in `custom_black_list` will keep in the float32 computation type.
amp_list = paddle.static.amp.CustomOpLists(
custom_black_list=['pool2d'])
# 4) The entry of Paddle AMP.
# Enable pure fp16 training by setting `use_pure_fp16` to True.
optimizer = paddle.static.amp.decorate(
optimizer,
amp_list,
init_loss_scaling=128.0,
use_dynamic_loss_scaling=True,
use_pure_fp16=True)
# If you don't use the default_startup_program(), you sholud pass
# your defined `startup_program` into `minimize`.
optimizer.minimize(loss)
exe.run(paddle.static.default_startup_program())
# 5) Use `amp_init` after FP32 parameters initialization(such as `exe.run(startup_program)`).
# If you want to perform the testing process, you should pass `test_program` into `amp_init`.
optimizer.amp_init(place, scope=paddle.static.global_scope())
if paddle.is_compiled_with_cuda() and len(paddle.static.cuda_places()) > 0:
run_example_code()
"""
# imitate target optimizer retrieval
return
self
.
user_defined_optimizer
.
amp_init
(
place
,
scope
=
None
,
test_program
=
None
,
use_fp16_test
=
False
)
def
_final_strategy
(
self
):
if
"valid_strategy"
not
in
self
.
_context
:
print
(
...
...
python/paddle/fluid/contrib/mixed_precision/fp16_lists.py
浏览文件 @
13862008
...
...
@@ -95,6 +95,9 @@ black_list = {
'sigmoid_cross_entropy_with_logits'
,
'cross_entropy'
,
'cross_entropy2'
,
# fp16 is slower than fp32, though fp16 is supported.
'lookup_table'
,
'lookup_table_v2'
,
}
# This set contains two types of ops. All ops supported fp16 calculation. One
...
...
@@ -115,8 +118,6 @@ gray_list = {
'layer_norm'
,
'tanh'
,
'sigmoid'
,
'lookup_table'
,
'lookup_table_v2'
,
'top_k'
,
'pool2d'
,
'pool3d'
,
...
...
@@ -284,6 +285,9 @@ unsupported_fp16_list = {
'generate_proposals'
,
'generate_proposal_labels'
,
'generate_mask_labels'
,
# fp16 is slower than fp32, though fp16 is supported.
'lookup_table'
,
'lookup_table_v2'
,
}
CustomOpLists
=
AutoMixedPrecisionLists
python/paddle/fluid/tests/unittests/test_fleet_amp_init.py
0 → 100644
浏览文件 @
13862008
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
paddle
import
paddle.distributed.fleet.base.role_maker
as
role_maker
import
paddle.distributed.fleet
as
fleet
import
paddle.fluid
as
fluid
import
unittest
import
paddle.nn.functional
as
F
import
numpy
as
np
paddle
.
enable_static
()
def
gen_data
():
return
{
"x"
:
np
.
random
.
random
(
size
=
(
128
,
32
)).
astype
(
'float32'
),
"y"
:
np
.
random
.
randint
(
2
,
size
=
(
128
,
1
)).
astype
(
'int64'
)
}
def
mlp
(
input_x
,
input_y
,
hid_dim
=
128
,
label_dim
=
2
):
fc_1
=
paddle
.
static
.
nn
.
fc
(
x
=
input_x
,
size
=
hid_dim
,
activation
=
'tanh'
)
fc_2
=
paddle
.
static
.
nn
.
fc
(
x
=
fc_1
,
size
=
hid_dim
,
activation
=
'tanh'
)
prediction
=
paddle
.
static
.
nn
.
fc
(
x
=
[
fc_2
],
size
=
label_dim
,
activation
=
'softmax'
)
cost
=
F
.
cross_entropy
(
input
=
prediction
,
label
=
input_y
)
avg_cost
=
paddle
.
mean
(
x
=
cost
)
return
avg_cost
class
TestFleetAMPInit
(
unittest
.
TestCase
):
def
test_fleet_amp_init
(
self
):
if
not
fluid
.
core
.
is_compiled_with_cuda
():
return
input_x
=
paddle
.
static
.
data
(
name
=
"x"
,
shape
=
[
None
,
32
],
dtype
=
'float32'
)
input_y
=
paddle
.
static
.
data
(
name
=
"y"
,
shape
=
[
None
,
1
],
dtype
=
'int64'
)
cost
=
mlp
(
input_x
,
input_y
)
optimizer
=
paddle
.
optimizer
.
Momentum
(
learning_rate
=
0.001
,
momentum
=
0.9
,
weight_decay
=
fluid
.
regularizer
.
L2Decay
(
1e-4
),
multi_precision
=
True
)
role
=
role_maker
.
PaddleCloudRoleMaker
(
is_collective
=
True
)
fleet
.
init
(
role
)
optimizer
=
paddle
.
static
.
amp
.
decorate
(
optimizer
)
optimizer
=
fleet
.
distributed_optimizer
(
optimizer
)
optimizer
.
minimize
(
cost
)
place
=
paddle
.
CUDAPlace
(
0
)
exe
=
paddle
.
static
.
Executor
(
place
)
exe
.
run
(
paddle
.
static
.
default_startup_program
())
optimizer
.
amp_init
(
place
,
use_fp16_test
=
True
)
step
=
1
for
i
in
range
(
step
):
cost_val
=
exe
.
run
(
program
=
paddle
.
static
.
default_main_program
(),
feed
=
gen_data
(),
fetch_list
=
[
cost
.
name
])
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录