未验证 提交 99b3727d 编写于 作者: Y Yuang Liu 提交者: GitHub

[fused_transformer] update transformer fustion for dygraph, test=allcases (#43858)

上级 72116696
......@@ -173,6 +173,10 @@ def pure_fp16_initialize(models):
paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D,
paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm)):
continue
if isinstance(layer, (paddle.incubate.nn.FusedFeedForward,
paddle.incubate.nn.FusedMultiHeadAttention)):
layer._amp_decorate(dtype='float16')
continue
layer._to_impl(dtype='float16',
include_sublayers=False,
floating_only=True)
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
from paddle.incubate.nn.layer.fused_transformer import FusedMultiHeadAttention, FusedFeedForward
import unittest
class PreModel(nn.Layer):
def __init__(self):
super(PreModel, self).__init__()
self.attn = FusedMultiHeadAttention(
embed_dim=1024,
num_heads=16,
normalize_before=False,
)
self.ffn = FusedFeedForward(d_model=1024,
dim_feedforward=4096,
normalize_before=False)
def forward(self, x):
x = self.attn(x)
x = self.ffn(x)
class PostModel(nn.Layer):
def __init__(self):
super(PostModel, self).__init__()
self.attn = FusedMultiHeadAttention(
embed_dim=1024,
num_heads=16,
normalize_before=True,
)
self.ffn = FusedFeedForward(d_model=1024,
dim_feedforward=4096,
normalize_before=True)
def forward(self, x):
x = self.attn(x)
x = self.ffn(x)
class TestFusedTransformerWithAmpDecorator(unittest.TestCase):
def get_model(self):
self.pre_model = PreModel()
self.post_model = PostModel()
def test_run(self):
self.get_model()
pre_model = paddle.amp.decorate(models=self.pre_model,
level='O2',
save_dtype='float32')
post_model = paddle.amp.decorate(models=self.post_model,
level='O2',
save_dtype='float32')
if __name__ == "__main__":
unittest.main()
......@@ -526,8 +526,10 @@ def fused_multi_head_attention(x,
0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]."
assert qkv_weight.shape[3] == x.shape[
2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim."
assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[
3], "embed_dim must be divisible by num_heads."
if ring_id == -1:
# under mp, the num head will be split, this equation will not hold
assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[
3], "embed_dim must be divisible by num_heads."
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, cache_kv_out, final_out = _C_ops.fused_attention(
x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, cache_kv,
......
......@@ -18,8 +18,11 @@ from paddle.framework import ParamAttr
import paddle
from paddle.nn.layer.transformer import _convert_attention_mask, _convert_param_attr_to_list
from paddle.nn.initializer import Constant
import collections
from paddle.fluid.dygraph import no_grad
from paddle.fluid.framework import convert_np_dtype_to_dtype_, _non_static_mode
from paddle.fluid.core import VarDesc
from paddle.fluid import core
import numpy as np
# for distributed tensor model parallel
......@@ -29,11 +32,48 @@ def _set_var_distributed(var):
var.is_distributed = True
# NOTE: use current_block and find_var_recursive to support while_loop
startup_block = paddle.static.default_startup_program().current_block()
main_block = paddle.static.default_main_program().current_block()
startup_block._find_var_recursive(var.name).is_distributed = True
main_block._find_var_recursive(var.name).is_distributed = True
if not _non_static_mode():
# NOTE: use current_block and find_var_recursive to support while_loop
startup_block = paddle.static.default_startup_program().current_block()
main_block = paddle.static.default_main_program().current_block()
startup_block._find_var_recursive(var.name).is_distributed = True
main_block._find_var_recursive(var.name).is_distributed = True
def _to_dtype(t, dtype):
# this function is a prune of Layer._transform function to fix fused op under amp.decorator(O2)
if not paddle.is_floating_point(t):
return t
if type(dtype) is not VarDesc.VarType:
dtype = convert_np_dtype_to_dtype_(dtype)
if t.place.is_gpu_place():
size_dtype = core.size_of_dtype(dtype)
waiting_alloc_memory = (
(np.prod(t.shape) * size_dtype) / 256 + 1) * 256 * 1.2
gpu_memory_available = core.gpu_memory_available()
if gpu_memory_available < waiting_alloc_memory:
t_used = t._copy_to(paddle.CPUPlace(), False)
t.value().get_tensor()._clear()
else:
t_used = t
else:
t_used = t
if dtype is not None and dtype != t_used.dtype:
with paddle.fluid.framework._dygraph_place_guard(place=t_used.place):
t_casted = t_used.cast(dtype=dtype)
else:
t_casted = t_used
new_t = t_casted
dst_tensor = t.value().get_tensor()
src_tensor = new_t.value().get_tensor()
dst_tensor._share_data_with(src_tensor)
return t
class FusedBiasDropoutResidualLayerNorm(Layer):
......@@ -374,6 +414,25 @@ class FusedMultiHeadAttention(Layer):
self.attn_dropout_rate, self._epsilon, self.kdim, self.vdim,
self.normalize_before, self.need_weights, self._dtype, name_str)
def _amp_decorate(self, dtype):
# tmp fix for amp.decorator(O2)
layer_norm_params_id = []
if self.normalize_before:
layer_norm_params_id.append(id(self.pre_ln_scale))
layer_norm_params_id.append(id(self.pre_ln_bias))
else:
layer_norm_params_id.append(id(self.ln_scale))
layer_norm_params_id.append(id(self.ln_bias))
for key, param in self._parameters.items():
if id(param) in layer_norm_params_id:
continue
if param is not None:
with no_grad():
param_applied = _to_dtype(param, dtype)
self._dtype = dtype
class FusedFeedForward(Layer):
"""
......@@ -559,6 +618,25 @@ class FusedFeedForward(Layer):
self._epsilon, self._act_method, self._act_dropout_rate,
self._normalize_before, self._dtype, name_str)
def _amp_decorate(self, dtype):
# tmp fix for amp.decorator(O2)
layer_norm_params_id = []
if self._normalize_before:
layer_norm_params_id.append(id(self._ln1_scale))
layer_norm_params_id.append(id(self._ln1_bias))
else:
layer_norm_params_id.append(id(self._ln2_scale))
layer_norm_params_id.append(id(self._ln2_bias))
for key, param in self._parameters.items():
if id(param) in layer_norm_params_id:
continue
if param is not None:
with no_grad():
param_applied = _to_dtype(param, dtype)
self._dtype = dtype
class FusedTransformerEncoderLayer(Layer):
"""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册