未验证 提交 b99c1d07 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto parallel] Mixed Precision FP16 Pass (#40615)

*  add FP16 Pass 

* Support the auto completion of while_op

*  acc aligned
上级 5c5a3660
......@@ -67,6 +67,8 @@ message AMPConfig {
repeated string custom_black_varnames = 9;
optional bool use_pure_fp16 = 10 [ default = false ];
optional bool use_fp16_guard = 11 [ default = true ];
optional bool use_optimizer_fp16 = 12
[ default = false ]; // auto parallel effective only
}
message LocalSGDConfig {
......
......@@ -105,9 +105,15 @@ class AutoParallelizer:
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["loss"] = loss
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
if config["use_pure_fp16"]:
config["base_opt"] = self._optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply(
[main_program], [startup_program], self._pass_context)
else:
auto_parallel_amp_pass = new_pass("auto_parallel_amp", config)
auto_parallel_amp_pass.apply([main_program], [startup_program],
self._pass_context)
# apply recompute pass
if self._dist_strategy.recompute:
......
......@@ -357,10 +357,11 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
src_var = src_block.var(src_varname)
if src_var.type in __not_shape_var_type__:
persist = getattr(src_var, 'persistable', False)
new_var = dst_block.create_var(
type=src_var.type,
name=dst_varname,
persistable=True,
persistable=persist,
stop_gradient=True)
target_shape = None
else:
......
......@@ -1047,8 +1047,7 @@ def set_grad_var_shape(program, dist_context):
forward_input_dist_attr = op_dist_attr.get_input_dist_attr(
forward_var_name)
assert forward_input_dist_attr is not None, f"{forward_var_name}"
assert forward_input_dist_attr is not None, f"{forward_var_name, str(op)}"
forward_var = vars[forward_var_name]
forward_var_dist_attr = dist_context.get_tensor_dist_attr_for_program(
forward_var)
......
......@@ -17,6 +17,7 @@ from .fuse_all_reduce import *
from .auto_parallel_gradient_merge import *
from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_fp16 import *
from .auto_parallel_recompute import *
from .cpp_pass import *
import os
......
......@@ -503,8 +503,6 @@ class AMPPass(PassBase):
return False
if self.get_attr("decr_ratio") < 0:
return False
if len(self.get_attr("params_grads")) <= 0:
return False
if self.get_attr("dist_context") is None:
return False
return True
......@@ -576,6 +574,8 @@ class AMPPass(PassBase):
main_block = paddle.static.default_main_program().global_block()
main_block._sync_with_cpp()
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
loss = self.get_attr("loss")
assert loss is not None
loss_op = loss.op
......@@ -583,6 +583,37 @@ class AMPPass(PassBase):
loss_op)
if loss.dtype != core.VarDesc.VarType.FP32:
# cast loss here will change the effective loss tensor for the computation graph
# and therefore will effect all following passes whose logic is based on the loss tensor(Recompute & Gradient Merge),
# so we it is not allowed by now. fixed it in future.
raise NotImplementedError(
"Loss's generator op is not support in FP16 in Auto Parallel by now, please put that op into your black-list."
)
tmp_name = unique_name.generate(loss.name + ".cast_fp32")
cast_loss = main_block.create_var(name=tmp_name, dtype=dtype)
loss_dist_attr = self.dist_context.get_tensor_dist_attr_for_program(
loss)
ref_mesh = loss_op_dist_attr.process_mesh
self.dist_context.set_tensor_dist_attr_for_program(cast_loss,
loss_dist_attr)
loss_op_idx = find_op_index(main_block.desc, loss_op.desc)
cast_op = main_block._insert_op(
loss_op_idx + 1,
type='cast',
inputs={'X': [loss]},
outputs={'Out': [cast_loss]},
attrs={
"in_dtype": loss.dtype,
"out_dtype": core.VarDesc.VarType.FP32,
'op_role': loss_op.all_attrs()[OP_ROLE_KEY],
})
loss_op._set_attr(OP_ROLE_KEY,
core.op_proto_and_checker_maker.OpRole.Forward)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_mesh, [-1], self.dist_context)
loss = loss.astype('float32')
if self.get_attr("use_dynamic_loss_scaling") or self.get_attr(
......@@ -600,7 +631,6 @@ class AMPPass(PassBase):
set_var_dist_attr(self.dist_context, self._scaled_loss, [-1],
ref_mesh)
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
elementwise_mul_op = main_block._insert_op(
loss_op_idx + 1,
type='elementwise_mul',
......@@ -667,8 +697,11 @@ class AMPPass(PassBase):
for e in grads:
check_variable_and_dtype(e, "x", ['float16', 'float32', 'float64'],
'update_loss_scaling')
assert self._loss_scaling.dtype == e.dtype, \
"The dtype of prev_loss_scaling should be equal to the dtype of x."
if e.dtype == core.VarDesc.VarType.FP16:
assert self._loss_scaling.dtype == core.VarDesc.VarType.FP32, \
"The dtype of prev_loss_scaling should be float32 when the dtype of x is float16."
else:
assert self._loss_scaling.dtype == e.dtype, "The dtype of prev_loss_scaling should be equal to the dtype of x."
inputs = {
'X': grads,
......
此差异已折叠。
......@@ -14,6 +14,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_XPU) AND NOT (WITH_ASCEND OR WITH_ASCEND_CL))
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_amp_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_recompute_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_sharding_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fp16_pass")
endif()
foreach(TEST_OP ${TEST_OPS})
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numpy as np
import unittest
import paddle
import paddle.distributed.fleet as fleet
from auto_parallel_pass_test_base import AutoPallelPassTestBase
from test_auto_parallel_amp_pass import TestAMPPass
class TestPF16Pass(TestAMPPass):
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.amp = True
dist_strategy.amp_configs = {
"custom_white_list": [
'softmax',
'layer_norm',
'gelu',
],
"custom_black_list": ['c_softmax_with_cross_entropy'],
"init_loss_scaling": 32768,
"use_dynamic_loss_scaling": True,
"use_pure_fp16": True
}
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册