Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6f3c9643
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6f3c9643
编写于
4月 14, 2023
作者:
J
JZ-LIANG
提交者:
GitHub
4月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Eb118 BF16 Adoption (#52827)
* pr1 * pr2 * pr3 * fixed unitest * adopt for scale
上级
8cbc75ca
变更
11
展开全部
隐藏空白更改
内联
并排
Showing
11 changed file
with
1878 addition
and
970 deletion
+1878
-970
python/paddle/distributed/auto_parallel/constants.py
python/paddle/distributed/auto_parallel/constants.py
+4
-2
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
...dle/distributed/auto_parallel/operators/dist_embedding.py
+7
-4
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
...paddle/distributed/auto_parallel/operators/dist_matmul.py
+1049
-631
python/paddle/distributed/auto_parallel/parallelizer_v2.py
python/paddle/distributed/auto_parallel/parallelizer_v2.py
+14
-5
python/paddle/distributed/passes/auto_parallel_amp.py
python/paddle/distributed/passes/auto_parallel_amp.py
+347
-166
python/paddle/distributed/passes/auto_parallel_fp16.py
python/paddle/distributed/passes/auto_parallel_fp16.py
+252
-156
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
...paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+3
-0
python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py
...paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py
+142
-0
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
.../fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
+1
-1
python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py
...e/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py
+55
-0
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
...ddle/fluid/tests/unittests/auto_parallel/test_strategy.py
+4
-5
未找到文件。
python/paddle/distributed/auto_parallel/constants.py
浏览文件 @
6f3c9643
...
@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
...
@@ -62,6 +62,8 @@ set_field_default_config(RECOMPUTE, "enable_tuning", False)
#########################################
#########################################
AMP
=
"amp"
AMP
=
"amp"
set_field_default_config
(
AMP
,
"enable"
,
False
)
set_field_default_config
(
AMP
,
"enable"
,
False
)
set_field_default_config
(
AMP
,
"dtype"
,
"float16"
)
set_field_default_config
(
AMP
,
"level"
,
"o1"
)
set_field_default_config
(
AMP
,
"init_loss_scaling"
,
32768.0
)
set_field_default_config
(
AMP
,
"init_loss_scaling"
,
32768.0
)
set_field_default_config
(
AMP
,
"incr_every_n_steps"
,
1000
)
set_field_default_config
(
AMP
,
"incr_every_n_steps"
,
1000
)
set_field_default_config
(
AMP
,
"decr_every_n_nan_or_inf"
,
2
)
set_field_default_config
(
AMP
,
"decr_every_n_nan_or_inf"
,
2
)
...
@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
...
@@ -71,8 +73,8 @@ set_field_default_config(AMP, "use_dynamic_loss_scaling", True)
set_field_default_config
(
AMP
,
"custom_white_list"
,
[])
set_field_default_config
(
AMP
,
"custom_white_list"
,
[])
set_field_default_config
(
AMP
,
"custom_black_list"
,
[])
set_field_default_config
(
AMP
,
"custom_black_list"
,
[])
set_field_default_config
(
AMP
,
"custom_black_varnames"
,
[])
set_field_default_config
(
AMP
,
"custom_black_varnames"
,
[])
set_field_default_config
(
AMP
,
"use_
pure_fp16
"
,
False
)
set_field_default_config
(
AMP
,
"use_
fp16_guard
"
,
False
)
set_field_default_config
(
AMP
,
"use_
fp16_guard"
,
Tru
e
)
set_field_default_config
(
AMP
,
"use_
bf16_guard"
,
Fals
e
)
set_field_default_config
(
AMP
,
"use_optimizer_fp16"
,
False
)
set_field_default_config
(
AMP
,
"use_optimizer_fp16"
,
False
)
#########################################
#########################################
...
...
python/paddle/distributed/auto_parallel/operators/dist_embedding.py
浏览文件 @
6f3c9643
...
@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -459,7 +459,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype
(
check_variable_and_dtype
(
Out_var
,
Out_var
,
'tensor'
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'c_allreduce_sum'
,
'c_allreduce_sum'
,
)
)
...
@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -649,7 +649,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
check_variable_and_dtype
(
check_variable_and_dtype
(
Out_grad
,
Out_grad
,
'tensor'
,
'tensor'
,
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
],
[
'float16'
,
'float32'
,
'float64'
,
'int32'
,
'int64'
,
'uint16'
],
'_c_identity'
,
'_c_identity'
,
)
)
...
@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
...
@@ -691,12 +691,15 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
},
},
)
)
check_variable_and_dtype
(
check_variable_and_dtype
(
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
intermediate_var_0
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
)
)
check_dtype
(
check_dtype
(
intermediate_var_0
.
dtype
,
intermediate_var_0
.
dtype
,
'dtype'
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
[
'float16'
,
'float32'
,
'float64'
,
'uint16'
],
'linear'
,
'linear'
,
)
)
...
...
python/paddle/distributed/auto_parallel/operators/dist_matmul.py
浏览文件 @
6f3c9643
此差异已折叠。
点击以展开。
python/paddle/distributed/auto_parallel/parallelizer_v2.py
浏览文件 @
6f3c9643
...
@@ -254,17 +254,26 @@ class Parallelizer:
...
@@ -254,17 +254,26 @@ class Parallelizer:
self
.
_dist_context
.
serial_feed_vars
[
"inputs"
]
self
.
_dist_context
.
serial_feed_vars
[
"inputs"
]
+
self
.
_dist_context
.
serial_feed_vars
[
"labels"
]
+
self
.
_dist_context
.
serial_feed_vars
[
"labels"
]
)
)
if
config
[
"use_pure_fp16"
]:
self
.
_logger
.
info
(
"Applying AMP-{}-{} ..."
.
format
(
config
[
"dtype"
],
config
[
'level'
]
),
)
if
config
[
'level'
]
==
"o1"
:
auto_parallel_amp_pass
=
new_pass
(
"auto_parallel_amp"
,
config
)
auto_parallel_amp_pass
.
apply
(
[
main_program
],
[
startup_program
],
self
.
_pass_context
)
loss
=
auto_parallel_amp_pass
.
get_loss
()
elif
config
[
'level'
]
in
[
'o2'
,
'o3'
]:
config
[
"base_opt"
]
=
optimizer
config
[
"base_opt"
]
=
optimizer
auto_parallel_fp16_pass
=
new_pass
(
"auto_parallel_fp16"
,
config
)
auto_parallel_fp16_pass
=
new_pass
(
"auto_parallel_fp16"
,
config
)
auto_parallel_fp16_pass
.
apply
(
auto_parallel_fp16_pass
.
apply
(
[
main_program
],
[
startup_program
],
self
.
_pass_context
[
main_program
],
[
startup_program
],
self
.
_pass_context
)
)
loss
=
auto_parallel_fp16_pass
.
get_loss
()
else
:
else
:
auto_parallel_amp_pass
=
new_pass
(
"auto_parallel_amp"
,
config
)
raise
ValueError
(
"AMP level should be one of o1, o2, o3"
)
auto_parallel_amp_pass
.
apply
(
[
main_program
],
[
startup_program
],
self
.
_pass_context
)
# apply recompute pass
# apply recompute pass
# recompute is then train-only optimization
# recompute is then train-only optimization
...
...
python/paddle/distributed/passes/auto_parallel_amp.py
浏览文件 @
6f3c9643
此差异已折叠。
点击以展开。
python/paddle/distributed/passes/auto_parallel_fp16.py
浏览文件 @
6f3c9643
此差异已折叠。
点击以展开。
python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
浏览文件 @
6f3c9643
...
@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
...
@@ -40,6 +40,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules
(
test_random_ctrl MODULES test_random_ctrl ENVS
${
dist_ENVS
}
)
py_test_modules
(
test_random_ctrl MODULES test_random_ctrl ENVS
${
dist_ENVS
}
)
set_tests_properties
(
test_random_ctrl PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
set_tests_properties
(
test_random_ctrl PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
TIMEOUT 50
)
TIMEOUT 50
)
py_test_modules
(
test_amp_o2_pass MODULES test_amp_o2_pass ENVS
${
dist_ENVS
}
)
set_tests_properties
(
test_amp_o2_pass PROPERTIES LABELS
"RUN_TYPE=EXCLUSIVE"
TIMEOUT 50
)
py_test_modules
(
test_iterable_dataset MODULES test_iterable_dataset ENVS
py_test_modules
(
test_iterable_dataset MODULES test_iterable_dataset ENVS
${
dist_ENVS
}
)
${
dist_ENVS
}
)
set_tests_properties
(
test_iterable_dataset
set_tests_properties
(
test_iterable_dataset
...
...
python/paddle/fluid/tests/unittests/auto_parallel/amp_o2_pass.py
0 → 100644
浏览文件 @
6f3c9643
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
random
import
re
import
unittest
import
numpy
as
np
from
get_gpt_model
import
FakeDataset
,
generate_model
import
paddle
from
paddle.distributed.fleet
import
auto
from
paddle.fluid.framework
import
core
paddle
.
enable_static
()
def
get_cuda_version
():
result
=
os
.
popen
(
"nvcc --version"
).
read
()
regex
=
r
'release (\S+),'
match
=
re
.
search
(
regex
,
result
)
if
match
:
num
=
str
(
match
.
group
(
1
))
integer
,
decimal
=
num
.
split
(
'.'
)
return
int
(
integer
)
*
1000
+
int
(
float
(
decimal
)
*
10
)
else
:
return
-
1
def
apply_pass
(
use_amp
=
False
,
amp_dtype
=
"bfloat16"
):
strategy
=
auto
.
Strategy
()
strategy
.
auto_mode
=
"semi"
strategy
.
reinit
=
True
if
use_amp
:
amp
=
strategy
.
amp
amp
.
enable
=
True
amp
.
dtype
=
amp_dtype
amp
.
level
=
"o2"
amp
.
custom_black_list
=
[
'c_softmax_with_cross_entropy'
,
'elementwise_div'
,
'reduce_sum'
,
]
return
strategy
def
reset_prog
():
paddle
.
fluid
.
framework
.
switch_main_program
(
paddle
.
static
.
Program
())
paddle
.
fluid
.
framework
.
switch_startup_program
(
paddle
.
static
.
Program
())
class
TestShardingStage2WithNewEXE
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
batch_size
=
2
self
.
batch_num
=
10
self
.
clip_norm
=
0.2
self
.
dataset
=
FakeDataset
(
self
.
batch_size
*
self
.
batch_num
)
def
init
(
self
,
engine
):
paddle
.
seed
(
2022
)
np
.
random
.
seed
(
2022
)
random
.
seed
(
2022
)
place
=
paddle
.
fluid
.
CUDAPlace
(
paddle
.
distributed
.
ParallelEnv
().
dev_id
)
engine
.
_executor
=
paddle
.
static
.
Executor
(
place
)
def
get_engine
(
self
,
use_amp
=
False
,
amp_dtype
=
"bfloat16"
):
reset_prog
()
strategy
=
apply_pass
(
use_amp
,
amp_dtype
)
# clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
clip
=
None
opt
=
paddle
.
optimizer
.
AdamW
(
learning_rate
=
0.00001
,
grad_clip
=
clip
)
model
,
loss
=
generate_model
(
"mp"
)
engine
=
auto
.
Engine
(
model
,
loss
,
opt
,
strategy
=
strategy
)
self
.
init
(
engine
)
return
engine
def
check_bf16
(
self
,
program
):
num_bf16
=
0
num_fp16
=
0
num_fp32
=
0
for
p
in
program
.
all_parameters
():
if
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP32
:
num_fp32
+=
1
if
p
.
dtype
==
core
.
VarDesc
.
VarType
.
FP16
:
num_fp16
+=
1
if
p
.
dtype
==
core
.
VarDesc
.
VarType
.
BF16
:
num_bf16
+=
1
self
.
assertEqual
(
num_bf16
,
25
)
self
.
assertEqual
(
num_fp16
,
0
)
self
.
assertEqual
(
num_fp32
,
11
)
def
test_param_grad_fuse_overlap
(
self
):
# std
mp_engine
=
self
.
get_engine
(
use_amp
=
False
)
mp_history
=
mp_engine
.
fit
(
self
.
dataset
,
3
,
epochs
=
1
,
steps_per_epoch
=
self
.
batch_num
,
log_freq
=
1
,
batch_size
=
self
.
batch_size
,
)
loss0
=
mp_history
.
history
[
'loss'
][
0
]
# bf16
mp_bf16_engine
=
self
.
get_engine
(
use_amp
=
True
)
if
not
paddle
.
is_compiled_with_cuda
()
or
get_cuda_version
()
<
11000
:
return
mp_bf16_history
=
mp_bf16_engine
.
fit
(
self
.
dataset
,
3
,
epochs
=
1
,
steps_per_epoch
=
self
.
batch_num
,
log_freq
=
1
,
batch_size
=
self
.
batch_size
,
)
loss1
=
mp_bf16_history
.
history
[
'loss'
][
0
]
np
.
testing
.
assert_allclose
(
loss0
,
loss1
,
atol
=
1e-3
,
rtol
=
1e-2
)
self
.
check_bf16
(
mp_bf16_engine
.
main_program
)
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/amp_pass_unittest.py
浏览文件 @
6f3c9643
...
@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None):
...
@@ -38,7 +38,7 @@ def apply_pass(use_amp=False, level=None):
]
]
amp
.
init_loss_scaling
=
32768
amp
.
init_loss_scaling
=
32768
amp
.
use_fp16_guard
=
False
amp
.
use_fp16_guard
=
False
amp
.
use_pure_fp16
=
level
in
[
"o2"
,
"o3"
]
amp
.
level
=
level
amp
.
use_optimizer_fp16
=
level
==
"o3"
amp
.
use_optimizer_fp16
=
level
==
"o3"
print
(
"amp level: "
,
level
)
print
(
"amp level: "
,
level
)
return
strategy
return
strategy
...
...
python/paddle/fluid/tests/unittests/auto_parallel/test_amp_o2_pass.py
0 → 100644
浏览文件 @
6f3c9643
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
os
import
subprocess
import
sys
import
tempfile
import
unittest
class
TestAMPO2
(
unittest
.
TestCase
):
def
test_bf16
(
self
):
file_dir
=
os
.
path
.
dirname
(
os
.
path
.
abspath
(
__file__
))
launch_model_path
=
os
.
path
.
join
(
file_dir
,
"amp_o2_pass.py"
)
if
os
.
environ
.
get
(
"WITH_COVERAGE"
,
"OFF"
)
==
"ON"
:
coverage_args
=
[
"-m"
,
"coverage"
,
"run"
,
"--branch"
,
"-p"
]
else
:
coverage_args
=
[]
tmp_dir
=
tempfile
.
TemporaryDirectory
()
cmd
=
(
[
sys
.
executable
,
"-u"
]
+
coverage_args
+
[
"-m"
,
"paddle.distributed.launch"
,
"--devices"
,
"0,1"
,
"--log_dir"
,
tmp_dir
.
name
,
launch_model_path
,
]
)
process
=
subprocess
.
Popen
(
cmd
)
process
.
wait
()
self
.
assertEqual
(
process
.
returncode
,
0
)
tmp_dir
.
cleanup
()
if
__name__
==
"__main__"
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/auto_parallel/test_strategy.py
浏览文件 @
6f3c9643
...
@@ -13,13 +13,13 @@
...
@@ -13,13 +13,13 @@
# limitations under the License.
# limitations under the License.
import
os
import
os
# import yaml
# import yaml
import
unittest
import
unittest
from
paddle.distributed.fleet
import
auto
from
paddle.distributed.fleet
import
auto
class
TestStrategy
(
unittest
.
TestCase
):
class
TestStrategy
(
unittest
.
TestCase
):
def
test_default_config
(
self
):
def
test_default_config
(
self
):
strategy
=
auto
.
Strategy
()
strategy
=
auto
.
Strategy
()
...
@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase):
...
@@ -29,6 +29,8 @@ class TestStrategy(unittest.TestCase):
amp
=
strategy
.
amp
amp
=
strategy
.
amp
self
.
assertEqual
(
amp
.
enable
,
False
)
self
.
assertEqual
(
amp
.
enable
,
False
)
self
.
assertAlmostEqual
(
amp
.
dtype
,
"float16"
)
self
.
assertAlmostEqual
(
amp
.
level
,
"o1"
)
self
.
assertAlmostEqual
(
amp
.
init_loss_scaling
,
32768.0
)
self
.
assertAlmostEqual
(
amp
.
init_loss_scaling
,
32768.0
)
self
.
assertEqual
(
amp
.
incr_every_n_steps
,
1000
)
self
.
assertEqual
(
amp
.
incr_every_n_steps
,
1000
)
self
.
assertEqual
(
amp
.
decr_every_n_nan_or_inf
,
2
)
self
.
assertEqual
(
amp
.
decr_every_n_nan_or_inf
,
2
)
...
@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase):
...
@@ -38,8 +40,7 @@ class TestStrategy(unittest.TestCase):
self
.
assertEqual
(
amp
.
custom_black_list
,
[])
self
.
assertEqual
(
amp
.
custom_black_list
,
[])
self
.
assertEqual
(
amp
.
custom_white_list
,
[])
self
.
assertEqual
(
amp
.
custom_white_list
,
[])
self
.
assertEqual
(
amp
.
custom_black_varnames
,
[])
self
.
assertEqual
(
amp
.
custom_black_varnames
,
[])
self
.
assertEqual
(
amp
.
use_pure_fp16
,
False
)
self
.
assertEqual
(
amp
.
use_fp16_guard
,
False
)
self
.
assertEqual
(
amp
.
use_fp16_guard
,
True
)
self
.
assertEqual
(
amp
.
use_optimizer_fp16
,
False
)
self
.
assertEqual
(
amp
.
use_optimizer_fp16
,
False
)
sharding
=
strategy
.
sharding
sharding
=
strategy
.
sharding
...
@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase):
...
@@ -92,7 +93,6 @@ class TestStrategy(unittest.TestCase):
amp
.
custom_white_list
=
[
"x"
]
amp
.
custom_white_list
=
[
"x"
]
amp
.
custom_black_list
=
[
"y"
]
amp
.
custom_black_list
=
[
"y"
]
amp
.
custom_black_varnames
=
[
"z"
]
amp
.
custom_black_varnames
=
[
"z"
]
amp
.
use_pure_fp16
=
True
amp
.
use_fp16_guard
=
False
amp
.
use_fp16_guard
=
False
amp
.
use_optimizer_fp16
=
True
amp
.
use_optimizer_fp16
=
True
self
.
assertEqual
(
amp
.
enable
,
True
)
self
.
assertEqual
(
amp
.
enable
,
True
)
...
@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase):
...
@@ -105,7 +105,6 @@ class TestStrategy(unittest.TestCase):
self
.
assertEqual
(
amp
.
custom_white_list
,
[
"x"
])
self
.
assertEqual
(
amp
.
custom_white_list
,
[
"x"
])
self
.
assertEqual
(
amp
.
custom_black_list
,
[
"y"
])
self
.
assertEqual
(
amp
.
custom_black_list
,
[
"y"
])
self
.
assertEqual
(
amp
.
custom_black_varnames
,
[
"z"
])
self
.
assertEqual
(
amp
.
custom_black_varnames
,
[
"z"
])
self
.
assertEqual
(
amp
.
use_pure_fp16
,
True
)
self
.
assertEqual
(
amp
.
use_fp16_guard
,
False
)
self
.
assertEqual
(
amp
.
use_fp16_guard
,
False
)
self
.
assertEqual
(
amp
.
use_optimizer_fp16
,
True
)
self
.
assertEqual
(
amp
.
use_optimizer_fp16
,
True
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录