From 6486e242ed3e473b54eb6cda3e28983599577ac4 Mon Sep 17 00:00:00 2001 From: zhangkaihuo Date: Fri, 12 Nov 2021 21:14:22 +0800 Subject: [PATCH] [fix]fix the bug of fused_attention and fused_feedforward (#36972) * fix bug: 1. atten: set the default value of attn_dropout_rate to None 2. ffn: add activation parameter --- .../fluid/operators/fused/attn_bias_add.cu.h | 20 +- .../operators/fused/fused_feedforward_op.cu | 11 +- .../fluid/tests/unittests/CMakeLists.txt | 1 + .../unittests/test_fused_attention_op.py | 3 + .../unittests/test_fused_feedforward_op.py | 17 +- .../test_fused_transformer_encoder_layer.py | 188 ++++++++++++++++++ .../nn/functional/fused_transformer.py | 67 ++++++- .../incubate/nn/layer/fused_transformer.py | 66 +++++- python/paddle/nn/functional/common.py | 6 +- 9 files changed, 350 insertions(+), 29 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h index 18ae932c932..f7478364cdf 100644 --- a/paddle/fluid/operators/fused/attn_bias_add.cu.h +++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h @@ -87,7 +87,8 @@ __global__ void BroadcastKernelBinary( kernel_primitives::ElementwiseBinary( result, arg0, arg1, func); // store - kernel_primitives::WriteData(out + fix, result, num); + kernel_primitives::WriteData(out + fix, result, + num); } // bias add forward impl for "[m, n] + [n] = [m, n]" @@ -267,25 +268,24 @@ __global__ void BiasAddBw1DReduceKernel(const ReduceParamType* temp_sum, } template -void Launch2DColumnReduce(gpuStream_t stream, const int max_threads, - const int reduce_num, const int left_num, - const T* d_out, T* d_bias) { +void Launch2DColumnReduce(const platform::CUDADeviceContext& dev_ctx, + const int max_threads, const int reduce_num, + const int left_num, const T* d_out, T* d_bias) { dim3 block; dim3 grid; bool should_reduce_again = false; int blocking_size = 1; SetConfigForColumnReduce(max_threads, reduce_num, left_num, &blocking_size, &should_reduce_again, &block, &grid); + const auto& stream = dev_ctx.stream(); if (!should_reduce_again) { BiasAddBwSinglePassKernel<<>>(d_out, reduce_num, left_num, d_bias); } else { framework::Tensor tmp_sum; - tmp_sum.mutable_data>( - framework::make_ddim({static_cast( - left_num * grid.y * sizeof(ReduceParamType))}), - paddle::platform::CUDAPlace()); + tmp_sum.Resize({grid.y, left_num}); + tmp_sum.mutable_data>(dev_ctx.GetPlace()); BiasAddBw2DReduceKernel<<>>( d_out, reduce_num, left_num, blocking_size, @@ -311,8 +311,8 @@ void LaunchBiasAddBwKernel(const platform::CUDADeviceContext& dev_ctx, int m, Launch1DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num, d_out, d_bias); } else { - Launch2DColumnReduce(dev_ctx.stream(), max_threads, reduce_num, left_num, - d_out, d_bias); + Launch2DColumnReduce(dev_ctx, max_threads, reduce_num, left_num, d_out, + d_bias); } } diff --git a/paddle/fluid/operators/fused/fused_feedforward_op.cu b/paddle/fluid/operators/fused/fused_feedforward_op.cu index 3b47e65c483..a241e3c3027 100644 --- a/paddle/fluid/operators/fused/fused_feedforward_op.cu +++ b/paddle/fluid/operators/fused/fused_feedforward_op.cu @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/matmul_v2_op.h" +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" #include "paddle/fluid/operators/fused/fused_dropout_helper.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h" @@ -261,7 +262,7 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { framework::Tensor d_linear2_out, d_dropout2_out, d_residual; d_linear2_out.mutable_data({bsz_seq, d_model}, place); d_dropout2_out.mutable_data({bsz_seq, d_model}, place); - d_residual.mutable_data({bsz_seq, d_model}, place); + d_residual.mutable_data(d_x->dims(), place); if (pre_layer_norm) { fused_dropout_layernorm_helper.ResidualDropoutBiasGrad( @@ -301,6 +302,14 @@ class FusedFeedForwardGradKernel : public framework::OpKernel { } else { MatMulGrad(ctx, d_linear1_out, x, linear1_weight, d_x, d_linear1_weight); } + std::vector ins(2); + std::vector outs(1); + ins[0] = &d_residual; + ins[1] = d_x; + outs[0] = d_x; + int elewise_add_axis = -1; + LaunchElementwiseCudaKernel( + ctx, ins, &outs, elewise_add_axis, AddFunctor()); } void Compute(const framework::ExecutionContext& context) const override { diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index 97af987ae3b..deabdc6c141 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -104,6 +104,7 @@ if(NOT WITH_GPU) LIST(REMOVE_ITEM TEST_OPS test_fused_feedforward_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op) LIST(REMOVE_ITEM TEST_OPS test_fused_attention_op_api) + LIST(REMOVE_ITEM TEST_OPS test_fused_transformer_encoder_layer) endif() if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) diff --git a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py index c0b3e27e671..b2b5cac2bff 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_attention_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_attention_op.py @@ -26,6 +26,9 @@ from paddle import tensor from paddle.fluid import layers import unittest from op_test import OpTest +from paddle.fluid.framework import default_main_program + +default_main_program().random_seed = 42 class TestFusedAttentionOp(OpTest): diff --git a/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py b/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py index 5ea43d2edf0..a533b5d87a5 100644 --- a/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py +++ b/python/paddle/fluid/tests/unittests/test_fused_feedforward_op.py @@ -23,6 +23,7 @@ from paddle.nn.layer.norm import LayerNorm from paddle.nn.layer.common import Linear, Dropout import unittest from op_test import OpTest +from paddle.fluid.framework import default_main_program class TestFusedFFNOp(OpTest): @@ -91,7 +92,7 @@ class TestFusedFFNOp(OpTest): def Base(self): paddle.disable_static() tensor_src = paddle.to_tensor(self.src, stop_gradient=False) - residual = paddle.to_tensor(self.src) + residual = tensor_src if self.pre_layer_norm: ln1_out = self.norm1(tensor_src) linear2_out = self.linear2( @@ -140,6 +141,7 @@ class TestFusedFFNOp(OpTest): return out, x.grad def test_out_and_grad(self): + default_main_program().random_seed = 42 base_out, base_grad = self.Base() fused_out, fused_grad = self.FusedFFN() np.testing.assert_allclose( @@ -192,6 +194,7 @@ class TestFusedFFNOpNormalizeBefore(TestFusedFFNOp): class APITestStaticFusedFFN(unittest.TestCase): def test_static(self): paddle.enable_static() + default_main_program().random_seed = 42 dtype = "float32" layer_norm_dtype = "float32" batch_size = 1 @@ -324,6 +327,18 @@ class TestFusedFFNOpError(unittest.TestCase): self.assertRaises(ValueError, test_dropout_rate_value) + def test_dropout_mode(): + x = paddle.static.data( + name='x3', shape=[1, 10, 10], dtype="float32") + linear1_weight = paddle.static.data( + name='linear1_weight3', shape=[10, 10], dtype="float32") + linear2_weight = paddle.static.data( + name='linear2_weight3', shape=[10, 10], dtype="float32") + incubate_f.fused_feedforward( + x, linear1_weight, linear2_weight, mode='test') + + self.assertRaises(ValueError, test_dropout_mode) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py b/python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py new file mode 100644 index 00000000000..e0281d6e21e --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_transformer_encoder_layer.py @@ -0,0 +1,188 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np + +import paddle +from paddle.incubate.nn import FusedTransformerEncoderLayer +from paddle.nn import TransformerEncoderLayer +from paddle.fluid.framework import default_main_program +import unittest + + +class TestFusedTransformerEncoderLayer(unittest.TestCase): + def setActivation(self): + self.activation = 'gelu' + + def setPreLayerNorm(self): + self.pre_layer_norm = False + + def setAttnMask(self): + self.has_attn_mask = True + + def setUp(self): + self.batch_size = np.random.randint(1, 8) + self.query_length = np.random.randint(1, 128) + self.nhead = 16 + self.head_dim = 4 + self.num_heads = self.nhead + self.d_model = self.head_dim * self.num_heads + self.embed_dim = self.d_model + self.dim_feedforward = np.random.randint(1, 32) + self.dropout_rate = 0 + self.attn_dropout_rate = None + self.act_dropout_rate = None + self.attn_mask_type = np.float64 + self.key_length = self.query_length + self.dtype = 'float32' + self.setActivation() + self.setPreLayerNorm() + self.setAttnMask() + + def fused_weight(self, weight, num_head): + a = paddle.transpose(weight, perm=[1, 0]) + return paddle.reshape( + a, shape=[1, num_head, int(a.shape[0] / num_head), a.shape[1]]) + + def fused_qkv(self, q, k, v, num_head): + fq = self.fused_weight(q, num_head) + fk = self.fused_weight(k, num_head) + fv = self.fused_weight(v, num_head) + return paddle.concat(x=[fq, fk, fv], axis=0) + + def test_out(self): + default_main_program().random_seed = 42 + base_encoder = TransformerEncoderLayer( + self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate, + self.activation, self.attn_dropout_rate, self.act_dropout_rate, + self.pre_layer_norm) + src = np.random.rand(self.batch_size, self.query_length, + self.embed_dim).astype(self.dtype) + + if self.has_attn_mask: + attn_mask = np.ones( + (self.batch_size, self.num_heads, self.query_length, + self.key_length), + dtype=self.attn_mask_type) + attn_mask_tensor = paddle.to_tensor(attn_mask) + else: + attn_mask = None + attn_mask_tensor = None + + dout = np.random.random(src.shape).astype(self.dtype) + + base_out = base_encoder( + paddle.to_tensor( + src, stop_gradient=False), attn_mask_tensor) + paddle.autograd.backward([base_out], [paddle.to_tensor(dout)], True) + + fused_encoder = FusedTransformerEncoderLayer( + self.d_model, self.nhead, self.dim_feedforward, self.dropout_rate, + self.activation, self.attn_dropout_rate, self.act_dropout_rate, + self.pre_layer_norm) + + fused_encoder.ffn._linear1_weight.set_value(base_encoder.linear1.weight) + fused_encoder.ffn._linear1_bias.set_value(base_encoder.linear1.bias) + fused_encoder.ffn._linear2_weight.set_value(base_encoder.linear2.weight) + fused_encoder.ffn._linear2_bias.set_value(base_encoder.linear2.bias) + if self.pre_layer_norm: + fused_encoder.ffn._ln1_scale.set_value(base_encoder.norm2.weight) + fused_encoder.ffn._ln1_bias.set_value(base_encoder.norm2.bias) + else: + fused_encoder.ffn._ln2_scale.set_value(base_encoder.norm2.weight) + fused_encoder.ffn._ln2_bias.set_value(base_encoder.norm2.bias) + + fused_encoder.fused_attn.linear_weight.set_value( + base_encoder.self_attn.out_proj.weight) + fused_encoder.fused_attn.linear_bias.set_value( + base_encoder.self_attn.out_proj.bias) + if self.pre_layer_norm: + fused_encoder.fused_attn.pre_ln_scale.set_value( + base_encoder.norm1.weight) + fused_encoder.fused_attn.pre_ln_bias.set_value( + base_encoder.norm1.bias) + else: + fused_encoder.fused_attn.ln_scale.set_value( + base_encoder.norm1.weight) + fused_encoder.fused_attn.ln_bias.set_value(base_encoder.norm1.bias) + + q = base_encoder.self_attn.q_proj.weight + q_bias = base_encoder.self_attn.q_proj.bias + k = base_encoder.self_attn.k_proj.weight + k_bias = base_encoder.self_attn.k_proj.bias + v = base_encoder.self_attn.v_proj.weight + v_bias = base_encoder.self_attn.v_proj.bias + qkv_weight = self.fused_qkv(q, k, v, self.num_heads) + fused_encoder.fused_attn.qkv_weight.set_value(qkv_weight) + + tmp = paddle.concat(x=[q_bias, k_bias, v_bias], axis=0) + qkv_bias = paddle.reshape( + tmp, + shape=[3, self.num_heads, int(tmp.shape[0] / 3 / self.num_heads)]) + fused_encoder.fused_attn.qkv_bias.set_value(qkv_bias) + + fused_out = fused_encoder( + paddle.to_tensor( + src, stop_gradient=False), attn_mask_tensor) + paddle.autograd.backward([fused_out], [paddle.to_tensor(dout)], True) + + correct_ffn_str = 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}'.format( + self.d_model, self.dim_feedforward, self.dropout_rate, + fused_encoder.ffn._epsilon, self.activation, self.dropout_rate, + self.pre_layer_norm, self.dtype) + self.assertTrue(fused_encoder.ffn.extra_repr(), correct_ffn_str) + + correct_attn_str = 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}'.format( + self.embed_dim, self.num_heads, self.dropout_rate, + self.dropout_rate, fused_encoder.fused_attn._epsilon, None, None, + self.pre_layer_norm, False, self.dtype) + self.assertTrue(fused_encoder.fused_attn.extra_repr(), correct_attn_str) + + np.testing.assert_allclose( + fused_out.numpy(), base_out.numpy(), rtol=1e-3, atol=1e-4) + self.assertTrue( + np.allclose( + fused_out.grad.numpy(), + base_out.grad.numpy(), + rtol=1e-3, + atol=1e-4)) + + +class TestFusedTransformerEncoderLayerAct(TestFusedTransformerEncoderLayer): + def setActivation(self): + self.activation = 'relu' + + +class TestFusedTransformerEncoderLayerPreLayerNorm( + TestFusedTransformerEncoderLayer): + def setPreLayerNorm(self): + self.pre_layer_norm = True + + +class TestFusedTransformerEncoderLayerAttnMaskIsNone( + TestFusedTransformerEncoderLayer): + def setAttnMask(self): + self.has_attn_mask = False + + +class TestFusedTransformerEncoderLayerPreLnTrueAttnMaskIsNone( + TestFusedTransformerEncoderLayer): + def setPreLayerNorm(self): + self.pre_layer_norm = True + + def setAttnMask(self): + self.has_attn_mask = False + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 6c447a73c52..3651f90d729 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -13,7 +13,7 @@ # limitations under the License. from paddle.fluid.layer_helper import LayerHelper -from paddle.fluid.framework import in_dygraph_mode +from paddle.fluid.framework import in_dygraph_mode, default_main_program from paddle.fluid.data_feeder import check_variable_and_dtype, check_dtype from paddle.fluid import core, dygraph_utils from paddle import _C_ops @@ -43,6 +43,8 @@ def fused_feedforward(x, ln1_epsilon=1e-5, ln2_epsilon=1e-5, pre_layer_norm=False, + training=True, + mode='upscale_in_train', name=None): """ This is a fusion operator to compute feed forward layer in transformer model architecture. @@ -74,6 +76,8 @@ def fused_feedforward(x, ln1_epsilon (float, optional): Small float of first layer_norm added to denominator to avoid dividing by zero. Default is 1e-5. ln2_epsilon (float, optional): Small float of second layer_norm added to denominator to avoid dividing by zero. Default is 1e-5. pre_layer_norm (bool, optional): add layer_norm in the pre-processing stage or post-processing state. + training (bool): A flag indicating whether it is in train phrase or not. Default True. + mode(str): ['upscale_in_train'(default) | 'downscale_in_infer']. name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: @@ -98,13 +102,27 @@ def fused_feedforward(x, _verify_dropout_rate(dropout1_rate) _verify_dropout_rate(dropout2_rate) + seed = None + if mode not in ('downscale_in_infer', 'upscale_in_train'): + raise ValueError( + "mode argument should be 'downscale_in_infer' or 'upscale_in_train'") + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + if in_dygraph_mode(): + if default_main_program().random_seed != 0: + seed = default_main_program().random_seed out, _, _, _, _, _, _, _, _, _, _ = _C_ops.fused_feedforward( x, None, None, linear1_weight, linear1_bias, linear2_weight, linear2_bias, ln1_scale, ln1_bias, ln2_scale, ln2_bias, 'pre_layer_norm', pre_layer_norm, 'ln1_epsilon', ln1_epsilon, 'ln2_epsilon', ln2_epsilon, 'act_method', activation, - 'dropout1_rate', dropout1_rate, 'dropout2_rate', dropout2_rate) + 'dropout1_rate', dropout1_rate, 'dropout2_rate', dropout2_rate, + "dropout1_is_test", not training, "dropout2_is_test", not training, + "dropout1_fix_seed", seed is not None, "dropout2_fix_seed", + seed is not None, "dropout1_seed", seed + if seed is not None else 0, "dropout2_seed", seed + if seed is not None else 0, 'dropout1_implementation', mode, + 'dropout2_implementation', mode) return out helper = LayerHelper("fused_feedforward") @@ -136,6 +154,9 @@ def fused_feedforward(x, dropout2_out = helper.create_variable_for_type_inference( x.dtype, stop_gradient=True) + if (seed is None or seed == 0) and helper.main_program.random_seed != 0: + seed = helper.main_program.random_seed + helper.append_op( type='fused_feedforward', inputs={ @@ -169,6 +190,14 @@ def fused_feedforward(x, 'pre_layer_norm': pre_layer_norm, 'ln1_epsilon': ln1_epsilon, 'ln2_epsilon': ln2_epsilon, + 'dropout1_is_test': not training, + 'dropout2_is_test': not training, + 'dropout1_fix_seed': seed is not None, + 'dropout2_fix_seed': seed is not None, + 'dropout1_seed': seed if seed is not None else 0, + 'dropout2_seed': seed if seed is not None else 0, + 'dropout1_implementation': mode, + 'dropout2_implementation': mode }) return out @@ -188,6 +217,8 @@ def fused_multi_head_attention(x, dropout_rate=0.5, attn_dropout_rate=0.5, ln_epsilon=1e-05, + training=True, + mode='upscale_in_train', name=None): """ Attention mapps queries and a set of key-value pairs to outputs, and @@ -247,6 +278,9 @@ def fused_multi_head_attention(x, 0 for no dropout. Default 0.5. ln_epsilon (float, optional): Small float value added to denominator of layer_norm to avoid dividing by zero. Default is 1e-5. + training (bool): A flag indicating whether it is in train phrase or not. Default True. + mode(str): ['upscale_in_train'(default) | 'downscale_in_infer']. + name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`. Returns: Tensor: The output Tensor, the data type and shape is same as `x`. @@ -280,7 +314,16 @@ def fused_multi_head_attention(x, # [2, 4, 128] print(output.shape) """ + + seed = None + if mode not in ('downscale_in_infer', 'upscale_in_train'): + raise ValueError( + "mode argument should be 'downscale_in_infer' or 'upscale_in_train'") + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + if in_dygraph_mode(): + if default_main_program().random_seed != 0: + seed = default_main_program().random_seed # pre_ln_mean, pre_ln_variance, pre_ln_out, qkv_out, qkv_bias_out, transpose_out, qk_out, # qktv_out, softmax_out, attn_dropout_mask_out, attn_dropout_out, attn_mask_out, fmha_out, # linear_out, dropout_mask_out, ln_mean_out, ln_var_out, bias_dropout_residual_out, final_out @@ -295,7 +338,12 @@ def fused_multi_head_attention(x, linear_weight, linear_bias, ln_scale, ln_bias, 'pre_layer_norm', pre_layer_norm, 'epsilon', pre_ln_epsilon, 'dropout_rate', dropout_rate, 'attn_dropout_rate', attn_dropout_rate, 'ln_epsilon', - ln_epsilon) + ln_epsilon, 'attn_dropout_is_test', not training, 'dropout_is_test', + not training, 'attn_dropout_fix_seed', seed is not None, + 'dropout_fix_seed', seed is not None, 'attn_dropout_seed', seed + if seed is not None else 0, 'dropout_seed', seed + if seed is not None else 0, 'attn_dropout_implementation', mode, + 'dropout_implementation', mode) return final_out else: helper = LayerHelper('fused_multi_head_attention', **locals()) @@ -323,13 +371,24 @@ def fused_multi_head_attention(x, if ln_bias: inputs['Ln2Bias'] = [ln_bias] + if (seed is None or seed == 0) and helper.main_program.random_seed != 0: + seed = helper.main_program.random_seed + # set attrs attrs = { 'pre_layer_norm': pre_layer_norm, 'epsilon': pre_ln_epsilon, 'ln_epsilon': ln_epsilon, 'dropout_rate': dropout_rate, - 'attn_dropout_rate': attn_dropout_rate + 'attn_dropout_rate': attn_dropout_rate, + 'attn_dropout_is_test': not training, + 'dropout_is_test': not training, + 'attn_dropout_fix_seed': seed is not None, + 'dropout_fix_seed': seed is not None, + 'attn_dropout_seed': seed if seed is not None else 0, + 'dropout_seed': seed if seed is not None else 0, + 'attn_dropout_implementation': mode, + 'dropout_implementation': mode, } # set outputs diff --git a/python/paddle/incubate/nn/layer/fused_transformer.py b/python/paddle/incubate/nn/layer/fused_transformer.py index a3d8a74844b..42d6c491d64 100644 --- a/python/paddle/incubate/nn/layer/fused_transformer.py +++ b/python/paddle/incubate/nn/layer/fused_transformer.py @@ -39,11 +39,13 @@ class FusedMultiHeadAttention(Layer): attn_dropout_rate (float, optional): The dropout probability used on attention weights to drop some attention targets for the dropout in attention. 0 for no dropout. Default 0.5. + epsilon (float, optional): he small value added to the variance to prevent + division by zero. Default: 1e-05. kdim (int, optional): The feature size in key. If None, assumed equal to `embed_dim`. Default None. vdim (int, optional): The feature size in value. If None, assumed equal to `embed_dim`. Default None. - normalize_before (bool, optional): Indicate whether it is pre_layer_norm + normalize_before (bool, optional): Indicate whether it is pre_layer_norm (True) or post_layer_norm architecture (False). Default False. need_weights (bool, optional): Indicate whether to return the attention weights. Now, only False is supported. Default False. @@ -80,6 +82,7 @@ class FusedMultiHeadAttention(Layer): need_weights=False, weight_attr=None, bias_attr=None, + epsilon=1e-5, name=None): super(FusedMultiHeadAttention, self).__init__() @@ -88,13 +91,18 @@ class FusedMultiHeadAttention(Layer): assert num_heads > 0, ("Expected nhead to be greater than 0, " "but recieved {}".format(num_heads)) - attn_dropout_rate = dropout_rate if attn_dropout_rate is None else attn_dropout_rate self.normalize_before = normalize_before self._dtype = self._helper.get_default_dtype() self._weight_attr = weight_attr self._bias_attr = bias_attr + self._epsilon = epsilon + self.embed_dim = embed_dim + self.num_heads = num_heads self.head_dim = embed_dim // num_heads + self.kdim = kdim + self.vdim = vdim + self.need_weights = need_weights assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" assert need_weights == False, "Only support need_weight is False now." @@ -186,15 +194,24 @@ class FusedMultiHeadAttention(Layer): pre_ln_bias=self.pre_ln_bias, ln_scale=self.ln_scale, ln_bias=self.ln_bias, - pre_ln_epsilon=1e-05, + pre_ln_epsilon=self._epsilon, qkv_bias=self.qkv_bias, linear_bias=self.linear_bias, attn_mask=attn_mask, dropout_rate=self.dropout_rate, attn_dropout_rate=self.attn_dropout_rate, - ln_epsilon=1e-05) + ln_epsilon=self._epsilon, + training=self.training, + name=self.name) return out + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'embed_dim={}, num_heads={}, dropout_rate={}, attn_dropout_rate={}, epsilon={}, kdim={}, vdim={}, normalize_before={}, need_weights={}, dtype={}{}'.format( + self.embed_dim, self.num_heads, self.dropout_rate, + self.attn_dropout_rate, self._epsilon, self.kdim, self.vdim, + self.normalize_before, self.need_weights, self._dtype, name_str) + class FusedFeedForward(Layer): """ @@ -203,6 +220,8 @@ class FusedFeedForward(Layer): dim_feedforward (int): The hidden layer size. dropout_rate (float, optional): The dropout probability used in pre-process and post-precess. Default 0.1 + epsilon (float, optional): he small value added to the variance to prevent + division by zero. Default: 1e-05. activation (str, optional): The activation function. Default relu. act_dropout_rate (float, optional): The dropout probability after activition. If None, use the value of `dropout_rate`. Default None @@ -235,11 +254,13 @@ class FusedFeedForward(Layer): d_model, dim_feedforward, dropout_rate=0.1, + epsilon=1e-05, activation="relu", act_dropout_rate=None, normalize_before=False, weight_attr=None, - bias_attr=None): + bias_attr=None, + name=None): super(FusedFeedForward, self).__init__() assert d_model > 0, ( @@ -256,6 +277,7 @@ class FusedFeedForward(Layer): self._act_dropout_rate = dropout_rate if act_dropout_rate is None else act_dropout_rate self._act_method = activation self._normalize_before = normalize_before + self._epsilon = epsilon self._linear1_weight = self.create_parameter( shape=[d_model, dim_feedforward], @@ -292,15 +314,36 @@ class FusedFeedForward(Layer): default_initializer=Constant(1.0)) self._ln2_bias = self.create_parameter( shape=[d_model], attr=None, is_bias=True) + self.name = name def forward(self, src, cache=None): out = incubate_f.fused_feedforward( - src, self._linear1_weight, self._linear2_weight, self._linear1_bias, - self._linear2_bias, self._ln1_scale, self._ln1_bias, - self._ln2_scale, self._ln2_bias, self._dropout_rate, - self._act_dropout_rate, self._act_method, self._normalize_before) + src, + self._linear1_weight, + self._linear2_weight, + self._linear1_bias, + self._linear2_bias, + self._ln1_scale, + self._ln1_bias, + self._ln2_scale, + self._ln2_bias, + dropout1_rate=self._act_dropout_rate, + dropout2_rate=self._dropout_rate, + activation=self._act_method, + ln1_epsilon=self._epsilon, + ln2_epsilon=self._epsilon, + pre_layer_norm=self._normalize_before, + training=self.training, + name=self.name) return out + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'd_model={}, dim_feedforward={}, dropout_rate={}, epsilon={}, activation={}, act_dropout_rate={}, normalize_before={}, dtype={}{}'.format( + self._d_model, self._dim_feedforward, self._dropout_rate, + self._epsilon, self._act_method, self._act_dropout_rate, + self._normalize_before, self._dtype, name_str) + class FusedTransformerEncoderLayer(Layer): """ @@ -393,7 +436,9 @@ class FusedTransformerEncoderLayer(Layer): self.fused_attn = FusedMultiHeadAttention( d_model, nhead, - dropout_rate=attn_dropout_rate, + dropout_rate=dropout_rate, + attn_dropout_rate=attn_dropout_rate, + normalize_before=self.normalize_before, weight_attr=weight_attrs[0], bias_attr=bias_attrs[0]) @@ -401,6 +446,7 @@ class FusedTransformerEncoderLayer(Layer): d_model, dim_feedforward, dropout_rate=dropout_rate, + activation=activation, act_dropout_rate=act_dropout_rate, normalize_before=self.normalize_before, weight_attr=weight_attrs[1], diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 7362b284eae..ef08982a6ff 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -235,8 +235,8 @@ def interpolate(x, Examples: .. code-block:: python - import paddle - import numpy as np + import paddle + import numpy as np import paddle.nn.functional as F # given out size @@ -244,7 +244,7 @@ def interpolate(x, x = paddle.to_tensor(input_data) output_1 = F.interpolate(x=x, size=[12,12]) print(output_1.shape) - # [2L, 3L, 12L, 12L] + # [2L, 3L, 12L, 12L] # given scale output_2 = F.interpolate(x=x, scale_factor=[2,1]) -- GitLab