From 99b3727d5bd37ae18ef63407dc4bd6ab75fb76fc Mon Sep 17 00:00:00 2001 From: Yuang Liu Date: Tue, 28 Jun 2022 22:15:26 +0800 Subject: [PATCH] [fused_transformer] update transformer fustion for dygraph, test=allcases (#43858) --- python/paddle/fluid/dygraph/amp/auto_cast.py | 4 + ...st_fused_transformer_with_amp_decorator.py | 74 +++++++++++++++ .../nn/functional/fused_transformer.py | 6 +- .../incubate/nn/layer/fused_transformer.py | 92 +++++++++++++++++-- 4 files changed, 167 insertions(+), 9 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_fused_transformer_with_amp_decorator.py diff --git a/python/paddle/fluid/dygraph/amp/auto_cast.py b/python/paddle/fluid/dygraph/amp/auto_cast.py index 8547501e1b3..49c03684342 100644 --- a/python/paddle/fluid/dygraph/amp/auto_cast.py +++ b/python/paddle/fluid/dygraph/amp/auto_cast.py @@ -173,6 +173,10 @@ def pure_fp16_initialize(models): paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D, paddle.nn.LayerNorm, paddle.nn.SyncBatchNorm)): continue + if isinstance(layer, (paddle.incubate.nn.FusedFeedForward, + paddle.incubate.nn.FusedMultiHeadAttention)): + layer._amp_decorate(dtype='float16') + continue layer._to_impl(dtype='float16', include_sublayers=False, floating_only=True) diff --git a/python/paddle/fluid/tests/unittests/test_fused_transformer_with_amp_decorator.py b/python/paddle/fluid/tests/unittests/test_fused_transformer_with_amp_decorator.py new file mode 100644 index 00000000000..f0173d9632f --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_transformer_with_amp_decorator.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import paddle +import paddle.nn as nn +from paddle.incubate.nn.layer.fused_transformer import FusedMultiHeadAttention, FusedFeedForward +import unittest + + +class PreModel(nn.Layer): + + def __init__(self): + super(PreModel, self).__init__() + self.attn = FusedMultiHeadAttention( + embed_dim=1024, + num_heads=16, + normalize_before=False, + ) + self.ffn = FusedFeedForward(d_model=1024, + dim_feedforward=4096, + normalize_before=False) + + def forward(self, x): + x = self.attn(x) + x = self.ffn(x) + + +class PostModel(nn.Layer): + + def __init__(self): + super(PostModel, self).__init__() + self.attn = FusedMultiHeadAttention( + embed_dim=1024, + num_heads=16, + normalize_before=True, + ) + self.ffn = FusedFeedForward(d_model=1024, + dim_feedforward=4096, + normalize_before=True) + + def forward(self, x): + x = self.attn(x) + x = self.ffn(x) + + +class TestFusedTransformerWithAmpDecorator(unittest.TestCase): + + def get_model(self): + self.pre_model = PreModel() + self.post_model = PostModel() + + def test_run(self): + self.get_model() + pre_model = paddle.amp.decorate(models=self.pre_model, + level='O2', + save_dtype='float32') + post_model = paddle.amp.decorate(models=self.post_model, + level='O2', + save_dtype='float32') + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 949f74b937a..3e4d015da1b 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -526,8 +526,10 @@ def fused_multi_head_attention(x, 0] == 3, "The shape of qkv_weight should be [3, num_head, head_dim, embed_dim]." assert qkv_weight.shape[3] == x.shape[ 2], "The 3rd dim of qkv_weight and 2nd dim of x should be the same, i.e., embed_dim." - assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[ - 3], "embed_dim must be divisible by num_heads." + if ring_id == -1: + # under mp, the num head will be split, this equation will not hold + assert qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[ + 3], "embed_dim must be divisible by num_heads." 
_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, cache_kv_out, final_out = _C_ops.fused_attention( x, pre_ln_scale, pre_ln_bias, qkv_weight, qkv_bias, cache_kv, diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index f52cbd2cd3e..4a8f7815ae9 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -18,8 +18,11 @@ from paddle.framework import ParamAttr import paddle from paddle.nn.layer.transformer import _convert_attention_mask, _convert_param_attr_to_list from paddle.nn.initializer import Constant - -import collections +from paddle.fluid.dygraph import no_grad +from paddle.fluid.framework import convert_np_dtype_to_dtype_, _non_static_mode +from paddle.fluid.core import VarDesc +from paddle.fluid import core +import numpy as np # for distributed tensor model parallel @@ -29,11 +32,48 @@ def _set_var_distributed(var): var.is_distributed = True - # NOTE: use current_block and find_var_recursive to support while_loop - startup_block = paddle.static.default_startup_program().current_block() - main_block = paddle.static.default_main_program().current_block() - startup_block._find_var_recursive(var.name).is_distributed = True - main_block._find_var_recursive(var.name).is_distributed = True + if not _non_static_mode(): + # NOTE: use current_block and find_var_recursive to support while_loop + startup_block = paddle.static.default_startup_program().current_block() + main_block = paddle.static.default_main_program().current_block() + startup_block._find_var_recursive(var.name).is_distributed = True + main_block._find_var_recursive(var.name).is_distributed = True + + +def _to_dtype(t, dtype): + # this function is a prune of Layer._transform function to fix fused op under amp.decorator(O2) + if not paddle.is_floating_point(t): + return t + + if type(dtype) is not VarDesc.VarType: + dtype = convert_np_dtype_to_dtype_(dtype) + + if t.place.is_gpu_place(): + size_dtype = core.size_of_dtype(dtype) + waiting_alloc_memory = ( + (np.prod(t.shape) * size_dtype) / 256 + 1) * 256 * 1.2 + gpu_memory_available = core.gpu_memory_available() + if gpu_memory_available < waiting_alloc_memory: + t_used = t._copy_to(paddle.CPUPlace(), False) + t.value().get_tensor()._clear() + else: + t_used = t + else: + t_used = t + + if dtype is not None and dtype != t_used.dtype: + with paddle.fluid.framework._dygraph_place_guard(place=t_used.place): + t_casted = t_used.cast(dtype=dtype) + else: + t_casted = t_used + + new_t = t_casted + + dst_tensor = t.value().get_tensor() + src_tensor = new_t.value().get_tensor() + dst_tensor._share_data_with(src_tensor) + + return t class FusedBiasDropoutResidualLayerNorm(Layer): @@ -374,6 +414,25 @@ class FusedMultiHeadAttention(Layer): self.attn_dropout_rate, self._epsilon, self.kdim, self.vdim, self.normalize_before, self.need_weights, self._dtype, name_str) + def _amp_decorate(self, dtype): + # tmp fix for amp.decorator(O2) + layer_norm_params_id = [] + if self.normalize_before: + layer_norm_params_id.append(id(self.pre_ln_scale)) + layer_norm_params_id.append(id(self.pre_ln_bias)) + else: + layer_norm_params_id.append(id(self.ln_scale)) + layer_norm_params_id.append(id(self.ln_bias)) + + for key, param in self._parameters.items(): + if id(param) in layer_norm_params_id: + continue + if param is not None: + with no_grad(): + param_applied = _to_dtype(param, dtype) + + self._dtype = dtype + class FusedFeedForward(Layer): """ @@ -559,6 +618,25 @@ 
class FusedFeedForward(Layer): self._epsilon, self._act_method, self._act_dropout_rate, self._normalize_before, self._dtype, name_str) + def _amp_decorate(self, dtype): + # tmp fix for amp.decorator(O2) + layer_norm_params_id = [] + if self._normalize_before: + layer_norm_params_id.append(id(self._ln1_scale)) + layer_norm_params_id.append(id(self._ln1_bias)) + else: + layer_norm_params_id.append(id(self._ln2_scale)) + layer_norm_params_id.append(id(self._ln2_bias)) + + for key, param in self._parameters.items(): + if id(param) in layer_norm_params_id: + continue + if param is not None: + with no_grad(): + param_applied = _to_dtype(param, dtype) + + self._dtype = dtype + class FusedTransformerEncoderLayer(Layer): """ -- GitLab
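
Note on how the new _amp_decorate path is exercised end to end: the sketch below mirrors the models in the new unit test, decorates them with paddle.amp.decorate(level='O2'), and then inspects parameter dtypes. The Block wrapper and the dtype-inspection loops are illustrative additions rather than part of the patch, and a GPU build of Paddle is assumed since the fused layers target CUDA kernels. Per pure_fp16_initialize above, every parameter except the layer norm scale/bias (ln_scale/ln_bias here, because normalize_before=False) is expected to come back as float16.

import paddle
from paddle.incubate.nn import FusedMultiHeadAttention, FusedFeedForward


class Block(paddle.nn.Layer):
    # Toy block mirroring PreModel/PostModel from the new unit test.

    def __init__(self):
        super(Block, self).__init__()
        self.attn = FusedMultiHeadAttention(embed_dim=1024,
                                            num_heads=16,
                                            normalize_before=False)
        self.ffn = FusedFeedForward(d_model=1024,
                                    dim_feedforward=4096,
                                    normalize_before=False)

    def forward(self, x):
        return self.ffn(self.attn(x))


model = paddle.amp.decorate(models=Block(), level='O2', save_dtype='float32')

# pure_fp16_initialize dispatches the fused layers to _amp_decorate, which
# casts every parameter to float16 except the layer norm scale/bias.
for name, param in model.attn.named_parameters():
    print(name, param.dtype)
for name, param in model.ffn.named_parameters():
    print(name, param.dtype)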
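
Note on the relaxed shape check in fused_multi_head_attention: the assertion qkv_weight.shape[1] * qkv_weight.shape[2] == qkv_weight.shape[3] is now skipped when ring_id != -1, because under tensor model parallelism the heads are split across ranks and the locally held qkv_weight no longer covers all of them. A minimal arithmetic sketch of why the product stops matching embed_dim (the concrete sizes and the mp degree are illustrative, not taken from the patch):

embed_dim, num_heads, mp_degree = 1024, 16, 2    # illustrative values
head_dim = embed_dim // num_heads                # 64
local_num_heads = num_heads // mp_degree         # 8 heads per rank

# Single-card case (ring_id == -1): the original assertion holds.
assert num_heads * head_dim == embed_dim         # 16 * 64 == 1024

# Model-parallel case: the local qkv_weight has shape
# [3, local_num_heads, head_dim, embed_dim], so the product is only
# embed_dim / mp_degree and the check would spuriously fail.
assert local_num_heads * head_dim == embed_dim // mp_degree   # 8 * 64 == 512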
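
Note on the headroom check in _to_dtype: before casting a GPU parameter in place, the helper estimates the bytes the cast will need, rounds that up to the next 256-byte block, and adds a 20% margin; if core.gpu_memory_available() reports less than the estimate, the parameter is first copied to the CPU and its GPU buffer cleared, and the cast is performed on the CPU copy. A worked example of the formula with an illustrative shape and target dtype (not taken from the patch):

import numpy as np

shape = [1024, 3072]    # illustrative parameter shape
size_dtype = 2          # bytes per element for the float16 target dtype

# Same expression as in _to_dtype: round up to a 256-byte block, add 20%.
waiting_alloc_memory = ((np.prod(shape) * size_dtype) / 256 + 1) * 256 * 1.2
print(waiting_alloc_memory)    # 7550054.4 bytes, roughly 7.2 MiB

# If the free GPU memory is below this figure, _to_dtype spills the parameter
# to the CPU (and clears the GPU tensor) before casting, instead of risking an
# out-of-memory error while source and casted copies briefly coexist.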