Unverified commit e9e07a19, authored by Wennie396, committed by GitHub

fix some bugs for amp and test case test_tuning_recompute_with_amp.py (#56864)

* replace amp.use_pure_fp16 with amp.dtype and amp.level
* old API still uses use_pure_fp16
* test_fuse_adamw_pass still uses use_pure_fp16
* add test case for tuning recompute with amp (float16, o2)
* set the new test case's TIMEOUT property to 60
* use smaller values for batch_size and batch_num
* deepcopy dist_context to fix the _rename_input problem
* fix the loss name after cast
* set tuning.enable=True and use engine._tune()
* restore some changes in _rename_input()/_rename_output()
* add self.amp_dtype for _cast_loss() in auto_parallel_amp.py
* fix the insert op index in _cast_loss()
Parent 54b247b1
@@ -308,7 +308,7 @@ class OptimizationTuner:
                 self._baseline_dist_context.serial_feed_vars["inputs"]
                 + self._baseline_dist_context.serial_feed_vars["labels"]
             )
-            if config["use_pure_fp16"]:
+            if config["dtype"] == "float16" and config["level"] == "o2":
                 config["base_opt"] = dist_context.serial_optimizer
                 auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
                 auto_parallel_fp16_pass.apply(
...
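For clarity, the new predicate spells out what the removed use_pure_fp16 flag used to mean: pure-fp16 training is now expressed as dtype "float16" at AMP level "o2". A minimal sketch with a hypothetical helper (not part of the Paddle codebase):

def is_pure_fp16(amp_config):
    # Hypothetical helper mirroring the check above: the old
    # use_pure_fp16=True corresponds to dtype "float16" plus level "o2".
    return (
        amp_config.get("dtype") == "float16"
        and amp_config.get("level") == "o2"
    )

assert is_pure_fp16({"dtype": "float16", "level": "o2"})
assert not is_pure_fp16({"dtype": "float16", "level": "o1"})
assert not is_pure_fp16({"dtype": "bfloat16", "level": "o2"})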
@@ -728,7 +728,7 @@ class AMPPass(PassBase):
         if is_train:
             self._update_backward_cast_ops()
-            self._cast_loss()
+            self._cast_loss(self.amp_dtype)

         if is_train and self.amp_dtype == "float16":
             self._init_amp_var()
@@ -913,7 +913,7 @@ class AMPPass(PassBase):
                 world_process_group.ranks,
             )

-    def _cast_loss(self):
+    def _cast_loss(self, target_dtype):
         main_block = paddle.static.default_main_program().global_block()
         main_block._sync_with_cpp()
@@ -957,11 +957,18 @@ class AMPPass(PassBase):
             )

             # backward
-            first_backward_op = main_block.ops[loss_op_idx + 2]
-            assert (
-                first_backward_op.type == "fill_constant"
-                and int(first_backward_op.all_attrs()[OP_ROLE_KEY]) == 257
-            )
+            first_backward_op = None
+            insert_op_offset = 3
+            for idx, op in enumerate(main_block.ops[loss_op_idx:]):
+                if op.type == "fill_constant" and is_loss_grad_op(op):
+                    first_backward_op = op
+                    insert_op_offset = idx + 1
+                    break
+                if is_backward_op(op):
+                    break
+            assert first_backward_op is not None, "There is not loss_grad op."
             cast_loss_grad = main_block.create_var(
                 name=unique_name.generate(tmp_name + "@GRAD"),
                 shape=loss.shape,
@@ -984,13 +991,13 @@ class AMPPass(PassBase):
                 self.dist_context,
             )
             cast_grad_op = main_block._insert_op(
-                loss_op_idx + 3,
+                loss_op_idx + insert_op_offset,
                 type='cast',
                 inputs={'X': [cast_loss_grad]},
                 outputs={'Out': [pre_grad_name]},
                 attrs={
                     "in_dtype": core.VarDesc.VarType.FP32,
-                    "out_dtype": _str_to_dtype(self.amp_dtype),
+                    "out_dtype": _str_to_dtype(target_dtype),
                     "op_role": OpRole.Backward,
                 },
             )
@@ -1002,6 +1009,7 @@ class AMPPass(PassBase):
             )
             loss_op = cast_op
             loss = cast_loss
+            self.set_attr("loss", loss)

         self._loss = loss
         main_block._sync_with_cpp()
...
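The new search loop drops the old assumption that the loss-grad fill_constant op always sits at a fixed offset after the loss op, which no longer holds once a cast op is inserted behind the loss. A standalone sketch of the same pattern, using mock op records rather than Paddle's Block/Operator API:

from collections import namedtuple

# Mock stand-in for an operator in the main block (illustration only).
MockOp = namedtuple("MockOp", ["type", "is_loss_grad", "is_backward"])

def find_loss_grad_op(ops, loss_op_idx):
    # Scan forward from the loss op until the loss_grad fill_constant is found;
    # return the op and the offset at which a following op should be inserted.
    for idx, op in enumerate(ops[loss_op_idx:]):
        if op.type == "fill_constant" and op.is_loss_grad:
            return op, idx + 1
        if op.is_backward:  # reached the backward section without finding it
            break
    raise AssertionError("no loss_grad op found after the loss op")

ops = [
    MockOp("reduce_mean", False, False),   # loss op
    MockOp("cast", False, False),          # cast inserted after the loss by AMP
    MockOp("fill_constant", True, True),   # loss_grad op
]
print(find_loss_grad_op(ops, 0))  # -> (MockOp(type='fill_constant', ...), 3)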
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import copy
 from collections import defaultdict

 import paddle
@@ -408,8 +409,8 @@ class FP16State:
                     (cast_name, in_var.name, dst_dtype, src_dtype, in_name)
                 ]

-                in_var_dist_attr = consume_op_attr.get_input_dist_attr(
-                    in_var.name
-                )
+                in_var_dist_attr = copy.deepcopy(
+                    consume_op_attr.get_input_dist_attr(in_var.name)
+                )
                 assert in_var_dist_attr is not None
                 # truly insert cast op
@@ -800,6 +801,8 @@ class FP16Pass(AMPPass):
         is_train = fp16_state._build_state()

         cast_startup_program()
+        if is_train:
+            self._cast_loss(self.target_dtype)

         if is_train:
             if self.target_dtype == "float16":
...
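The copy.deepcopy above guards against aliasing: get_input_dist_attr hands back the attribute object stored on the consuming op, so a later rename or rewrite through that reference would also mutate the op's own entry (the "_rename_input problem" from the commit message). A minimal standalone illustration with plain dicts instead of Paddle's dist_attr objects:

import copy

# Pretend this maps input names to distributed attributes on one operator.
op_input_dist_attrs = {"x": {"dims_mapping": [0, -1]}}

aliased = op_input_dist_attrs["x"]       # shared reference, as without deepcopy
aliased["dims_mapping"] = [-1, -1]       # the rewrite leaks back into the op
print(op_input_dist_attrs["x"])          # {'dims_mapping': [-1, -1]}

op_input_dist_attrs["x"] = {"dims_mapping": [0, -1]}    # reset
detached = copy.deepcopy(op_input_dist_attrs["x"])      # independent copy
detached["dims_mapping"] = [-1, -1]
print(op_input_dist_attrs["x"])          # {'dims_mapping': [0, -1]}, unchanged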
@@ -105,6 +105,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   set_tests_properties(test_selective_recompute PROPERTIES TIMEOUT 50)
   py_test_modules(test_tuning_recompute MODULES test_tuning_recompute)
   set_tests_properties(test_tuning_recompute PROPERTIES TIMEOUT 300)
+  py_test_modules(test_tuning_recompute_with_amp MODULES
+                  test_tuning_recompute_with_amp)
+  set_tests_properties(test_tuning_recompute_with_amp PROPERTIES TIMEOUT 60)
   py_test_modules(test_fused_linear_pass MODULES test_fused_linear_pass)
   set_tests_properties(test_fused_linear_pass PROPERTIES TIMEOUT 40)
   py_test_modules(test_align_tool MODULES test_align_tool)
...
@@ -80,7 +80,8 @@ def parallelizer(program_func, rank):
     strategy = auto.Strategy()
     amp = strategy.amp
     amp.enable = True
-    amp.use_pure_fp16 = True
+    amp.dtype = "float16"
+    amp.level = "o2"
     amp.init_loss_scaling = 32768
     amp.use_fp16_guard = False
     amp.custom_black_list = ['where']
...
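The same migration applies wherever an auto-parallel Strategy configures AMP. A minimal sketch, limited to the strategy fields that appear in this diff:

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
amp = strategy.amp
amp.enable = True
# Old API (removed from these tests):
#     amp.use_pure_fp16 = True
# New API: pure-fp16 training is an explicit dtype + level pair.
amp.dtype = "float16"
amp.level = "o2"   # the combination the tuner checks before applying auto_parallel_fp16
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.custom_black_list = ['where']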
New test file test_tuning_recompute_with_amp.py:

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import unittest

from get_gpt_model import FakeDataset

import paddle
from paddle.distributed.fleet import auto

sys.path.append("../legacy_test")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import (
    GPTForPretraining,
    GPTModel,
    GPTPretrainingCriterion,
)

paddle.enable_static()


def generate_model():
    modeling.init_global()
    modeling._global_parallel_strategy = "serial"
    ranks = list(range(paddle.distributed.get_world_size()))
    modeling._global_process_mesh = auto.ProcessMesh(
        mesh=ranks, dim_names=["x"]
    )

    gpt = GPTModel(
        vocab_size=50304,
        hidden_size=1024,
        num_hidden_layers=8,
        num_attention_heads=16,
        intermediate_size=1024 * 4,
        hidden_act="gelu",
        hidden_dropout_prob=0.1,
        attention_probs_dropout_prob=0.1,
        max_position_embeddings=1024,
        type_vocab_size=1,
        initializer_range=0.02,
        pad_token_id=0,
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
        use_new_recompute=True,
        recompute_granularity="full",
    )
    model = GPTForPretraining(
        gpt, vocab_size=50304, hidden_size=1024, initializer_range=0.02
    )
    criterion = GPTPretrainingCriterion()
    return model, criterion


def apply_pass():
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"

    recompute = strategy.recompute
    recompute.enable = True
    recompute.enable_tuning = True

    tuning = strategy.tuning
    tuning.enable = True
    tuning.profile_start_step = 1
    tuning.profile_end_step = 2
    tuning.run_after_tuning = True
    tuning.verbose = True

    amp = strategy.amp
    amp.enable = True
    amp.dtype = "float16"
    amp.level = "o2"

    return strategy


class TestRecomputeWithAMPPassTuning(unittest.TestCase):
    def setUp(self):
        self.batch_size = 2
        self.batch_num = 10
        self.dataset = FakeDataset(
            self.batch_size * self.batch_num,
            vocab_size=50304,
            sequence_len=1024,
        )

    def test_recompute_with_amp_pass(self):
        strategy = apply_pass()
        clip = paddle.nn.ClipGradByGlobalNorm(0.2)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model()

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        # engine.fit(self.dataset, 3, batch_size=self.batch_size)
        engine._tune(self.dataset, 3, batch_size=self.batch_size)


if __name__ == "__main__":
    unittest.main()